Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82d2ef6f40 | ||
|
|
44dc2ddc05 |
@@ -28,14 +28,5 @@ github:
|
|||||||
squash: true
|
squash: true
|
||||||
merge: false
|
merge: false
|
||||||
rebase: false
|
rebase: false
|
||||||
features:
|
|
||||||
# Enable wiki for documentation
|
|
||||||
wiki: true
|
|
||||||
# Enable issue management
|
|
||||||
issues: true
|
|
||||||
# Enable projects for project management boardS
|
|
||||||
projects: true
|
|
||||||
# Enable discussions
|
|
||||||
discussions: true
|
|
||||||
notifications:
|
notifications:
|
||||||
pullrequests_status: commits@doris.apache.org
|
pullrequests_status: commits@doris.apache.org
|
||||||
|
|||||||
91
.env.example
@@ -1,90 +1,71 @@
|
|||||||
# Doris MCP Server Configuration
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
# Copy this file to .env and modify the values according to your environment
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
# Doris MCP Server Environment Configuration
|
||||||
|
# Copy this file to .env and modify the values as needed
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Database Configuration
|
# Database Configuration
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
# Doris FE connection settings
|
|
||||||
DORIS_HOST=localhost
|
DORIS_HOST=localhost
|
||||||
DORIS_PORT=9030
|
DORIS_PORT=9030
|
||||||
DORIS_USER=root
|
DORIS_USER=root
|
||||||
DORIS_PASSWORD=
|
DORIS_PASSWORD=your_password_here
|
||||||
DORIS_DATABASE=information_schema
|
DORIS_DATABASE=your_database_name
|
||||||
|
|
||||||
# Doris FE HTTP API port
|
|
||||||
DORIS_FE_HTTP_PORT=8030
|
|
||||||
|
|
||||||
# BE nodes configuration for external access
|
|
||||||
# If DORIS_BE_HOSTS is empty, will use "show backends" to get BE nodes automatically
|
|
||||||
# Format: comma-separated list of BE host addresses
|
|
||||||
# Example: DORIS_BE_HOSTS=192.168.1.100,192.168.1.101,192.168.1.102
|
|
||||||
DORIS_BE_HOSTS=
|
|
||||||
|
|
||||||
# BE webserver port for HTTP APIs (memory tracker, metrics, etc.)
|
|
||||||
DORIS_BE_WEBSERVER_PORT=8040
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Connection Pool Configuration
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
# Connection Pool Settings
|
||||||
DORIS_MIN_CONNECTIONS=5
|
DORIS_MIN_CONNECTIONS=5
|
||||||
DORIS_MAX_CONNECTIONS=20
|
DORIS_MAX_CONNECTIONS=20
|
||||||
DORIS_CONNECTION_TIMEOUT=30
|
DORIS_CONNECTION_TIMEOUT=30
|
||||||
DORIS_HEALTH_CHECK_INTERVAL=60
|
DORIS_HEALTH_CHECK_INTERVAL=60
|
||||||
DORIS_MAX_CONNECTION_AGE=3600
|
DORIS_MAX_CONNECTION_AGE=3600
|
||||||
|
|
||||||
# =============================================================================
|
# Security Settings
|
||||||
# Profile And Explain Max Data Size
|
|
||||||
# =============================================================================
|
|
||||||
MAX_RESPONSE_CONTENT_SIZE=4096
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Security Configuration
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
ENABLE_SECURITY_CHECK=true
|
|
||||||
BLOCKED_KEYWORDS="DROP,TRUNCATE,DELETE,SHUTDOWN,INSERT,UPDATE,CREATE,ALTER,GRANT,REVOKE,KILL"
|
|
||||||
AUTH_TYPE=token
|
AUTH_TYPE=token
|
||||||
TOKEN_SECRET=your_secret_key_here
|
TOKEN_SECRET=your_256_bit_secret_key_here
|
||||||
TOKEN_EXPIRY=3600
|
TOKEN_EXPIRY=3600
|
||||||
MAX_RESULT_ROWS=10000
|
MAX_RESULT_ROWS=10000
|
||||||
MAX_QUERY_COMPLEXITY=100
|
|
||||||
ENABLE_MASKING=true
|
ENABLE_MASKING=true
|
||||||
|
|
||||||
# =============================================================================
|
# Performance Settings
|
||||||
# Performance Configuration
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
ENABLE_QUERY_CACHE=true
|
ENABLE_QUERY_CACHE=true
|
||||||
CACHE_TTL=300
|
CACHE_TTL=300
|
||||||
MAX_CACHE_SIZE=1000
|
MAX_CACHE_SIZE=1000
|
||||||
MAX_CONCURRENT_QUERIES=50
|
MAX_CONCURRENT_QUERIES=50
|
||||||
QUERY_TIMEOUT=300
|
QUERY_TIMEOUT=300
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Logging Configuration
|
# Logging Configuration
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
LOG_LEVEL=INFO
|
LOG_LEVEL=INFO
|
||||||
LOG_FILE_PATH=
|
LOG_FILE_PATH=./log/doris-mcp-server.log
|
||||||
ENABLE_AUDIT=true
|
ENABLE_AUDIT=true
|
||||||
AUDIT_FILE_PATH=
|
AUDIT_FILE_PATH=./log/doris-mcp-audit.log
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# Monitoring Configuration
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
|
# Monitoring Settings
|
||||||
ENABLE_METRICS=true
|
ENABLE_METRICS=true
|
||||||
METRICS_PORT=3001
|
METRICS_PORT=3001
|
||||||
|
METRICS_PATH=/metrics
|
||||||
HEALTH_CHECK_PORT=3002
|
HEALTH_CHECK_PORT=3002
|
||||||
|
HEALTH_CHECK_PATH=/health
|
||||||
ENABLE_ALERTS=false
|
ENABLE_ALERTS=false
|
||||||
ALERT_WEBHOOK_URL=
|
ALERT_WEBHOOK_URL=
|
||||||
|
|
||||||
# =============================================================================
|
# Server Settings
|
||||||
# Server Configuration
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
SERVER_NAME=doris-mcp-server
|
SERVER_NAME=doris-mcp-server
|
||||||
SERVER_VERSION=0.4.1
|
SERVER_VERSION=0.3.0
|
||||||
SERVER_PORT=3000
|
SERVER_PORT=3000
|
||||||
|
|
||||||
|
# Development Settings (for development environment only)
|
||||||
|
DEBUG=false
|
||||||
|
VERBOSE=false
|
||||||
65
Dockerfile
@@ -1,65 +0,0 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
# Use Python 3.12 as base image
|
|
||||||
FROM python:3.12-slim
|
|
||||||
|
|
||||||
# Set working directory
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
# Set environment variables
|
|
||||||
ENV PYTHONPATH=/app
|
|
||||||
ENV PYTHONUNBUFFERED=1
|
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
|
||||||
|
|
||||||
# Install system dependencies
|
|
||||||
RUN apt-get update && apt-get install -y \
|
|
||||||
curl \
|
|
||||||
gcc \
|
|
||||||
g++ \
|
|
||||||
pkg-config \
|
|
||||||
default-libmysqlclient-dev \
|
|
||||||
&& rm -rf /var/lib/apt/lists/*
|
|
||||||
|
|
||||||
# Copy requirements file
|
|
||||||
COPY requirements.txt .
|
|
||||||
|
|
||||||
# Install Python dependencies
|
|
||||||
RUN pip install --no-cache-dir -r requirements.txt
|
|
||||||
|
|
||||||
# Copy application code
|
|
||||||
COPY . .
|
|
||||||
|
|
||||||
# Create necessary directories
|
|
||||||
RUN mkdir -p /app/logs /app/config /app/data
|
|
||||||
|
|
||||||
# Set permissions
|
|
||||||
RUN chmod +x /app/start_server.sh
|
|
||||||
|
|
||||||
# Create non-root user
|
|
||||||
RUN groupadd -r doris && useradd -r -g doris doris
|
|
||||||
RUN chown -R doris:doris /app
|
|
||||||
USER doris
|
|
||||||
|
|
||||||
# Health check
|
|
||||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
|
||||||
CMD curl -f http://localhost:3000/health || exit 1
|
|
||||||
|
|
||||||
# Expose ports
|
|
||||||
EXPOSE 3000 3001 3002
|
|
||||||
|
|
||||||
# Start command
|
|
||||||
CMD ["/app/start_server.sh"]
|
|
||||||
135
Makefile
@@ -1,135 +0,0 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
# Doris MCP Server Makefile
|
|
||||||
# Provides convenient commands using UV
|
|
||||||
|
|
||||||
.PHONY: help install sync dev test lint format build clean check start-stdio start-sse
|
|
||||||
|
|
||||||
# Default target
|
|
||||||
help:
|
|
||||||
@echo "Available commands:"
|
|
||||||
@echo " install - Install dependencies using UV"
|
|
||||||
@echo " sync - Sync dependencies and create virtual environment"
|
|
||||||
@echo " dev - Install development dependencies"
|
|
||||||
@echo " test - Run tests"
|
|
||||||
@echo " lint - Run linting tools"
|
|
||||||
@echo " format - Format code with black and isort"
|
|
||||||
@echo " build - Build the package"
|
|
||||||
@echo " clean - Clean build artifacts"
|
|
||||||
@echo " check - Run all checks (format, lint, test)"
|
|
||||||
@echo " start-stdio - Start server in stdio mode"
|
|
||||||
@echo " start-sse - Start server in SSE mode"
|
|
||||||
|
|
||||||
# Install dependencies
|
|
||||||
install:
|
|
||||||
uv sync
|
|
||||||
|
|
||||||
# Sync dependencies with development extras
|
|
||||||
sync:
|
|
||||||
uv sync
|
|
||||||
|
|
||||||
# Install development dependencies
|
|
||||||
dev:
|
|
||||||
uv sync --dev
|
|
||||||
|
|
||||||
# Run tests
|
|
||||||
test:
|
|
||||||
uv run pytest
|
|
||||||
|
|
||||||
# Run linting tools
|
|
||||||
lint:
|
|
||||||
uv run ruff check doris_mcp_server/
|
|
||||||
uv run mypy doris_mcp_server/
|
|
||||||
|
|
||||||
# Format code
|
|
||||||
format:
|
|
||||||
uv run ruff format doris_mcp_server/
|
|
||||||
uv run ruff check --fix doris_mcp_server/
|
|
||||||
|
|
||||||
# Build the package
|
|
||||||
build:
|
|
||||||
uv build
|
|
||||||
|
|
||||||
# Clean build artifacts
|
|
||||||
clean:
|
|
||||||
rm -rf build/
|
|
||||||
rm -rf dist/
|
|
||||||
rm -rf *.egg-info/
|
|
||||||
find . -type d -name __pycache__ -exec rm -rf {} +
|
|
||||||
find . -type d -name .pytest_cache -exec rm -rf {} +
|
|
||||||
find . -type d -name .mypy_cache -exec rm -rf {} +
|
|
||||||
|
|
||||||
# Run all checks
|
|
||||||
check: format lint test
|
|
||||||
|
|
||||||
# Start server in stdio mode
|
|
||||||
start-stdio:
|
|
||||||
uv run python -m doris_mcp_server.main --transport stdio
|
|
||||||
|
|
||||||
# Start server in SSE mode
|
|
||||||
start-sse:
|
|
||||||
uv run python -m doris_mcp_server.main --transport sse --host 0.0.0.0 --port 8080
|
|
||||||
|
|
||||||
# Start server with custom database settings
|
|
||||||
start-dev:
|
|
||||||
uv run python -m doris_mcp_server.main \
|
|
||||||
--transport stdio \
|
|
||||||
--db-host localhost \
|
|
||||||
--db-port 9030 \
|
|
||||||
--db-user root \
|
|
||||||
--log-level DEBUG
|
|
||||||
|
|
||||||
# Run a single test file
|
|
||||||
test-file:
|
|
||||||
uv run pytest $(FILE) -v
|
|
||||||
|
|
||||||
# Install and run in one command
|
|
||||||
run: install start-stdio
|
|
||||||
|
|
||||||
# Development setup
|
|
||||||
setup: dev
|
|
||||||
@echo "✅ Development environment is ready!"
|
|
||||||
@echo "Run 'make start-stdio' to start the server"
|
|
||||||
|
|
||||||
# Add dependencies
|
|
||||||
add:
|
|
||||||
uv add $(PACKAGE)
|
|
||||||
|
|
||||||
# Add development dependencies
|
|
||||||
add-dev:
|
|
||||||
uv add --dev $(PACKAGE)
|
|
||||||
|
|
||||||
# Show dependency tree
|
|
||||||
deps:
|
|
||||||
uv tree
|
|
||||||
|
|
||||||
# Lock dependencies
|
|
||||||
lock:
|
|
||||||
uv lock
|
|
||||||
|
|
||||||
# Check for outdated dependencies
|
|
||||||
outdated:
|
|
||||||
uv tree --outdated
|
|
||||||
|
|
||||||
# Export requirements.txt
|
|
||||||
export-requirements:
|
|
||||||
uv export --no-hashes > requirements.txt
|
|
||||||
|
|
||||||
# Show UV version and info
|
|
||||||
info:
|
|
||||||
uv --version
|
|
||||||
uv python list
|
|
||||||
385
README.md
@@ -21,37 +21,37 @@ under the License.
|
|||||||
|
|
||||||
Doris MCP (Model Context Protocol) Server is a backend service built with Python and FastAPI. It implements the MCP, allowing clients to interact with it through defined "Tools". It's primarily designed to connect to Apache Doris databases, potentially leveraging Large Language Models (LLMs) for tasks like converting natural language queries to SQL (NL2SQL), executing queries, and performing metadata management and analysis.
|
Doris MCP (Model Context Protocol) Server is a backend service built with Python and FastAPI. It implements the MCP, allowing clients to interact with it through defined "Tools". It's primarily designed to connect to Apache Doris databases, potentially leveraging Large Language Models (LLMs) for tasks like converting natural language queries to SQL (NL2SQL), executing queries, and performing metadata management and analysis.
|
||||||
|
|
||||||
## 🚀 What's New in v0.4.2
|
## 🚀 What's New in v0.3.0
|
||||||
|
|
||||||
- **🔒 Enhanced Security Framework**: Comprehensive SQL security validation with configurable blocked keywords, SQL injection protection, and unified security configuration management
|
- **🔄 Streamlined Communication**: Completely migrated from SSE to Streamable HTTP for better performance and reliability
|
||||||
- **🛠️ Connection Stability Improvements**: Fixed critical `at_eof` connection errors with advanced connection health monitoring, automatic retry mechanisms, and proactive connection cleanup
|
- **🏗️ Unified Architecture**: Consolidated tools management with centralized registration and routing
|
||||||
- **⚙️ Flexible Security Configuration**: Environment variable support for security policies (`BLOCKED_KEYWORDS`, `ENABLE_SECURITY_CHECK`) with unified configuration architecture eliminating code duplication
|
- **⚡ Enhanced Performance**: Improved query execution with advanced caching and optimization
|
||||||
- **🎯 Centralized Configuration Management**: All security keywords now managed through single configuration source with consistent enforcement across all components
|
- **🔒 Enterprise Security**: Added comprehensive security management with SQL validation and data masking
|
||||||
- **🔧 MCP Version Compatibility**: Resolved MCP library version conflicts with intelligent compatibility layer supporting both MCP 1.8.x and 1.9.x versions
|
- **📊 Advanced Analytics**: New column analysis and performance monitoring tools
|
||||||
- **🚀 Production Reliability**: Enhanced error handling, connection diagnostics, and automatic recovery from database connection issues
|
- **🛠️ Simplified Development**: Streamlined tool development process with unified interfaces
|
||||||
- **🙏 Community Contribution**: Special thanks to Hailin Xie for supporting the doris-mcp-server project by graciously transferring the PyPI project to the community free of charge, contributing to open source. The mcp-doris-server repository will be retained but no longer maintained, with ongoing development continuing on the doris-mcp-server repository
|
|
||||||
|
|
||||||
> **🔧 Key Improvements**: Resolved connection stability issues, unified security keyword management, added comprehensive environment variable configuration for security policies, and fixed MCP library version compatibility conflicts.
|
> **⚠️ Breaking Changes**: SSE endpoints have been removed. Please update your client configurations to use Streamable HTTP (`/mcp` endpoint).
|
||||||
|
|
||||||
## Core Features
|
## Core Features
|
||||||
|
|
||||||
* **MCP Protocol Implementation**: Provides standard MCP interfaces, supporting tool calls, resource management, and prompt interactions.
|
* **MCP Protocol Implementation**: Provides standard MCP interfaces, supporting tool calls, resource management, and prompt interactions.
|
||||||
* **Streamable HTTP Communication**: Unified HTTP endpoint supporting both request/response and streaming communication for optimal performance and reliability.
|
* **Multiple Communication Modes** (Updated in v0.3.0):
|
||||||
* **Stdio Communication**: Standard input/output mode for direct integration with MCP clients like Cursor.
|
* **Stdio**: Standard input/output mode for direct integration with MCP clients like Cursor.
|
||||||
|
* **Streamable HTTP**: Unified HTTP endpoint supporting request/response and streaming (Primary mode since v0.3.0).
|
||||||
|
|
||||||
|
> **⚠️ Breaking Change in v0.3.0**: SSE (Server-Sent Events) mode has been completely removed in favor of the more robust Streamable HTTP implementation.
|
||||||
* **Enterprise-Grade Architecture**: Modular design with comprehensive functionality:
|
* **Enterprise-Grade Architecture**: Modular design with comprehensive functionality:
|
||||||
* **Tools Manager**: Centralized tool registration and routing with unified interfaces (`doris_mcp_server/tools/tools_manager.py`)
|
* **Tools Manager**: Centralized tool registration and routing (`doris_mcp_server/tools/tools_manager.py`)
|
||||||
* **Enhanced Monitoring Tools Module**: Advanced memory tracking, metrics collection, and flexible BE node discovery with modular, extensible design
|
|
||||||
* **Query Information Tools**: Enhanced SQL explain and profiling with configurable content truncation, file export for LLM attachments, and advanced query analytics
|
|
||||||
* **Resources Manager**: Resource management and metadata exposure (`doris_mcp_server/tools/resources_manager.py`)
|
* **Resources Manager**: Resource management and metadata exposure (`doris_mcp_server/tools/resources_manager.py`)
|
||||||
* **Prompts Manager**: Intelligent prompt templates for data analysis (`doris_mcp_server/tools/prompts_manager.py`)
|
* **Prompts Manager**: Intelligent prompt templates for data analysis (`doris_mcp_server/tools/prompts_manager.py`)
|
||||||
* **Advanced Database Features**:
|
* **Advanced Database Features**:
|
||||||
* **Query Execution**: High-performance SQL execution with advanced caching and optimization, enhanced connection stability and automatic retry mechanisms (`doris_mcp_server/utils/query_executor.py`)
|
* **Query Execution**: High-performance SQL execution with caching and optimization (`doris_mcp_server/utils/query_executor.py`)
|
||||||
* **Security Management**: Comprehensive SQL security validation with configurable blocked keywords, SQL injection protection, data masking, and unified security configuration management (`doris_mcp_server/utils/security.py`)
|
* **Security Management**: SQL security validation, data masking, and access control (`doris_mcp_server/utils/security.py`)
|
||||||
* **Metadata Extraction**: Comprehensive database metadata with catalog federation support (`doris_mcp_server/utils/schema_extractor.py`)
|
* **Metadata Extraction**: Comprehensive database metadata with catalog federation support (`doris_mcp_server/utils/schema_extractor.py`)
|
||||||
* **Performance Analysis**: Advanced column analysis, performance monitoring, and data analysis tools (`doris_mcp_server/utils/analysis_tools.py`)
|
* **Performance Analysis**: Column statistics, performance monitoring, and data analysis tools (`doris_mcp_server/utils/analysis_tools.py`)
|
||||||
* **Catalog Federation Support**: Full support for multi-catalog environments (internal Doris tables and external data sources like Hive, MySQL, etc.)
|
* **Catalog Federation Support**: Full support for multi-catalog environments (internal Doris tables and external data sources like Hive, MySQL, etc.)
|
||||||
* **Enterprise Security**: Comprehensive security framework with authentication, authorization, SQL injection protection, and data masking capabilities with environment variable configuration support
|
* **Enterprise Security**: Comprehensive security framework with authentication, authorization, SQL injection protection, and data masking (`doris_mcp_server/utils/security.py`)
|
||||||
* **Unified Configuration Framework**: Centralized configuration management through `config.py` with comprehensive validation, standardized parameter naming, and smart default database handling with automatic fallback to `information_schema`
|
* **Flexible Configuration**: Comprehensive configuration management with environment variables, file-based config, and validation (`doris_mcp_server/utils/config.py`)
|
||||||
|
|
||||||
## System Requirements
|
## System Requirements
|
||||||
|
|
||||||
@@ -64,18 +64,16 @@ Doris MCP (Model Context Protocol) Server is a backend service built with Python
|
|||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Install the latest version
|
# Install the latest version
|
||||||
pip install doris-mcp-server
|
pip install mcp-doris-server
|
||||||
|
|
||||||
# Install specific version
|
# Install specific version
|
||||||
pip install doris-mcp-server==0.4.2
|
pip install mcp-doris-server==0.3
|
||||||
```
|
```
|
||||||
|
|
||||||
> **💡 Command Compatibility**: After installation, both `doris-mcp-server` commands are available for backward compatibility. You can use either command interchangeably.
|
> **💡 Command Compatibility**: After installation, both `doris-mcp-server` and `mcp-doris-server` commands are available for backward compatibility. You can use either command interchangeably.
|
||||||
|
|
||||||
### Start Streamable HTTP Mode (Web Service)
|
### Start Streamable HTTP Mode (Web Service)
|
||||||
|
|
||||||
The primary communication mode offering optimal performance and reliability:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Full configuration with database connection
|
# Full configuration with database connection
|
||||||
doris-mcp-server \
|
doris-mcp-server \
|
||||||
@@ -90,8 +88,6 @@ doris-mcp-server \
|
|||||||
|
|
||||||
### Start Stdio Mode (for Cursor and other MCP clients)
|
### Start Stdio Mode (for Cursor and other MCP clients)
|
||||||
|
|
||||||
Standard input/output mode for direct integration with MCP clients:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# For direct integration with MCP clients like Cursor
|
# For direct integration with MCP clients like Cursor
|
||||||
doris-mcp-server --transport stdio
|
doris-mcp-server --transport stdio
|
||||||
@@ -155,10 +151,10 @@ pip install -r requirements.txt
|
|||||||
|
|
||||||
### 3. Configure Environment Variables
|
### 3. Configure Environment Variables
|
||||||
|
|
||||||
Copy the `.env.example` file to `.env` and modify the settings according to your environment:
|
Copy the `env.example` file to `.env` and modify the settings according to your environment:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cp .env.example .env
|
cp env.example .env
|
||||||
```
|
```
|
||||||
|
|
||||||
**Key Environment Variables:**
|
**Key Environment Variables:**
|
||||||
@@ -168,23 +164,18 @@ cp .env.example .env
|
|||||||
* `DORIS_PORT`: Database port (default: 9030)
|
* `DORIS_PORT`: Database port (default: 9030)
|
||||||
* `DORIS_USER`: Database username (default: root)
|
* `DORIS_USER`: Database username (default: root)
|
||||||
* `DORIS_PASSWORD`: Database password
|
* `DORIS_PASSWORD`: Database password
|
||||||
* `DORIS_DATABASE`: Default database name (default: information_schema)
|
* `DORIS_DATABASE`: Default database name (default: test)
|
||||||
* `DORIS_MIN_CONNECTIONS`: Minimum connection pool size (default: 5)
|
* `DORIS_MIN_CONNECTIONS`: Minimum connection pool size (default: 5)
|
||||||
* `DORIS_MAX_CONNECTIONS`: Maximum connection pool size (default: 20)
|
* `DORIS_MAX_CONNECTIONS`: Maximum connection pool size (default: 20)
|
||||||
* `DORIS_BE_HOSTS`: BE nodes for monitoring (comma-separated, optional - auto-discovery via SHOW BACKENDS if empty)
|
|
||||||
* `DORIS_BE_WEBSERVER_PORT`: BE webserver port for monitoring tools (default: 8040)
|
|
||||||
* **Security Configuration**:
|
* **Security Configuration**:
|
||||||
* `AUTH_TYPE`: Authentication type (token/basic/oauth, default: token)
|
* `AUTH_TYPE`: Authentication type (token/basic/oauth, default: token)
|
||||||
* `TOKEN_SECRET`: Token secret key
|
* `TOKEN_SECRET`: Token secret key
|
||||||
* `ENABLE_SECURITY_CHECK`: Enable/disable SQL security validation (default: true, New in v0.4.2)
|
|
||||||
* `BLOCKED_KEYWORDS`: Comma-separated list of blocked SQL keywords (New in v0.4.2)
|
|
||||||
* `ENABLE_MASKING`: Enable data masking (default: true)
|
* `ENABLE_MASKING`: Enable data masking (default: true)
|
||||||
* `MAX_RESULT_ROWS`: Maximum result rows (default: 10000)
|
* `MAX_RESULT_ROWS`: Maximum result rows (default: 10000)
|
||||||
* **Performance Configuration**:
|
* **Performance Configuration**:
|
||||||
* `ENABLE_QUERY_CACHE`: Enable query caching (default: true)
|
* `ENABLE_QUERY_CACHE`: Enable query caching (default: true)
|
||||||
* `CACHE_TTL`: Cache time-to-live in seconds (default: 300)
|
* `CACHE_TTL`: Cache time-to-live in seconds (default: 300)
|
||||||
* `MAX_CONCURRENT_QUERIES`: Maximum concurrent queries (default: 50)
|
* `MAX_CONCURRENT_QUERIES`: Maximum concurrent queries (default: 50)
|
||||||
* `MAX_RESPONSE_CONTENT_SIZE`: Maximum response content size for LLM compatibility (default: 4096, New in v0.4.0)
|
|
||||||
* **Logging Configuration**:
|
* **Logging Configuration**:
|
||||||
* `LOG_LEVEL`: Log level (DEBUG/INFO/WARNING/ERROR, default: INFO)
|
* `LOG_LEVEL`: Log level (DEBUG/INFO/WARNING/ERROR, default: INFO)
|
||||||
* `LOG_FILE_PATH`: Log file path
|
* `LOG_FILE_PATH`: Log file path
|
||||||
@@ -194,26 +185,21 @@ cp .env.example .env
|
|||||||
|
|
||||||
The following table lists the main tools currently available for invocation via an MCP client:
|
The following table lists the main tools currently available for invocation via an MCP client:
|
||||||
|
|
||||||
| Tool Name | Description | Parameters |
|
| Tool Name | Description | Parameters | Status |
|
||||||
|-----------------------------|--------------------------------------------------------------|--------------------------------------------------------------|
|
|:----------------------------| :---------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------- | :------- |
|
||||||
| `exec_query` | Execute SQL query and return results. | `sql` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional), `max_rows` (integer, Optional), `timeout` (integer, Optional) |
|
| `exec_query` | Execute SQL query with catalog federation support. | `sql` (string, Required - MUST use three-part naming), `db_name` (string, Optional), `catalog_name` (string, Optional), `max_rows` (integer, Optional, default 100), `timeout` (integer, Optional, default 30) | ✅ Active |
|
||||||
| `get_table_schema` | Get detailed table structure information. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) |
|
| `get_catalog_list` | Get a list of all catalogs with detailed information. | `random_string` (string, Required) | ✅ Active |
|
||||||
| `get_db_table_list` | Get list of all table names in specified database. | `db_name` (string, Optional), `catalog_name` (string, Optional) |
|
| `get_db_list` | Get a list of all database names in the specified catalog. | `catalog_name` (string, Optional, defaults to internal catalog) | ✅ Active |
|
||||||
| `get_db_list` | Get list of all database names. | `catalog_name` (string, Optional) |
|
| `get_db_table_list` | Get a list of all table names in the specified database. | `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||||
| `get_table_comment` | Get table comment information. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) |
|
| `get_table_schema` | Get detailed structure of the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||||
| `get_table_column_comments` | Get comment information for all columns in table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) |
|
| `get_table_comment` | Get the comment for the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||||
| `get_table_indexes` | Get index information for specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) |
|
| `get_table_column_comments` | Get comments for all columns in the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||||
| `get_recent_audit_logs` | Get audit log records for recent period. | `days` (integer, Optional), `limit` (integer, Optional) |
|
| `get_table_indexes` | Get index information for the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||||
| `get_catalog_list` | Get list of all catalog names. | `random_string` (string, Required) |
|
| `get_recent_audit_logs` | Get audit log records for a recent period. | `days` (integer, Optional, default 7), `limit` (integer, Optional, default 100) | ✅ Active |
|
||||||
| `get_sql_explain` | Get SQL execution plan with configurable content truncation and file export for LLM analysis. | `sql` (string, Required), `verbose` (boolean, Optional), `db_name` (string, Optional), `catalog_name` (string, Optional) |
|
| `column_analysis` | Analyze statistical information and data distribution. | `table_name` (string, Required), `column_name` (string, Required), `analysis_type` (string, Optional: basic/distribution/detailed) | ⚠️ Experimental |
|
||||||
| `get_sql_profile` | Get SQL execution profile with content management and file export for LLM optimization workflows. | `sql` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional), `timeout` (integer, Optional) |
|
| `performance_stats` | Get database performance statistics information. | `metric_type` (string, Optional: queries/connections/tables/system), `time_range` (string, Optional: 1h/6h/24h/7d) | ⚠️ Experimental |
|
||||||
| `get_table_data_size` | Get table data size information via FE HTTP API. | `db_name` (string, Optional), `table_name` (string, Optional), `single_replica` (boolean, Optional) |
|
|
||||||
| `get_monitoring_metrics_info` | Get Doris monitoring metrics definitions and descriptions. | `role` (string, Optional), `monitor_type` (string, Optional), `priority` (string, Optional) |
|
|
||||||
| `get_monitoring_metrics_data` | Get actual Doris monitoring metrics data from nodes with flexible BE discovery. | `role` (string, Optional), `monitor_type` (string, Optional), `priority` (string, Optional) |
|
|
||||||
| `get_realtime_memory_stats` | Get real-time memory statistics via BE Memory Tracker with auto/manual BE discovery. | `tracker_type` (string, Optional), `include_details` (boolean, Optional) |
|
|
||||||
| `get_historical_memory_stats` | Get historical memory statistics via BE Bvar interface with flexible BE configuration. | `tracker_names` (array, Optional), `time_range` (string, Optional) |
|
|
||||||
|
|
||||||
**Note:** All metadata tools support catalog federation for multi-catalog environments. The `get_catalog_list` tool requires a `random_string` parameter for compatibility reasons. Enhanced monitoring tools in v0.4.0 provide comprehensive memory tracking and metrics collection capabilities with flexible BE node discovery.
|
**Note:** All metadata tools support catalog federation for multi-catalog environments. The `get_catalog_list` tool requires a `random_string` parameter for compatibility reasons.
|
||||||
|
|
||||||
### 4. Run the Service
|
### 4. Run the Service
|
||||||
|
|
||||||
@@ -225,18 +211,19 @@ Execute the following command to start the server:
|
|||||||
|
|
||||||
This command starts the FastAPI application with Streamable HTTP MCP service.
|
This command starts the FastAPI application with Streamable HTTP MCP service.
|
||||||
|
|
||||||
**Service Endpoints:**
|
**Service Endpoints (v0.3.0+):**
|
||||||
|
|
||||||
* **Streamable HTTP**: `http://<host>:<port>/mcp` (Primary MCP endpoint - supports GET, POST, DELETE, OPTIONS)
|
* **Streamable HTTP**: `http://<host>:<port>/mcp` (Primary MCP endpoint - supports GET, POST, DELETE, OPTIONS)
|
||||||
* **Health Check**: `http://<host>:<port>/health`
|
* **Health Check**: `http://<host>:<port>/health`
|
||||||
|
* **Status Check**: `http://<host>:<port>/status`
|
||||||
|
|
||||||
> **Note**: The server uses Streamable HTTP for web-based communication, providing unified request/response and streaming capabilities.
|
> **Note**: Starting from v0.3.0, only Streamable HTTP mode is supported for web-based communication. SSE endpoints have been removed.
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Interaction with the Doris MCP Server requires an **MCP Client**. The client connects to the server's Streamable HTTP endpoint and sends requests according to the MCP specification to invoke the server's tools.
|
Interaction with the Doris MCP Server requires an **MCP Client**. The client connects to the server's Streamable HTTP endpoint and sends requests according to the MCP specification to invoke the server's tools.
|
||||||
|
|
||||||
**Main Interaction Flow:**
|
**Main Interaction Flow (v0.3.0+):**
|
||||||
|
|
||||||
1. **Client Initialization**: Send an `initialize` method call to `/mcp` (Streamable HTTP).
|
1. **Client Initialization**: Send an `initialize` method call to `/mcp` (Streamable HTTP).
|
||||||
2. **(Optional) Discover Tools**: The client can call `tools/list` to get the list of supported tools, their descriptions, and parameter schemas.
|
2. **(Optional) Discover Tools**: The client can call `tools/list` to get the list of supported tools, their descriptions, and parameter schemas.
|
||||||
@@ -248,6 +235,8 @@ Interaction with the Doris MCP Server requires an **MCP Client**. The client con
|
|||||||
* **Non-streaming**: The client receives a response containing `content` or `isError`.
|
* **Non-streaming**: The client receives a response containing `content` or `isError`.
|
||||||
* **Streaming**: The client receives a series of progress notifications, followed by a final response.
|
* **Streaming**: The client receives a series of progress notifications, followed by a final response.
|
||||||
|
|
||||||
|
> **Migration Note**: If you're upgrading from v0.2.x, note that tool names have been simplified (removed `mcp_doris_` prefix) and the communication protocol has been updated to use Streamable HTTP exclusively.
|
||||||
|
|
||||||
### Catalog Federation Support
|
### Catalog Federation Support
|
||||||
|
|
||||||
The Doris MCP Server supports **catalog federation**, enabling interaction with multiple data catalogs (internal Doris tables and external data sources like Hive, MySQL, etc.) within a unified interface.
|
The Doris MCP Server supports **catalog federation**, enabling interaction with multiple data catalogs (internal Doris tables and external data sources like Hive, MySQL, etc.) within a unified interface.
|
||||||
@@ -316,7 +305,7 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## Security Configuration
|
## Security Configuration (v0.3.0+)
|
||||||
|
|
||||||
The Doris MCP Server includes a comprehensive security framework that provides enterprise-level protection through authentication, authorization, SQL security validation, and data masking capabilities.
|
The Doris MCP Server includes a comprehensive security framework that provides enterprise-level protection through authentication, authorization, SQL security validation, and data masking capabilities.
|
||||||
|
|
||||||
@@ -408,25 +397,16 @@ The system automatically validates SQL queries for security risks:
|
|||||||
|
|
||||||
#### Blocked Operations
|
#### Blocked Operations
|
||||||
|
|
||||||
Configure blocked SQL operations using environment variables (New in v0.4.2):
|
Configure blocked SQL operations:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Enable/disable SQL security check (New in v0.4.2)
|
# Environment variable
|
||||||
ENABLE_SECURITY_CHECK=true
|
BLOCKED_SQL_OPERATIONS=DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE
|
||||||
|
|
||||||
# Customize blocked keywords via environment variable (New in v0.4.2)
|
|
||||||
BLOCKED_KEYWORDS="DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE,EXEC,EXECUTE,SHUTDOWN,KILL"
|
|
||||||
|
|
||||||
# Maximum query complexity score
|
# Maximum query complexity score
|
||||||
MAX_QUERY_COMPLEXITY=100
|
MAX_QUERY_COMPLEXITY=100
|
||||||
```
|
```
|
||||||
|
|
||||||
**Default Blocked Keywords (Unified in v0.4.2):**
|
|
||||||
- **DDL Operations**: DROP, CREATE, ALTER, TRUNCATE
|
|
||||||
- **DML Operations**: DELETE, INSERT, UPDATE
|
|
||||||
- **DCL Operations**: GRANT, REVOKE
|
|
||||||
- **System Operations**: EXEC, EXECUTE, SHUTDOWN, KILL
|
|
||||||
|
|
||||||
#### SQL Injection Protection
|
#### SQL Injection Protection
|
||||||
|
|
||||||
The system automatically detects and blocks:
|
The system automatically detects and blocks:
|
||||||
@@ -632,7 +612,7 @@ uv run --project /path/to/doris-mcp-server doris-mcp-server
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
### Streamable HTTP Mode
|
### Streamable HTTP Mode (v0.3.0+)
|
||||||
|
|
||||||
Streamable HTTP mode requires you to run the MCP server independently first, and then configure Cursor to connect to it.
|
Streamable HTTP mode requires you to run the MCP server independently first, and then configure Cursor to connect to it.
|
||||||
|
|
||||||
@@ -654,10 +634,12 @@ Streamable HTTP mode requires you to run the MCP server independently first, and
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
> **Note**: Adjust the host/port if your server runs on a different address. The `/mcp` endpoint is the unified Streamable HTTP interface.
|
> **Note**: Adjust the host/port if your server runs on a different address. The `/mcp` endpoint is the unified Streamable HTTP interface introduced in v0.3.0.
|
||||||
|
|
||||||
After configuring either mode in Cursor, you should be able to select the server (e.g., `doris-stdio` or `doris-http`) and use its tools.
|
After configuring either mode in Cursor, you should be able to select the server (e.g., `doris-stdio` or `doris-http`) and use its tools.
|
||||||
|
|
||||||
|
> **⚠️ Migration from v0.2.x**: If you were using SSE mode (`/sse` endpoint), update your configuration to use the new Streamable HTTP endpoint (`/mcp`).
|
||||||
|
|
||||||
## Directory Structure
|
## Directory Structure
|
||||||
|
|
||||||
```
|
```
|
||||||
@@ -696,22 +678,22 @@ doris-mcp-server/
|
|||||||
|
|
||||||
## Developing New Tools
|
## Developing New Tools
|
||||||
|
|
||||||
This section outlines the process for adding new MCP tools to the Doris MCP Server, based on the unified modular architecture with centralized tool management.
|
This section outlines the process for adding new MCP tools to the Doris MCP Server, based on the current modular architecture.
|
||||||
|
|
||||||
### 1. Leverage Existing Utility Modules
|
### 1. Leverage Existing Utility Modules
|
||||||
|
|
||||||
The server provides comprehensive utility modules for common database operations:
|
The server provides comprehensive utility modules for common database operations:
|
||||||
|
|
||||||
* **`doris_mcp_server/utils/db.py`**: Database connection management with connection pooling and health monitoring.
|
* **`doris_mcp_server/utils/db.py`**: Database connection management with connection pooling and health monitoring.
|
||||||
* **`doris_mcp_server/utils/query_executor.py`**: High-performance SQL execution with advanced caching, optimization, and performance monitoring.
|
* **`doris_mcp_server/utils/query_executor.py`**: High-performance SQL execution with caching, optimization, and performance monitoring.
|
||||||
* **`doris_mcp_server/utils/schema_extractor.py`**: Metadata extraction with full catalog federation support.
|
* **`doris_mcp_server/utils/schema_extractor.py`**: Metadata extraction with full catalog federation support.
|
||||||
* **`doris_mcp_server/utils/security.py`**: Comprehensive security management, SQL validation, and data masking.
|
* **`doris_mcp_server/utils/security.py`**: Security management, SQL validation, and data masking.
|
||||||
* **`doris_mcp_server/utils/analysis_tools.py`**: Advanced data analysis and statistical tools.
|
* **`doris_mcp_server/utils/analysis_tools.py`**: Data analysis and statistical tools.
|
||||||
* **`doris_mcp_server/utils/config.py`**: Configuration management with validation.
|
* **`doris_mcp_server/utils/config.py`**: Configuration management with validation.
|
||||||
|
|
||||||
### 2. Implement Tool Logic
|
### 2. Implement Tool Logic
|
||||||
|
|
||||||
Add your new tool to the `DorisToolsManager` class in `doris_mcp_server/tools/tools_manager.py`. The tools manager provides a centralized approach to tool registration and execution with unified interfaces.
|
Add your new tool to the `DorisToolsManager` class in `doris_mcp_server/tools/tools_manager.py`. The tools manager provides a centralized approach to tool registration and execution.
|
||||||
|
|
||||||
**Example:** Adding a new analysis tool:
|
**Example:** Adding a new analysis tool:
|
||||||
|
|
||||||
@@ -780,13 +762,12 @@ async def your_new_analysis_tool_wrapper(arguments: Dict[str, Any]) -> List[Dict
|
|||||||
|
|
||||||
### 4. Advanced Features
|
### 4. Advanced Features
|
||||||
|
|
||||||
For more complex tools, you can leverage the comprehensive framework:
|
For more complex tools, you can leverage:
|
||||||
|
|
||||||
* **Advanced Caching**: Use the query executor's built-in caching for enhanced performance
|
* **Caching**: Use the query executor's built-in caching for performance
|
||||||
* **Enterprise Security**: Apply comprehensive SQL validation and data masking through the security manager
|
* **Security**: Apply SQL validation and data masking through the security manager
|
||||||
* **Intelligent Prompts**: Use the prompts manager for advanced query generation
|
* **Prompts**: Use the prompts manager for intelligent query generation
|
||||||
* **Resource Management**: Expose metadata through the resources manager
|
* **Resources**: Expose metadata through the resources manager
|
||||||
* **Performance Monitoring**: Integrate with the analysis tools for monitoring capabilities
|
|
||||||
|
|
||||||
### 5. Testing
|
### 5. Testing
|
||||||
|
|
||||||
@@ -817,242 +798,4 @@ Contributions are welcome via Issues or Pull Requests.
|
|||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
This project is licensed under the Apache 2.0 License. See the LICENSE file for details.
|
This project is licensed under the Apache 2.0 License. See the LICENSE file (if it exists) for details.
|
||||||
|
|
||||||
## FAQ
|
|
||||||
|
|
||||||
### Q: Why do Qwen3-32b and other small parameter models always fail when calling tools?
|
|
||||||
|
|
||||||
**A:** This is a common issue. The main reason is that these models need more explicit guidance to correctly use MCP tools. It's recommended to add the following instruction prompt for the model:
|
|
||||||
|
|
||||||
- Chinese version:
|
|
||||||
|
|
||||||
```xml
|
|
||||||
<instruction>
|
|
||||||
尽可能使用MCP工具完成任务,仔细阅读每个工具的注解、方法名、参数说明等内容。请按照以下步骤操作:
|
|
||||||
|
|
||||||
1. 仔细分析用户的问题,从已有的Tools列表中匹配最合适的工具。
|
|
||||||
2. 确保工具名称、方法名和参数完全按照工具注释中的定义使用,不要自行创造工具名称或参数。
|
|
||||||
3. 传入参数时,严格遵循工具注释中规定的参数格式和要求。
|
|
||||||
4. 调用工具时,根据需要直接调用工具,但参数请求参考以下请求格式:{"mcp_sse_call_tool": {"tool_name": "$tools_name", "arguments": "{}"}}
|
|
||||||
5. 输出结果时,不要包含任何XML标签,仅返回纯文本内容。
|
|
||||||
|
|
||||||
<input>
|
|
||||||
用户问题:user_query
|
|
||||||
</input>
|
|
||||||
|
|
||||||
<output>
|
|
||||||
返回工具调用结果或最终答案,以及对结果的分析。
|
|
||||||
</output>
|
|
||||||
</instruction>
|
|
||||||
```
|
|
||||||
- English version:
|
|
||||||
|
|
||||||
```xml
|
|
||||||
<instruction>
|
|
||||||
Use MCP tools to complete tasks as much as possible. Carefully read the annotations, method names, and parameter descriptions of each tool. Please follow these steps:
|
|
||||||
|
|
||||||
1. Carefully analyze the user's question and match the most appropriate tool from the existing Tools list.
|
|
||||||
2. Ensure tool names, method names, and parameters are used exactly as defined in the tool annotations. Do not create tool names or parameters on your own.
|
|
||||||
3. When passing parameters, strictly follow the parameter format and requirements specified in the tool annotations.
|
|
||||||
4. When calling tools, call them directly as needed, but refer to the following request format for parameters: {"mcp_sse_call_tool": {"tool_name": "$tools_name", "arguments": "{}"}}
|
|
||||||
5. When outputting results, do not include any XML tags, return plain text content only.
|
|
||||||
|
|
||||||
<input>
|
|
||||||
User question: user_query
|
|
||||||
</input>
|
|
||||||
|
|
||||||
<output>
|
|
||||||
Return tool call results or final answer, along with analysis of the results.
|
|
||||||
</output>
|
|
||||||
</instruction>
|
|
||||||
```
|
|
||||||
|
|
||||||
If you have further requirements for the returned results, you can describe the specific requirements in the `<output>` tag.
|
|
||||||
|
|
||||||
### Q: How to configure different database connections?
|
|
||||||
|
|
||||||
**A:** You can configure database connections in several ways:
|
|
||||||
|
|
||||||
1. **Environment Variables** (Recommended):
|
|
||||||
```bash
|
|
||||||
export DORIS_HOST="your_doris_host"
|
|
||||||
export DORIS_PORT="9030"
|
|
||||||
export DORIS_USER="root"
|
|
||||||
export DORIS_PASSWORD="your_password"
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Command Line Arguments**:
|
|
||||||
```bash
|
|
||||||
doris-mcp-server --db-host your_host --db-port 9030 --db-user root --db-password your_password
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Configuration File**:
|
|
||||||
Modify the corresponding configuration items in the `.env` file.
|
|
||||||
|
|
||||||
### Q: How to configure BE nodes for monitoring tools?
|
|
||||||
|
|
||||||
**A:** Choose the appropriate configuration based on your deployment scenario:
|
|
||||||
|
|
||||||
**External Network (Manual Configuration):**
|
|
||||||
```bash
|
|
||||||
# Manually specify BE node addresses
|
|
||||||
DORIS_BE_HOSTS=10.1.1.100,10.1.1.101,10.1.1.102
|
|
||||||
DORIS_BE_WEBSERVER_PORT=8040
|
|
||||||
```
|
|
||||||
|
|
||||||
**Internal Network (Automatic Discovery):**
|
|
||||||
```bash
|
|
||||||
# Leave BE_HOSTS empty for auto-discovery
|
|
||||||
# DORIS_BE_HOSTS= # Not set or empty
|
|
||||||
# System will use 'SHOW BACKENDS' command to get internal IPs
|
|
||||||
```
|
|
||||||
|
|
||||||
### Q: How to use SQL Explain/Profile files with LLM for optimization?
|
|
||||||
|
|
||||||
**A:** The tools provide both truncated content and complete files for LLM analysis:
|
|
||||||
|
|
||||||
1. **Get Analysis Results:**
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"content": "Truncated plan for immediate review",
|
|
||||||
"file_path": "/tmp/explain_12345.txt",
|
|
||||||
"is_content_truncated": true
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **LLM Analysis Workflow:**
|
|
||||||
- Review truncated content for quick insights
|
|
||||||
- Upload the complete file to your LLM as an attachment
|
|
||||||
- Request optimization suggestions or performance analysis
|
|
||||||
- Implement recommended improvements
|
|
||||||
|
|
||||||
3. **Configure Content Size:**
|
|
||||||
```bash
|
|
||||||
MAX_RESPONSE_CONTENT_SIZE=4096 # Adjust as needed
|
|
||||||
```
|
|
||||||
|
|
||||||
### Q: How to enable data security and masking features?
|
|
||||||
|
|
||||||
**A:** Set the following configurations in your `.env` file:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Enable data masking
|
|
||||||
ENABLE_MASKING=true
|
|
||||||
# Set authentication type
|
|
||||||
AUTH_TYPE=token
|
|
||||||
# Configure token secret
|
|
||||||
TOKEN_SECRET=your_secret_key
|
|
||||||
# Set maximum result rows
|
|
||||||
MAX_RESULT_ROWS=10000
|
|
||||||
```
|
|
||||||
|
|
||||||
### Q: What's the difference between Stdio mode and HTTP mode?
|
|
||||||
|
|
||||||
**A:**
|
|
||||||
|
|
||||||
- **Stdio Mode**: Suitable for direct integration with MCP clients (like Cursor), where the client manages the server process
|
|
||||||
- **HTTP Mode**: Independent web service that supports multiple client connections, suitable for production environments
|
|
||||||
|
|
||||||
Recommendations:
|
|
||||||
- Development and personal use: Stdio mode
|
|
||||||
- Production and multi-user environments: HTTP mode
|
|
||||||
|
|
||||||
### Q: How to resolve connection timeout issues?
|
|
||||||
|
|
||||||
**A:** Try the following solutions:
|
|
||||||
|
|
||||||
1. **Increase timeout settings**:
|
|
||||||
```bash
|
|
||||||
# Set in .env file
|
|
||||||
QUERY_TIMEOUT=60
|
|
||||||
CONNECTION_TIMEOUT=30
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Check network connectivity**:
|
|
||||||
```bash
|
|
||||||
# Test database connection
|
|
||||||
curl http://localhost:3000/health
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Optimize connection pool configuration**:
|
|
||||||
```bash
|
|
||||||
DORIS_MIN_CONNECTIONS=5
|
|
||||||
DORIS_MAX_CONNECTIONS=20
|
|
||||||
```
|
|
||||||
|
|
||||||
### Q: How to resolve `at_eof` connection errors? (Fixed in v0.4.2)
|
|
||||||
|
|
||||||
**A:** Version 0.4.2 has resolved the critical `at_eof` connection errors. The improvements include:
|
|
||||||
|
|
||||||
1. **Enhanced Connection Health Monitoring**: Strict connection state validation before operations
|
|
||||||
2. **Automatic Retry Mechanism**: Failed queries are automatically retried up to 2 times
|
|
||||||
3. **Proactive Connection Cleanup**: Automatic detection and cleanup of problematic connections
|
|
||||||
4. **Connection Diagnostics**: Comprehensive connection health analysis and reporting
|
|
||||||
|
|
||||||
If you still encounter connection issues after upgrading to v0.4.2:
|
|
||||||
```bash
|
|
||||||
# Check connection diagnostics
|
|
||||||
# The system now automatically handles connection recovery
|
|
||||||
# Monitor logs for connection health reports
|
|
||||||
tail -f logs/doris_mcp_server.log | grep "connection"
|
|
||||||
```
|
|
||||||
|
|
||||||
### Q: How to resolve MCP library version compatibility issues? (Fixed in v0.4.2)
|
|
||||||
|
|
||||||
**A:** Version 0.4.2 introduced an intelligent MCP compatibility layer that supports both MCP 1.8.x and 1.9.x versions:
|
|
||||||
|
|
||||||
**The Problem:**
|
|
||||||
- MCP 1.9.3 introduced breaking changes to the `RequestContext` class (changed from 2 to 3 generic parameters)
|
|
||||||
- This caused `TypeError: Too few arguments for RequestContext` errors
|
|
||||||
|
|
||||||
**The Solution (v0.4.2):**
|
|
||||||
- **Intelligent Version Detection**: Automatically detects the installed MCP version
|
|
||||||
- **Compatibility Layer**: Gracefully handles API differences between versions
|
|
||||||
- **Flexible Version Support**: `mcp>=1.8.0,<2.0.0` in dependencies
|
|
||||||
|
|
||||||
**Supported MCP Versions:**
|
|
||||||
```bash
|
|
||||||
# Both versions now work seamlessly
|
|
||||||
pip install mcp==1.8.0 # Stable version (recommended)
|
|
||||||
pip install mcp==1.9.3 # Latest version with new features
|
|
||||||
```
|
|
||||||
|
|
||||||
**Version Information:**
|
|
||||||
```bash
|
|
||||||
# Check which MCP version is being used
|
|
||||||
doris-mcp-server --transport stdio
|
|
||||||
# The server will log: "Using MCP version: x.x.x"
|
|
||||||
```
|
|
||||||
|
|
||||||
If you encounter MCP-related startup errors:
|
|
||||||
```bash
|
|
||||||
# Recommended: Use stable version
|
|
||||||
pip uninstall mcp
|
|
||||||
pip install mcp==1.8.0
|
|
||||||
|
|
||||||
# Or upgrade to latest compatible version
|
|
||||||
pip install --upgrade mcp-doris-server==0.4.2
|
|
||||||
```
|
|
||||||
|
|
||||||
### Q: How to view server logs?
|
|
||||||
|
|
||||||
**A:** Log files are located in the `logs/` directory. You can:
|
|
||||||
|
|
||||||
1. **View real-time logs**:
|
|
||||||
```bash
|
|
||||||
tail -f logs/doris_mcp_server.log
|
|
||||||
```
|
|
||||||
|
|
||||||
2. **Adjust log level**:
|
|
||||||
```bash
|
|
||||||
# Set in .env file
|
|
||||||
LOG_LEVEL=DEBUG
|
|
||||||
```
|
|
||||||
|
|
||||||
3. **Enable audit logging**:
|
|
||||||
```bash
|
|
||||||
ENABLE_AUDIT=true
|
|
||||||
```
|
|
||||||
|
|
||||||
For other issues, please check GitHub Issues or submit a new issue.
|
|
||||||
|
|||||||
@@ -1,218 +0,0 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
version: '3.8'
|
|
||||||
|
|
||||||
services:
|
|
||||||
# Doris MCP Server
|
|
||||||
doris-mcp-server:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
container_name: doris-mcp-server
|
|
||||||
ports:
|
|
||||||
- "3000:3000" # MCP service port
|
|
||||||
- "3001:3001" # Monitoring metrics port
|
|
||||||
- "3002:3002" # Health check port
|
|
||||||
environment:
|
|
||||||
# Database configuration
|
|
||||||
- DORIS_HOST=doris-fe
|
|
||||||
- DORIS_PORT=9030
|
|
||||||
- DORIS_USER=root
|
|
||||||
- DORIS_PASSWORD=doris123
|
|
||||||
- DORIS_DATABASE=test_db
|
|
||||||
|
|
||||||
# Connection pool configuration
|
|
||||||
- DORIS_MIN_CONNECTIONS=5
|
|
||||||
- DORIS_MAX_CONNECTIONS=20
|
|
||||||
|
|
||||||
# Security configuration
|
|
||||||
- AUTH_TYPE=token
|
|
||||||
- TOKEN_SECRET=your_secret_key_here
|
|
||||||
- MAX_RESULT_ROWS=10000
|
|
||||||
|
|
||||||
# Performance configuration
|
|
||||||
- ENABLE_QUERY_CACHE=true
|
|
||||||
- MAX_CONCURRENT_QUERIES=50
|
|
||||||
|
|
||||||
# Logging configuration
|
|
||||||
- LOG_LEVEL=INFO
|
|
||||||
- LOG_FILE_PATH=/app/logs/doris-mcp-server.log
|
|
||||||
|
|
||||||
# Monitoring configuration
|
|
||||||
- ENABLE_METRICS=true
|
|
||||||
- METRICS_PORT=8081
|
|
||||||
volumes:
|
|
||||||
- ./logs:/app/logs
|
|
||||||
- ./config:/app/config
|
|
||||||
depends_on:
|
|
||||||
- doris-fe
|
|
||||||
- doris-be
|
|
||||||
networks:
|
|
||||||
- doris-network
|
|
||||||
restart: unless-stopped
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8082/health"]
|
|
||||||
interval: 30s
|
|
||||||
timeout: 10s
|
|
||||||
retries: 3
|
|
||||||
start_period: 40s
|
|
||||||
|
|
||||||
# Apache Doris Frontend
|
|
||||||
doris-fe:
|
|
||||||
image: apache/doris:2.0.3-fe-x86_64
|
|
||||||
container_name: doris-fe
|
|
||||||
ports:
|
|
||||||
- "8030:8030" # FE HTTP port
|
|
||||||
- "9030:9030" # FE MySQL port
|
|
||||||
environment:
|
|
||||||
- FE_SERVERS=fe1:doris-fe:9010
|
|
||||||
- FE_ID=1
|
|
||||||
volumes:
|
|
||||||
- doris-fe-data:/opt/apache-doris/fe/doris-meta
|
|
||||||
- doris-fe-log:/opt/apache-doris/fe/log
|
|
||||||
- ./doris-config/fe.conf:/opt/apache-doris/fe/conf/fe.conf
|
|
||||||
networks:
|
|
||||||
- doris-network
|
|
||||||
restart: unless-stopped
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8030/api/health"]
|
|
||||||
interval: 30s
|
|
||||||
timeout: 10s
|
|
||||||
retries: 5
|
|
||||||
start_period: 60s
|
|
||||||
|
|
||||||
# Apache Doris Backend
|
|
||||||
doris-be:
|
|
||||||
image: apache/doris:2.0.3-be-x86_64
|
|
||||||
container_name: doris-be
|
|
||||||
ports:
|
|
||||||
- "8040:8040" # BE HTTP port
|
|
||||||
- "9060:9060" # BE heartbeat port
|
|
||||||
environment:
|
|
||||||
- FE_SERVERS=doris-fe:9010
|
|
||||||
- BE_ADDR=doris-be:9050
|
|
||||||
volumes:
|
|
||||||
- doris-be-data:/opt/apache-doris/be/storage
|
|
||||||
- doris-be-log:/opt/apache-doris/be/log
|
|
||||||
- ./doris-config/be.conf:/opt/apache-doris/be/conf/be.conf
|
|
||||||
depends_on:
|
|
||||||
- doris-fe
|
|
||||||
networks:
|
|
||||||
- doris-network
|
|
||||||
restart: unless-stopped
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "curl", "-f", "http://localhost:8040/api/health"]
|
|
||||||
interval: 30s
|
|
||||||
timeout: 10s
|
|
||||||
retries: 5
|
|
||||||
start_period: 60s
|
|
||||||
|
|
||||||
# Redis cache (optional)
|
|
||||||
redis:
|
|
||||||
image: redis:7-alpine
|
|
||||||
container_name: doris-redis
|
|
||||||
ports:
|
|
||||||
- "6379:6379"
|
|
||||||
command: redis-server --appendonly yes --requirepass redis123
|
|
||||||
volumes:
|
|
||||||
- redis-data:/data
|
|
||||||
networks:
|
|
||||||
- doris-network
|
|
||||||
restart: unless-stopped
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
|
||||||
interval: 30s
|
|
||||||
timeout: 10s
|
|
||||||
retries: 3
|
|
||||||
|
|
||||||
# Prometheus monitoring
|
|
||||||
prometheus:
|
|
||||||
image: prom/prometheus:latest
|
|
||||||
container_name: doris-prometheus
|
|
||||||
ports:
|
|
||||||
- "9090:9090"
|
|
||||||
volumes:
|
|
||||||
- ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
|
|
||||||
- prometheus-data:/prometheus
|
|
||||||
command:
|
|
||||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
|
||||||
- '--storage.tsdb.path=/prometheus'
|
|
||||||
- '--web.console.libraries=/etc/prometheus/console_libraries'
|
|
||||||
- '--web.console.templates=/etc/prometheus/consoles'
|
|
||||||
- '--storage.tsdb.retention.time=200h'
|
|
||||||
- '--web.enable-lifecycle'
|
|
||||||
networks:
|
|
||||||
- doris-network
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
# Grafana visualization
|
|
||||||
grafana:
|
|
||||||
image: grafana/grafana:latest
|
|
||||||
container_name: doris-grafana
|
|
||||||
ports:
|
|
||||||
- "3000:3000"
|
|
||||||
environment:
|
|
||||||
- GF_SECURITY_ADMIN_PASSWORD=admin123
|
|
||||||
volumes:
|
|
||||||
- grafana-data:/var/lib/grafana
|
|
||||||
- ./monitoring/grafana/dashboards:/etc/grafana/provisioning/dashboards
|
|
||||||
- ./monitoring/grafana/datasources:/etc/grafana/provisioning/datasources
|
|
||||||
depends_on:
|
|
||||||
- prometheus
|
|
||||||
networks:
|
|
||||||
- doris-network
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
# Nginx load balancer
|
|
||||||
nginx:
|
|
||||||
image: nginx:alpine
|
|
||||||
container_name: doris-nginx
|
|
||||||
ports:
|
|
||||||
- "80:80"
|
|
||||||
- "443:443"
|
|
||||||
volumes:
|
|
||||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
|
||||||
- ./nginx/ssl:/etc/nginx/ssl
|
|
||||||
- ./nginx/logs:/var/log/nginx
|
|
||||||
depends_on:
|
|
||||||
- doris-mcp-server
|
|
||||||
networks:
|
|
||||||
- doris-network
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
volumes:
|
|
||||||
doris-fe-data:
|
|
||||||
driver: local
|
|
||||||
doris-fe-log:
|
|
||||||
driver: local
|
|
||||||
doris-be-data:
|
|
||||||
driver: local
|
|
||||||
doris-be-log:
|
|
||||||
driver: local
|
|
||||||
redis-data:
|
|
||||||
driver: local
|
|
||||||
prometheus-data:
|
|
||||||
driver: local
|
|
||||||
grafana-data:
|
|
||||||
driver: local
|
|
||||||
|
|
||||||
networks:
|
|
||||||
doris-network:
|
|
||||||
driver: bridge
|
|
||||||
ipam:
|
|
||||||
config:
|
|
||||||
- subnet: 172.20.0.0/16
|
|
||||||
@@ -1,328 +0,0 @@
|
|||||||
<!--
|
|
||||||
Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
or more contributor license agreements. See the NOTICE file
|
|
||||||
distributed with this work for additional information
|
|
||||||
regarding copyright ownership. The ASF licenses this file
|
|
||||||
to you under the Apache License, Version 2.0 (the
|
|
||||||
"License"); you may not use this file except in compliance
|
|
||||||
with the License. You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing,
|
|
||||||
software distributed under the License is distributed on an
|
|
||||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
KIND, either express or implied. See the License for the
|
|
||||||
specific language governing permissions and limitations
|
|
||||||
under the License.
|
|
||||||
-->
|
|
||||||
# Doris Unified MCP Client
|
|
||||||
|
|
||||||
This is a unified Doris MCP client that supports both **stdio** and **Streamable HTTP** transport modes, providing complete MCP protocol support.
|
|
||||||
|
|
||||||
## 🚀 Features
|
|
||||||
|
|
||||||
- ✅ **Dual Mode Support**: Both stdio and HTTP transport methods
|
|
||||||
- ✅ **Complete MCP Support**: Resources, Tools, and Prompts primitives
|
|
||||||
- ✅ **Unified API**: Same interface for different transport modes
|
|
||||||
- ✅ **Asynchronous Design**: High-performance async client based on asyncio
|
|
||||||
- ✅ **Enterprise Features**: Connection pooling, error handling, logging
|
|
||||||
- ✅ **Convenience Methods**: High-level wrappers for common database operations
|
|
||||||
|
|
||||||
## 📦 Install Dependencies
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install mcp
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🎯 Quick Start
|
|
||||||
|
|
||||||
### 1. stdio Mode
|
|
||||||
|
|
||||||
```python
|
|
||||||
import asyncio
|
|
||||||
from client import create_stdio_client
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
# Create stdio client
|
|
||||||
client = await create_stdio_client(
|
|
||||||
"python",
|
|
||||||
["-m", "doris_mcp_server.main", "--transport", "stdio"]
|
|
||||||
)
|
|
||||||
|
|
||||||
async def test_client(client):
|
|
||||||
# Get database list
|
|
||||||
db_result = await client.get_database_list()
|
|
||||||
print(f"Databases: {db_result}")
|
|
||||||
|
|
||||||
# Execute SQL query
|
|
||||||
query_result = await client.execute_sql("SELECT 1 as test")
|
|
||||||
print(f"Query result: {query_result}")
|
|
||||||
|
|
||||||
await client.connect_and_run(test_client)
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. HTTP Mode
|
|
||||||
|
|
||||||
```python
|
|
||||||
import asyncio
|
|
||||||
from unified_client import create_http_client
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
# Create HTTP client
|
|
||||||
client = await create_http_client("http://localhost:3000/mcp")
|
|
||||||
|
|
||||||
async def test_client(client):
|
|
||||||
# Get all tools
|
|
||||||
tools = await client.list_all_tools()
|
|
||||||
print(f"Available tools: {len(tools)}")
|
|
||||||
|
|
||||||
# Execute query
|
|
||||||
result = await client.execute_sql(
|
|
||||||
"SELECT COUNT(*) FROM internal.ssb.lineorder LIMIT 1"
|
|
||||||
)
|
|
||||||
print(f"Query result: {result}")
|
|
||||||
|
|
||||||
await client.connect_and_run(test_client)
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🔧 API Reference
|
|
||||||
|
|
||||||
### Client Creation
|
|
||||||
|
|
||||||
```python
|
|
||||||
# stdio mode
|
|
||||||
client = await create_stdio_client(command, args)
|
|
||||||
|
|
||||||
# HTTP mode
|
|
||||||
client = await create_http_client(server_url, timeout=60)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Basic Operations
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def test_client(client):
|
|
||||||
# Get server capabilities
|
|
||||||
tools = await client.list_all_tools()
|
|
||||||
resources = await client.list_all_resources()
|
|
||||||
prompts = await client.list_all_prompts()
|
|
||||||
|
|
||||||
# Call tool
|
|
||||||
result = await client.call_tool("tool_name", {"param": "value"})
|
|
||||||
|
|
||||||
# Read resource
|
|
||||||
content = await client.read_resource("resource://uri")
|
|
||||||
|
|
||||||
# Get prompt
|
|
||||||
prompt = await client.get_prompt("prompt_name", {"param": "value"})
|
|
||||||
```
|
|
||||||
|
|
||||||
### Advanced Database Operations
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def database_operations(client):
|
|
||||||
# Execute SQL query
|
|
||||||
result = await client.execute_sql("SELECT * FROM table LIMIT 10")
|
|
||||||
|
|
||||||
# Get database list
|
|
||||||
databases = await client.get_database_list()
|
|
||||||
|
|
||||||
# Get table schema
|
|
||||||
schema = await client.get_table_schema("table_name", "db_name")
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🧪 Testing
|
|
||||||
|
|
||||||
### Run Test Suite
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Interactive testing
|
|
||||||
python test_unified_client.py
|
|
||||||
|
|
||||||
# Test stdio mode
|
|
||||||
python test_unified_client.py stdio
|
|
||||||
|
|
||||||
# Test HTTP mode
|
|
||||||
python test_unified_client.py http
|
|
||||||
|
|
||||||
# Test both modes
|
|
||||||
python test_unified_client.py both
|
|
||||||
|
|
||||||
# Performance benchmark
|
|
||||||
python test_unified_client.py benchmark
|
|
||||||
```
|
|
||||||
|
|
||||||
### Test Output Example
|
|
||||||
|
|
||||||
```
|
|
||||||
🎯 Doris Unified Client Test Suite
|
|
||||||
============================================================
|
|
||||||
|
|
||||||
🚀 Testing HTTP Mode
|
|
||||||
==================================================
|
|
||||||
📋 Getting server capabilities...
|
|
||||||
✅ Found 11 tools
|
|
||||||
✅ Found 0 resources
|
|
||||||
✅ Found 0 prompts
|
|
||||||
|
|
||||||
🔧 Available tools:
|
|
||||||
1. get_db_list: Get database list
|
|
||||||
2. get_table_list: Get table list for specified database
|
|
||||||
3. get_table_schema: Get table structure information
|
|
||||||
4. exec_query: Execute SQL query
|
|
||||||
...
|
|
||||||
|
|
||||||
🧪 Testing basic functionality...
|
|
||||||
1️⃣ Getting database list...
|
|
||||||
✅ Success: 3 databases
|
|
||||||
2️⃣ Executing simple query...
|
|
||||||
✅ Query successful
|
|
||||||
3️⃣ Executing SSB data query...
|
|
||||||
✅ SSB query successful
|
|
||||||
4️⃣ Getting table structure...
|
|
||||||
✅ Table structure retrieved successfully
|
|
||||||
|
|
||||||
✅ HTTP mode testing completed!
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🏗️ Architecture Design
|
|
||||||
|
|
||||||
### Unified Client Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
DorisUnifiedClient
|
|
||||||
├── DorisResourceClient # Resource management
|
|
||||||
├── DorisToolsClient # Tool invocation
|
|
||||||
├── DorisPromptClient # Prompt management
|
|
||||||
└── Transport Layer
|
|
||||||
├── stdio mode # Standard input/output
|
|
||||||
└── HTTP mode # Streamable HTTP
|
|
||||||
```
|
|
||||||
|
|
||||||
### Key Features
|
|
||||||
|
|
||||||
1. **Unified Interface**: Same API for different transport modes
|
|
||||||
2. **Async Context**: Proper resource management and connection cleanup
|
|
||||||
3. **Error Handling**: Comprehensive exception handling and error recovery
|
|
||||||
4. **Performance Optimization**: Connection reuse and request caching
|
|
||||||
|
|
||||||
## 📚 Usage Examples
|
|
||||||
|
|
||||||
### Complete Example
|
|
||||||
|
|
||||||
```python
|
|
||||||
import asyncio
|
|
||||||
from client import DorisUnifiedClient, DorisClientConfig
|
|
||||||
|
|
||||||
async def comprehensive_example():
|
|
||||||
# Create configuration
|
|
||||||
config = DorisClientConfig.stdio(
|
|
||||||
"python",
|
|
||||||
["-m", "doris_mcp_server.main"]
|
|
||||||
)
|
|
||||||
|
|
||||||
client = DorisUnifiedClient(config)
|
|
||||||
|
|
||||||
async def demo_operations(client):
|
|
||||||
print("🔍 Discovering server capabilities...")
|
|
||||||
|
|
||||||
# List all available tools
|
|
||||||
tools = await client.list_all_tools()
|
|
||||||
print(f"Available tools: {[tool.name for tool in tools]}")
|
|
||||||
|
|
||||||
# Get database list
|
|
||||||
print("\n📊 Getting database information...")
|
|
||||||
db_result = await client.get_database_list()
|
|
||||||
print(f"Databases: {db_result}")
|
|
||||||
|
|
||||||
# Execute queries
|
|
||||||
print("\n🔍 Executing queries...")
|
|
||||||
|
|
||||||
# Simple query
|
|
||||||
result1 = await client.execute_sql("SELECT 1 as test_column")
|
|
||||||
print(f"Simple query result: {result1}")
|
|
||||||
|
|
||||||
# Get table schema
|
|
||||||
schema_result = await client.get_table_schema("lineorder", "ssb")
|
|
||||||
print(f"Table schema: {schema_result}")
|
|
||||||
|
|
||||||
await client.connect_and_run(demo_operations)
|
|
||||||
|
|
||||||
# Run the example
|
|
||||||
asyncio.run(comprehensive_example())
|
|
||||||
```
|
|
||||||
|
|
||||||
### Error Handling
|
|
||||||
|
|
||||||
```python
|
|
||||||
async def error_handling_example(client):
|
|
||||||
try:
|
|
||||||
# This might fail
|
|
||||||
result = await client.execute_sql("INVALID SQL")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"SQL execution failed: {e}")
|
|
||||||
|
|
||||||
# Check result status
|
|
||||||
result = await client.get_database_list()
|
|
||||||
if result.get("success", True):
|
|
||||||
print("Operation successful")
|
|
||||||
else:
|
|
||||||
print(f"Operation failed: {result.get('error')}")
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🔧 Configuration
|
|
||||||
|
|
||||||
### Client Configuration Options
|
|
||||||
|
|
||||||
```python
|
|
||||||
# stdio mode with custom arguments
|
|
||||||
config = DorisClientConfig(
|
|
||||||
transport="stdio",
|
|
||||||
server_command="python",
|
|
||||||
server_args=["-m", "doris_mcp_server.main", "--debug"],
|
|
||||||
timeout=30
|
|
||||||
)
|
|
||||||
|
|
||||||
# HTTP mode with custom timeout
|
|
||||||
config = DorisClientConfig(
|
|
||||||
transport="http",
|
|
||||||
server_url="http://localhost:8080/mcp",
|
|
||||||
timeout=60
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
### Environment Variables
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Set default server URL
|
|
||||||
export DORIS_MCP_SERVER_URL="http://localhost:8080"
|
|
||||||
|
|
||||||
# Set default timeout
|
|
||||||
export DORIS_MCP_TIMEOUT=60
|
|
||||||
|
|
||||||
# Enable debug logging
|
|
||||||
export DORIS_MCP_DEBUG=true
|
|
||||||
```
|
|
||||||
|
|
||||||
## 🚀 Performance Tips
|
|
||||||
|
|
||||||
1. **Connection Reuse**: Use the same client instance for multiple operations
|
|
||||||
2. **Batch Operations**: Group related queries together
|
|
||||||
3. **Async Context**: Always use proper async context management
|
|
||||||
4. **Error Recovery**: Implement retry logic for transient failures
|
|
||||||
|
|
||||||
## 🤝 Contributing
|
|
||||||
|
|
||||||
1. Fork the repository
|
|
||||||
2. Create a feature branch
|
|
||||||
3. Make your changes
|
|
||||||
4. Add tests
|
|
||||||
5. Submit a pull request
|
|
||||||
|
|
||||||
## 📄 License
|
|
||||||
|
|
||||||
This project is licensed under the Apache 2.0 License.
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Doris MCP Client Package
|
|
||||||
|
|
||||||
Unified MCP client supporting both stdio and HTTP transport modes
|
|
||||||
"""
|
|
||||||
|
|
||||||
from .client import DorisUnifiedClient, DorisClientConfig
|
|
||||||
|
|
||||||
__all__ = ["DorisUnifiedClient", "DorisClientConfig"]
|
|
||||||
@@ -1,509 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Unified Doris MCP Client - Supports both stdio and Streamable HTTP modes
|
|
||||||
|
|
||||||
Combines the correct HTTP implementation from http_client.py and the complete architecture from client.py
|
|
||||||
Provides complete support for the three major primitives: Resources, Tools, and Prompts
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any, Callable
|
|
||||||
from datetime import timedelta
|
|
||||||
|
|
||||||
from mcp.client.session import ClientSession
|
|
||||||
from mcp.client.stdio import stdio_client
|
|
||||||
from mcp.client.streamable_http import streamablehttp_client
|
|
||||||
from mcp import StdioServerParameters
|
|
||||||
from mcp.types import (
|
|
||||||
Prompt,
|
|
||||||
Resource,
|
|
||||||
Tool,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Configure logging
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class DorisClientConfig:
|
|
||||||
"""Doris client configuration class"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
transport: str = "stdio",
|
|
||||||
server_command: str | None = None,
|
|
||||||
server_args: list[str] | None = None,
|
|
||||||
server_url: str | None = None,
|
|
||||||
timeout: int = 60,
|
|
||||||
):
|
|
||||||
self.transport = transport
|
|
||||||
self.server_command = server_command
|
|
||||||
self.server_args = server_args or []
|
|
||||||
self.server_url = server_url
|
|
||||||
self.timeout = timeout
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def stdio(cls, command: str, args: list[str] = None) -> "DorisClientConfig":
|
|
||||||
"""Create stdio connection configuration"""
|
|
||||||
return cls(
|
|
||||||
transport="stdio",
|
|
||||||
server_command=command,
|
|
||||||
server_args=args or []
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def http(cls, url: str, timeout: int = 60) -> "DorisClientConfig":
|
|
||||||
"""Create HTTP connection configuration"""
|
|
||||||
return cls(
|
|
||||||
transport="http",
|
|
||||||
server_url=url,
|
|
||||||
timeout=timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DorisResourceClient:
|
|
||||||
"""Doris resource client - Handles Resources related operations"""
|
|
||||||
|
|
||||||
def __init__(self, session: ClientSession):
|
|
||||||
self.session = session
|
|
||||||
self.logger = logging.getLogger(f"{__name__}.DorisResourceClient")
|
|
||||||
self._resources_cache = None
|
|
||||||
|
|
||||||
async def list_resources(self) -> list[Resource]:
|
|
||||||
"""Get list of all available resources"""
|
|
||||||
try:
|
|
||||||
self.logger.info("Getting resource list")
|
|
||||||
response = await self.session.list_resources()
|
|
||||||
resources = response.resources if hasattr(response, "resources") else []
|
|
||||||
self._resources_cache = resources
|
|
||||||
self.logger.info(f"Retrieved {len(resources)} resources")
|
|
||||||
return resources
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to get resource list: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def read_resource(self, uri: str) -> str | None:
|
|
||||||
"""Read specified resource content"""
|
|
||||||
try:
|
|
||||||
self.logger.info(f"Reading resource: {uri}")
|
|
||||||
response = await self.session.read_resource(uri)
|
|
||||||
|
|
||||||
if hasattr(response, "contents") and response.contents:
|
|
||||||
# Merge all content
|
|
||||||
content_parts = []
|
|
||||||
for content in response.contents:
|
|
||||||
if hasattr(content, "text"):
|
|
||||||
content_parts.append(content.text)
|
|
||||||
content = "\n".join(content_parts)
|
|
||||||
self.logger.info(f"Successfully read resource content: {len(content)} characters")
|
|
||||||
return content
|
|
||||||
elif hasattr(response, "content"):
|
|
||||||
return str(response.content)
|
|
||||||
else:
|
|
||||||
self.logger.warning(f"Resource {uri} returned no content")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to read resource {uri}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def filter_resources_by_type(self, resource_type: str) -> list[Resource]:
|
|
||||||
"""Filter resources by type"""
|
|
||||||
if not self._resources_cache:
|
|
||||||
await self.list_resources()
|
|
||||||
|
|
||||||
if resource_type == "table":
|
|
||||||
return [r for r in self._resources_cache if "table" in r.uri]
|
|
||||||
elif resource_type == "view":
|
|
||||||
return [r for r in self._resources_cache if "view" in r.uri]
|
|
||||||
elif resource_type == "database":
|
|
||||||
return [
|
|
||||||
r for r in self._resources_cache
|
|
||||||
if "database" in r.uri and "table" not in r.uri
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
return self._resources_cache
|
|
||||||
|
|
||||||
|
|
||||||
class DorisToolsClient:
|
|
||||||
"""Doris tools client - Handles Tools related operations"""
|
|
||||||
|
|
||||||
def __init__(self, session: ClientSession):
|
|
||||||
self.session = session
|
|
||||||
self.logger = logging.getLogger(f"{__name__}.DorisToolsClient")
|
|
||||||
self._tools_cache = None
|
|
||||||
|
|
||||||
async def list_tools(self) -> list[Tool]:
|
|
||||||
"""Get list of all available tools"""
|
|
||||||
try:
|
|
||||||
self.logger.info("Getting tool list")
|
|
||||||
response = await self.session.list_tools()
|
|
||||||
tools = response.tools if hasattr(response, "tools") else []
|
|
||||||
self._tools_cache = tools
|
|
||||||
self.logger.info(f"Retrieved {len(tools)} tools")
|
|
||||||
return tools
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to get tool list: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Call specified tool"""
|
|
||||||
try:
|
|
||||||
self.logger.info(f"Calling tool: {name}")
|
|
||||||
self.logger.debug(f"Tool arguments: {arguments}")
|
|
||||||
|
|
||||||
response = await self.session.call_tool(name, arguments)
|
|
||||||
|
|
||||||
if hasattr(response, "content") and response.content:
|
|
||||||
# Parse response content
|
|
||||||
result_text = ""
|
|
||||||
for content in response.content:
|
|
||||||
if hasattr(content, "text"):
|
|
||||||
result_text += content.text
|
|
||||||
|
|
||||||
# Try to parse as JSON
|
|
||||||
try:
|
|
||||||
result = json.loads(result_text)
|
|
||||||
self.logger.info(f"Tool call successful: {name}")
|
|
||||||
return result
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
# If not JSON format, return text directly
|
|
||||||
return {"success": True, "data": result_text}
|
|
||||||
|
|
||||||
self.logger.warning(f"Tool {name} returned no content")
|
|
||||||
return {"success": False, "error": "No response content"}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Tool call failed {name}: {e}")
|
|
||||||
return {"success": False, "error": str(e)}
|
|
||||||
|
|
||||||
async def get_tool_by_name(self, name: str) -> Tool | None:
|
|
||||||
"""Get tool definition by name"""
|
|
||||||
if not self._tools_cache:
|
|
||||||
await self.list_tools()
|
|
||||||
|
|
||||||
for tool in self._tools_cache:
|
|
||||||
if tool.name == name:
|
|
||||||
return tool
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_tools_by_category(self, category: str) -> list[Tool]:
|
|
||||||
"""Filter tools by category"""
|
|
||||||
if not self._tools_cache:
|
|
||||||
await self.list_tools()
|
|
||||||
|
|
||||||
category_lower = category.lower()
|
|
||||||
return [
|
|
||||||
tool for tool in self._tools_cache
|
|
||||||
if category_lower in tool.description.lower()
|
|
||||||
or category_lower in tool.name.lower()
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class DorisPromptClient:
|
|
||||||
"""Doris prompt client - Handles Prompts related operations"""
|
|
||||||
|
|
||||||
def __init__(self, session: ClientSession):
|
|
||||||
self.session = session
|
|
||||||
self.logger = logging.getLogger(f"{__name__}.DorisPromptClient")
|
|
||||||
self._prompts_cache = None
|
|
||||||
|
|
||||||
async def list_prompts(self) -> list[Prompt]:
|
|
||||||
"""Get list of all available prompts"""
|
|
||||||
try:
|
|
||||||
self.logger.info("Getting prompt list")
|
|
||||||
response = await self.session.list_prompts()
|
|
||||||
prompts = response.prompts if hasattr(response, "prompts") else []
|
|
||||||
self._prompts_cache = prompts
|
|
||||||
self.logger.info(f"Retrieved {len(prompts)} prompts")
|
|
||||||
return prompts
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to get prompt list: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None:
|
|
||||||
"""Get specified prompt content"""
|
|
||||||
try:
|
|
||||||
self.logger.info(f"Getting prompt: {name}")
|
|
||||||
self.logger.debug(f"Prompt arguments: {arguments}")
|
|
||||||
|
|
||||||
response = await self.session.get_prompt(name, arguments)
|
|
||||||
|
|
||||||
if hasattr(response, "messages") and response.messages:
|
|
||||||
# Merge all message content
|
|
||||||
content_parts = []
|
|
||||||
for message in response.messages:
|
|
||||||
if hasattr(message, "content"):
|
|
||||||
if hasattr(message.content, "text"):
|
|
||||||
content_parts.append(message.content.text)
|
|
||||||
else:
|
|
||||||
content_parts.append(str(message.content))
|
|
||||||
|
|
||||||
content = "\n".join(content_parts)
|
|
||||||
self.logger.info(f"Successfully retrieved prompt content: {len(content)} characters")
|
|
||||||
return content
|
|
||||||
|
|
||||||
self.logger.warning(f"Prompt {name} returned no content")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to get prompt {name}: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
class DorisUnifiedClient:
|
|
||||||
"""Unified Doris MCP client - Provides complete MCP functionality"""
|
|
||||||
|
|
||||||
def __init__(self, config: DorisClientConfig):
|
|
||||||
self.config = config
|
|
||||||
self.logger = logging.getLogger(f"{__name__}.DorisUnifiedClient")
|
|
||||||
self.session = None
|
|
||||||
self.resources = None
|
|
||||||
self.tools = None
|
|
||||||
self.prompts = None
|
|
||||||
|
|
||||||
async def connect_and_run(self, callback_func: Callable):
|
|
||||||
"""Connect to server and execute callback function"""
|
|
||||||
if self.config.transport == "stdio":
|
|
||||||
await self._run_stdio_mode(callback_func)
|
|
||||||
elif self.config.transport == "http":
|
|
||||||
await self._run_http_mode(callback_func)
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unsupported transport type: {self.config.transport}")
|
|
||||||
|
|
||||||
async def _run_stdio_mode(self, callback_func: Callable):
|
|
||||||
"""Run in stdio mode"""
|
|
||||||
try:
|
|
||||||
self.logger.info(f"Starting stdio client: {self.config.server_command}")
|
|
||||||
|
|
||||||
server_params = StdioServerParameters(
|
|
||||||
command=self.config.server_command,
|
|
||||||
args=self.config.server_args,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with stdio_client(server_params) as (read, write):
|
|
||||||
async with ClientSession(read, write) as session:
|
|
||||||
self.session = session
|
|
||||||
self._init_sub_clients()
|
|
||||||
|
|
||||||
# Initialize server
|
|
||||||
await session.initialize()
|
|
||||||
self.logger.info("Server initialized successfully")
|
|
||||||
|
|
||||||
# Execute callback function
|
|
||||||
await callback_func(self)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"stdio mode execution failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
async def _run_http_mode(self, callback_func: Callable):
|
|
||||||
"""Run in HTTP mode"""
|
|
||||||
try:
|
|
||||||
self.logger.info(f"Starting HTTP client: {self.config.server_url}")
|
|
||||||
|
|
||||||
async with streamablehttp_client(
|
|
||||||
self.config.server_url,
|
|
||||||
timeout=timedelta(seconds=self.config.timeout)
|
|
||||||
) as (read, write):
|
|
||||||
async with ClientSession(read, write) as session:
|
|
||||||
self.session = session
|
|
||||||
self._init_sub_clients()
|
|
||||||
|
|
||||||
# Initialize server
|
|
||||||
await session.initialize()
|
|
||||||
self.logger.info("Server initialized successfully")
|
|
||||||
|
|
||||||
# Execute callback function
|
|
||||||
await callback_func(self)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"HTTP mode execution failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _init_sub_clients(self):
|
|
||||||
"""Initialize sub-clients"""
|
|
||||||
self.resources = DorisResourceClient(self.session)
|
|
||||||
self.tools = DorisToolsClient(self.session)
|
|
||||||
self.prompts = DorisPromptClient(self.session)
|
|
||||||
|
|
||||||
# Convenience methods
|
|
||||||
async def list_all_resources(self) -> list[Resource]:
|
|
||||||
"""Get all resources"""
|
|
||||||
return await self.resources.list_resources()
|
|
||||||
|
|
||||||
async def list_all_tools(self) -> list[Tool]:
|
|
||||||
"""Get all tools"""
|
|
||||||
return await self.tools.list_tools()
|
|
||||||
|
|
||||||
async def list_all_prompts(self) -> list[Prompt]:
|
|
||||||
"""Get all prompts"""
|
|
||||||
return await self.prompts.list_prompts()
|
|
||||||
|
|
||||||
async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Call tool"""
|
|
||||||
return await self.tools.call_tool(name, arguments)
|
|
||||||
|
|
||||||
async def read_resource(self, uri: str) -> str | None:
|
|
||||||
"""Read resource"""
|
|
||||||
return await self.resources.read_resource(uri)
|
|
||||||
|
|
||||||
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None:
|
|
||||||
"""Get prompt"""
|
|
||||||
return await self.prompts.get_prompt(name, arguments)
|
|
||||||
|
|
||||||
# Smart tool finding methods
|
|
||||||
async def _find_tool_by_pattern(self, patterns: list[str]) -> str | None:
|
|
||||||
"""Find tool by name pattern"""
|
|
||||||
tools = await self.list_all_tools()
|
|
||||||
for pattern in patterns:
|
|
||||||
for tool in tools:
|
|
||||||
if pattern in tool.name:
|
|
||||||
return tool.name
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _find_tool_by_function(self, function_keywords: list[str]) -> str | None:
|
|
||||||
"""Find tool by function keywords"""
|
|
||||||
tools = await self.list_all_tools()
|
|
||||||
for tool in tools:
|
|
||||||
tool_desc = tool.description.lower()
|
|
||||||
tool_name = tool.name.lower()
|
|
||||||
for keyword in function_keywords:
|
|
||||||
if keyword.lower() in tool_desc or keyword.lower() in tool_name:
|
|
||||||
return tool.name
|
|
||||||
return None
|
|
||||||
|
|
||||||
# High-level business methods
|
|
||||||
async def execute_sql(self, sql: str, **kwargs) -> dict[str, Any]:
|
|
||||||
"""Execute SQL query"""
|
|
||||||
tool_name = await self._find_tool_by_pattern(["exec_query", "execute", "query"])
|
|
||||||
if not tool_name:
|
|
||||||
return {"success": False, "error": "SQL execution tool not found"}
|
|
||||||
|
|
||||||
arguments = {"sql": sql, **kwargs}
|
|
||||||
return await self.call_tool(tool_name, arguments)
|
|
||||||
|
|
||||||
async def get_table_schema(self, table_name: str, db_name: str = None, **kwargs) -> dict[str, Any]:
|
|
||||||
"""Get table schema"""
|
|
||||||
tool_name = await self._find_tool_by_pattern(["get_table_schema", "table_schema", "schema"])
|
|
||||||
if not tool_name:
|
|
||||||
return {"success": False, "error": "Table schema tool not found"}
|
|
||||||
|
|
||||||
arguments = {"table_name": table_name}
|
|
||||||
if db_name:
|
|
||||||
arguments["db_name"] = db_name
|
|
||||||
arguments.update(kwargs)
|
|
||||||
|
|
||||||
return await self.call_tool(tool_name, arguments)
|
|
||||||
|
|
||||||
async def get_database_list(self, **kwargs) -> dict[str, Any]:
|
|
||||||
"""Get database list"""
|
|
||||||
tool_name = await self._find_tool_by_pattern(["get_db_list", "database_list", "db_list"])
|
|
||||||
if not tool_name:
|
|
||||||
return {"success": False, "error": "Database list tool not found"}
|
|
||||||
|
|
||||||
return await self.call_tool(tool_name, kwargs)
|
|
||||||
|
|
||||||
async def get_memory_stats(self, tracker_type: str = "overview", include_details: bool = True, **kwargs) -> dict[str, Any]:
|
|
||||||
"""Get memory statistics"""
|
|
||||||
tool_name = await self._find_tool_by_pattern(["memory", "realtime_memory"])
|
|
||||||
if not tool_name:
|
|
||||||
return {"success": False, "error": "Memory stats tool not found"}
|
|
||||||
|
|
||||||
arguments = {"tracker_type": tracker_type, "include_details": include_details}
|
|
||||||
arguments.update(kwargs)
|
|
||||||
return await self.call_tool(tool_name, arguments)
|
|
||||||
|
|
||||||
async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Call tool by function description"""
|
|
||||||
# Try to find appropriate tool based on function description
|
|
||||||
function_keywords = function_description.lower().split()
|
|
||||||
tool_name = await self._find_tool_by_function(function_keywords)
|
|
||||||
|
|
||||||
if not tool_name:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"No tool found for function: {function_description}"
|
|
||||||
}
|
|
||||||
|
|
||||||
return await self.call_tool(tool_name, arguments)
|
|
||||||
|
|
||||||
|
|
||||||
# Convenience factory functions
|
|
||||||
async def create_stdio_client(command: str, args: list[str] = None) -> DorisUnifiedClient:
|
|
||||||
"""Create stdio client"""
|
|
||||||
config = DorisClientConfig.stdio(command, args)
|
|
||||||
return DorisUnifiedClient(config)
|
|
||||||
|
|
||||||
|
|
||||||
async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnifiedClient:
|
|
||||||
"""Create HTTP client"""
|
|
||||||
config = DorisClientConfig.http(server_url, timeout)
|
|
||||||
return DorisUnifiedClient(config)
|
|
||||||
|
|
||||||
|
|
||||||
# Example usage
|
|
||||||
async def example_stdio():
|
|
||||||
"""stdio mode example"""
|
|
||||||
client = await create_stdio_client("python", ["doris_mcp_server/main.py"])
|
|
||||||
|
|
||||||
async def test_client(client: DorisUnifiedClient):
|
|
||||||
# Get server capabilities
|
|
||||||
resources = await client.list_all_resources()
|
|
||||||
tools = await client.list_all_tools()
|
|
||||||
prompts = await client.list_all_prompts()
|
|
||||||
|
|
||||||
print(f"Resources: {len(resources)}")
|
|
||||||
print(f"Tools: {len(tools)}")
|
|
||||||
print(f"Prompts: {len(prompts)}")
|
|
||||||
|
|
||||||
# Test SQL execution
|
|
||||||
result = await client.execute_sql("SELECT 1 as test")
|
|
||||||
print(f"SQL execution result: {result}")
|
|
||||||
|
|
||||||
await client.connect_and_run(test_client)
|
|
||||||
|
|
||||||
|
|
||||||
async def example_http():
|
|
||||||
"""HTTP mode example"""
|
|
||||||
client = await create_http_client("http://localhost:8080")
|
|
||||||
|
|
||||||
async def test_client(client: DorisUnifiedClient):
|
|
||||||
# Get server capabilities
|
|
||||||
resources = await client.list_all_resources()
|
|
||||||
tools = await client.list_all_tools()
|
|
||||||
|
|
||||||
print(f"Resources: {len(resources)}")
|
|
||||||
print(f"Tools: {len(tools)}")
|
|
||||||
|
|
||||||
# Test database list
|
|
||||||
result = await client.get_database_list()
|
|
||||||
print(f"Database list: {result}")
|
|
||||||
|
|
||||||
await client.connect_and_run(test_client)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Run stdio example
|
|
||||||
asyncio.run(example_stdio())
|
|
||||||
|
|
||||||
# Run HTTP example
|
|
||||||
# asyncio.run(example_http())
|
|
||||||
@@ -28,17 +28,6 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
# MCP version compatibility check
|
|
||||||
try:
|
|
||||||
import mcp
|
|
||||||
MCP_VERSION = getattr(mcp, '__version__', 'unknown')
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.info(f"Using MCP version: {MCP_VERSION}")
|
|
||||||
except Exception as e:
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.warning(f"Could not determine MCP version: {e}")
|
|
||||||
MCP_VERSION = 'unknown'
|
|
||||||
|
|
||||||
from mcp.server import Server
|
from mcp.server import Server
|
||||||
from mcp.server.models import InitializationOptions
|
from mcp.server.models import InitializationOptions
|
||||||
|
|
||||||
@@ -55,15 +44,11 @@ from .tools.resources_manager import DorisResourcesManager
|
|||||||
from .utils.config import DorisConfig
|
from .utils.config import DorisConfig
|
||||||
from .utils.db import DorisConnectionManager
|
from .utils.db import DorisConnectionManager
|
||||||
from .utils.security import DorisSecurityManager
|
from .utils.security import DorisSecurityManager
|
||||||
import os
|
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Create a default config instance for getting default values
|
|
||||||
_default_config = DorisConfig()
|
|
||||||
|
|
||||||
|
|
||||||
class DorisServer:
|
class DorisServer:
|
||||||
"""Apache Doris MCP Server main class"""
|
"""Apache Doris MCP Server main class"""
|
||||||
@@ -86,47 +71,6 @@ class DorisServer:
|
|||||||
self.logger = logging.getLogger(f"{__name__}.DorisServer")
|
self.logger = logging.getLogger(f"{__name__}.DorisServer")
|
||||||
self._setup_handlers()
|
self._setup_handlers()
|
||||||
|
|
||||||
def _get_mcp_capabilities(self):
|
|
||||||
"""Get MCP capabilities with version compatibility"""
|
|
||||||
try:
|
|
||||||
# For MCP 1.9.x and newer
|
|
||||||
from mcp.server.lowlevel.server import NotificationOptions
|
|
||||||
|
|
||||||
return self.server.get_capabilities(
|
|
||||||
notification_options=NotificationOptions(
|
|
||||||
prompts_changed=True,
|
|
||||||
resources_changed=True,
|
|
||||||
tools_changed=True
|
|
||||||
),
|
|
||||||
experimental_capabilities={}
|
|
||||||
)
|
|
||||||
except TypeError:
|
|
||||||
try:
|
|
||||||
# For MCP 1.8.x
|
|
||||||
from mcp.server.lowlevel.server import NotificationOptions
|
|
||||||
|
|
||||||
return self.server.get_capabilities(
|
|
||||||
notification_options=NotificationOptions(
|
|
||||||
prompts_changed=True,
|
|
||||||
resources_changed=True,
|
|
||||||
tools_changed=True
|
|
||||||
),
|
|
||||||
experimental_capabilities={}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.warning(f"Could not get capabilities with NotificationOptions: {e}")
|
|
||||||
# Fallback for older versions
|
|
||||||
try:
|
|
||||||
return self.server.get_capabilities()
|
|
||||||
except Exception as fallback_e:
|
|
||||||
self.logger.error(f"Failed to get capabilities: {fallback_e}")
|
|
||||||
# Return minimal capabilities
|
|
||||||
return {
|
|
||||||
"resources": {},
|
|
||||||
"tools": {},
|
|
||||||
"prompts": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
def _setup_handlers(self):
|
def _setup_handlers(self):
|
||||||
"""Setup MCP protocol handlers"""
|
"""Setup MCP protocol handlers"""
|
||||||
|
|
||||||
@@ -245,12 +189,22 @@ class DorisServer:
|
|||||||
read_stream, write_stream = streams
|
read_stream, write_stream = streams
|
||||||
self.logger.info("stdio_server streams created successfully")
|
self.logger.info("stdio_server streams created successfully")
|
||||||
|
|
||||||
# Create initialization options with version compatibility
|
# Create initialization options
|
||||||
capabilities = self._get_mcp_capabilities()
|
# MCP 1.8.0 requires parameters for get_capabilities
|
||||||
|
from mcp.server.lowlevel.server import NotificationOptions
|
||||||
|
|
||||||
|
capabilities = self.server.get_capabilities(
|
||||||
|
notification_options=NotificationOptions(
|
||||||
|
prompts_changed=True,
|
||||||
|
resources_changed=True,
|
||||||
|
tools_changed=True
|
||||||
|
),
|
||||||
|
experimental_capabilities={}
|
||||||
|
)
|
||||||
|
|
||||||
init_options = InitializationOptions(
|
init_options = InitializationOptions(
|
||||||
server_name="doris-mcp-server",
|
server_name="doris-mcp-server",
|
||||||
server_version=os.getenv("SERVER_VERSION", _default_config.server_version),
|
server_version="1.0.0",
|
||||||
capabilities=capabilities,
|
capabilities=capabilities,
|
||||||
)
|
)
|
||||||
self.logger.info("Initialization options created successfully")
|
self.logger.info("Initialization options created successfully")
|
||||||
@@ -283,7 +237,7 @@ class DorisServer:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port)):
|
async def start_http(self, host: str = "localhost", port: int = 3000):
|
||||||
"""Start Streamable HTTP transport mode"""
|
"""Start Streamable HTTP transport mode"""
|
||||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
|
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
|
||||||
|
|
||||||
@@ -297,9 +251,9 @@ class DorisServer:
|
|||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.routing import Route
|
from starlette.routing import Mount, Route
|
||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
from starlette.types import Scope
|
from starlette.types import Receive, Scope, Send
|
||||||
|
|
||||||
# Create session manager
|
# Create session manager
|
||||||
session_manager = StreamableHTTPSessionManager(
|
session_manager = StreamableHTTPSessionManager(
|
||||||
@@ -459,34 +413,34 @@ Examples:
|
|||||||
"--transport",
|
"--transport",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["stdio", "http"],
|
choices=["stdio", "http"],
|
||||||
default=os.getenv("TRANSPORT", _default_config.transport),
|
default="stdio",
|
||||||
help=f"Transport protocol type: stdio (local), http (Streamable HTTP) (default: {_default_config.transport})",
|
help="Transport protocol type: stdio (local), http (Streamable HTTP)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host",
|
"--host",
|
||||||
type=str,
|
type=str,
|
||||||
default=os.getenv("SERVER_HOST", _default_config.database.host),
|
default="localhost",
|
||||||
help=f"Host address for HTTP mode (default: {_default_config.database.host})",
|
help="Host address for HTTP mode (default: localhost)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port", type=int, default=os.getenv("SERVER_PORT", _default_config.server_port), help=f"Port number for HTTP mode (default: {_default_config.server_port})"
|
"--port", type=int, default=3000, help="Port number for HTTP mode (default: 3000)"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-host",
|
"--db-host",
|
||||||
type=str,
|
type=str,
|
||||||
default=os.getenv("DB_HOST", _default_config.database.host),
|
default="localhost",
|
||||||
help=f"Doris database host address (default: {_default_config.database.host})",
|
help="Doris database host address (default: localhost)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-port", type=int, default=os.getenv("DB_PORT", _default_config.database.port), help=f"Doris database port number (default: {_default_config.database.port})"
|
"--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-user", type=str, default=os.getenv("DB_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
|
"--db-user", type=str, default="root", help="Doris database username (default: root)"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--db-password", type=str, default="", help="Doris database password")
|
parser.add_argument("--db-password", type=str, default="", help="Doris database password")
|
||||||
@@ -494,16 +448,16 @@ Examples:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-database",
|
"--db-database",
|
||||||
type=str,
|
type=str,
|
||||||
default=os.getenv("DB_DATABASE", _default_config.database.database),
|
default="information_schema",
|
||||||
help=f"Doris database name (default: {_default_config.database.database})",
|
help="Doris database name (default: information_schema)",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-level",
|
"--log-level",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||||
default=os.getenv("LOG_LEVEL", _default_config.logging.level),
|
default="INFO",
|
||||||
help=f"Log level (default: {_default_config.logging.level})",
|
help="Log level (default: INFO)",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@@ -521,17 +475,17 @@ async def main():
|
|||||||
config = DorisConfig.from_env() # First load from .env file and environment variables
|
config = DorisConfig.from_env() # First load from .env file and environment variables
|
||||||
|
|
||||||
# Command line arguments override configuration (if provided)
|
# Command line arguments override configuration (if provided)
|
||||||
if args.db_host != _default_config.database.host: # If not default value, use command line argument
|
if args.db_host != "localhost": # If not default value, use command line argument
|
||||||
config.database.host = args.db_host
|
config.database.host = args.db_host
|
||||||
if args.db_port != _default_config.database.port:
|
if args.db_port != 9030:
|
||||||
config.database.port = args.db_port
|
config.database.port = args.db_port
|
||||||
if args.db_user != _default_config.database.user:
|
if args.db_user != "root":
|
||||||
config.database.user = args.db_user
|
config.database.user = args.db_user
|
||||||
if args.db_password: # Use password if provided
|
if args.db_password: # Use password if provided
|
||||||
config.database.password = args.db_password
|
config.database.password = args.db_password
|
||||||
if args.db_database != _default_config.database.database:
|
if args.db_database != "information_schema":
|
||||||
config.database.database = args.db_database
|
config.database.database = args.db_database
|
||||||
if args.log_level != _default_config.logging.level:
|
if args.log_level != "INFO":
|
||||||
config.logging.level = args.log_level
|
config.logging.level = args.log_level
|
||||||
|
|
||||||
# Create server instance
|
# Create server instance
|
||||||
|
|||||||
@@ -28,8 +28,7 @@ from mcp.types import Tool
|
|||||||
|
|
||||||
from ..utils.db import DorisConnectionManager
|
from ..utils.db import DorisConnectionManager
|
||||||
from ..utils.query_executor import DorisQueryExecutor
|
from ..utils.query_executor import DorisQueryExecutor
|
||||||
from ..utils.analysis_tools import TableAnalyzer, SQLAnalyzer, MemoryTracker
|
from ..utils.analysis_tools import TableAnalyzer, PerformanceMonitor
|
||||||
from ..utils.monitoring_tools import DorisMonitoringTools
|
|
||||||
from ..utils.schema_extractor import MetadataExtractor
|
from ..utils.schema_extractor import MetadataExtractor
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
@@ -46,10 +45,8 @@ class DorisToolsManager:
|
|||||||
# Initialize business logic processors
|
# Initialize business logic processors
|
||||||
self.query_executor = DorisQueryExecutor(connection_manager)
|
self.query_executor = DorisQueryExecutor(connection_manager)
|
||||||
self.table_analyzer = TableAnalyzer(connection_manager)
|
self.table_analyzer = TableAnalyzer(connection_manager)
|
||||||
self.sql_analyzer = SQLAnalyzer(connection_manager)
|
self.performance_monitor = PerformanceMonitor(connection_manager)
|
||||||
self.metadata_extractor = MetadataExtractor(connection_manager=connection_manager)
|
self.metadata_extractor = MetadataExtractor(connection_manager=connection_manager)
|
||||||
self.monitoring_tools = DorisMonitoringTools(connection_manager)
|
|
||||||
self.memory_tracker = MemoryTracker(connection_manager)
|
|
||||||
|
|
||||||
logger.info("DorisToolsManager initialized with business logic processors")
|
logger.info("DorisToolsManager initialized with business logic processors")
|
||||||
|
|
||||||
@@ -57,6 +54,99 @@ class DorisToolsManager:
|
|||||||
"""Register all tools to MCP server"""
|
"""Register all tools to MCP server"""
|
||||||
logger.info("Starting to register MCP tools")
|
logger.info("Starting to register MCP tools")
|
||||||
|
|
||||||
|
# Column statistical analysis tool
|
||||||
|
@mcp.tool(
|
||||||
|
"column_analysis",
|
||||||
|
description="""[Function Description]: Analyze statistical information and data distribution of the specified column.
|
||||||
|
|
||||||
|
[Parameter Content]:
|
||||||
|
|
||||||
|
- table_name (string) [Required] - Name of the table to analyze
|
||||||
|
|
||||||
|
- column_name (string) [Required] - Name of the column to analyze
|
||||||
|
|
||||||
|
- analysis_type (string) [Optional] - Type of analysis to perform, default is "basic"
|
||||||
|
* "basic": Basic statistics (count, null values, distinct values)
|
||||||
|
* "distribution": Data distribution analysis (frequency, percentiles)
|
||||||
|
* "detailed": Comprehensive analysis including all above plus patterns and outliers
|
||||||
|
""",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"table_name": {"type": "string", "description": "Table name"},
|
||||||
|
"column_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Column name to analyze",
|
||||||
|
},
|
||||||
|
"analysis_type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["basic", "distribution", "detailed"],
|
||||||
|
"description": "Analysis type",
|
||||||
|
"default": "basic",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["table_name", "column_name"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
async def column_analysis_tool(
|
||||||
|
table_name: str,
|
||||||
|
column_name: str,
|
||||||
|
analysis_type: str = "basic"
|
||||||
|
) -> str:
|
||||||
|
"""Column statistical analysis tool"""
|
||||||
|
return await self.call_tool("column_analysis", {
|
||||||
|
"table_name": table_name,
|
||||||
|
"column_name": column_name,
|
||||||
|
"analysis_type": analysis_type
|
||||||
|
})
|
||||||
|
|
||||||
|
# Database performance monitoring tool
|
||||||
|
@mcp.tool(
|
||||||
|
"performance_stats[Experimental]",
|
||||||
|
description="""[Important]: This tool is experimental and may not be fully functional!
|
||||||
|
[Function Description]: Get database performance statistics information.
|
||||||
|
|
||||||
|
[Parameter Content]:
|
||||||
|
|
||||||
|
- metric_type (string) [Optional] - Type of performance metrics to retrieve, default is "queries"
|
||||||
|
* "queries": Query performance metrics (execution time, frequency, etc.)
|
||||||
|
* "connections": Connection statistics (active connections, connection pool status)
|
||||||
|
* "tables": Table-level statistics (size, row count, access patterns)
|
||||||
|
* "system": System-level metrics (CPU, memory, disk usage)
|
||||||
|
|
||||||
|
- time_range (string) [Optional] - Time range for statistics, default is "1h"
|
||||||
|
* "1h": Last 1 hour
|
||||||
|
* "6h": Last 6 hours
|
||||||
|
* "24h": Last 24 hours
|
||||||
|
* "7d": Last 7 days
|
||||||
|
""",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"metric_type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["queries", "connections", "tables", "system"],
|
||||||
|
"description": "Performance metric type",
|
||||||
|
"default": "queries",
|
||||||
|
},
|
||||||
|
"time_range": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["1h", "6h", "24h", "7d"],
|
||||||
|
"description": "Time range",
|
||||||
|
"default": "1h",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
async def performance_stats_tool(
|
||||||
|
metric_type: str = "queries",
|
||||||
|
time_range: str = "1h"
|
||||||
|
) -> str:
|
||||||
|
"""Database performance monitoring tool"""
|
||||||
|
return await self.call_tool("performance_stats", {
|
||||||
|
"metric_type": metric_type,
|
||||||
|
"time_range": time_range
|
||||||
|
})
|
||||||
|
|
||||||
# SQL query execution tool (supports catalog federation queries)
|
# SQL query execution tool (supports catalog federation queries)
|
||||||
@mcp.tool(
|
@mcp.tool(
|
||||||
@@ -262,227 +352,81 @@ class DorisToolsManager:
|
|||||||
"random_string": random_string
|
"random_string": random_string
|
||||||
})
|
})
|
||||||
|
|
||||||
# SQL Explain tool
|
logger.info("Successfully registered 11 tools to MCP server (2 core tools + 9 migrated tools)")
|
||||||
@mcp.tool(
|
|
||||||
"get_sql_explain",
|
|
||||||
description="""[Function Description]: Get SQL execution plan using EXPLAIN command based on Doris syntax.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- sql (string) [Required] - SQL statement to explain
|
|
||||||
|
|
||||||
- verbose (boolean) [Optional] - Whether to show verbose information, default is false
|
|
||||||
|
|
||||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
|
||||||
|
|
||||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
async def get_sql_explain_tool(
|
|
||||||
sql: str,
|
|
||||||
verbose: bool = False,
|
|
||||||
db_name: str = None,
|
|
||||||
catalog_name: str = None
|
|
||||||
) -> str:
|
|
||||||
"""Get SQL execution plan"""
|
|
||||||
return await self.call_tool("get_sql_explain", {
|
|
||||||
"sql": sql,
|
|
||||||
"verbose": verbose,
|
|
||||||
"db_name": db_name,
|
|
||||||
"catalog_name": catalog_name
|
|
||||||
})
|
|
||||||
|
|
||||||
# SQL Profile tool
|
|
||||||
@mcp.tool(
|
|
||||||
"get_sql_profile",
|
|
||||||
description="""[Function Description]: Get SQL execution profile by setting trace ID and fetching profile via FE HTTP API.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- sql (string) [Required] - SQL statement to profile
|
|
||||||
|
|
||||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
|
||||||
|
|
||||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
|
||||||
|
|
||||||
- timeout (integer) [Optional] - Query timeout in seconds, default is 30
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
async def get_sql_profile_tool(
|
|
||||||
sql: str,
|
|
||||||
db_name: str = None,
|
|
||||||
catalog_name: str = None,
|
|
||||||
timeout: int = 30
|
|
||||||
) -> str:
|
|
||||||
"""Get SQL execution profile"""
|
|
||||||
return await self.call_tool("get_sql_profile", {
|
|
||||||
"sql": sql,
|
|
||||||
"db_name": db_name,
|
|
||||||
"catalog_name": catalog_name,
|
|
||||||
"timeout": timeout
|
|
||||||
})
|
|
||||||
|
|
||||||
# Table data size tool
|
|
||||||
@mcp.tool(
|
|
||||||
"get_table_data_size",
|
|
||||||
description="""[Function Description]: Get table data size information via FE HTTP API.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- db_name (string) [Optional] - Database name, if not specified returns all databases
|
|
||||||
|
|
||||||
- table_name (string) [Optional] - Table name, if not specified returns all tables in the database
|
|
||||||
|
|
||||||
- single_replica (boolean) [Optional] - Whether to get single replica data size, default is false
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
async def get_table_data_size_tool(
|
|
||||||
db_name: str = None,
|
|
||||||
table_name: str = None,
|
|
||||||
single_replica: bool = False
|
|
||||||
) -> str:
|
|
||||||
"""Get table data size information"""
|
|
||||||
return await self.call_tool("get_table_data_size", {
|
|
||||||
"db_name": db_name,
|
|
||||||
"table_name": table_name,
|
|
||||||
"single_replica": single_replica
|
|
||||||
})
|
|
||||||
|
|
||||||
# Monitoring metrics definition tool
|
|
||||||
@mcp.tool(
|
|
||||||
"get_monitoring_metrics_info",
|
|
||||||
description="""[Function Description]: Get Doris monitoring metrics definitions and descriptions without executing queries.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- role (string) [Optional] - Node role to get metric definitions for, default is "all"
|
|
||||||
* "fe": Only FE metrics definitions
|
|
||||||
* "be": Only BE metrics definitions
|
|
||||||
* "all": Both FE and BE metrics definitions
|
|
||||||
|
|
||||||
- monitor_type (string) [Optional] - Type of monitoring metrics, default is "all"
|
|
||||||
* "process": Process monitoring metrics
|
|
||||||
* "jvm": JVM monitoring metrics (FE only)
|
|
||||||
* "machine": Machine monitoring metrics
|
|
||||||
* "all": All monitoring types
|
|
||||||
|
|
||||||
- priority (string) [Optional] - Metric priority level, default is "core"
|
|
||||||
* "core": Only core essential metrics (10-12 items for production use)
|
|
||||||
* "p0": Only P0 (highest priority) metrics definitions
|
|
||||||
* "all": All metrics definitions (P0 and non-P0)
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
async def get_monitoring_metrics_info_tool(
|
|
||||||
role: str = "all",
|
|
||||||
monitor_type: str = "all",
|
|
||||||
priority: str = "core"
|
|
||||||
) -> str:
|
|
||||||
"""Get Doris monitoring metrics definitions"""
|
|
||||||
return await self.call_tool("get_monitoring_metrics_info", {
|
|
||||||
"role": role,
|
|
||||||
"monitor_type": monitor_type,
|
|
||||||
"priority": priority
|
|
||||||
})
|
|
||||||
|
|
||||||
# Monitoring metrics data tool
|
|
||||||
@mcp.tool(
|
|
||||||
"get_monitoring_metrics_data",
|
|
||||||
description="""[Function Description]: Get actual Doris monitoring metrics data from FE and BE nodes via HTTP API.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- role (string) [Optional] - Node role to monitor, default is "all"
|
|
||||||
* "fe": Only FE nodes
|
|
||||||
* "be": Only BE nodes
|
|
||||||
* "all": Both FE and BE nodes
|
|
||||||
|
|
||||||
- monitor_type (string) [Optional] - Type of monitoring metrics, default is "all"
|
|
||||||
* "process": Process monitoring metrics
|
|
||||||
* "jvm": JVM monitoring metrics (FE only)
|
|
||||||
* "machine": Machine monitoring metrics
|
|
||||||
* "all": All monitoring types
|
|
||||||
|
|
||||||
- priority (string) [Optional] - Metric priority level, default is "core"
|
|
||||||
* "core": Only core essential metrics (10-12 items for production use)
|
|
||||||
* "p0": Only P0 (highest priority) metrics
|
|
||||||
* "all": All metrics (P0 and non-P0)
|
|
||||||
|
|
||||||
- include_raw_metrics (boolean) [Optional] - Whether to include raw detailed metrics data (can be very large)
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
async def get_monitoring_metrics_data_tool(
|
|
||||||
role: str = "all",
|
|
||||||
monitor_type: str = "all",
|
|
||||||
priority: str = "core",
|
|
||||||
include_raw_metrics: bool = False
|
|
||||||
) -> str:
|
|
||||||
"""Get Doris monitoring metrics data"""
|
|
||||||
return await self.call_tool("get_monitoring_metrics_data", {
|
|
||||||
"role": role,
|
|
||||||
"monitor_type": monitor_type,
|
|
||||||
"priority": priority,
|
|
||||||
"include_raw_metrics": include_raw_metrics
|
|
||||||
})
|
|
||||||
|
|
||||||
# Real-time memory tracker tool
|
|
||||||
@mcp.tool(
|
|
||||||
"get_realtime_memory_stats",
|
|
||||||
description="""[Function Description]: Get real-time memory statistics via Doris BE Memory Tracker web interface.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- tracker_type (string) [Optional] - Type of memory trackers to retrieve, default is "overview"
|
|
||||||
* "overview": Overview type trackers (process memory, tracked memory summary)
|
|
||||||
* "global": Global shared memory trackers (cache, metadata)
|
|
||||||
* "query": Query-related memory trackers
|
|
||||||
* "load": Load-related memory trackers
|
|
||||||
* "compaction": Compaction-related memory trackers
|
|
||||||
* "all": All memory tracker types
|
|
||||||
|
|
||||||
- include_details (boolean) [Optional] - Whether to include detailed tracker information and definitions, default is true
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
async def get_realtime_memory_stats_tool(
|
|
||||||
tracker_type: str = "overview",
|
|
||||||
include_details: bool = True
|
|
||||||
) -> str:
|
|
||||||
"""Get real-time memory statistics tool"""
|
|
||||||
return await self.call_tool("get_realtime_memory_stats", {
|
|
||||||
"tracker_type": tracker_type,
|
|
||||||
"include_details": include_details
|
|
||||||
})
|
|
||||||
|
|
||||||
# Historical memory tracker tool
|
|
||||||
@mcp.tool(
|
|
||||||
"get_historical_memory_stats",
|
|
||||||
description="""[Function Description]: Get historical memory statistics via Doris BE Bvar interface.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- tracker_names (array) [Optional] - List of specific tracker names to query, if not specified will get common trackers
|
|
||||||
* Example: ["process_resident_memory", "global", "query", "load", "compaction"]
|
|
||||||
|
|
||||||
- time_range (string) [Optional] - Time range for historical data, default is "1h"
|
|
||||||
* "1h": Last 1 hour
|
|
||||||
* "6h": Last 6 hours
|
|
||||||
* "24h": Last 24 hours
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
async def get_historical_memory_stats_tool(
|
|
||||||
tracker_names: List[str] = None,
|
|
||||||
time_range: str = "1h"
|
|
||||||
) -> str:
|
|
||||||
"""Get historical memory statistics tool"""
|
|
||||||
return await self.call_tool("get_historical_memory_stats", {
|
|
||||||
"tracker_names": tracker_names,
|
|
||||||
"time_range": time_range
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info("Successfully registered 16 tools to MCP server")
|
|
||||||
|
|
||||||
async def list_tools(self) -> List[Tool]:
|
async def list_tools(self) -> List[Tool]:
|
||||||
"""List all available query tools (for stdio mode)"""
|
"""List all available query tools (for stdio mode)"""
|
||||||
tools = [
|
tools = [
|
||||||
|
Tool(
|
||||||
|
name="column_analysis[Experimental]",
|
||||||
|
description="""[Important]: This tool is experimental and may not be fully functional!
|
||||||
|
[Function Description]: Analyze statistical information and data distribution of the specified column.
|
||||||
|
|
||||||
|
[Parameter Content]:
|
||||||
|
|
||||||
|
- table_name (string) [Required] - Name of the table to analyze
|
||||||
|
|
||||||
|
- column_name (string) [Required] - Name of the column to analyze
|
||||||
|
|
||||||
|
- analysis_type (string) [Optional] - Type of analysis to perform, default is "basic"
|
||||||
|
* "basic": Basic statistics (count, null values, distinct values)
|
||||||
|
* "distribution": Data distribution analysis (frequency, percentiles)
|
||||||
|
* "detailed": Comprehensive analysis including all above plus patterns and outliers
|
||||||
|
""",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"table_name": {"type": "string", "description": "Table name"},
|
||||||
|
"column_name": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Column name to analyze",
|
||||||
|
},
|
||||||
|
"analysis_type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["basic", "distribution", "detailed"],
|
||||||
|
"description": "Analysis type",
|
||||||
|
"default": "basic",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"required": ["table_name", "column_name"],
|
||||||
|
},
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="performance_stats",
|
||||||
|
description="""[Function Description]: Get database performance statistics information.
|
||||||
|
|
||||||
|
[Parameter Content]:
|
||||||
|
|
||||||
|
- metric_type (string) [Optional] - Type of performance metrics to retrieve, default is "queries"
|
||||||
|
* "queries": Query performance metrics (execution time, frequency, etc.)
|
||||||
|
* "connections": Connection statistics (active connections, connection pool status)
|
||||||
|
* "tables": Table-level statistics (size, row count, access patterns)
|
||||||
|
* "system": System-level metrics (CPU, memory, disk usage)
|
||||||
|
|
||||||
|
- time_range (string) [Optional] - Time range for statistics, default is "1h"
|
||||||
|
* "1h": Last 1 hour
|
||||||
|
* "6h": Last 6 hours
|
||||||
|
* "24h": Last 24 hours
|
||||||
|
* "7d": Last 7 days
|
||||||
|
""",
|
||||||
|
inputSchema={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"metric_type": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["queries", "connections", "tables", "system"],
|
||||||
|
"description": "Performance metric type",
|
||||||
|
"default": "queries",
|
||||||
|
},
|
||||||
|
"time_range": {
|
||||||
|
"type": "string",
|
||||||
|
"enum": ["1h", "6h", "24h", "7d"],
|
||||||
|
"description": "Time range",
|
||||||
|
"default": "1h",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
),
|
||||||
Tool(
|
Tool(
|
||||||
name="exec_query",
|
name="exec_query",
|
||||||
description="""[Function Description]: Execute SQL query and return result command with catalog federation support.
|
description="""[Function Description]: Execute SQL query and return result command with catalog federation support.
|
||||||
@@ -666,188 +610,6 @@ class DorisToolsManager:
|
|||||||
"required": ["random_string"],
|
"required": ["random_string"],
|
||||||
},
|
},
|
||||||
),
|
),
|
||||||
Tool(
|
|
||||||
name="get_sql_explain",
|
|
||||||
description="""[Function Description]: Get SQL execution plan using EXPLAIN command based on Doris syntax.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- sql (string) [Required] - SQL statement to explain
|
|
||||||
|
|
||||||
- verbose (boolean) [Optional] - Whether to show verbose information, default is false
|
|
||||||
|
|
||||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
|
||||||
|
|
||||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
|
||||||
""",
|
|
||||||
inputSchema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"sql": {"type": "string", "description": "SQL statement to explain"},
|
|
||||||
"verbose": {"type": "boolean", "description": "Whether to show verbose information", "default": False},
|
|
||||||
"db_name": {"type": "string", "description": "Database name"},
|
|
||||||
"catalog_name": {"type": "string", "description": "Catalog name"},
|
|
||||||
},
|
|
||||||
"required": ["sql"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Tool(
|
|
||||||
name="get_sql_profile",
|
|
||||||
description="""[Function Description]: Get SQL execution profile by setting trace ID and fetching profile via FE HTTP API.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- sql (string) [Required] - SQL statement to profile
|
|
||||||
|
|
||||||
- db_name (string) [Optional] - Target database name, defaults to the current database
|
|
||||||
|
|
||||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
|
|
||||||
|
|
||||||
- timeout (integer) [Optional] - Query timeout in seconds, default is 30
|
|
||||||
""",
|
|
||||||
inputSchema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"sql": {"type": "string", "description": "SQL statement to profile"},
|
|
||||||
"db_name": {"type": "string", "description": "Database name"},
|
|
||||||
"catalog_name": {"type": "string", "description": "Catalog name"},
|
|
||||||
"timeout": {"type": "integer", "description": "Query timeout in seconds", "default": 30},
|
|
||||||
},
|
|
||||||
"required": ["sql"],
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Tool(
|
|
||||||
name="get_table_data_size",
|
|
||||||
description="""[Function Description]: Get table data size information via FE HTTP API.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- db_name (string) [Optional] - Database name, if not specified returns all databases
|
|
||||||
|
|
||||||
- table_name (string) [Optional] - Table name, if not specified returns all tables in the database
|
|
||||||
|
|
||||||
- single_replica (boolean) [Optional] - Whether to get single replica data size, default is false
|
|
||||||
""",
|
|
||||||
inputSchema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"db_name": {"type": "string", "description": "Database name"},
|
|
||||||
"table_name": {"type": "string", "description": "Table name"},
|
|
||||||
"single_replica": {"type": "boolean", "description": "Whether to get single replica data size", "default": False},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Tool(
|
|
||||||
name="get_monitoring_metrics_info",
|
|
||||||
description="""[Function Description]: Get Doris monitoring metrics definitions and descriptions without executing queries.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- role (string) [Optional] - Node role to get metric definitions for, default is "all"
|
|
||||||
* "fe": Only FE metrics definitions
|
|
||||||
* "be": Only BE metrics definitions
|
|
||||||
* "all": Both FE and BE metrics definitions
|
|
||||||
|
|
||||||
- monitor_type (string) [Optional] - Type of monitoring metrics, default is "all"
|
|
||||||
* "process": Process monitoring metrics
|
|
||||||
* "jvm": JVM monitoring metrics (FE only)
|
|
||||||
* "machine": Machine monitoring metrics
|
|
||||||
* "all": All monitoring types
|
|
||||||
|
|
||||||
- priority (string) [Optional] - Metric priority level, default is "core"
|
|
||||||
* "core": Only core essential metrics (10-12 items for production use)
|
|
||||||
* "p0": Only P0 (highest priority) metrics definitions
|
|
||||||
* "all": All metrics definitions (P0 and non-P0)
|
|
||||||
""",
|
|
||||||
inputSchema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"role": {"type": "string", "enum": ["fe", "be", "all"], "description": "Node role to get metric definitions for", "default": "all"},
|
|
||||||
"monitor_type": {"type": "string", "enum": ["process", "jvm", "machine", "all"], "description": "Type of monitoring metrics", "default": "all"},
|
|
||||||
"priority": {"type": "string", "enum": ["core", "p0", "all"], "description": "Metric priority level", "default": "core"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Tool(
|
|
||||||
name="get_monitoring_metrics_data",
|
|
||||||
description="""[Function Description]: Get actual Doris monitoring metrics data from FE and BE nodes via HTTP API.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- role (string) [Optional] - Node role to monitor, default is "all"
|
|
||||||
* "fe": Only FE nodes
|
|
||||||
* "be": Only BE nodes
|
|
||||||
* "all": Both FE and BE nodes
|
|
||||||
|
|
||||||
- monitor_type (string) [Optional] - Type of monitoring metrics, default is "all"
|
|
||||||
* "process": Process monitoring metrics
|
|
||||||
* "jvm": JVM monitoring metrics (FE only)
|
|
||||||
* "machine": Machine monitoring metrics
|
|
||||||
* "all": All monitoring types
|
|
||||||
|
|
||||||
- priority (string) [Optional] - Metric priority level, default is "core"
|
|
||||||
* "core": Only core essential metrics (10-12 items for production use)
|
|
||||||
* "p0": Only P0 (highest priority) metrics
|
|
||||||
* "all": All metrics (P0 and non-P0)
|
|
||||||
|
|
||||||
- include_raw_metrics (boolean) [Optional] - Whether to include raw detailed metrics data (can be very large)
|
|
||||||
""",
|
|
||||||
inputSchema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"role": {"type": "string", "enum": ["fe", "be", "all"], "description": "Node role to monitor", "default": "all"},
|
|
||||||
"monitor_type": {"type": "string", "enum": ["process", "jvm", "machine", "all"], "description": "Type of monitoring metrics", "default": "all"},
|
|
||||||
"priority": {"type": "string", "enum": ["core", "p0", "all"], "description": "Metric priority level", "default": "core"},
|
|
||||||
"include_raw_metrics": {"type": "boolean", "description": "Whether to include raw detailed metrics data (can be very large)", "default": False},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Tool(
|
|
||||||
name="get_realtime_memory_stats",
|
|
||||||
description="""[Function Description]: Get real-time memory statistics via Doris BE Memory Tracker web interface.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- tracker_type (string) [Optional] - Type of memory trackers to retrieve, default is "overview"
|
|
||||||
* "overview": Overview type trackers (process memory, tracked memory summary)
|
|
||||||
* "global": Global shared memory trackers (cache, metadata)
|
|
||||||
* "query": Query-related memory trackers
|
|
||||||
* "load": Load-related memory trackers
|
|
||||||
* "compaction": Compaction-related memory trackers
|
|
||||||
* "all": All memory tracker types
|
|
||||||
|
|
||||||
- include_details (boolean) [Optional] - Whether to include detailed tracker information and definitions, default is true
|
|
||||||
""",
|
|
||||||
inputSchema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"tracker_type": {"type": "string", "enum": ["overview", "global", "query", "load", "compaction", "all"], "description": "Type of memory trackers to retrieve", "default": "overview"},
|
|
||||||
"include_details": {"type": "boolean", "description": "Whether to include detailed tracker information and definitions", "default": True},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
Tool(
|
|
||||||
name="get_historical_memory_stats",
|
|
||||||
description="""[Function Description]: Get historical memory statistics via Doris BE Bvar interface.
|
|
||||||
|
|
||||||
[Parameter Content]:
|
|
||||||
|
|
||||||
- tracker_names (array) [Optional] - List of specific tracker names to query, if not specified will get common trackers
|
|
||||||
* Example: ["process_resident_memory", "global", "query", "load", "compaction"]
|
|
||||||
|
|
||||||
- time_range (string) [Optional] - Time range for historical data, default is "1h"
|
|
||||||
* "1h": Last 1 hour
|
|
||||||
* "6h": Last 6 hours
|
|
||||||
* "24h": Last 24 hours
|
|
||||||
""",
|
|
||||||
inputSchema={
|
|
||||||
"type": "object",
|
|
||||||
"properties": {
|
|
||||||
"tracker_names": {"type": "array", "items": {"type": "string"}, "description": "List of specific tracker names to query"},
|
|
||||||
"time_range": {"type": "string", "enum": ["1h", "6h", "24h"], "description": "Time range for historical data", "default": "1h"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
]
|
]
|
||||||
|
|
||||||
return tools
|
return tools
|
||||||
@@ -860,7 +622,12 @@ class DorisToolsManager:
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# Tool routing - dispatch requests to corresponding business logic processors
|
# Tool routing - dispatch requests to corresponding business logic processors
|
||||||
if name == "exec_query":
|
if name == "column_analysis":
|
||||||
|
result = await self._column_analysis_tool(arguments)
|
||||||
|
elif name == "performance_stats":
|
||||||
|
result = await self._performance_stats_tool(arguments)
|
||||||
|
# ===== 9 tool routes migrated from source project =====
|
||||||
|
elif name == "exec_query":
|
||||||
result = await self._exec_query_tool(arguments)
|
result = await self._exec_query_tool(arguments)
|
||||||
elif name == "get_table_schema":
|
elif name == "get_table_schema":
|
||||||
result = await self._get_table_schema_tool(arguments)
|
result = await self._get_table_schema_tool(arguments)
|
||||||
@@ -878,20 +645,6 @@ class DorisToolsManager:
|
|||||||
result = await self._get_recent_audit_logs_tool(arguments)
|
result = await self._get_recent_audit_logs_tool(arguments)
|
||||||
elif name == "get_catalog_list":
|
elif name == "get_catalog_list":
|
||||||
result = await self._get_catalog_list_tool(arguments)
|
result = await self._get_catalog_list_tool(arguments)
|
||||||
elif name == "get_sql_explain":
|
|
||||||
result = await self._get_sql_explain_tool(arguments)
|
|
||||||
elif name == "get_sql_profile":
|
|
||||||
result = await self._get_sql_profile_tool(arguments)
|
|
||||||
elif name == "get_table_data_size":
|
|
||||||
result = await self._get_table_data_size_tool(arguments)
|
|
||||||
elif name == "get_monitoring_metrics_info":
|
|
||||||
result = await self._get_monitoring_metrics_info_tool(arguments)
|
|
||||||
elif name == "get_monitoring_metrics_data":
|
|
||||||
result = await self._get_monitoring_metrics_data_tool(arguments)
|
|
||||||
elif name == "get_realtime_memory_stats":
|
|
||||||
result = await self._get_realtime_memory_stats_tool(arguments)
|
|
||||||
elif name == "get_historical_memory_stats":
|
|
||||||
result = await self._get_historical_memory_stats_tool(arguments)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown tool: {name}")
|
raise ValueError(f"Unknown tool: {name}")
|
||||||
|
|
||||||
@@ -917,6 +670,28 @@ class DorisToolsManager:
|
|||||||
}
|
}
|
||||||
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
return json.dumps(error_result, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
|
# The following are tool routing methods, responsible for calling corresponding business logic processors
|
||||||
|
|
||||||
|
async def _column_analysis_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Column statistical analysis tool routing"""
|
||||||
|
table_name = arguments.get("table_name")
|
||||||
|
column_name = arguments.get("column_name")
|
||||||
|
analysis_type = arguments.get("analysis_type", "basic")
|
||||||
|
|
||||||
|
# Delegate to table analyzer for processing
|
||||||
|
return await self.table_analyzer.analyze_column(
|
||||||
|
table_name, column_name, analysis_type
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _performance_stats_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Database performance statistics tool routing"""
|
||||||
|
metric_type = arguments.get("metric_type", "queries")
|
||||||
|
time_range = arguments.get("time_range", "1h")
|
||||||
|
|
||||||
|
# Delegate to performance monitor for processing
|
||||||
|
return await self.performance_monitor.get_performance_stats(
|
||||||
|
metric_type, time_range
|
||||||
|
)
|
||||||
|
|
||||||
async def _exec_query_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
async def _exec_query_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""SQL query execution tool routing (supports federation queries)"""
|
"""SQL query execution tool routing (supports federation queries)"""
|
||||||
@@ -1004,82 +779,4 @@ class DorisToolsManager:
|
|||||||
# Here we ignore it and directly call business logic
|
# Here we ignore it and directly call business logic
|
||||||
|
|
||||||
# Delegate to metadata extractor for processing
|
# Delegate to metadata extractor for processing
|
||||||
return await self.metadata_extractor.get_catalog_list_for_mcp()
|
return await self.metadata_extractor.get_catalog_list_for_mcp()
|
||||||
|
|
||||||
async def _get_sql_explain_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""SQL Explain tool routing"""
|
|
||||||
sql = arguments.get("sql")
|
|
||||||
verbose = arguments.get("verbose", False)
|
|
||||||
db_name = arguments.get("db_name")
|
|
||||||
catalog_name = arguments.get("catalog_name")
|
|
||||||
|
|
||||||
# Delegate to SQL analyzer for processing
|
|
||||||
return await self.sql_analyzer.get_sql_explain(
|
|
||||||
sql, verbose, db_name, catalog_name
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_sql_profile_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""SQL Profile tool routing"""
|
|
||||||
sql = arguments.get("sql")
|
|
||||||
db_name = arguments.get("db_name")
|
|
||||||
catalog_name = arguments.get("catalog_name")
|
|
||||||
timeout = arguments.get("timeout", 30)
|
|
||||||
|
|
||||||
# Delegate to SQL analyzer for processing
|
|
||||||
return await self.sql_analyzer.get_sql_profile(
|
|
||||||
sql, db_name, catalog_name, timeout
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_table_data_size_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Table data size tool routing"""
|
|
||||||
db_name = arguments.get("db_name")
|
|
||||||
table_name = arguments.get("table_name")
|
|
||||||
single_replica = arguments.get("single_replica", False)
|
|
||||||
|
|
||||||
# Delegate to SQL analyzer for processing
|
|
||||||
return await self.sql_analyzer.get_table_data_size(
|
|
||||||
db_name, table_name, single_replica
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_monitoring_metrics_info_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Monitoring metrics info tool routing"""
|
|
||||||
role = arguments.get("role", "all")
|
|
||||||
monitor_type = arguments.get("monitor_type", "all")
|
|
||||||
priority = arguments.get("priority", "p0")
|
|
||||||
|
|
||||||
# Delegate to monitoring tools for processing (info_only=True)
|
|
||||||
return await self.monitoring_tools.get_monitoring_metrics(
|
|
||||||
role, monitor_type, priority, info_only=True, format_type="prometheus"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_monitoring_metrics_data_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Monitoring metrics data tool routing"""
|
|
||||||
role = arguments.get("role", "all")
|
|
||||||
monitor_type = arguments.get("monitor_type", "all")
|
|
||||||
priority = arguments.get("priority", "p0")
|
|
||||||
include_raw_metrics = arguments.get("include_raw_metrics", False)
|
|
||||||
|
|
||||||
# Delegate to monitoring tools for processing (info_only=False)
|
|
||||||
return await self.monitoring_tools.get_monitoring_metrics(
|
|
||||||
role, monitor_type, priority, info_only=False, format_type="prometheus", include_raw_metrics=include_raw_metrics
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_realtime_memory_stats_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Real-time memory statistics tool routing"""
|
|
||||||
tracker_type = arguments.get("tracker_type", "overview")
|
|
||||||
include_details = arguments.get("include_details", True)
|
|
||||||
|
|
||||||
# Delegate to memory tracker for processing
|
|
||||||
return await self.memory_tracker.get_realtime_memory_stats(
|
|
||||||
tracker_type, include_details
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _get_historical_memory_stats_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""Historical memory statistics tool routing"""
|
|
||||||
tracker_names = arguments.get("tracker_names")
|
|
||||||
time_range = arguments.get("time_range", "1h")
|
|
||||||
|
|
||||||
# Delegate to memory tracker for processing
|
|
||||||
return await self.memory_tracker.get_historical_memory_stats(
|
|
||||||
tracker_names, time_range
|
|
||||||
)
|
|
||||||
@@ -22,10 +22,6 @@ Provides data analysis functions including table analysis, column statistics, pe
|
|||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
import uuid
|
|
||||||
import aiohttp
|
|
||||||
import hashlib
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from .db import DorisConnectionManager
|
from .db import DorisConnectionManager
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
@@ -335,906 +331,4 @@ class PerformanceMonitor:
|
|||||||
"limit": limit,
|
"limit": limit,
|
||||||
"order_by": order_by,
|
"order_by": order_by,
|
||||||
"note": "Query history feature requires audit log configuration"
|
"note": "Query history feature requires audit log configuration"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class SQLAnalyzer:
|
|
||||||
"""SQL analyzer for EXPLAIN and PROFILE operations"""
|
|
||||||
|
|
||||||
def __init__(self, connection_manager: DorisConnectionManager):
|
|
||||||
self.connection_manager = connection_manager
|
|
||||||
|
|
||||||
async def get_sql_explain(
|
|
||||||
self,
|
|
||||||
sql: str,
|
|
||||||
verbose: bool = False,
|
|
||||||
db_name: str = None,
|
|
||||||
catalog_name: str = None
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get SQL execution plan using EXPLAIN command based on Doris syntax
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sql: SQL statement to explain
|
|
||||||
verbose: Whether to show verbose information
|
|
||||||
db_name: Target database name
|
|
||||||
catalog_name: Target catalog name
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict containing explain plan file path, content, and basic info
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Generate unique query ID for file naming
|
|
||||||
import time
|
|
||||||
query_hash = hashlib.md5(sql.encode()).hexdigest()[:8]
|
|
||||||
timestamp = int(time.time())
|
|
||||||
query_id = f"{timestamp}_{query_hash}"
|
|
||||||
|
|
||||||
# Ensure temp directory exists
|
|
||||||
temp_dir = Path(self.connection_manager.config.temp_files_dir)
|
|
||||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Create explain file path
|
|
||||||
explain_file = temp_dir / f"explain_{query_id}.txt"
|
|
||||||
|
|
||||||
logger.info(f"Generating SQL explain for query ID: {query_id}")
|
|
||||||
|
|
||||||
# Switch database if specified
|
|
||||||
if db_name:
|
|
||||||
await self.connection_manager.execute_query("explain_session", f"USE {db_name}")
|
|
||||||
|
|
||||||
# Construct EXPLAIN query
|
|
||||||
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
|
|
||||||
explain_sql = f"{explain_type} {sql.strip().rstrip(';')}"
|
|
||||||
|
|
||||||
logger.info(f"Executing explain query: {explain_sql}")
|
|
||||||
|
|
||||||
# Execute explain query
|
|
||||||
result = await self.connection_manager.execute_query("explain_session", explain_sql)
|
|
||||||
|
|
||||||
# Format explain output
|
|
||||||
explain_content = []
|
|
||||||
explain_content.append(f"=== SQL EXPLAIN PLAN ===")
|
|
||||||
explain_content.append(f"Query ID: {query_id}")
|
|
||||||
explain_content.append(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
||||||
explain_content.append(f"Database: {db_name or 'current'}")
|
|
||||||
explain_content.append(f"Verbose: {verbose}")
|
|
||||||
explain_content.append("")
|
|
||||||
explain_content.append("=== ORIGINAL SQL ===")
|
|
||||||
explain_content.append(sql)
|
|
||||||
explain_content.append("")
|
|
||||||
explain_content.append("=== EXPLAIN QUERY ===")
|
|
||||||
explain_content.append(explain_sql)
|
|
||||||
explain_content.append("")
|
|
||||||
explain_content.append("=== EXECUTION PLAN ===")
|
|
||||||
|
|
||||||
if result and result.data:
|
|
||||||
for row in result.data:
|
|
||||||
if isinstance(row, dict):
|
|
||||||
# Handle dict format
|
|
||||||
for key, value in row.items():
|
|
||||||
explain_content.append(f"{key}: {value}")
|
|
||||||
elif isinstance(row, (list, tuple)):
|
|
||||||
# Handle tuple/list format
|
|
||||||
explain_content.append(" | ".join(str(col) for col in row))
|
|
||||||
else:
|
|
||||||
# Handle string format
|
|
||||||
explain_content.append(str(row))
|
|
||||||
else:
|
|
||||||
explain_content.append("No execution plan data returned")
|
|
||||||
|
|
||||||
explain_content.append("")
|
|
||||||
explain_content.append("=== METADATA ===")
|
|
||||||
explain_content.append(f"Execution time: {result.execution_time if result else 'N/A'} seconds")
|
|
||||||
explain_content.append(f"Rows returned: {len(result.data) if result and result.data else 0}")
|
|
||||||
|
|
||||||
# Get full content
|
|
||||||
full_content = '\n'.join(explain_content)
|
|
||||||
|
|
||||||
# Write to file
|
|
||||||
with open(explain_file, 'w', encoding='utf-8') as f:
|
|
||||||
f.write(full_content)
|
|
||||||
|
|
||||||
logger.info(f"Explain plan saved to: {explain_file.absolute()}")
|
|
||||||
|
|
||||||
# Get max response size from config
|
|
||||||
max_size = self.connection_manager.config.performance.max_response_content_size
|
|
||||||
|
|
||||||
# Truncate content if needed
|
|
||||||
truncated_content = full_content
|
|
||||||
is_truncated = False
|
|
||||||
if len(full_content) > max_size:
|
|
||||||
truncated_content = full_content[:max_size] + "\n\n=== CONTENT TRUNCATED ===\n[Content is truncated due to size limit. Full content is saved to file.]"
|
|
||||||
is_truncated = True
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"query_id": query_id,
|
|
||||||
"explain_file_path": str(explain_file.absolute()),
|
|
||||||
"file_size_bytes": explain_file.stat().st_size,
|
|
||||||
"content": truncated_content,
|
|
||||||
"content_size": len(truncated_content),
|
|
||||||
"is_content_truncated": is_truncated,
|
|
||||||
"original_content_size": len(full_content),
|
|
||||||
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
|
|
||||||
"verbose": verbose,
|
|
||||||
"database": db_name,
|
|
||||||
"catalog": catalog_name,
|
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
|
|
||||||
"execution_time": result.execution_time if result else None,
|
|
||||||
"plan_lines_count": len(result.data) if result and result.data else 0
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get SQL explain: {str(e)}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Failed to get SQL explain: {str(e)}",
|
|
||||||
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
|
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
}
|
|
||||||
|
|
||||||
async def get_sql_profile(
|
|
||||||
self,
|
|
||||||
sql: str,
|
|
||||||
db_name: str = None,
|
|
||||||
catalog_name: str = None,
|
|
||||||
timeout: int = 30
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get SQL execution profile by setting trace ID and fetching profile via HTTP API
|
|
||||||
|
|
||||||
Args:
|
|
||||||
sql: SQL statement to profile
|
|
||||||
db_name: Target database name
|
|
||||||
catalog_name: Target catalog name
|
|
||||||
timeout: Query timeout in seconds
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict containing profile file path, content, and basic info
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Generate unique trace ID and query ID for file naming
|
|
||||||
trace_id = str(uuid.uuid4())
|
|
||||||
import time
|
|
||||||
query_hash = hashlib.md5(sql.encode()).hexdigest()[:8]
|
|
||||||
timestamp = int(time.time())
|
|
||||||
file_query_id = f"{timestamp}_{query_hash}"
|
|
||||||
|
|
||||||
# Ensure temp directory exists
|
|
||||||
temp_dir = Path(self.connection_manager.config.temp_files_dir)
|
|
||||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Create profile file path
|
|
||||||
profile_file = temp_dir / f"profile_{file_query_id}.txt"
|
|
||||||
|
|
||||||
logger.info(f"Generated trace ID for SQL profiling: {trace_id}")
|
|
||||||
logger.info(f"Profile will be saved to: {profile_file}")
|
|
||||||
|
|
||||||
connection = await self.connection_manager.get_connection("query")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Switch to specified database/catalog if provided
|
|
||||||
if catalog_name:
|
|
||||||
await connection.execute(f"USE `{catalog_name}`")
|
|
||||||
if db_name:
|
|
||||||
await connection.execute(f"USE `{db_name}`")
|
|
||||||
|
|
||||||
# Set trace ID for the session using session variable
|
|
||||||
# According to official docs: set session_context="trace_id:your_trace_id"
|
|
||||||
await connection.execute(f'set session_context="trace_id:{trace_id}"')
|
|
||||||
logger.info(f"Set trace ID: {trace_id}")
|
|
||||||
|
|
||||||
# Enable profile
|
|
||||||
await connection.execute(f'set enable_profile=true')
|
|
||||||
logger.info(f"Enabled profile")
|
|
||||||
|
|
||||||
# Execute the SQL statement
|
|
||||||
logger.info(f"Executing SQL with trace ID: {sql}")
|
|
||||||
start_time = time.time()
|
|
||||||
sql_result = await connection.execute(sql)
|
|
||||||
execution_time = time.time() - start_time
|
|
||||||
logger.info(f"SQL execution completed in {execution_time:.3f}s")
|
|
||||||
|
|
||||||
# Get query ID from trace ID via HTTP API
|
|
||||||
query_id = await self._get_query_id_by_trace_id(trace_id)
|
|
||||||
if not query_id:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": "Failed to get query ID from trace ID",
|
|
||||||
"trace_id": trace_id,
|
|
||||||
"sql": sql,
|
|
||||||
"execution_time": execution_time
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Retrieved query ID: {query_id}")
|
|
||||||
|
|
||||||
# Get profile data
|
|
||||||
profile_data = await self._get_profile_by_query_id(query_id)
|
|
||||||
|
|
||||||
if not profile_data:
|
|
||||||
# Save error info to file
|
|
||||||
profile_content = [
|
|
||||||
f"=== SQL PROFILE RESULT ===",
|
|
||||||
f"File Query ID: {file_query_id}",
|
|
||||||
f"Trace ID: {trace_id}",
|
|
||||||
f"Query ID: {query_id}",
|
|
||||||
f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
|
||||||
f"Database: {db_name or 'current'}",
|
|
||||||
f"Status: FAILED",
|
|
||||||
"",
|
|
||||||
"=== ORIGINAL SQL ===",
|
|
||||||
sql,
|
|
||||||
"",
|
|
||||||
"=== ERROR INFO ===",
|
|
||||||
"Failed to get profile data. This may be due to:",
|
|
||||||
"1) Profile data not generated yet",
|
|
||||||
"2) Query ID expired",
|
|
||||||
"3) Insufficient permissions to access profile data",
|
|
||||||
"",
|
|
||||||
"=== EXECUTION INFO ===",
|
|
||||||
f"Query execution: SUCCESSFUL",
|
|
||||||
f"Execution time: {execution_time:.3f} seconds",
|
|
||||||
f"Note: Query execution was successful, but profile data is not available"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Get full content
|
|
||||||
full_profile_content = '\n'.join(profile_content)
|
|
||||||
|
|
||||||
with open(profile_file, 'w', encoding='utf-8') as f:
|
|
||||||
f.write(full_profile_content)
|
|
||||||
|
|
||||||
# Get max response size from config
|
|
||||||
max_size = self.connection_manager.config.performance.max_response_content_size
|
|
||||||
|
|
||||||
# Truncate content if needed
|
|
||||||
truncated_content = full_profile_content
|
|
||||||
is_truncated = False
|
|
||||||
if len(full_profile_content) > max_size:
|
|
||||||
truncated_content = full_profile_content[:max_size] + "\n\n=== CONTENT TRUNCATED ===\n[Content is truncated due to size limit. Full content is saved to file.]"
|
|
||||||
is_truncated = True
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"file_query_id": file_query_id,
|
|
||||||
"trace_id": trace_id,
|
|
||||||
"query_id": query_id,
|
|
||||||
"profile_file_path": str(profile_file.absolute()),
|
|
||||||
"file_size_bytes": profile_file.stat().st_size,
|
|
||||||
"content": truncated_content,
|
|
||||||
"content_size": len(truncated_content),
|
|
||||||
"is_content_truncated": is_truncated,
|
|
||||||
"original_content_size": len(full_profile_content),
|
|
||||||
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
|
|
||||||
"execution_time": execution_time,
|
|
||||||
"error": "Failed to get profile data",
|
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
}
|
|
||||||
|
|
||||||
# Format profile output
|
|
||||||
profile_content = []
|
|
||||||
profile_content.append(f"=== SQL PROFILE RESULT ===")
|
|
||||||
profile_content.append(f"File Query ID: {file_query_id}")
|
|
||||||
profile_content.append(f"Trace ID: {trace_id}")
|
|
||||||
profile_content.append(f"Query ID: {query_id}")
|
|
||||||
profile_content.append(f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}")
|
|
||||||
profile_content.append(f"Database: {db_name or 'current'}")
|
|
||||||
profile_content.append(f"Status: SUCCESS")
|
|
||||||
profile_content.append("")
|
|
||||||
profile_content.append("=== ORIGINAL SQL ===")
|
|
||||||
profile_content.append(sql)
|
|
||||||
profile_content.append("")
|
|
||||||
profile_content.append("=== EXECUTION INFO ===")
|
|
||||||
profile_content.append(f"Execution time: {execution_time:.3f} seconds")
|
|
||||||
if hasattr(sql_result, 'data') and sql_result.data:
|
|
||||||
profile_content.append(f"Result rows: {len(sql_result.data)}")
|
|
||||||
if sql_result.data and sql_result.data[0]:
|
|
||||||
profile_content.append(f"Result columns: {list(sql_result.data[0].keys())}")
|
|
||||||
profile_content.append("")
|
|
||||||
profile_content.append("=== PROFILE DATA ===")
|
|
||||||
|
|
||||||
if isinstance(profile_data, dict):
|
|
||||||
import json
|
|
||||||
profile_content.append(json.dumps(profile_data, indent=2, ensure_ascii=False))
|
|
||||||
else:
|
|
||||||
profile_content.append(str(profile_data))
|
|
||||||
|
|
||||||
# Get full content
|
|
||||||
full_profile_content = '\n'.join(profile_content)
|
|
||||||
|
|
||||||
# Write to file
|
|
||||||
with open(profile_file, 'w', encoding='utf-8') as f:
|
|
||||||
f.write(full_profile_content)
|
|
||||||
|
|
||||||
logger.info(f"Profile data saved to: {profile_file.absolute()}")
|
|
||||||
|
|
||||||
# Get max response size from config
|
|
||||||
max_size = self.connection_manager.config.performance.max_response_content_size
|
|
||||||
|
|
||||||
# Truncate content if needed
|
|
||||||
truncated_content = full_profile_content
|
|
||||||
is_truncated = False
|
|
||||||
if len(full_profile_content) > max_size:
|
|
||||||
truncated_content = full_profile_content[:max_size] + "\n\n=== CONTENT TRUNCATED ===\n[Content is truncated due to size limit. Full content is saved to file.]"
|
|
||||||
is_truncated = True
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"file_query_id": file_query_id,
|
|
||||||
"trace_id": trace_id,
|
|
||||||
"query_id": query_id,
|
|
||||||
"profile_file_path": str(profile_file.absolute()),
|
|
||||||
"file_size_bytes": profile_file.stat().st_size,
|
|
||||||
"content": truncated_content,
|
|
||||||
"content_size": len(truncated_content),
|
|
||||||
"is_content_truncated": is_truncated,
|
|
||||||
"original_content_size": len(full_profile_content),
|
|
||||||
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
|
|
||||||
"database": db_name,
|
|
||||||
"catalog": catalog_name,
|
|
||||||
"execution_time": execution_time,
|
|
||||||
"sql_result_summary": {
|
|
||||||
"row_count": len(sql_result.data) if hasattr(sql_result, 'data') and sql_result.data else 0,
|
|
||||||
"columns": list(sql_result.data[0].keys()) if hasattr(sql_result, 'data') and sql_result.data and sql_result.data[0] else []
|
|
||||||
},
|
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error during SQL execution or profile retrieval: {str(e)}")
|
|
||||||
# Save error info to file
|
|
||||||
profile_content = [
|
|
||||||
f"=== SQL PROFILE RESULT ===",
|
|
||||||
f"File Query ID: {file_query_id}",
|
|
||||||
f"Trace ID: {trace_id}",
|
|
||||||
f"Timestamp: {time.strftime('%Y-%m-%d %H:%M:%S')}",
|
|
||||||
f"Database: {db_name or 'current'}",
|
|
||||||
f"Status: ERROR",
|
|
||||||
"",
|
|
||||||
"=== ORIGINAL SQL ===",
|
|
||||||
sql,
|
|
||||||
"",
|
|
||||||
"=== ERROR INFO ===",
|
|
||||||
f"SQL execution or profile retrieval failed: {str(e)}",
|
|
||||||
"",
|
|
||||||
"=== EXECUTION INFO ===",
|
|
||||||
"Query execution failed during profiling process"
|
|
||||||
]
|
|
||||||
|
|
||||||
# Get full content
|
|
||||||
full_profile_content = '\n'.join(profile_content)
|
|
||||||
|
|
||||||
with open(profile_file, 'w', encoding='utf-8') as f:
|
|
||||||
f.write(full_profile_content)
|
|
||||||
|
|
||||||
# Get max response size from config
|
|
||||||
max_size = self.connection_manager.config.performance.max_response_content_size
|
|
||||||
|
|
||||||
# Truncate content if needed
|
|
||||||
truncated_content = full_profile_content
|
|
||||||
is_truncated = False
|
|
||||||
if len(full_profile_content) > max_size:
|
|
||||||
truncated_content = full_profile_content[:max_size] + "\n\n=== CONTENT TRUNCATED ===\n[Content is truncated due to size limit. Full content is saved to file.]"
|
|
||||||
is_truncated = True
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"file_query_id": file_query_id,
|
|
||||||
"trace_id": trace_id,
|
|
||||||
"profile_file_path": str(profile_file.absolute()),
|
|
||||||
"file_size_bytes": profile_file.stat().st_size,
|
|
||||||
"content": truncated_content,
|
|
||||||
"content_size": len(truncated_content),
|
|
||||||
"is_content_truncated": is_truncated,
|
|
||||||
"original_content_size": len(full_profile_content),
|
|
||||||
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
|
|
||||||
"error": f"SQL execution or profile retrieval failed: {str(e)}",
|
|
||||||
"database": db_name,
|
|
||||||
"catalog": catalog_name,
|
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"SQL PROFILE failed: {str(e)}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"SQL PROFILE failed: {str(e)}",
|
|
||||||
"sql_preview": sql[:100] + "..." if len(sql) > 100 else sql,
|
|
||||||
"database": db_name,
|
|
||||||
"catalog": catalog_name,
|
|
||||||
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S')
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _get_query_id_by_trace_id(self, trace_id: str) -> str:
|
|
||||||
"""
|
|
||||||
Get query ID by trace ID via FE HTTP API
|
|
||||||
|
|
||||||
Args:
|
|
||||||
trace_id: The trace ID set during query execution
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Query ID string or None if not found
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get database config
|
|
||||||
db_config = self.connection_manager.config.database
|
|
||||||
|
|
||||||
# Build HTTP API URL according to official documentation
|
|
||||||
# Reference: https://doris.apache.org/zh-CN/docs/admin-manual/open-api/fe-http/query-profile-action#通过-trace-id-获取-query-id
|
|
||||||
url = f"http://{db_config.host}:{db_config.fe_http_port}/rest/v2/manager/query/trace_id/{trace_id}"
|
|
||||||
|
|
||||||
# HTTP Basic Auth
|
|
||||||
auth = aiohttp.BasicAuth(db_config.user, db_config.password)
|
|
||||||
|
|
||||||
logger.info(f"Requesting query ID from: {url}")
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(url, auth=auth, timeout=10) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
# Check content type first
|
|
||||||
content_type = response.headers.get('content-type', '')
|
|
||||||
response_text = await response.text()
|
|
||||||
logger.info(f"Response content type: {content_type}")
|
|
||||||
logger.info(f"Response body: {response_text}")
|
|
||||||
|
|
||||||
# Parse JSON response (regardless of content-type)
|
|
||||||
if response_text.strip():
|
|
||||||
try:
|
|
||||||
import json
|
|
||||||
result = json.loads(response_text)
|
|
||||||
logger.info(f"Query ID API response: {result}")
|
|
||||||
|
|
||||||
# Parse response according to Doris API format
|
|
||||||
if result.get("code") == 0 and result.get("data"):
|
|
||||||
data = result["data"]
|
|
||||||
# Data can be either a string (query_id) or object with query_ids
|
|
||||||
if isinstance(data, str):
|
|
||||||
logger.info(f"Found query ID: {data}")
|
|
||||||
return data
|
|
||||||
elif isinstance(data, dict) and "query_ids" in data:
|
|
||||||
query_ids = data["query_ids"]
|
|
||||||
if query_ids:
|
|
||||||
query_id = query_ids[0] # Take the first query ID
|
|
||||||
logger.info(f"Found query ID: {query_id}")
|
|
||||||
return query_id
|
|
||||||
else:
|
|
||||||
logger.warning("No query IDs found in response")
|
|
||||||
else:
|
|
||||||
logger.error(f"API returned error: {result}")
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"Failed to parse JSON response: {e}")
|
|
||||||
# Fallback: try to extract query ID using regex
|
|
||||||
import re
|
|
||||||
query_id_pattern = r'[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}'
|
|
||||||
matches = re.findall(query_id_pattern, response_text)
|
|
||||||
if matches:
|
|
||||||
query_id = matches[0]
|
|
||||||
logger.info(f"Extracted query ID from text: {query_id}")
|
|
||||||
return query_id
|
|
||||||
else:
|
|
||||||
logger.error(f"HTTP request failed with status {response.status}")
|
|
||||||
response_text = await response.text()
|
|
||||||
logger.error(f"Response body: {response_text}")
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get query ID by trace ID: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def _get_profile_by_query_id(self, query_id: str) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get profile data by query ID via FE HTTP API
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query_id: The query ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Profile data dict or None if failed
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get database config
|
|
||||||
db_config = self.connection_manager.config.database
|
|
||||||
|
|
||||||
# Try both API endpoints according to official documentation
|
|
||||||
urls = [
|
|
||||||
f"http://{db_config.host}:{db_config.fe_http_port}/rest/v2/manager/query/profile/text/{query_id}",
|
|
||||||
f"http://{db_config.host}:{db_config.fe_http_port}/api/profile/text?query_id={query_id}"
|
|
||||||
]
|
|
||||||
|
|
||||||
# HTTP Basic Auth
|
|
||||||
auth = aiohttp.BasicAuth(db_config.user, db_config.password)
|
|
||||||
|
|
||||||
for i, url in enumerate(urls):
|
|
||||||
logger.info(f"Requesting profile from URL {i+1}: {url}")
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(url, auth=auth, timeout=60) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
content_type = response.headers.get('content-type', '')
|
|
||||||
response_text = await response.text()
|
|
||||||
logger.info(f"Profile response content type: {content_type}")
|
|
||||||
logger.info(f"Profile response length: {len(response_text)}")
|
|
||||||
|
|
||||||
# Handle JSON response
|
|
||||||
if 'application/json' in content_type:
|
|
||||||
try:
|
|
||||||
result = await response.json()
|
|
||||||
logger.info(f"Profile JSON response: {result}")
|
|
||||||
|
|
||||||
if result.get("code") == 0 and result.get("data"):
|
|
||||||
profile_text = result["data"].get("profile", "")
|
|
||||||
return {
|
|
||||||
"query_id": query_id,
|
|
||||||
"profile_text": profile_text,
|
|
||||||
"profile_size": len(profile_text),
|
|
||||||
"retrieved_at": datetime.now().isoformat(),
|
|
||||||
"api_endpoint": url
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
logger.warning(f"Profile API returned error: {result}")
|
|
||||||
continue # Try next URL
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to parse profile JSON: {e}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle plain text response
|
|
||||||
else:
|
|
||||||
if response_text.strip() and "not found" not in response_text.lower():
|
|
||||||
return {
|
|
||||||
"query_id": query_id,
|
|
||||||
"profile_text": response_text,
|
|
||||||
"profile_size": len(response_text),
|
|
||||||
"retrieved_at": datetime.now().isoformat(),
|
|
||||||
"api_endpoint": url
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
logger.warning(f"Profile not found or empty: {response_text}")
|
|
||||||
continue # Try next URL
|
|
||||||
|
|
||||||
elif response.status == 404:
|
|
||||||
logger.warning(f"Profile not found (404) at {url}")
|
|
||||||
continue # Try next URL
|
|
||||||
else:
|
|
||||||
logger.error(f"Profile HTTP request failed with status {response.status} at {url}")
|
|
||||||
response_text = await response.text()
|
|
||||||
logger.error(f"Response body: {response_text}")
|
|
||||||
continue # Try next URL
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get profile by query ID: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_table_data_size(
|
|
||||||
self,
|
|
||||||
db_name: str = None,
|
|
||||||
table_name: str = None,
|
|
||||||
single_replica: bool = False
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get table data size information via FE HTTP API
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_name: Database name, if not specified returns all databases
|
|
||||||
table_name: Table name, if not specified returns all tables in the database
|
|
||||||
single_replica: Whether to get single replica data size
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict containing table data size information
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Get database config
|
|
||||||
db_config = self.connection_manager.config.database
|
|
||||||
|
|
||||||
# Build HTTP API URL according to official documentation
|
|
||||||
# Reference: https://doris.apache.org/zh-CN/docs/admin-manual/open-api/fe-http/show-table-data-action
|
|
||||||
url = f"http://{db_config.host}:{db_config.fe_http_port}/api/show_table_data"
|
|
||||||
|
|
||||||
# Build query parameters
|
|
||||||
params = {}
|
|
||||||
if db_name:
|
|
||||||
params["db"] = db_name
|
|
||||||
if table_name:
|
|
||||||
params["table"] = table_name
|
|
||||||
if single_replica:
|
|
||||||
params["single_replica"] = "true"
|
|
||||||
|
|
||||||
# HTTP Basic Auth
|
|
||||||
auth = aiohttp.BasicAuth(db_config.user, db_config.password)
|
|
||||||
|
|
||||||
logger.info(f"Requesting table data size from: {url} with params: {params}")
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.get(url, auth=auth, params=params, timeout=30) as response:
|
|
||||||
if response.status == 200:
|
|
||||||
response_text = await response.text()
|
|
||||||
logger.info(f"Table data size response length: {len(response_text)}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Parse JSON response
|
|
||||||
import json
|
|
||||||
result = json.loads(response_text)
|
|
||||||
|
|
||||||
if result.get("code") == 0 and result.get("data"):
|
|
||||||
data = result["data"]
|
|
||||||
|
|
||||||
# Process and format the data
|
|
||||||
formatted_data = self._format_table_data_size(data, db_name, table_name, single_replica)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"db_name": db_name,
|
|
||||||
"table_name": table_name,
|
|
||||||
"single_replica": single_replica,
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"data": formatted_data,
|
|
||||||
"url": url,
|
|
||||||
"note": "Table data size information from Doris FE HTTP API"
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"API returned error: {result}",
|
|
||||||
"db_name": db_name,
|
|
||||||
"table_name": table_name,
|
|
||||||
"url": url,
|
|
||||||
"timestamp": datetime.now().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"Failed to parse JSON response: {e}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Failed to parse JSON response: {e}",
|
|
||||||
"response_text": response_text[:500], # First 500 chars for debugging
|
|
||||||
"url": url,
|
|
||||||
"timestamp": datetime.now().isoformat()
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
logger.error(f"HTTP request failed with status {response.status}")
|
|
||||||
response_text = await response.text()
|
|
||||||
logger.error(f"Response body: {response_text}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"HTTP request failed with status {response.status}",
|
|
||||||
"response_text": response_text[:500], # First 500 chars for debugging
|
|
||||||
"url": url,
|
|
||||||
"timestamp": datetime.now().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Table data size request failed: {str(e)}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Table data size request failed: {str(e)}",
|
|
||||||
"db_name": db_name,
|
|
||||||
"table_name": table_name,
|
|
||||||
"timestamp": datetime.now().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
def _format_table_data_size(self, data: Dict[str, Any], db_name: str, table_name: str, single_replica: bool) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Format table data size response data
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: Raw response data from API
|
|
||||||
db_name: Database name filter
|
|
||||||
table_name: Table name filter
|
|
||||||
single_replica: Single replica flag
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted data structure
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
formatted = {
|
|
||||||
"summary": {
|
|
||||||
"total_databases": 0,
|
|
||||||
"total_tables": 0,
|
|
||||||
"total_size_bytes": 0,
|
|
||||||
"total_size_formatted": "0 B",
|
|
||||||
"single_replica": single_replica,
|
|
||||||
"query_filters": {
|
|
||||||
"db_name": db_name,
|
|
||||||
"table_name": table_name
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"databases": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
# Process the data based on its structure
|
|
||||||
if isinstance(data, list):
|
|
||||||
# Data is a list of table records
|
|
||||||
for record in data:
|
|
||||||
db = record.get("database", "unknown")
|
|
||||||
table = record.get("table", "unknown")
|
|
||||||
size_bytes = int(record.get("size", 0))
|
|
||||||
|
|
||||||
if db not in formatted["databases"]:
|
|
||||||
formatted["databases"][db] = {
|
|
||||||
"database_name": db,
|
|
||||||
"table_count": 0,
|
|
||||||
"total_size_bytes": 0,
|
|
||||||
"total_size_formatted": "0 B",
|
|
||||||
"tables": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
formatted["databases"][db]["tables"][table] = {
|
|
||||||
"table_name": table,
|
|
||||||
"size_bytes": size_bytes,
|
|
||||||
"size_formatted": self._format_bytes(size_bytes),
|
|
||||||
"replica_count": record.get("replica_count", 1),
|
|
||||||
"details": record
|
|
||||||
}
|
|
||||||
|
|
||||||
formatted["databases"][db]["table_count"] += 1
|
|
||||||
formatted["databases"][db]["total_size_bytes"] += size_bytes
|
|
||||||
formatted["summary"]["total_size_bytes"] += size_bytes
|
|
||||||
|
|
||||||
elif isinstance(data, dict):
|
|
||||||
# Data is a dict with database structure
|
|
||||||
for db, db_info in data.items():
|
|
||||||
if isinstance(db_info, dict) and "tables" in db_info:
|
|
||||||
formatted["databases"][db] = {
|
|
||||||
"database_name": db,
|
|
||||||
"table_count": len(db_info["tables"]),
|
|
||||||
"total_size_bytes": 0,
|
|
||||||
"total_size_formatted": "0 B",
|
|
||||||
"tables": {}
|
|
||||||
}
|
|
||||||
|
|
||||||
for table, table_info in db_info["tables"].items():
|
|
||||||
size_bytes = int(table_info.get("size", 0))
|
|
||||||
formatted["databases"][db]["tables"][table] = {
|
|
||||||
"table_name": table,
|
|
||||||
"size_bytes": size_bytes,
|
|
||||||
"size_formatted": self._format_bytes(size_bytes),
|
|
||||||
"replica_count": table_info.get("replica_count", 1),
|
|
||||||
"details": table_info
|
|
||||||
}
|
|
||||||
formatted["databases"][db]["total_size_bytes"] += size_bytes
|
|
||||||
formatted["summary"]["total_size_bytes"] += size_bytes
|
|
||||||
|
|
||||||
# Update summary
|
|
||||||
formatted["summary"]["total_databases"] = len(formatted["databases"])
|
|
||||||
formatted["summary"]["total_tables"] = sum(db["table_count"] for db in formatted["databases"].values())
|
|
||||||
formatted["summary"]["total_size_formatted"] = self._format_bytes(formatted["summary"]["total_size_bytes"])
|
|
||||||
|
|
||||||
# Update database totals formatting
|
|
||||||
for db_info in formatted["databases"].values():
|
|
||||||
db_info["total_size_formatted"] = self._format_bytes(db_info["total_size_bytes"])
|
|
||||||
|
|
||||||
return formatted
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to format table data size: {str(e)}")
|
|
||||||
return {
|
|
||||||
"error": f"Failed to format data: {str(e)}",
|
|
||||||
"raw_data": data
|
|
||||||
}
|
|
||||||
|
|
||||||
def _format_bytes(self, bytes_value: int) -> str:
|
|
||||||
"""
|
|
||||||
Format bytes value to human readable string
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bytes_value: Bytes value
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string like "1.23 GB"
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
bytes_value = int(bytes_value)
|
|
||||||
if bytes_value == 0:
|
|
||||||
return "0 B"
|
|
||||||
|
|
||||||
units = ["B", "KB", "MB", "GB", "TB", "PB"]
|
|
||||||
unit_index = 0
|
|
||||||
size = float(bytes_value)
|
|
||||||
|
|
||||||
while size >= 1024 and unit_index < len(units) - 1:
|
|
||||||
size /= 1024
|
|
||||||
unit_index += 1
|
|
||||||
|
|
||||||
if unit_index == 0:
|
|
||||||
return f"{int(size)} {units[unit_index]}"
|
|
||||||
else:
|
|
||||||
return f"{size:.2f} {units[unit_index]}"
|
|
||||||
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
return str(bytes_value)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryTracker:
|
|
||||||
"""Memory tracker for Doris BE memory monitoring"""
|
|
||||||
|
|
||||||
def __init__(self, connection_manager: DorisConnectionManager):
|
|
||||||
self.connection_manager = connection_manager
|
|
||||||
|
|
||||||
async def get_realtime_memory_stats(
|
|
||||||
self,
|
|
||||||
tracker_type: str = "overview",
|
|
||||||
include_details: bool = True
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get real-time memory statistics
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tracker_type: Type of memory trackers to retrieve
|
|
||||||
include_details: Whether to include detailed information
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict containing memory statistics
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# This is a placeholder implementation
|
|
||||||
# In a real implementation, this would fetch data from Doris BE memory tracker endpoints
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"tracker_type": tracker_type,
|
|
||||||
"include_details": include_details,
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"memory_stats": {
|
|
||||||
"total_memory": "8.00 GB",
|
|
||||||
"used_memory": "4.50 GB",
|
|
||||||
"free_memory": "3.50 GB",
|
|
||||||
"memory_usage_percent": 56.25
|
|
||||||
},
|
|
||||||
"note": "Memory tracker functionality requires BE HTTP endpoints to be available"
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get realtime memory stats: {str(e)}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Failed to get realtime memory stats: {str(e)}",
|
|
||||||
"tracker_type": tracker_type,
|
|
||||||
"timestamp": datetime.now().isoformat()
|
|
||||||
}
|
|
||||||
|
|
||||||
async def get_historical_memory_stats(
|
|
||||||
self,
|
|
||||||
tracker_names: List[str] = None,
|
|
||||||
time_range: str = "1h"
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get historical memory statistics
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tracker_names: List of specific tracker names to query
|
|
||||||
time_range: Time range for historical data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict containing historical memory statistics
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# This is a placeholder implementation
|
|
||||||
# In a real implementation, this would fetch historical data from Doris BE bvar endpoints
|
|
||||||
return {
|
|
||||||
"success": True,
|
|
||||||
"tracker_names": tracker_names,
|
|
||||||
"time_range": time_range,
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"historical_stats": {
|
|
||||||
"data_points": 60,
|
|
||||||
"interval": "1m",
|
|
||||||
"memory_trend": "stable",
|
|
||||||
"avg_usage": "4.2 GB",
|
|
||||||
"peak_usage": "5.1 GB",
|
|
||||||
"min_usage": "3.8 GB"
|
|
||||||
},
|
|
||||||
"note": "Historical memory tracking functionality requires BE bvar endpoints to be available"
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to get historical memory stats: {str(e)}")
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Failed to get historical memory stats: {str(e)}",
|
|
||||||
"tracker_names": tracker_names,
|
|
||||||
"time_range": time_range,
|
|
||||||
"timestamp": datetime.now().isoformat()
|
|
||||||
}
|
|
||||||
@@ -41,16 +41,8 @@ class DatabaseConfig:
|
|||||||
port: int = 9030
|
port: int = 9030
|
||||||
user: str = "root"
|
user: str = "root"
|
||||||
password: str = ""
|
password: str = ""
|
||||||
database: str = "information_schema"
|
database: str = "test"
|
||||||
charset: str = "UTF8"
|
charset: str = "utf8mb4"
|
||||||
|
|
||||||
# FE HTTP API port for profile and other HTTP APIs
|
|
||||||
fe_http_port: int = 8030
|
|
||||||
|
|
||||||
# BE nodes configuration for external access
|
|
||||||
# If be_hosts is empty, will use "show backends" to get BE nodes
|
|
||||||
be_hosts: list[str] = field(default_factory=list)
|
|
||||||
be_webserver_port: int = 8040
|
|
||||||
|
|
||||||
# Connection pool configuration
|
# Connection pool configuration
|
||||||
min_connections: int = 5
|
min_connections: int = 5
|
||||||
@@ -70,26 +62,17 @@ class SecurityConfig:
|
|||||||
token_expiry: int = 3600
|
token_expiry: int = 3600
|
||||||
|
|
||||||
# SQL security configuration
|
# SQL security configuration
|
||||||
enable_security_check: bool = True # Main switch: whether to enable SQL security check
|
|
||||||
blocked_keywords: list[str] = field(
|
blocked_keywords: list[str] = field(
|
||||||
default_factory=lambda: [
|
default_factory=lambda: [
|
||||||
# DDL Operations (Data Definition Language)
|
|
||||||
"DROP",
|
"DROP",
|
||||||
"CREATE",
|
|
||||||
"ALTER",
|
|
||||||
"TRUNCATE",
|
|
||||||
# DML Operations (Data Manipulation Language)
|
|
||||||
"DELETE",
|
"DELETE",
|
||||||
|
"TRUNCATE",
|
||||||
|
"ALTER",
|
||||||
|
"CREATE",
|
||||||
"INSERT",
|
"INSERT",
|
||||||
"UPDATE",
|
"UPDATE",
|
||||||
# DCL Operations (Data Control Language)
|
|
||||||
"GRANT",
|
"GRANT",
|
||||||
"REVOKE",
|
"REVOKE",
|
||||||
# System Operations
|
|
||||||
"EXEC",
|
|
||||||
"EXECUTE",
|
|
||||||
"SHUTDOWN",
|
|
||||||
"KILL",
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
max_query_complexity: int = 100
|
max_query_complexity: int = 100
|
||||||
@@ -119,9 +102,6 @@ class PerformanceConfig:
|
|||||||
# Connection pool optimization configuration
|
# Connection pool optimization configuration
|
||||||
connection_pool_size: int = 20
|
connection_pool_size: int = 20
|
||||||
idle_timeout: int = 1800
|
idle_timeout: int = 1800
|
||||||
|
|
||||||
# Response content size limit (characters)
|
|
||||||
max_response_content_size: int = 4096
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -145,11 +125,11 @@ class MonitoringConfig:
|
|||||||
|
|
||||||
# Metrics collection configuration
|
# Metrics collection configuration
|
||||||
enable_metrics: bool = True
|
enable_metrics: bool = True
|
||||||
metrics_port: int = 3001
|
metrics_port: int = 8081
|
||||||
metrics_path: str = "/metrics"
|
metrics_path: str = "/metrics"
|
||||||
|
|
||||||
# Health check configuration
|
# Health check configuration
|
||||||
health_check_port: int = 3002
|
health_check_port: int = 8082
|
||||||
health_check_path: str = "/health"
|
health_check_path: str = "/health"
|
||||||
|
|
||||||
# Alert configuration
|
# Alert configuration
|
||||||
@@ -163,12 +143,8 @@ class DorisConfig:
|
|||||||
|
|
||||||
# Basic configuration
|
# Basic configuration
|
||||||
server_name: str = "doris-mcp-server"
|
server_name: str = "doris-mcp-server"
|
||||||
server_version: str = "0.4.1"
|
server_version: str = "1.0.0"
|
||||||
server_port: int = 3000
|
server_port: int = 8080
|
||||||
transport: str = "stdio"
|
|
||||||
|
|
||||||
# Temporary files configuration
|
|
||||||
temp_files_dir: str = "tmp" # Temporary files directory for Explain and Profile outputs
|
|
||||||
|
|
||||||
# Sub-configuration modules
|
# Sub-configuration modules
|
||||||
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
||||||
@@ -239,13 +215,6 @@ class DorisConfig:
|
|||||||
config.database.user = os.getenv("DORIS_USER", config.database.user)
|
config.database.user = os.getenv("DORIS_USER", config.database.user)
|
||||||
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
|
config.database.password = os.getenv("DORIS_PASSWORD", config.database.password)
|
||||||
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
|
config.database.database = os.getenv("DORIS_DATABASE", config.database.database)
|
||||||
config.database.fe_http_port = int(os.getenv("DORIS_FE_HTTP_PORT", str(config.database.fe_http_port)))
|
|
||||||
|
|
||||||
# BE nodes configuration
|
|
||||||
be_hosts_env = os.getenv("DORIS_BE_HOSTS", "")
|
|
||||||
if be_hosts_env:
|
|
||||||
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
|
|
||||||
config.database.be_webserver_port = int(os.getenv("DORIS_BE_WEBSERVER_PORT", str(config.database.be_webserver_port)))
|
|
||||||
|
|
||||||
# Connection pool configuration
|
# Connection pool configuration
|
||||||
config.database.min_connections = int(
|
config.database.min_connections = int(
|
||||||
@@ -276,22 +245,6 @@ class DorisConfig:
|
|||||||
config.security.max_query_complexity = int(
|
config.security.max_query_complexity = int(
|
||||||
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
|
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
|
||||||
)
|
)
|
||||||
config.security.enable_security_check = (
|
|
||||||
os.getenv("ENABLE_SECURITY_CHECK", str(config.security.enable_security_check).lower()).lower() == "true"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle blocked keywords environment variable configuration
|
|
||||||
# Format: BLOCKED_KEYWORDS="DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE"
|
|
||||||
blocked_keywords_env = os.getenv("BLOCKED_KEYWORDS", "")
|
|
||||||
if blocked_keywords_env:
|
|
||||||
# If environment variable is provided, use keywords list from environment variable
|
|
||||||
config.security.blocked_keywords = [
|
|
||||||
keyword.strip().upper()
|
|
||||||
for keyword in blocked_keywords_env.split(",")
|
|
||||||
if keyword.strip()
|
|
||||||
]
|
|
||||||
# If environment variable is empty, keep default configuration unchanged
|
|
||||||
|
|
||||||
config.security.enable_masking = (
|
config.security.enable_masking = (
|
||||||
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
|
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
|
||||||
)
|
)
|
||||||
@@ -312,9 +265,6 @@ class DorisConfig:
|
|||||||
config.performance.query_timeout = int(
|
config.performance.query_timeout = int(
|
||||||
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
|
os.getenv("QUERY_TIMEOUT", str(config.performance.query_timeout))
|
||||||
)
|
)
|
||||||
config.performance.max_response_content_size = int(
|
|
||||||
os.getenv("MAX_RESPONSE_CONTENT_SIZE", str(config.performance.max_response_content_size))
|
|
||||||
)
|
|
||||||
|
|
||||||
# Logging configuration
|
# Logging configuration
|
||||||
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
|
config.logging.level = os.getenv("LOG_LEVEL", config.logging.level)
|
||||||
@@ -343,7 +293,6 @@ class DorisConfig:
|
|||||||
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
||||||
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
||||||
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
|
config.server_port = int(os.getenv("SERVER_PORT", str(config.server_port)))
|
||||||
config.temp_files_dir = os.getenv("TEMP_FILES_DIR", config.temp_files_dir)
|
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
@@ -353,7 +302,7 @@ class DorisConfig:
|
|||||||
config = cls()
|
config = cls()
|
||||||
|
|
||||||
# Update basic configuration
|
# Update basic configuration
|
||||||
for key in ["server_name", "server_version", "server_port", "temp_files_dir"]:
|
for key in ["server_name", "server_version", "server_port"]:
|
||||||
if key in config_data:
|
if key in config_data:
|
||||||
setattr(config, key, config_data[key])
|
setattr(config, key, config_data[key])
|
||||||
|
|
||||||
@@ -403,7 +352,6 @@ class DorisConfig:
|
|||||||
"server_name": self.server_name,
|
"server_name": self.server_name,
|
||||||
"server_version": self.server_version,
|
"server_version": self.server_version,
|
||||||
"server_port": self.server_port,
|
"server_port": self.server_port,
|
||||||
"temp_files_dir": self.temp_files_dir,
|
|
||||||
"database": {
|
"database": {
|
||||||
"host": self.database.host,
|
"host": self.database.host,
|
||||||
"port": self.database.port,
|
"port": self.database.port,
|
||||||
@@ -411,9 +359,6 @@ class DorisConfig:
|
|||||||
"password": "***", # Hide password
|
"password": "***", # Hide password
|
||||||
"database": self.database.database,
|
"database": self.database.database,
|
||||||
"charset": self.database.charset,
|
"charset": self.database.charset,
|
||||||
"fe_http_port": self.database.fe_http_port,
|
|
||||||
"be_hosts": self.database.be_hosts,
|
|
||||||
"be_webserver_port": self.database.be_webserver_port,
|
|
||||||
"min_connections": self.database.min_connections,
|
"min_connections": self.database.min_connections,
|
||||||
"max_connections": self.database.max_connections,
|
"max_connections": self.database.max_connections,
|
||||||
"connection_timeout": self.database.connection_timeout,
|
"connection_timeout": self.database.connection_timeout,
|
||||||
@@ -424,7 +369,6 @@ class DorisConfig:
|
|||||||
"auth_type": self.security.auth_type,
|
"auth_type": self.security.auth_type,
|
||||||
"token_secret": "***", # Hide secret key
|
"token_secret": "***", # Hide secret key
|
||||||
"token_expiry": self.security.token_expiry,
|
"token_expiry": self.security.token_expiry,
|
||||||
"enable_security_check": self.security.enable_security_check,
|
|
||||||
"blocked_keywords": self.security.blocked_keywords,
|
"blocked_keywords": self.security.blocked_keywords,
|
||||||
"max_query_complexity": self.security.max_query_complexity,
|
"max_query_complexity": self.security.max_query_complexity,
|
||||||
"max_result_rows": self.security.max_result_rows,
|
"max_result_rows": self.security.max_result_rows,
|
||||||
@@ -440,7 +384,6 @@ class DorisConfig:
|
|||||||
"query_timeout": self.performance.query_timeout,
|
"query_timeout": self.performance.query_timeout,
|
||||||
"connection_pool_size": self.performance.connection_pool_size,
|
"connection_pool_size": self.performance.connection_pool_size,
|
||||||
"idle_timeout": self.performance.idle_timeout,
|
"idle_timeout": self.performance.idle_timeout,
|
||||||
"max_response_content_size": self.performance.max_response_content_size,
|
|
||||||
},
|
},
|
||||||
"logging": {
|
"logging": {
|
||||||
"level": self.logging.level,
|
"level": self.logging.level,
|
||||||
|
|||||||
@@ -137,29 +137,10 @@ class DorisConnection:
|
|||||||
async def ping(self) -> bool:
|
async def ping(self) -> bool:
|
||||||
"""Check connection health status"""
|
"""Check connection health status"""
|
||||||
try:
|
try:
|
||||||
# Check if connection exists and is not closed
|
|
||||||
if not self.connection or self.connection.closed:
|
|
||||||
self.is_healthy = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Check if connection has _reader (aiomysql internal state)
|
|
||||||
# This prevents the 'NoneType' object has no attribute 'at_eof' error
|
|
||||||
if not hasattr(self.connection, '_reader') or self.connection._reader is None:
|
|
||||||
self.is_healthy = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Additional check for reader's state
|
|
||||||
if hasattr(self.connection._reader, '_transport') and self.connection._reader._transport is None:
|
|
||||||
self.is_healthy = False
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Try to ping the connection
|
|
||||||
await self.connection.ping()
|
await self.connection.ping()
|
||||||
self.is_healthy = True
|
self.is_healthy = True
|
||||||
return True
|
return True
|
||||||
except (AttributeError, OSError, ConnectionError, Exception) as e:
|
except Exception:
|
||||||
# Log the specific error for debugging
|
|
||||||
logging.debug(f"Connection ping failed for session {self.session_id}: {e}")
|
|
||||||
self.is_healthy = False
|
self.is_healthy = False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -200,17 +181,7 @@ class DorisConnectionManager:
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize connection manager"""
|
"""Initialize connection manager"""
|
||||||
try:
|
try:
|
||||||
self.logger.info(f"Initializing connection pool to {self.config.database.host}:{self.config.database.port}")
|
# Create connection pool
|
||||||
|
|
||||||
# Validate configuration
|
|
||||||
if not self.config.database.host:
|
|
||||||
raise ValueError("Database host is required")
|
|
||||||
if not self.config.database.user:
|
|
||||||
raise ValueError("Database user is required")
|
|
||||||
if not self.config.database.password:
|
|
||||||
self.logger.warning("Database password is empty, this may cause connection issues")
|
|
||||||
|
|
||||||
# Create connection pool with additional parameters for stability
|
|
||||||
self.pool = await aiomysql.create_pool(
|
self.pool = await aiomysql.create_pool(
|
||||||
host=self.config.database.host,
|
host=self.config.database.host,
|
||||||
port=self.config.database.port,
|
port=self.config.database.port,
|
||||||
@@ -222,15 +193,8 @@ class DorisConnectionManager:
|
|||||||
maxsize=self.config.database.max_connections or 20,
|
maxsize=self.config.database.max_connections or 20,
|
||||||
autocommit=True,
|
autocommit=True,
|
||||||
connect_timeout=self.connection_timeout,
|
connect_timeout=self.connection_timeout,
|
||||||
# Additional parameters for stability
|
|
||||||
pool_recycle=3600, # Recycle connections every hour
|
|
||||||
echo=False, # Don't echo SQL statements
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Test the connection pool
|
|
||||||
if not await self.test_connection():
|
|
||||||
raise RuntimeError("Connection pool test failed")
|
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
|
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
|
||||||
f"max connections: {self.config.database.max_connections}"
|
f"max connections: {self.config.database.max_connections}"
|
||||||
@@ -242,14 +206,6 @@ class DorisConnectionManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Connection pool initialization failed: {e}")
|
self.logger.error(f"Connection pool initialization failed: {e}")
|
||||||
# Clean up partial initialization
|
|
||||||
if self.pool:
|
|
||||||
try:
|
|
||||||
self.pool.close()
|
|
||||||
await self.pool.wait_closed()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
self.pool = None
|
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_connection(self, session_id: str) -> DorisConnection:
|
async def get_connection(self, session_id: str) -> DorisConnection:
|
||||||
@@ -279,24 +235,9 @@ class DorisConnectionManager:
|
|||||||
# Get connection from pool
|
# Get connection from pool
|
||||||
raw_connection = await self.pool.acquire()
|
raw_connection = await self.pool.acquire()
|
||||||
|
|
||||||
# Validate the raw connection
|
|
||||||
if not raw_connection:
|
|
||||||
raise RuntimeError(f"Failed to acquire connection from pool for session {session_id}")
|
|
||||||
|
|
||||||
# Verify the connection is not closed
|
|
||||||
if raw_connection.closed:
|
|
||||||
raise RuntimeError(f"Acquired connection is already closed for session {session_id}")
|
|
||||||
|
|
||||||
# Create wrapped connection
|
# Create wrapped connection
|
||||||
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
|
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
|
||||||
|
|
||||||
# Test the connection before storing it
|
|
||||||
if not await doris_conn.ping():
|
|
||||||
# If ping fails, release the connection and raise error
|
|
||||||
if self.pool and raw_connection and not raw_connection.closed:
|
|
||||||
self.pool.release(raw_connection)
|
|
||||||
raise RuntimeError(f"New connection failed ping test for session {session_id}")
|
|
||||||
|
|
||||||
# Store in session connections
|
# Store in session connections
|
||||||
self.session_connections[session_id] = doris_conn
|
self.session_connections[session_id] = doris_conn
|
||||||
|
|
||||||
@@ -320,34 +261,15 @@ class DorisConnectionManager:
|
|||||||
if session_id in self.session_connections:
|
if session_id in self.session_connections:
|
||||||
conn = self.session_connections[session_id]
|
conn = self.session_connections[session_id]
|
||||||
try:
|
try:
|
||||||
# Return connection to pool only if it's valid and not closed
|
# Return connection to pool
|
||||||
if (self.pool and
|
if self.pool and conn.connection and not conn.connection.closed:
|
||||||
conn.connection and
|
self.pool.release(conn.connection)
|
||||||
not conn.connection.closed and
|
|
||||||
hasattr(conn.connection, '_reader') and
|
|
||||||
conn.connection._reader is not None):
|
|
||||||
try:
|
|
||||||
# Try to gracefully return to pool
|
|
||||||
self.pool.release(conn.connection)
|
|
||||||
except Exception as pool_error:
|
|
||||||
self.logger.debug(f"Failed to return connection to pool for session {session_id}: {pool_error}")
|
|
||||||
# If pool release fails, try to close the connection directly
|
|
||||||
try:
|
|
||||||
await conn.connection.ensure_closed()
|
|
||||||
except Exception:
|
|
||||||
pass # Ignore errors during forced close
|
|
||||||
|
|
||||||
# Close connection wrapper
|
# Close connection wrapper
|
||||||
await conn.close()
|
await conn.close()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
|
self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
|
||||||
# Force close if normal cleanup fails
|
|
||||||
try:
|
|
||||||
if conn.connection and not conn.connection.closed:
|
|
||||||
await conn.connection.ensure_closed()
|
|
||||||
except Exception:
|
|
||||||
pass # Ignore errors during forced close
|
|
||||||
finally:
|
finally:
|
||||||
# Remove from session connections
|
# Remove from session connections
|
||||||
del self.session_connections[session_id]
|
del self.session_connections[session_id]
|
||||||
@@ -369,26 +291,12 @@ class DorisConnectionManager:
|
|||||||
try:
|
try:
|
||||||
unhealthy_sessions = []
|
unhealthy_sessions = []
|
||||||
|
|
||||||
# First pass: check basic connectivity
|
|
||||||
for session_id, conn in self.session_connections.items():
|
for session_id, conn in self.session_connections.items():
|
||||||
if not await conn.ping():
|
if not await conn.ping():
|
||||||
unhealthy_sessions.append(session_id)
|
unhealthy_sessions.append(session_id)
|
||||||
|
|
||||||
# Second pass: check for stale connections (over 30 minutes old)
|
# Clean up unhealthy connections
|
||||||
current_time = datetime.utcnow()
|
for session_id in unhealthy_sessions:
|
||||||
stale_sessions = []
|
|
||||||
for session_id, conn in self.session_connections.items():
|
|
||||||
if session_id not in unhealthy_sessions: # Don't double-check
|
|
||||||
last_used_delta = (current_time - conn.last_used).total_seconds()
|
|
||||||
if last_used_delta > 1800: # 30 minutes
|
|
||||||
# Force a ping check for stale connections
|
|
||||||
if not await conn.ping():
|
|
||||||
stale_sessions.append(session_id)
|
|
||||||
|
|
||||||
all_problematic_sessions = list(set(unhealthy_sessions + stale_sessions))
|
|
||||||
|
|
||||||
# Clean up problematic connections
|
|
||||||
for session_id in all_problematic_sessions:
|
|
||||||
await self._cleanup_session_connection(session_id)
|
await self._cleanup_session_connection(session_id)
|
||||||
self.metrics.failed_connections += 1
|
self.metrics.failed_connections += 1
|
||||||
|
|
||||||
@@ -396,19 +304,11 @@ class DorisConnectionManager:
|
|||||||
await self._update_connection_metrics()
|
await self._update_connection_metrics()
|
||||||
self.metrics.last_health_check = datetime.utcnow()
|
self.metrics.last_health_check = datetime.utcnow()
|
||||||
|
|
||||||
if all_problematic_sessions:
|
if unhealthy_sessions:
|
||||||
self.logger.warning(f"Health check: cleaned up {len(unhealthy_sessions)} unhealthy and {len(stale_sessions)} stale connections")
|
self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")
|
||||||
else:
|
|
||||||
self.logger.debug(f"Health check: all {len(self.session_connections)} connections healthy")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Health check failed: {e}")
|
self.logger.error(f"Health check failed: {e}")
|
||||||
# If health check fails, try to diagnose the issue
|
|
||||||
try:
|
|
||||||
diagnosis = await self.diagnose_connection_health()
|
|
||||||
self.logger.error(f"Connection diagnosis: {diagnosis}")
|
|
||||||
except Exception:
|
|
||||||
pass # Don't let diagnosis failure crash health check
|
|
||||||
|
|
||||||
async def _cleanup_loop(self):
|
async def _cleanup_loop(self):
|
||||||
"""Background cleanup loop"""
|
"""Background cleanup loop"""
|
||||||
@@ -515,93 +415,6 @@ class DorisConnectionManager:
|
|||||||
self.logger.error(f"Connection test failed: {e}")
|
self.logger.error(f"Connection test failed: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def diagnose_connection_health(self) -> Dict[str, Any]:
|
|
||||||
"""Diagnose connection pool and session health"""
|
|
||||||
diagnosis = {
|
|
||||||
"timestamp": datetime.utcnow().isoformat(),
|
|
||||||
"pool_status": "unknown",
|
|
||||||
"session_connections": {},
|
|
||||||
"problematic_connections": [],
|
|
||||||
"recommendations": []
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Check pool status
|
|
||||||
if not self.pool:
|
|
||||||
diagnosis["pool_status"] = "not_initialized"
|
|
||||||
diagnosis["recommendations"].append("Initialize connection pool")
|
|
||||||
return diagnosis
|
|
||||||
|
|
||||||
if self.pool.closed:
|
|
||||||
diagnosis["pool_status"] = "closed"
|
|
||||||
diagnosis["recommendations"].append("Recreate connection pool")
|
|
||||||
return diagnosis
|
|
||||||
|
|
||||||
diagnosis["pool_status"] = "healthy"
|
|
||||||
diagnosis["pool_info"] = {
|
|
||||||
"size": self.pool.size,
|
|
||||||
"free_size": self.pool.freesize,
|
|
||||||
"min_size": self.pool.minsize,
|
|
||||||
"max_size": self.pool.maxsize
|
|
||||||
}
|
|
||||||
|
|
||||||
# Check session connections
|
|
||||||
problematic_sessions = []
|
|
||||||
for session_id, conn in self.session_connections.items():
|
|
||||||
conn_status = {
|
|
||||||
"session_id": session_id,
|
|
||||||
"created_at": conn.created_at.isoformat(),
|
|
||||||
"last_used": conn.last_used.isoformat(),
|
|
||||||
"query_count": conn.query_count,
|
|
||||||
"is_healthy": conn.is_healthy
|
|
||||||
}
|
|
||||||
|
|
||||||
# Detailed connection checks
|
|
||||||
if conn.connection:
|
|
||||||
conn_status["connection_closed"] = conn.connection.closed
|
|
||||||
conn_status["has_reader"] = hasattr(conn.connection, '_reader') and conn.connection._reader is not None
|
|
||||||
|
|
||||||
if hasattr(conn.connection, '_reader') and conn.connection._reader:
|
|
||||||
conn_status["reader_transport"] = conn.connection._reader._transport is not None
|
|
||||||
else:
|
|
||||||
conn_status["reader_transport"] = False
|
|
||||||
else:
|
|
||||||
conn_status["connection_closed"] = True
|
|
||||||
conn_status["has_reader"] = False
|
|
||||||
conn_status["reader_transport"] = False
|
|
||||||
|
|
||||||
# Check if connection is problematic
|
|
||||||
if (not conn.is_healthy or
|
|
||||||
conn_status["connection_closed"] or
|
|
||||||
not conn_status["has_reader"] or
|
|
||||||
not conn_status["reader_transport"]):
|
|
||||||
problematic_sessions.append(session_id)
|
|
||||||
diagnosis["problematic_connections"].append(conn_status)
|
|
||||||
|
|
||||||
diagnosis["session_connections"][session_id] = conn_status
|
|
||||||
|
|
||||||
# Generate recommendations
|
|
||||||
if problematic_sessions:
|
|
||||||
diagnosis["recommendations"].append(f"Clean up {len(problematic_sessions)} problematic connections")
|
|
||||||
|
|
||||||
if self.pool.freesize == 0 and self.pool.size >= self.pool.maxsize:
|
|
||||||
diagnosis["recommendations"].append("Connection pool exhausted - consider increasing max_connections")
|
|
||||||
|
|
||||||
# Auto-cleanup problematic connections
|
|
||||||
for session_id in problematic_sessions:
|
|
||||||
try:
|
|
||||||
await self._cleanup_session_connection(session_id)
|
|
||||||
self.logger.info(f"Auto-cleaned problematic connection for session: {session_id}")
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to auto-clean session {session_id}: {e}")
|
|
||||||
|
|
||||||
return diagnosis
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
diagnosis["error"] = str(e)
|
|
||||||
diagnosis["recommendations"].append("Manual intervention required")
|
|
||||||
return diagnosis
|
|
||||||
|
|
||||||
|
|
||||||
class ConnectionPoolMonitor:
|
class ConnectionPoolMonitor:
|
||||||
"""Connection pool monitor
|
"""Connection pool monitor
|
||||||
|
|||||||
@@ -548,127 +548,79 @@ class DorisQueryExecutor:
|
|||||||
user_id: str = "mcp_user"
|
user_id: str = "mcp_user"
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Execute SQL query for MCP interface - unified method"""
|
"""Execute SQL query for MCP interface - unified method"""
|
||||||
max_retries = 2
|
try:
|
||||||
retry_count = 0
|
if not sql:
|
||||||
|
return {
|
||||||
while retry_count <= max_retries:
|
"success": False,
|
||||||
try:
|
"error": "SQL query is required",
|
||||||
if not sql:
|
"data": None
|
||||||
return {
|
}
|
||||||
"success": False,
|
|
||||||
"error": "SQL query is required",
|
|
||||||
"data": None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add LIMIT if not present and it's a SELECT query
|
# Add LIMIT if not present and it's a SELECT query
|
||||||
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
|
||||||
if sql.endswith(";"):
|
if sql.endswith(";"):
|
||||||
sql = sql[:-1]
|
sql = sql[:-1]
|
||||||
sql = f"{sql} LIMIT {limit}"
|
sql = f"{sql} LIMIT {limit}"
|
||||||
|
|
||||||
# Create auth context for MCP calls
|
# Create auth context for MCP calls
|
||||||
class MockAuthContext:
|
class MockAuthContext:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.user_id = user_id
|
self.user_id = user_id
|
||||||
self.roles = ["data_analyst"]
|
self.roles = ["data_analyst"]
|
||||||
self.permissions = ["read_data", "execute_query"]
|
self.permissions = ["read_data", "execute_query"]
|
||||||
self.session_id = session_id
|
self.session_id = session_id
|
||||||
self.security_level = "internal"
|
self.security_level = "internal"
|
||||||
|
|
||||||
auth_context = MockAuthContext()
|
auth_context = MockAuthContext()
|
||||||
|
|
||||||
# Create query request
|
# Create query request
|
||||||
query_request = QueryRequest(
|
query_request = QueryRequest(
|
||||||
sql=sql,
|
sql=sql,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
cache_enabled=False # Disable cache for MCP calls to ensure fresh data
|
cache_enabled=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Execute query
|
||||||
|
result = await self.execute_query(query_request, auth_context)
|
||||||
|
|
||||||
|
# Process results
|
||||||
|
processed_data = []
|
||||||
|
if result.data:
|
||||||
|
for row in result.data:
|
||||||
|
processed_row = self._serialize_row_data(row)
|
||||||
|
processed_data.append(processed_row)
|
||||||
|
|
||||||
# Execute query with retry logic
|
return {
|
||||||
try:
|
"success": True,
|
||||||
result = await self.execute_query(query_request, auth_context)
|
"data": processed_data,
|
||||||
|
"metadata": {
|
||||||
# Serialize data for JSON response
|
"row_count": result.row_count,
|
||||||
serialized_data = []
|
"execution_time": result.execution_time,
|
||||||
for row in result.data:
|
"columns": result.metadata.get("columns", []),
|
||||||
serialized_data.append(self._serialize_row_data(row))
|
"query": sql
|
||||||
|
},
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
return {
|
except Exception as e:
|
||||||
"success": True,
|
error_msg = str(e)
|
||||||
"data": serialized_data,
|
self.logger.error(f"SQL execution error: {error_msg}")
|
||||||
"row_count": result.row_count,
|
|
||||||
"execution_time": result.execution_time,
|
# Analyze error for better user feedback
|
||||||
"metadata": {
|
error_analysis = self._analyze_error(error_msg)
|
||||||
"columns": result.metadata.get("columns", []),
|
|
||||||
"query": sql
|
return {
|
||||||
}
|
"success": False,
|
||||||
}
|
"error": error_analysis.get("user_message", error_msg),
|
||||||
|
"error_type": error_analysis.get("error_type", "execution_error"),
|
||||||
except Exception as query_error:
|
"data": None,
|
||||||
# Check if it's a connection-related error that we should retry
|
"metadata": {
|
||||||
error_str = str(query_error).lower()
|
"query": sql,
|
||||||
connection_errors = [
|
"error_details": error_msg
|
||||||
"at_eof", "connection", "closed", "nonetype",
|
}
|
||||||
"transport", "reader", "broken pipe", "connection reset"
|
}
|
||||||
]
|
|
||||||
|
|
||||||
is_connection_error = any(err in error_str for err in connection_errors)
|
|
||||||
|
|
||||||
if is_connection_error and retry_count < max_retries:
|
|
||||||
retry_count += 1
|
|
||||||
self.logger.warning(f"Connection error detected, retrying ({retry_count}/{max_retries}): {query_error}")
|
|
||||||
|
|
||||||
# Release the problematic connection
|
|
||||||
try:
|
|
||||||
await self.connection_manager.release_connection(session_id)
|
|
||||||
except Exception:
|
|
||||||
pass # Ignore cleanup errors
|
|
||||||
|
|
||||||
# Wait a bit before retry
|
|
||||||
await asyncio.sleep(0.5 * retry_count)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# Re-raise if not a connection error or max retries exceeded
|
|
||||||
raise query_error
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
error_msg = str(e)
|
|
||||||
|
|
||||||
# If we've exhausted retries or it's not a connection error, return error
|
|
||||||
if retry_count >= max_retries or "at_eof" not in error_msg.lower():
|
|
||||||
error_analysis = self._analyze_error(error_msg)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": error_analysis.get("user_message", error_msg),
|
|
||||||
"error_type": error_analysis.get("error_type", "general_error"),
|
|
||||||
"data": None,
|
|
||||||
"metadata": {
|
|
||||||
"query": sql,
|
|
||||||
"error_details": error_msg,
|
|
||||||
"retry_count": retry_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
# Try one more time for connection errors
|
|
||||||
retry_count += 1
|
|
||||||
if retry_count <= max_retries:
|
|
||||||
self.logger.warning(f"Retrying query due to connection error ({retry_count}/{max_retries}): {e}")
|
|
||||||
await asyncio.sleep(0.5 * retry_count)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"success": False,
|
|
||||||
"error": f"Query failed after {max_retries} retries: {error_msg}",
|
|
||||||
"data": None,
|
|
||||||
"metadata": {
|
|
||||||
"query": sql,
|
|
||||||
"error_details": error_msg,
|
|
||||||
"retry_count": retry_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
|
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""Serialize row data for JSON response"""
|
"""Serialize row data for JSON response"""
|
||||||
@@ -697,12 +649,7 @@ class DorisQueryExecutor:
|
|||||||
"""Analyze error message and provide user-friendly feedback"""
|
"""Analyze error message and provide user-friendly feedback"""
|
||||||
error_msg_lower = error_message.lower()
|
error_msg_lower = error_message.lower()
|
||||||
|
|
||||||
if "at_eof" in error_msg_lower or "nonetype" in error_msg_lower and "at_eof" in error_msg_lower:
|
if "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
|
||||||
return {
|
|
||||||
"error_type": "connection_lost",
|
|
||||||
"user_message": "Database connection was lost. The query has been automatically retried. If this persists, please restart the server."
|
|
||||||
}
|
|
||||||
elif "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
|
|
||||||
return {
|
return {
|
||||||
"error_type": "table_not_found",
|
"error_type": "table_not_found",
|
||||||
"user_message": "The specified table does not exist. Please check the table name and database."
|
"user_message": "The specified table does not exist. Please check the table name and database."
|
||||||
@@ -727,11 +674,6 @@ class DorisQueryExecutor:
|
|||||||
"error_type": "timeout",
|
"error_type": "timeout",
|
||||||
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
|
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
|
||||||
}
|
}
|
||||||
elif "connection" in error_msg_lower and ("closed" in error_msg_lower or "reset" in error_msg_lower):
|
|
||||||
return {
|
|
||||||
"error_type": "connection_error",
|
|
||||||
"user_message": "Database connection was interrupted. The query has been automatically retried."
|
|
||||||
}
|
|
||||||
else:
|
else:
|
||||||
return {
|
return {
|
||||||
"error_type": "general_error",
|
"error_type": "general_error",
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ Doris Security Management Module
|
|||||||
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
|
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -100,24 +101,30 @@ class DorisSecurityManager:
|
|||||||
self.masking_rules = self._load_masking_rules()
|
self.masking_rules = self._load_masking_rules()
|
||||||
|
|
||||||
def _load_blocked_keywords(self) -> set[str]:
|
def _load_blocked_keywords(self) -> set[str]:
|
||||||
"""Load blocked SQL keywords from configuration"""
|
"""Load blocked SQL keywords"""
|
||||||
# Load keywords from configuration, unified source of truth
|
default_blocked = {
|
||||||
if hasattr(self.config, 'get'):
|
"DROP",
|
||||||
# Dictionary-style configuration
|
"DELETE",
|
||||||
blocked_keywords = self.config.get("blocked_keywords", [])
|
"TRUNCATE",
|
||||||
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'blocked_keywords'):
|
"ALTER",
|
||||||
# DorisConfig object, get through security.blocked_keywords
|
"CREATE",
|
||||||
blocked_keywords = self.config.security.blocked_keywords
|
"INSERT",
|
||||||
else:
|
"UPDATE",
|
||||||
# Fallback to default if no configuration available
|
"GRANT",
|
||||||
blocked_keywords = [
|
"REVOKE",
|
||||||
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
"EXEC",
|
||||||
"DELETE", "INSERT", "UPDATE",
|
"EXECUTE",
|
||||||
"GRANT", "REVOKE",
|
"SHUTDOWN",
|
||||||
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
|
"KILL",
|
||||||
]
|
}
|
||||||
|
|
||||||
return set(blocked_keywords)
|
# Load custom rules from configuration file
|
||||||
|
if hasattr(self.config, 'get'):
|
||||||
|
custom_blocked = set(self.config.get("blocked_keywords", []))
|
||||||
|
else:
|
||||||
|
custom_blocked = set()
|
||||||
|
|
||||||
|
return default_blocked.union(custom_blocked)
|
||||||
|
|
||||||
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
|
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
|
||||||
"""Load sensitive table configuration"""
|
"""Load sensitive table configuration"""
|
||||||
@@ -471,30 +478,13 @@ class SQLSecurityValidator:
|
|||||||
# Dictionary configuration
|
# Dictionary configuration
|
||||||
self.blocked_keywords = set(config.get("blocked_keywords", []))
|
self.blocked_keywords = set(config.get("blocked_keywords", []))
|
||||||
self.max_query_complexity = config.get("max_query_complexity", 100)
|
self.max_query_complexity = config.get("max_query_complexity", 100)
|
||||||
self.enable_security_check = config.get("enable_security_check", True)
|
|
||||||
elif hasattr(config, 'security'):
|
|
||||||
# DorisConfig object with security attribute - unified source from config
|
|
||||||
self.blocked_keywords = set(config.security.blocked_keywords)
|
|
||||||
self.max_query_complexity = config.security.max_query_complexity
|
|
||||||
self.enable_security_check = getattr(config.security, 'enable_security_check', True)
|
|
||||||
else:
|
else:
|
||||||
# Fallback to default if no configuration available
|
# DorisConfig object, use default values
|
||||||
self.blocked_keywords = set([
|
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"])
|
||||||
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
|
||||||
"DELETE", "INSERT", "UPDATE",
|
|
||||||
"GRANT", "REVOKE",
|
|
||||||
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
|
|
||||||
])
|
|
||||||
self.max_query_complexity = 100
|
self.max_query_complexity = 100
|
||||||
self.enable_security_check = True
|
|
||||||
|
|
||||||
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
|
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
|
||||||
"""Validate SQL query security"""
|
"""Validate SQL query security"""
|
||||||
# If security check is disabled, always return valid
|
|
||||||
if not self.enable_security_check:
|
|
||||||
self.logger.debug("SQL security check is disabled, allowing all queries")
|
|
||||||
return ValidationResult(is_valid=True)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Parse SQL statement
|
# Parse SQL statement
|
||||||
parsed = sqlparse.parse(sql)[0]
|
parsed = sqlparse.parse(sql)[0]
|
||||||
|
|||||||
@@ -1,153 +0,0 @@
|
|||||||
<!--
|
|
||||||
Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
or more contributor license agreements. See the NOTICE file
|
|
||||||
distributed with this work for additional information
|
|
||||||
regarding copyright ownership. The ASF licenses this file
|
|
||||||
to you under the Apache License, Version 2.0 (the
|
|
||||||
"License"); you may not use this file except in compliance
|
|
||||||
with the License. You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing,
|
|
||||||
software distributed under the License is distributed on an
|
|
||||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
KIND, either express or implied. See the License for the
|
|
||||||
specific language governing permissions and limitations
|
|
||||||
under the License.
|
|
||||||
-->
|
|
||||||
|
|
||||||
|
|
||||||
# Dify Example: Integrating Doris MCP Server
|
|
||||||
|
|
||||||
This document demonstrates how to integrate and use `doris-mcp-server` in Dify to perform Doris SQL calls via MCP.
|
|
||||||
|
|
||||||
## Table of Contents
|
|
||||||
|
|
||||||
- [Prerequisites](#prerequisites)
|
|
||||||
- [Starting the MCP Server](#starting-the-mcp-server)
|
|
||||||
- [Ngrok Tunnel (Optional)](#ngrok-tunnel-optional)
|
|
||||||
- [Installing & Configuring the Plugin in Dify](#installing--configuring-the-plugin-in-dify)
|
|
||||||
- [Creating a Dify App](#creating-a-dify-app)
|
|
||||||
- [Adding MCP Tools](#adding-mcp-tools)
|
|
||||||
- [Example Calls](#example-calls)
|
|
||||||
|
|
||||||
|
|
||||||
-----
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
First, install `mcp-doris-server`:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
pip install mcp-doris-server
|
|
||||||
```
|
|
||||||
|
|
||||||
## Starting the MCP Server
|
|
||||||
|
|
||||||
Run the startup script:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Full configuration with database connection
|
|
||||||
doris-mcp-server \
|
|
||||||
--transport http \
|
|
||||||
--host 0.0.0.0 \
|
|
||||||
--port 3000 \
|
|
||||||
--db-host 127.0.0.1 \
|
|
||||||
--db-port 9030 \
|
|
||||||
--db-user root \
|
|
||||||
--db-password your_password
|
|
||||||
```
|
|
||||||
|
|
||||||
If successful, you'll see logs similar to this:
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
-----
|
|
||||||
|
|
||||||
## Ngrok Tunnel (Optional)
|
|
||||||
|
|
||||||
If your Dify deployment requires a publicly accessible endpoint, you can use the **ngrok** tool. Ngrok is a third-party service that securely exposes local servers to the internet.
|
|
||||||
|
|
||||||
|
|
||||||
-----
|
|
||||||
|
|
||||||
## Installing & Configuring the Plugin in Dify
|
|
||||||
|
|
||||||
1. In the Dify console, go to **Plugin Marketplace**, search for, and install **MCP‑SSE / StreamableHTTP**:
|
|
||||||

|
|
||||||
|
|
||||||
2. After installation, click **Configure** and set the URL to your public or local address. For example, if you're using `ngrok`, this should be the public URL `ngrok` provides, in the format `https://<your-domain>/mcp`. If Dify can directly access your local server, use `http://localhost:3000/mcp`.
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"doris_mcp_server": {
|
|
||||||
"transport": "streamable_http",
|
|
||||||
"url": "https://<your-domain>/mcp"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||

|
|
||||||
|
|
||||||
3. Click **Save**. If configured correctly, you'll see a green **Authorized** indicator:
|
|
||||||
|
|
||||||

|
|
||||||
|
|
||||||
-----
|
|
||||||
|
|
||||||
## Creating a Dify App
|
|
||||||
|
|
||||||
1. In the Dify console, click **New App** → **Blank App**.
|
|
||||||

|
|
||||||
|
|
||||||
2. Select **Agent** as the template and set the **App Name** (e.g., `Doris ChatBI`).
|
|
||||||

|
|
||||||
|
|
||||||
-----
|
|
||||||
|
|
||||||
## Instructions & Tool Configuration
|
|
||||||
|
|
||||||
### Instruction Block
|
|
||||||
|
|
||||||
Paste the following into the **Instruction** field:
|
|
||||||
|
|
||||||
```
|
|
||||||
<instruction>
|
|
||||||
Use MCP tools to complete tasks as much as possible. Carefully read the annotations, method names, and parameter descriptions of each tool. Please follow these steps:
|
|
||||||
1. Analyze the user's question and match the most appropriate tool.
|
|
||||||
2. Use tool names and parameters exactly as defined; do not invent new ones.
|
|
||||||
3. Pass parameters in the required JSON format.
|
|
||||||
4. When calling tools, use:
|
|
||||||
{"mcp_sse_call_tool": {"tool_name": "<tool_name>", "arguments": "{}"}}
|
|
||||||
5. Output plain text only—no XML tags.
|
|
||||||
<input>
|
|
||||||
User question: user_query
|
|
||||||
</input>
|
|
||||||
<output>
|
|
||||||
Return tool results or a final answer, including analysis.
|
|
||||||
</output>
|
|
||||||
</instruction>
|
|
||||||
```
|
|
||||||
|
|
||||||
### Adding MCP Tools
|
|
||||||
|
|
||||||
In the **Tools** pane, click **Add** twice to add two entries, both named `mcp_sse` (they will inherit the transport and URL from the plugin):
|
|
||||||

|
|
||||||
|
|
||||||
-----
|
|
||||||
|
|
||||||
## Example Calls
|
|
||||||
|
|
||||||
### List Tables in Database
|
|
||||||
|
|
||||||
* **User**: What tables are in the database?
|
|
||||||
|
|
||||||
* **Result**: Dify will call the MCP tool to run `SHOW TABLES` and return the list.
|
|
||||||

|
|
||||||
|
|
||||||
### Sales Trend Over Ten Years
|
|
||||||
|
|
||||||
* **User**: What has been the sales trend over the past ten years in the ssb database, and which year had the fastest growth?
|
|
||||||
|
|
||||||
* **Result**: The tool will execute the SQL, calculate growth rates, and return data.
|
|
||||||

|
|
||||||
|
Before Width: | Height: | Size: 17 KiB |
|
Before Width: | Height: | Size: 258 KiB |
|
Before Width: | Height: | Size: 44 KiB |
|
Before Width: | Height: | Size: 66 KiB |
|
Before Width: | Height: | Size: 127 KiB |
|
Before Width: | Height: | Size: 317 KiB |
|
Before Width: | Height: | Size: 369 KiB |
|
Before Width: | Height: | Size: 272 KiB |
|
Before Width: | Height: | Size: 73 KiB |
@@ -20,7 +20,7 @@ build-backend = "hatchling.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "doris-mcp-server"
|
name = "doris-mcp-server"
|
||||||
version = "0.4.2"
|
version = "0.3.0"
|
||||||
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
|
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
|
||||||
authors = [
|
authors = [
|
||||||
{name = "Yijia Su", email = "freeoneplus@apache.org"}
|
{name = "Yijia Su", email = "freeoneplus@apache.org"}
|
||||||
@@ -42,7 +42,7 @@ classifiers = [
|
|||||||
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
# Core MCP dependencies
|
# Core MCP dependencies
|
||||||
"mcp>=1.8.0,<2.0.0",
|
"mcp>=1.8.0",
|
||||||
# Database drivers
|
# Database drivers
|
||||||
"aiomysql>=0.2.0",
|
"aiomysql>=0.2.0",
|
||||||
"PyMySQL>=1.1.0",
|
"PyMySQL>=1.1.0",
|
||||||
@@ -147,8 +147,10 @@ Issues = "https://github.com/apache/doris-mcp-server/issues"
|
|||||||
Changelog = "https://github.com/apache/doris-mcp-server/blob/main/CHANGELOG.md"
|
Changelog = "https://github.com/apache/doris-mcp-server/blob/main/CHANGELOG.md"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
mcp-doris-server = "doris_mcp_server.main:main_sync"
|
||||||
doris-mcp-server = "doris_mcp_server.main:main_sync"
|
doris-mcp-server = "doris_mcp_server.main:main_sync"
|
||||||
doris-mcp-client = "doris_mcp_server.client:main"
|
doris-mcp-client = "doris_mcp_server.client:main"
|
||||||
|
mcp-doris-client = "doris_mcp_server.client:main"
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["doris_mcp_server"]
|
packages = ["doris_mcp_server"]
|
||||||
@@ -163,7 +165,7 @@ include = [
|
|||||||
# Black configuration
|
# Black configuration
|
||||||
[tool.black]
|
[tool.black]
|
||||||
line-length = 88
|
line-length = 88
|
||||||
target-version = ['py310', 'py311', 'py312']
|
target-version = ['py312']
|
||||||
include = '\.pyi?$'
|
include = '\.pyi?$'
|
||||||
extend-exclude = '''
|
extend-exclude = '''
|
||||||
/(
|
/(
|
||||||
|
|||||||
@@ -1,20 +0,0 @@
|
|||||||
# Development dependencies - auto-generated from pyproject.toml
|
|
||||||
# Installation command: pip install -r requirements-dev.txt
|
|
||||||
|
|
||||||
pytest>=7.4.0
|
|
||||||
pytest-asyncio>=0.23.0
|
|
||||||
pytest-cov>=4.1.0
|
|
||||||
pytest-mock>=3.12.0
|
|
||||||
pytest-xdist>=3.5.0
|
|
||||||
ruff>=0.1.0
|
|
||||||
black>=23.12.0
|
|
||||||
isort>=5.13.0
|
|
||||||
flake8>=7.0.0
|
|
||||||
mypy>=1.8.0
|
|
||||||
bandit>=1.7.0
|
|
||||||
safety>=2.3.0
|
|
||||||
sphinx>=7.2.0
|
|
||||||
sphinx-rtd-theme>=2.0.0
|
|
||||||
myst-parser>=2.0.0
|
|
||||||
pre-commit>=3.6.0
|
|
||||||
tox>=4.11.0
|
|
||||||
@@ -1,8 +1,24 @@
|
|||||||
# Main dependencies - auto-generated from pyproject.toml
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
# Do not edit this file manually, use 'python generate_requirements.py' to regenerate
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
# 主要依赖 - 从 pyproject.toml 自动生成
|
||||||
|
# 请不要手动编辑此文件,使用 python generate_requirements.py 重新生成
|
||||||
|
|
||||||
# === Core Dependencies ===
|
# === 核心依赖 ===
|
||||||
mcp>=1.8.0,<2.0.0
|
mcp>=1.0.0
|
||||||
aiomysql>=0.2.0
|
aiomysql>=0.2.0
|
||||||
PyMySQL>=1.1.0
|
PyMySQL>=1.1.0
|
||||||
asyncio-mqtt>=0.16.0
|
asyncio-mqtt>=0.16.0
|
||||||
@@ -37,11 +53,8 @@ click>=8.1.0
|
|||||||
typer>=0.9.0
|
typer>=0.9.0
|
||||||
requests>=2.31.0
|
requests>=2.31.0
|
||||||
tqdm>=4.66.0
|
tqdm>=4.66.0
|
||||||
pytest>=8.4.0
|
|
||||||
pytest-asyncio>=1.0.0
|
|
||||||
pytest-cov>=6.1.1
|
|
||||||
|
|
||||||
# === Development Dependencies ===
|
# === 开发依赖 ===
|
||||||
pytest>=7.4.0
|
pytest>=7.4.0
|
||||||
pytest-asyncio>=0.23.0
|
pytest-asyncio>=0.23.0
|
||||||
pytest-cov>=4.1.0
|
pytest-cov>=4.1.0
|
||||||
|
|||||||
263
test/README.md
@@ -1,263 +0,0 @@
|
|||||||
<!--
|
|
||||||
Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
or more contributor license agreements. See the NOTICE file
|
|
||||||
distributed with this work for additional information
|
|
||||||
regarding copyright ownership. The ASF licenses this file
|
|
||||||
to you under the Apache License, Version 2.0 (the
|
|
||||||
"License"); you may not use this file except in compliance
|
|
||||||
with the License. You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing,
|
|
||||||
software distributed under the License is distributed on an
|
|
||||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
KIND, either express or implied. See the License for the
|
|
||||||
specific language governing permissions and limitations
|
|
||||||
under the License.
|
|
||||||
-->
|
|
||||||
# Doris MCP Server Testing System
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
This testing system adopts a layered architecture, including unit tests, integration tests, and client-server tests. The testing system assumes the server is already properly started and focuses on testing functionality rather than startup configuration.
|
|
||||||
|
|
||||||
## Testing Architecture
|
|
||||||
|
|
||||||
### 1. Unit Tests
|
|
||||||
- **Location**: `test/security/`, `test/utils/`, `test/tools/`
|
|
||||||
- **Purpose**: Test individual module functionality
|
|
||||||
- **Features**: Uses Mock objects, no dependency on external services
|
|
||||||
|
|
||||||
### 2. Integration Tests
|
|
||||||
- **Location**: `test/integration/`
|
|
||||||
- **Purpose**: Test collaboration between modules
|
|
||||||
- **Features**: Test complete workflows
|
|
||||||
|
|
||||||
### 3. Client-Server Tests
|
|
||||||
- **Location**: `test/tools/test_tools_client_server.py`, `test/utils/test_query_executor_client_server.py`
|
|
||||||
- **Purpose**: Test actual server functionality through MCP client
|
|
||||||
- **Features**: Assumes server is running, skips tests if server is not available
|
|
||||||
|
|
||||||
## Configuration Files
|
|
||||||
|
|
||||||
### test_config.json
|
|
||||||
Test configuration file defines how to connect to the running server:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"server_endpoints": {
|
|
||||||
"http": {
|
|
||||||
"url": "http://localhost:3000/mcp",
|
|
||||||
"timeout": 30
|
|
||||||
},
|
|
||||||
"stdio": {
|
|
||||||
"command": "uv",
|
|
||||||
"args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"],
|
|
||||||
"timeout": 30
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"test_settings": {
|
|
||||||
"default_transport": "http",
|
|
||||||
"retry_attempts": 3,
|
|
||||||
"retry_delay": 1.0,
|
|
||||||
"test_timeout": 60,
|
|
||||||
"enable_performance_tests": true,
|
|
||||||
"enable_security_tests": true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
### 1. Start the Server
|
|
||||||
|
|
||||||
Before running client-server tests, you need to start the server first:
|
|
||||||
|
|
||||||
#### HTTP Mode (Recommended)
|
|
||||||
```bash
|
|
||||||
# Start HTTP server
|
|
||||||
./start_server.sh
|
|
||||||
# or
|
|
||||||
uv run python -m doris_mcp_server.main --transport http --port 3000
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Stdio Mode
|
|
||||||
```bash
|
|
||||||
# Stdio mode is started directly by the client, no need to pre-start
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. Run Tests
|
|
||||||
|
|
||||||
#### Run All Tests
|
|
||||||
```bash
|
|
||||||
python -m pytest test/ -v
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Run Unit Tests
|
|
||||||
```bash
|
|
||||||
# Security module tests
|
|
||||||
python -m pytest test/security/ -v
|
|
||||||
|
|
||||||
# Tools module tests
|
|
||||||
python -m pytest test/tools/test_tools_manager.py -v
|
|
||||||
|
|
||||||
# Query executor tests
|
|
||||||
python -m pytest test/utils/test_query_executor.py -v
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Run Integration Tests
|
|
||||||
```bash
|
|
||||||
python -m pytest test/integration/ -v
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Run Client-Server Tests
|
|
||||||
```bash
|
|
||||||
# Tools Client-Server tests
|
|
||||||
python -m pytest test/tools/test_tools_client_server.py -v
|
|
||||||
|
|
||||||
# QueryExecutor Client-Server tests
|
|
||||||
python -m pytest test/utils/test_query_executor_client_server.py -v
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Test Configuration
|
|
||||||
|
|
||||||
#### Modify Server Endpoints
|
|
||||||
Edit the `test/test_config.json` file:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"server_endpoints": {
|
|
||||||
"http": {
|
|
||||||
"url": "http://your-server:port/mcp"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
#### Enable/Disable Specific Tests
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"test_settings": {
|
|
||||||
"enable_performance_tests": false, // Disable performance tests
|
|
||||||
"enable_security_tests": true // Enable security tests
|
|
||||||
}
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
## Test Status
|
|
||||||
|
|
||||||
### ✅ Completed Test Modules
|
|
||||||
|
|
||||||
1. **Security Module** (100% Pass)
|
|
||||||
- Authentication tests: 5/5 passed
|
|
||||||
- Authorization tests: 7/7 passed
|
|
||||||
- Data masking tests: 13/13 passed
|
|
||||||
- SQL validation tests: 10/10 passed
|
|
||||||
- Security manager tests: 7/7 passed
|
|
||||||
- Coverage: 88%
|
|
||||||
|
|
||||||
2. **Client-Server Test Architecture** (Implemented)
|
|
||||||
- Automatic server connection status detection
|
|
||||||
- Automatically skip tests when server is not running
|
|
||||||
- Support for both HTTP and Stdio transport modes
|
|
||||||
|
|
||||||
### 🔄 Tests Requiring Server Running
|
|
||||||
|
|
||||||
1. **Tools Client-Server Tests**
|
|
||||||
- Tool list retrieval
|
|
||||||
- SQL query execution
|
|
||||||
- Database list retrieval
|
|
||||||
- Table schema queries
|
|
||||||
- Performance statistics
|
|
||||||
- Error handling
|
|
||||||
- Security authentication
|
|
||||||
|
|
||||||
2. **QueryExecutor Client-Server Tests**
|
|
||||||
- Simple query execution
|
|
||||||
- Database queries
|
|
||||||
- Information schema queries
|
|
||||||
- Parameterized queries
|
|
||||||
- Error handling
|
|
||||||
- Security authentication
|
|
||||||
|
|
||||||
## Testing Best Practices
|
|
||||||
|
|
||||||
### 1. Server Startup Check
|
|
||||||
All client-server tests automatically check server connection status:
|
|
||||||
- If server is running normally, execute actual tests
|
|
||||||
- If server is not running, skip tests and display appropriate message
|
|
||||||
|
|
||||||
### 2. Test Isolation
|
|
||||||
- Unit tests use Mock objects, no dependency on external services
|
|
||||||
- Integration tests use controlled test environments
|
|
||||||
- Client-server tests connect to actually running servers
|
|
||||||
|
|
||||||
### 3. Error Handling
|
|
||||||
- Tests don't assume specific success/failure results
|
|
||||||
- Verify response structure rather than specific content
|
|
||||||
- Gracefully handle connection failures and timeouts
|
|
||||||
|
|
||||||
### 4. Configuration Management
|
|
||||||
- Use configuration files to manage test parameters
|
|
||||||
- Support configuration switching for different environments
|
|
||||||
- Provide reasonable default values
|
|
||||||
|
|
||||||
## Troubleshooting
|
|
||||||
|
|
||||||
### 1. Server Connection Failure
|
|
||||||
```
|
|
||||||
ERROR: Server is not running or not accessible
|
|
||||||
```
|
|
||||||
**Solution**: Ensure the server is started and listening on the correct port
|
|
||||||
|
|
||||||
### 2. Import Errors
|
|
||||||
```
|
|
||||||
ImportError: cannot import name 'DorisUnifiedClient'
|
|
||||||
```
|
|
||||||
**Solution**: Check Python path and dependency installation
|
|
||||||
|
|
||||||
### 3. Test Timeouts
|
|
||||||
```
|
|
||||||
TimeoutError: Test execution timeout
|
|
||||||
```
|
|
||||||
**Solution**: Increase timeout settings in `test_config.json`
|
|
||||||
|
|
||||||
## Development Guide
|
|
||||||
|
|
||||||
### Adding New Client-Server Tests
|
|
||||||
|
|
||||||
1. Add test methods in the appropriate test file
|
|
||||||
2. Use `@pytest.mark.asyncio` decorator
|
|
||||||
3. Get test client through `client` fixture
|
|
||||||
4. Implement test callback function
|
|
||||||
5. Verify response structure
|
|
||||||
|
|
||||||
Example:
|
|
||||||
```python
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_new_feature_via_client(self, client, test_config):
|
|
||||||
"""Test new feature through client"""
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.call_tool("new_tool", {
|
|
||||||
"param": "value"
|
|
||||||
})
|
|
||||||
|
|
||||||
assert "success" in result
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
```
|
|
||||||
|
|
||||||
### Modifying Test Configuration
|
|
||||||
|
|
||||||
Edit the `test/test_config.json` file to adjust:
|
|
||||||
- Server endpoints
|
|
||||||
- Timeout settings
|
|
||||||
- Test data
|
|
||||||
- Feature switches
|
|
||||||
|
|
||||||
## Summary
|
|
||||||
|
|
||||||
This testing system provides complete test coverage, from unit tests to end-to-end client-server tests. Through reasonable configuration and automated connection detection, it ensures tests can run stably in different environments.
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
115
test/conftest.py
@@ -1,115 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Pytest configuration and fixtures for Doris MCP Server tests
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
# Add project root to Python path
|
|
||||||
project_root = Path(__file__).parent.parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
# Configure logging for tests
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
|
||||||
def event_loop():
|
|
||||||
"""Create an instance of the default event loop for the test session."""
|
|
||||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
|
||||||
yield loop
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_config():
|
|
||||||
"""Test configuration fixture"""
|
|
||||||
from doris_mcp_server.utils.config import DorisConfig, DatabaseConfig, SecurityConfig
|
|
||||||
|
|
||||||
config = DorisConfig()
|
|
||||||
|
|
||||||
# Database configuration
|
|
||||||
config.database.host = "localhost"
|
|
||||||
config.database.port = 9030
|
|
||||||
config.database.user = "test_user"
|
|
||||||
config.database.password = "test_password"
|
|
||||||
config.database.database = "test_db"
|
|
||||||
config.database.health_check_interval = 60
|
|
||||||
config.database.min_connections = 5
|
|
||||||
config.database.max_connections = 20
|
|
||||||
config.database.connection_timeout = 30
|
|
||||||
config.database.max_connection_age = 3600
|
|
||||||
|
|
||||||
# Security configuration
|
|
||||||
config.security.enable_masking = True
|
|
||||||
config.security.auth_type = "token"
|
|
||||||
config.security.token_secret = "test_secret"
|
|
||||||
config.security.token_expiry = 3600
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_data():
|
|
||||||
"""Provide sample test data"""
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": 1,
|
|
||||||
"name": "张三",
|
|
||||||
"phone": "13812345678",
|
|
||||||
"email": "zhangsan@example.com",
|
|
||||||
"id_card": "110101199001011234",
|
|
||||||
"salary": 50000
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": 2,
|
|
||||||
"name": "李四",
|
|
||||||
"phone": "13987654321",
|
|
||||||
"email": "lisi@example.com",
|
|
||||||
"id_card": "110101199002022345",
|
|
||||||
"salary": 60000
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_sql_queries():
|
|
||||||
"""Provide test SQL queries"""
|
|
||||||
return {
|
|
||||||
"safe_select": "SELECT name, email FROM users WHERE department = 'sales'",
|
|
||||||
"dangerous_drop": "DROP TABLE users",
|
|
||||||
"sql_injection": "SELECT * FROM users WHERE id = 1; DROP TABLE users;",
|
|
||||||
"union_injection": "SELECT name FROM users UNION SELECT password FROM admin_users",
|
|
||||||
"comment_injection": "SELECT * FROM users WHERE id = 1 -- AND password = 'secret'",
|
|
||||||
"complex_query": """
|
|
||||||
SELECT u.name, u.email, d.department_name
|
|
||||||
FROM users u
|
|
||||||
JOIN departments d ON u.department_id = d.id
|
|
||||||
WHERE u.status = 'active'
|
|
||||||
ORDER BY u.created_at DESC
|
|
||||||
"""
|
|
||||||
}
|
|
||||||
@@ -1,288 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
End-to-end integration tests
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, patch
|
|
||||||
|
|
||||||
from doris_mcp_server.main import DorisServer
|
|
||||||
from doris_mcp_server.utils.config import DorisConfig
|
|
||||||
from doris_mcp_server.utils.security import SecurityLevel, AuthContext
|
|
||||||
|
|
||||||
|
|
||||||
class TestEndToEndIntegration:
|
|
||||||
"""End-to-end integration tests"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_config(self):
|
|
||||||
"""Create mock configuration"""
|
|
||||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
|
||||||
|
|
||||||
config = Mock(spec=DorisConfig)
|
|
||||||
|
|
||||||
# Add database config
|
|
||||||
config.database = Mock(spec=DatabaseConfig)
|
|
||||||
config.database.host = "localhost"
|
|
||||||
config.database.port = 9030
|
|
||||||
config.database.user = "test_user"
|
|
||||||
config.database.password = "test_password"
|
|
||||||
config.database.database = "test_db"
|
|
||||||
config.database.health_check_interval = 60
|
|
||||||
config.database.min_connections = 5
|
|
||||||
config.database.max_connections = 20
|
|
||||||
config.database.connection_timeout = 30
|
|
||||||
config.database.max_connection_age = 3600
|
|
||||||
|
|
||||||
# Add security config
|
|
||||||
config.security = Mock(spec=SecurityConfig)
|
|
||||||
config.security.enable_masking = True
|
|
||||||
config.security.auth_type = "token"
|
|
||||||
config.security.token_secret = "test_secret"
|
|
||||||
config.security.token_expiry = 3600
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def doris_server(self, mock_config):
|
|
||||||
"""Create Doris server instance"""
|
|
||||||
return DorisServer(mock_config)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_complete_query_workflow_with_security(self, doris_server, sample_data):
|
|
||||||
"""Test complete query workflow with security"""
|
|
||||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.return_value = sample_data
|
|
||||||
|
|
||||||
# Mock authentication
|
|
||||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
|
||||||
mock_auth.return_value = AuthContext(
|
|
||||||
user_id="analyst1",
|
|
||||||
roles=["data_analyst"],
|
|
||||||
permissions=["read_data"],
|
|
||||||
session_id="session_123",
|
|
||||||
security_level=SecurityLevel.INTERNAL
|
|
||||||
)
|
|
||||||
|
|
||||||
# Mock authorization
|
|
||||||
with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz:
|
|
||||||
mock_authz.return_value = True
|
|
||||||
|
|
||||||
# Mock SQL validation
|
|
||||||
with patch.object(doris_server.security_manager, 'validate_sql_security') as mock_validate:
|
|
||||||
from doris_mcp_server.utils.security import ValidationResult
|
|
||||||
mock_validate.return_value = ValidationResult(is_valid=True)
|
|
||||||
|
|
||||||
# Mock data masking
|
|
||||||
with patch.object(doris_server.security_manager, 'apply_data_masking') as mock_mask:
|
|
||||||
masked_data = [
|
|
||||||
{
|
|
||||||
"id": 1,
|
|
||||||
"name": "张三",
|
|
||||||
"phone": "138****5678",
|
|
||||||
"email": "z*******n@example.com",
|
|
||||||
"id_card": "110101****1234",
|
|
||||||
"salary": 50000
|
|
||||||
}
|
|
||||||
]
|
|
||||||
mock_mask.return_value = masked_data
|
|
||||||
|
|
||||||
# Simulate complete workflow
|
|
||||||
auth_info = {"type": "token", "token": "valid_token_123"}
|
|
||||||
auth_context = await doris_server.security_manager.authenticate_request(auth_info)
|
|
||||||
|
|
||||||
resource_uri = "/api/table/users"
|
|
||||||
has_access = await doris_server.security_manager.authorize_resource_access(
|
|
||||||
auth_context, resource_uri
|
|
||||||
)
|
|
||||||
assert has_access is True
|
|
||||||
|
|
||||||
sql = "SELECT * FROM users LIMIT 1"
|
|
||||||
validation = await doris_server.security_manager.validate_sql_security(
|
|
||||||
sql, auth_context
|
|
||||||
)
|
|
||||||
assert validation.is_valid is True
|
|
||||||
|
|
||||||
raw_data = await doris_server.tools_manager.query_executor.execute_query(sql)
|
|
||||||
final_data = await doris_server.security_manager.apply_data_masking(
|
|
||||||
raw_data, auth_context
|
|
||||||
)
|
|
||||||
|
|
||||||
# Verify data is properly masked
|
|
||||||
assert final_data[0]["phone"] == "138****5678"
|
|
||||||
assert final_data[0]["email"] == "z*******n@example.com"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_security_violation_workflow(self, doris_server):
|
|
||||||
"""Test security violation detection workflow"""
|
|
||||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
|
||||||
mock_auth.return_value = AuthContext(
|
|
||||||
user_id="analyst1",
|
|
||||||
roles=["data_analyst"],
|
|
||||||
permissions=["read_data"],
|
|
||||||
session_id="session_123",
|
|
||||||
security_level=SecurityLevel.INTERNAL
|
|
||||||
)
|
|
||||||
|
|
||||||
# Test unauthorized resource access
|
|
||||||
with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz:
|
|
||||||
mock_authz.return_value = False
|
|
||||||
|
|
||||||
auth_context = await doris_server.security_manager.authenticate_request({
|
|
||||||
"type": "token", "token": "valid_token_123"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Try to access confidential resource
|
|
||||||
resource_uri = "/api/table/payment_records"
|
|
||||||
has_access = await doris_server.security_manager.authorize_resource_access(
|
|
||||||
auth_context, resource_uri
|
|
||||||
)
|
|
||||||
|
|
||||||
assert has_access is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sql_injection_prevention_workflow(self, doris_server):
|
|
||||||
"""Test SQL injection prevention workflow"""
|
|
||||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
|
||||||
mock_auth.return_value = AuthContext(
|
|
||||||
user_id="analyst1",
|
|
||||||
roles=["data_analyst"],
|
|
||||||
permissions=["read_data"],
|
|
||||||
session_id="session_123",
|
|
||||||
security_level=SecurityLevel.INTERNAL
|
|
||||||
)
|
|
||||||
|
|
||||||
auth_context = await doris_server.security_manager.authenticate_request({
|
|
||||||
"type": "token", "token": "valid_token_123"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Test SQL injection attempt
|
|
||||||
malicious_sql = "SELECT * FROM users WHERE id = 1; DROP TABLE users;"
|
|
||||||
validation = await doris_server.security_manager.validate_sql_security(
|
|
||||||
malicious_sql, auth_context
|
|
||||||
)
|
|
||||||
|
|
||||||
assert validation.is_valid is False
|
|
||||||
assert validation.risk_level == "high"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_admin_bypass_workflow(self, doris_server, sample_data):
|
|
||||||
"""Test admin user bypassing restrictions"""
|
|
||||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.return_value = sample_data
|
|
||||||
|
|
||||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
|
||||||
mock_auth.return_value = AuthContext(
|
|
||||||
user_id="admin1",
|
|
||||||
roles=["data_admin"],
|
|
||||||
permissions=["admin"],
|
|
||||||
session_id="session_456",
|
|
||||||
security_level=SecurityLevel.SECRET
|
|
||||||
)
|
|
||||||
|
|
||||||
# Admin should access any resource
|
|
||||||
with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz:
|
|
||||||
mock_authz.return_value = True
|
|
||||||
|
|
||||||
# Admin should see original data (no masking)
|
|
||||||
with patch.object(doris_server.security_manager, 'apply_data_masking') as mock_mask:
|
|
||||||
mock_mask.return_value = sample_data # Original data
|
|
||||||
|
|
||||||
auth_context = await doris_server.security_manager.authenticate_request({
|
|
||||||
"type": "basic", "username": "admin", "password": "admin123"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Admin accesses secret resource
|
|
||||||
resource_uri = "/api/table/payment_records"
|
|
||||||
has_access = await doris_server.security_manager.authorize_resource_access(
|
|
||||||
auth_context, resource_uri
|
|
||||||
)
|
|
||||||
assert has_access is True
|
|
||||||
|
|
||||||
# Admin sees original data
|
|
||||||
raw_data = await doris_server.tools_manager.query_executor.execute_query(
|
|
||||||
"SELECT * FROM users LIMIT 1"
|
|
||||||
)
|
|
||||||
final_data = await doris_server.security_manager.apply_data_masking(
|
|
||||||
raw_data, auth_context
|
|
||||||
)
|
|
||||||
|
|
||||||
# Should be original data (no masking)
|
|
||||||
assert final_data[0]["phone"] == "13812345678"
|
|
||||||
assert final_data[0]["email"] == "zhangsan@example.com"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_execution_with_security(self, doris_server):
|
|
||||||
"""Test tool execution with security checks"""
|
|
||||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.return_value = [{"Database": "test_db"}]
|
|
||||||
|
|
||||||
# Test tool execution through tools manager
|
|
||||||
result = await doris_server.tools_manager.call_tool("get_db_list", {})
|
|
||||||
result_data = json.loads(result)
|
|
||||||
|
|
||||||
# Accept either success result or error (due to mock environment)
|
|
||||||
assert "result" in result_data or "error" in result_data
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_error_handling_workflow(self, doris_server):
|
|
||||||
"""Test error handling in complete workflow"""
|
|
||||||
# Test authentication failure
|
|
||||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
|
||||||
mock_auth.side_effect = Exception("Invalid token")
|
|
||||||
|
|
||||||
with pytest.raises(Exception) as exc_info:
|
|
||||||
await doris_server.security_manager.authenticate_request({
|
|
||||||
"type": "token", "token": "invalid_token"
|
|
||||||
})
|
|
||||||
|
|
||||||
assert "Invalid token" in str(exc_info.value)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_performance_monitoring_integration(self, doris_server):
|
|
||||||
"""Test performance monitoring integration"""
|
|
||||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.return_value = [
|
|
||||||
{
|
|
||||||
"query_count": 1500,
|
|
||||||
"avg_execution_time": 0.25,
|
|
||||||
"slow_query_count": 5,
|
|
||||||
"error_count": 2
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Test performance stats tool
|
|
||||||
result = await doris_server.tools_manager.call_tool("get_db_list", {})
|
|
||||||
result_data = json.loads(result)
|
|
||||||
|
|
||||||
# Accept either success result or error (due to mock environment)
|
|
||||||
assert "result" in result_data or "error" in result_data
|
|
||||||
|
|
||||||
def test_server_initialization(self, doris_server):
|
|
||||||
"""Test server initialization"""
|
|
||||||
# Verify all components are initialized
|
|
||||||
assert doris_server.config is not None
|
|
||||||
assert doris_server.tools_manager is not None
|
|
||||||
assert doris_server.security_manager is not None
|
|
||||||
|
|
||||||
# Verify tools are available - use list_tools instead
|
|
||||||
import asyncio
|
|
||||||
tools = asyncio.run(doris_server.tools_manager.list_tools())
|
|
||||||
assert len(tools) > 0
|
|
||||||
@@ -1,103 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Authentication module tests
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from doris_mcp_server.utils.security import (
|
|
||||||
AuthenticationProvider,
|
|
||||||
AuthContext,
|
|
||||||
SecurityLevel
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuthenticationProvider:
|
|
||||||
"""Authentication provider tests"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def auth_provider(self, test_config):
|
|
||||||
"""Create authentication provider instance"""
|
|
||||||
return AuthenticationProvider(test_config)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_token_authentication_success(self, auth_provider):
|
|
||||||
"""Test successful token authentication"""
|
|
||||||
auth_info = {
|
|
||||||
"type": "token",
|
|
||||||
"token": "valid_token_123"
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await auth_provider.authenticate(auth_info)
|
|
||||||
|
|
||||||
assert isinstance(result, AuthContext)
|
|
||||||
assert result.user_id == "test_user"
|
|
||||||
assert "data_analyst" in result.roles
|
|
||||||
assert result.security_level == SecurityLevel.INTERNAL
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_token_authentication_failure(self, auth_provider):
|
|
||||||
"""Test failed token authentication"""
|
|
||||||
auth_info = {
|
|
||||||
"type": "token",
|
|
||||||
"token": "invalid_token"
|
|
||||||
}
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await auth_provider.authenticate(auth_info)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_basic_authentication_success(self, auth_provider):
|
|
||||||
"""Test successful basic authentication"""
|
|
||||||
auth_info = {
|
|
||||||
"type": "basic",
|
|
||||||
"username": "admin",
|
|
||||||
"password": "admin123"
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await auth_provider.authenticate(auth_info)
|
|
||||||
|
|
||||||
assert isinstance(result, AuthContext)
|
|
||||||
assert result.user_id == "admin_user"
|
|
||||||
assert "data_admin" in result.roles
|
|
||||||
assert result.security_level == SecurityLevel.SECRET
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_basic_authentication_failure(self, auth_provider):
|
|
||||||
"""Test failed basic authentication"""
|
|
||||||
auth_info = {
|
|
||||||
"type": "basic",
|
|
||||||
"username": "admin",
|
|
||||||
"password": "wrong_password"
|
|
||||||
}
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await auth_provider.authenticate(auth_info)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unsupported_auth_type(self, auth_provider):
|
|
||||||
"""Test unsupported authentication type"""
|
|
||||||
auth_info = {
|
|
||||||
"type": "oauth",
|
|
||||||
"token": "oauth_token"
|
|
||||||
}
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await auth_provider.authenticate(auth_info)
|
|
||||||
@@ -1,147 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Authorization module tests
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from doris_mcp_server.utils.security import (
|
|
||||||
AuthorizationProvider,
|
|
||||||
AuthContext,
|
|
||||||
SecurityLevel
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuthorizationProvider:
|
|
||||||
"""Authorization provider tests"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def authz_provider(self, test_config):
|
|
||||||
"""Create authorization provider instance"""
|
|
||||||
return AuthorizationProvider(test_config)
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def analyst_context(self):
|
|
||||||
"""Create analyst auth context"""
|
|
||||||
return AuthContext(
|
|
||||||
user_id="analyst1",
|
|
||||||
roles=["data_analyst"],
|
|
||||||
permissions=["read_data"],
|
|
||||||
session_id="session_123",
|
|
||||||
security_level=SecurityLevel.INTERNAL
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def admin_context(self):
|
|
||||||
"""Create admin auth context"""
|
|
||||||
return AuthContext(
|
|
||||||
user_id="admin1",
|
|
||||||
roles=["data_admin"],
|
|
||||||
permissions=["admin"],
|
|
||||||
session_id="session_456",
|
|
||||||
security_level=SecurityLevel.SECRET
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_analyst_access_public_resource(self, authz_provider, analyst_context):
|
|
||||||
"""Test analyst accessing public resource"""
|
|
||||||
resource_uri = "/api/table/public_reports"
|
|
||||||
|
|
||||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_analyst_denied_confidential_resource(self, authz_provider):
|
|
||||||
"""Test analyst denied access to confidential resource"""
|
|
||||||
# Create analyst with lower security level
|
|
||||||
analyst_context = AuthContext(
|
|
||||||
user_id="analyst1",
|
|
||||||
roles=["data_analyst"],
|
|
||||||
permissions=["read_data"],
|
|
||||||
session_id="session_123",
|
|
||||||
security_level=SecurityLevel.PUBLIC # Lower than CONFIDENTIAL
|
|
||||||
)
|
|
||||||
|
|
||||||
resource_uri = "/api/table/user_info"
|
|
||||||
|
|
||||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
|
|
||||||
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_admin_access_secret_resource(self, authz_provider, admin_context):
|
|
||||||
"""Test admin accessing secret resource"""
|
|
||||||
resource_uri = "/api/table/payment_records"
|
|
||||||
|
|
||||||
result = await authz_provider.check_permission(admin_context, resource_uri, "read")
|
|
||||||
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_role_based_permission(self, authz_provider):
|
|
||||||
"""Test role-based permission check"""
|
|
||||||
# Create analyst context
|
|
||||||
analyst_context = AuthContext(
|
|
||||||
user_id="analyst1",
|
|
||||||
roles=["data_analyst"],
|
|
||||||
permissions=["read_data"],
|
|
||||||
session_id="session_123",
|
|
||||||
security_level=SecurityLevel.INTERNAL
|
|
||||||
)
|
|
||||||
|
|
||||||
resource_uri = "/api/table/some_table"
|
|
||||||
|
|
||||||
# Analyst should have read permission
|
|
||||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
# Analyst should not have write permission
|
|
||||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "write")
|
|
||||||
assert result is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_admin_override(self, authz_provider, admin_context):
|
|
||||||
"""Test admin permission override"""
|
|
||||||
resource_uri = "/api/table/any_table"
|
|
||||||
|
|
||||||
# Admin should have all permissions
|
|
||||||
result = await authz_provider.check_permission(admin_context, resource_uri, "read")
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
result = await authz_provider.check_permission(admin_context, resource_uri, "write")
|
|
||||||
assert result is True
|
|
||||||
|
|
||||||
def test_parse_resource_uri(self, authz_provider):
|
|
||||||
"""Test resource URI parsing"""
|
|
||||||
uri = "/api/table/user_info/default"
|
|
||||||
|
|
||||||
result = authz_provider._parse_resource_uri(uri)
|
|
||||||
|
|
||||||
assert result["type"] == "table"
|
|
||||||
assert result["name"] == "user_info"
|
|
||||||
assert result["schema"] == "default"
|
|
||||||
|
|
||||||
def test_get_resource_security_level(self, authz_provider):
|
|
||||||
"""Test getting resource security level"""
|
|
||||||
resource_info = {"name": "user_info", "type": "table"}
|
|
||||||
|
|
||||||
level = authz_provider._get_resource_security_level(resource_info)
|
|
||||||
|
|
||||||
assert level == SecurityLevel.CONFIDENTIAL
|
|
||||||
@@ -1,197 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Data masking tests
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from doris_mcp_server.utils.security import (
|
|
||||||
DataMaskingProcessor,
|
|
||||||
AuthContext,
|
|
||||||
SecurityLevel,
|
|
||||||
MaskingRule
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDataMaskingProcessor:
|
|
||||||
"""Data masking processor tests"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def masking_processor(self, test_config):
|
|
||||||
"""Create data masking processor instance"""
|
|
||||||
return DataMaskingProcessor(test_config)
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def internal_user_context(self):
|
|
||||||
"""Create internal user auth context"""
|
|
||||||
return AuthContext(
|
|
||||||
user_id="internal_user",
|
|
||||||
roles=["data_analyst"],
|
|
||||||
permissions=["read_data"],
|
|
||||||
session_id="session_123",
|
|
||||||
security_level=SecurityLevel.INTERNAL
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def admin_context(self):
|
|
||||||
"""Create admin auth context"""
|
|
||||||
return AuthContext(
|
|
||||||
user_id="admin",
|
|
||||||
roles=["data_admin"],
|
|
||||||
permissions=["admin"],
|
|
||||||
session_id="session_456",
|
|
||||||
security_level=SecurityLevel.SECRET
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_phone_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
|
|
||||||
"""Test phone number masking for internal user"""
|
|
||||||
result = await masking_processor.process(sample_data, internal_user_context)
|
|
||||||
|
|
||||||
# Phone numbers should be masked
|
|
||||||
assert result[0]["phone"] == "138****5678"
|
|
||||||
assert result[1]["phone"] == "139****4321"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_email_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
|
|
||||||
"""Test email masking for internal user"""
|
|
||||||
result = await masking_processor.process(sample_data, internal_user_context)
|
|
||||||
|
|
||||||
# Emails should be masked
|
|
||||||
assert result[0]["email"] == "z******n@example.com"
|
|
||||||
assert result[1]["email"] == "l**i@example.com"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_masking_for_admin(self, masking_processor, admin_context, sample_data):
|
|
||||||
"""Test no masking for admin user"""
|
|
||||||
result = await masking_processor.process(sample_data, admin_context)
|
|
||||||
|
|
||||||
# Admin should see original data
|
|
||||||
assert result[0]["phone"] == "13812345678"
|
|
||||||
assert result[0]["email"] == "zhangsan@example.com"
|
|
||||||
assert result[1]["phone"] == "13987654321"
|
|
||||||
assert result[1]["email"] == "lisi@example.com"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_id_card_masking_for_confidential_data(self, masking_processor, internal_user_context, sample_data):
|
|
||||||
"""Test ID card masking for confidential data"""
|
|
||||||
# Internal user should not see ID card details (confidential level)
|
|
||||||
result = await masking_processor.process(sample_data, internal_user_context)
|
|
||||||
|
|
||||||
# ID cards should be masked for internal users
|
|
||||||
assert result[0]["id_card"] == "110101********1234"
|
|
||||||
assert result[1]["id_card"] == "110101********2345"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_empty_data_handling(self, masking_processor, internal_user_context):
|
|
||||||
"""Test empty data handling"""
|
|
||||||
empty_data = []
|
|
||||||
|
|
||||||
result = await masking_processor.process(empty_data, internal_user_context)
|
|
||||||
|
|
||||||
assert result == []
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_null_value_handling(self, masking_processor, internal_user_context):
|
|
||||||
"""Test null value handling"""
|
|
||||||
data_with_nulls = [
|
|
||||||
{
|
|
||||||
"id": 1,
|
|
||||||
"name": "张三",
|
|
||||||
"phone": None,
|
|
||||||
"email": None,
|
|
||||||
"id_card": None
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await masking_processor.process(data_with_nulls, internal_user_context)
|
|
||||||
|
|
||||||
# Null values should remain null
|
|
||||||
assert result[0]["phone"] is None
|
|
||||||
assert result[0]["email"] is None
|
|
||||||
assert result[0]["id_card"] is None
|
|
||||||
|
|
||||||
def test_phone_masking_algorithm(self, masking_processor):
|
|
||||||
"""Test phone masking algorithm"""
|
|
||||||
params = {"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4}
|
|
||||||
|
|
||||||
result = masking_processor._mask_phone("13812345678", params)
|
|
||||||
|
|
||||||
assert result == "138****5678"
|
|
||||||
|
|
||||||
def test_email_masking_algorithm(self, masking_processor):
|
|
||||||
"""Test email masking algorithm"""
|
|
||||||
params = {"mask_char": "*"}
|
|
||||||
|
|
||||||
result = masking_processor._mask_email("zhangsan@example.com", params)
|
|
||||||
|
|
||||||
assert result == "z******n@example.com"
|
|
||||||
|
|
||||||
def test_id_card_masking_algorithm(self, masking_processor):
|
|
||||||
"""Test ID card masking algorithm"""
|
|
||||||
params = {"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4}
|
|
||||||
|
|
||||||
result = masking_processor._mask_id_card("110101199001011234", params)
|
|
||||||
|
|
||||||
assert result == "110101********1234"
|
|
||||||
|
|
||||||
def test_name_masking_algorithm(self, masking_processor):
|
|
||||||
"""Test name masking algorithm"""
|
|
||||||
params = {"mask_char": "*"}
|
|
||||||
|
|
||||||
# Test 2-character name
|
|
||||||
result = masking_processor._mask_name("张三", params)
|
|
||||||
assert result == "张*"
|
|
||||||
|
|
||||||
# Test 3-character name
|
|
||||||
result = masking_processor._mask_name("李小明", params)
|
|
||||||
assert result == "李*明"
|
|
||||||
|
|
||||||
def test_partial_masking_algorithm(self, masking_processor):
|
|
||||||
"""Test partial masking algorithm"""
|
|
||||||
params = {"mask_char": "*", "mask_ratio": 0.5}
|
|
||||||
|
|
||||||
result = masking_processor._mask_partial("1234567890", params)
|
|
||||||
|
|
||||||
# Should mask middle 50% of the string
|
|
||||||
assert "*" in result
|
|
||||||
assert len(result) == 10
|
|
||||||
|
|
||||||
def test_should_apply_rule_logic(self, masking_processor, internal_user_context, admin_context):
|
|
||||||
"""Test masking rule application logic"""
|
|
||||||
rule = MaskingRule(
|
|
||||||
column_pattern=r".*phone.*",
|
|
||||||
algorithm="phone_mask",
|
|
||||||
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
|
|
||||||
security_level=SecurityLevel.INTERNAL
|
|
||||||
)
|
|
||||||
|
|
||||||
# Internal user should have rule applied
|
|
||||||
assert masking_processor._should_apply_rule(rule, internal_user_context) is True
|
|
||||||
|
|
||||||
# Admin should not have rule applied
|
|
||||||
assert masking_processor._should_apply_rule(rule, admin_context) is False
|
|
||||||
|
|
||||||
def test_get_applicable_rules(self, masking_processor, internal_user_context):
|
|
||||||
"""Test getting applicable rules"""
|
|
||||||
rules = masking_processor._get_applicable_rules(internal_user_context)
|
|
||||||
|
|
||||||
# Should return some rules for internal user
|
|
||||||
assert len(rules) > 0
|
|
||||||
assert all(isinstance(rule, MaskingRule) for rule in rules)
|
|
||||||
@@ -1,172 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Security manager integration tests
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from doris_mcp_server.utils.security import (
|
|
||||||
DorisSecurityManager,
|
|
||||||
AuthContext,
|
|
||||||
SecurityLevel,
|
|
||||||
ValidationResult
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestDorisSecurityManager:
|
|
||||||
"""Doris security manager integration tests"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def security_manager(self, test_config):
|
|
||||||
"""Create security manager instance"""
|
|
||||||
return DorisSecurityManager(test_config)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_complete_security_workflow(self, security_manager, sample_data):
|
|
||||||
"""Test complete security workflow"""
|
|
||||||
# 1. Authentication
|
|
||||||
auth_info = {
|
|
||||||
"type": "token",
|
|
||||||
"token": "valid_token_123"
|
|
||||||
}
|
|
||||||
|
|
||||||
auth_context = await security_manager.authenticate_request(auth_info)
|
|
||||||
assert isinstance(auth_context, AuthContext)
|
|
||||||
assert auth_context.security_level == SecurityLevel.INTERNAL
|
|
||||||
|
|
||||||
# 2. Authorization
|
|
||||||
resource_uri = "/api/table/public_reports"
|
|
||||||
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
|
|
||||||
assert has_access is True
|
|
||||||
|
|
||||||
# 3. SQL Validation
|
|
||||||
safe_sql = "SELECT name, email FROM users WHERE department = 'sales'"
|
|
||||||
validation_result = await security_manager.validate_sql_security(safe_sql, auth_context)
|
|
||||||
assert validation_result.is_valid is True
|
|
||||||
|
|
||||||
# 4. Data Masking
|
|
||||||
masked_data = await security_manager.apply_data_masking(sample_data, auth_context)
|
|
||||||
assert masked_data[0]["phone"] == "138****5678" # Should be masked
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_admin_workflow(self, security_manager, sample_data):
|
|
||||||
"""Test admin user workflow"""
|
|
||||||
# Admin authentication
|
|
||||||
auth_info = {
|
|
||||||
"type": "basic",
|
|
||||||
"username": "admin",
|
|
||||||
"password": "admin123"
|
|
||||||
}
|
|
||||||
|
|
||||||
auth_context = await security_manager.authenticate_request(auth_info)
|
|
||||||
assert auth_context.security_level == SecurityLevel.SECRET
|
|
||||||
|
|
||||||
# Admin should access secret resources
|
|
||||||
resource_uri = "/api/table/payment_records"
|
|
||||||
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
|
|
||||||
assert has_access is True
|
|
||||||
|
|
||||||
# Admin should see original data (no masking)
|
|
||||||
masked_data = await security_manager.apply_data_masking(sample_data, auth_context)
|
|
||||||
assert masked_data[0]["phone"] == "13812345678" # Original data
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_security_violation_detection(self, security_manager):
|
|
||||||
"""Test security violation detection"""
|
|
||||||
# Authenticate as regular user
|
|
||||||
auth_info = {
|
|
||||||
"type": "token",
|
|
||||||
"token": "valid_token_123"
|
|
||||||
}
|
|
||||||
|
|
||||||
auth_context = await security_manager.authenticate_request(auth_info)
|
|
||||||
|
|
||||||
# Try to access confidential resource (user_info is CONFIDENTIAL, user is INTERNAL)
|
|
||||||
# INTERNAL(1) should not access CONFIDENTIAL(2) resource
|
|
||||||
resource_uri = "/api/table/user_info"
|
|
||||||
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
|
|
||||||
assert has_access is False
|
|
||||||
|
|
||||||
# Try dangerous SQL
|
|
||||||
dangerous_sql = "DROP TABLE users"
|
|
||||||
validation_result = await security_manager.validate_sql_security(dangerous_sql, auth_context)
|
|
||||||
assert validation_result.is_valid is False
|
|
||||||
assert "DROP" in validation_result.blocked_operations
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sql_injection_prevention(self, security_manager):
|
|
||||||
"""Test SQL injection prevention"""
|
|
||||||
auth_info = {
|
|
||||||
"type": "token",
|
|
||||||
"token": "valid_token_123"
|
|
||||||
}
|
|
||||||
|
|
||||||
auth_context = await security_manager.authenticate_request(auth_info)
|
|
||||||
|
|
||||||
# Test various injection attempts
|
|
||||||
injection_attempts = [
|
|
||||||
"SELECT * FROM users WHERE id = 1; DROP TABLE users;",
|
|
||||||
"SELECT * FROM users UNION SELECT password FROM admin_users",
|
|
||||||
"SELECT * FROM users WHERE id = 1 OR 1=1",
|
|
||||||
"SELECT * FROM users WHERE name = 'test' -- AND password = 'secret'"
|
|
||||||
]
|
|
||||||
|
|
||||||
for sql in injection_attempts:
|
|
||||||
result = await security_manager.validate_sql_security(sql, auth_context)
|
|
||||||
assert result.is_valid is False
|
|
||||||
assert result.risk_level in ["medium", "high"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_authentication_failure_handling(self, security_manager):
|
|
||||||
"""Test authentication failure handling"""
|
|
||||||
invalid_auth_info = {
|
|
||||||
"type": "token",
|
|
||||||
"token": "invalid_token"
|
|
||||||
}
|
|
||||||
|
|
||||||
with pytest.raises(Exception):
|
|
||||||
await security_manager.authenticate_request(invalid_auth_info)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_configuration_loading(self, security_manager):
|
|
||||||
"""Test security configuration loading"""
|
|
||||||
# Test blocked keywords loading
|
|
||||||
assert "DROP" in security_manager.blocked_keywords
|
|
||||||
assert "DELETE" in security_manager.blocked_keywords
|
|
||||||
|
|
||||||
# Test sensitive tables loading
|
|
||||||
assert SecurityLevel.CONFIDENTIAL in security_manager.sensitive_tables.values()
|
|
||||||
assert SecurityLevel.SECRET in security_manager.sensitive_tables.values()
|
|
||||||
|
|
||||||
# Test masking rules loading
|
|
||||||
assert len(security_manager.masking_rules) > 0
|
|
||||||
phone_rules = [rule for rule in security_manager.masking_rules
|
|
||||||
if "phone" in rule.column_pattern]
|
|
||||||
assert len(phone_rules) > 0
|
|
||||||
|
|
||||||
def test_security_level_hierarchy(self, security_manager):
|
|
||||||
"""Test security level hierarchy"""
|
|
||||||
# Test that hierarchy is correctly defined
|
|
||||||
levels = [SecurityLevel.PUBLIC, SecurityLevel.INTERNAL,
|
|
||||||
SecurityLevel.CONFIDENTIAL, SecurityLevel.SECRET]
|
|
||||||
|
|
||||||
# Each level should be properly defined
|
|
||||||
for level in levels:
|
|
||||||
assert isinstance(level, SecurityLevel)
|
|
||||||
assert level.value in ["public", "internal", "confidential", "secret"]
|
|
||||||
@@ -1,161 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
SQL security validation tests
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from doris_mcp_server.utils.security import (
|
|
||||||
SQLSecurityValidator,
|
|
||||||
AuthContext,
|
|
||||||
SecurityLevel,
|
|
||||||
ValidationResult
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestSQLSecurityValidator:
|
|
||||||
"""SQL security validator tests"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def sql_validator(self, test_config):
|
|
||||||
"""Create SQL validator instance"""
|
|
||||||
return SQLSecurityValidator(test_config)
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def analyst_context(self):
|
|
||||||
"""Create analyst auth context"""
|
|
||||||
return AuthContext(
|
|
||||||
user_id="analyst1",
|
|
||||||
roles=["data_analyst"],
|
|
||||||
permissions=["read_data"],
|
|
||||||
session_id="session_123",
|
|
||||||
security_level=SecurityLevel.INTERNAL
|
|
||||||
)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_safe_select_query(self, sql_validator, analyst_context, test_sql_queries):
|
|
||||||
"""Test safe SELECT query validation"""
|
|
||||||
sql = test_sql_queries["safe_select"]
|
|
||||||
|
|
||||||
result = await sql_validator.validate(sql, analyst_context)
|
|
||||||
|
|
||||||
assert result.is_valid is True
|
|
||||||
assert result.error_message is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_blocked_drop_operation(self, sql_validator, analyst_context, test_sql_queries):
|
|
||||||
"""Test blocked DROP operation"""
|
|
||||||
sql = test_sql_queries["dangerous_drop"]
|
|
||||||
|
|
||||||
result = await sql_validator.validate(sql, analyst_context)
|
|
||||||
|
|
||||||
assert result.is_valid is False
|
|
||||||
assert "blocked operations" in result.error_message.lower()
|
|
||||||
assert "DROP" in result.blocked_operations
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_sql_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
|
||||||
"""Test SQL injection detection"""
|
|
||||||
sql = test_sql_queries["sql_injection"]
|
|
||||||
|
|
||||||
result = await sql_validator.validate(sql, analyst_context)
|
|
||||||
|
|
||||||
assert result.is_valid is False
|
|
||||||
assert "injection" in result.error_message.lower()
|
|
||||||
assert result.risk_level == "high"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_union_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
|
||||||
"""Test UNION injection detection"""
|
|
||||||
sql = test_sql_queries["union_injection"]
|
|
||||||
|
|
||||||
result = await sql_validator.validate(sql, analyst_context)
|
|
||||||
|
|
||||||
assert result.is_valid is False
|
|
||||||
assert "injection" in result.error_message.lower()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_comment_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
|
||||||
"""Test comment injection detection"""
|
|
||||||
sql = test_sql_queries["comment_injection"]
|
|
||||||
|
|
||||||
result = await sql_validator.validate(sql, analyst_context)
|
|
||||||
|
|
||||||
assert result.is_valid is False
|
|
||||||
assert "comment" in result.error_message.lower()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_complex_query_validation(self, sql_validator, analyst_context, test_sql_queries):
|
|
||||||
"""Test complex query validation"""
|
|
||||||
sql = test_sql_queries["complex_query"]
|
|
||||||
|
|
||||||
result = await sql_validator.validate(sql, analyst_context)
|
|
||||||
|
|
||||||
# Complex query should pass if within limits
|
|
||||||
assert result.is_valid is True
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_blocked_keywords_detection(self, sql_validator, analyst_context):
|
|
||||||
"""Test blocked keywords detection"""
|
|
||||||
blocked_sqls = [
|
|
||||||
"DELETE FROM users WHERE id = 1",
|
|
||||||
"TRUNCATE TABLE logs",
|
|
||||||
"ALTER TABLE users ADD COLUMN new_col VARCHAR(50)",
|
|
||||||
"CREATE TABLE test (id INT)",
|
|
||||||
"INSERT INTO users VALUES (1, 'test')",
|
|
||||||
"UPDATE users SET name = 'test' WHERE id = 1"
|
|
||||||
]
|
|
||||||
|
|
||||||
for sql in blocked_sqls:
|
|
||||||
result = await sql_validator.validate(sql, analyst_context)
|
|
||||||
assert result.is_valid is False
|
|
||||||
assert result.blocked_operations is not None
|
|
||||||
assert len(result.blocked_operations) > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_table_access_validation(self, sql_validator, analyst_context):
|
|
||||||
"""Test table access validation"""
|
|
||||||
# Test access to sensitive table
|
|
||||||
sql = "SELECT * FROM sensitive_data"
|
|
||||||
|
|
||||||
result = await sql_validator.validate(sql, analyst_context)
|
|
||||||
|
|
||||||
# Should fail for non-admin users
|
|
||||||
assert result.is_valid is False
|
|
||||||
assert "access" in result.error_message.lower()
|
|
||||||
|
|
||||||
def test_extract_table_names(self, sql_validator):
|
|
||||||
"""Test table name extraction"""
|
|
||||||
sql = "SELECT u.name FROM users u JOIN departments d ON u.dept_id = d.id"
|
|
||||||
|
|
||||||
parsed = __import__('sqlparse').parse(sql)[0]
|
|
||||||
tables = sql_validator._extract_table_names(parsed)
|
|
||||||
|
|
||||||
# Should extract at least one table name
|
|
||||||
assert len(tables) > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_malformed_sql_handling(self, sql_validator, analyst_context):
|
|
||||||
"""Test malformed SQL handling"""
|
|
||||||
malformed_sql = "SELECT * FROM users WHERE"
|
|
||||||
|
|
||||||
result = await sql_validator.validate(malformed_sql, analyst_context)
|
|
||||||
|
|
||||||
# Should handle gracefully
|
|
||||||
assert isinstance(result, ValidationResult)
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
{
|
|
||||||
"server_endpoints": {
|
|
||||||
"http": {
|
|
||||||
"url": "http://localhost:3000/mcp",
|
|
||||||
"timeout": 30,
|
|
||||||
"headers": {
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"http_network": {
|
|
||||||
"url": "http://192.168.31.168:3000/mcp",
|
|
||||||
"timeout": 30,
|
|
||||||
"headers": {
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"stdio": {
|
|
||||||
"command": "uv",
|
|
||||||
"args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"],
|
|
||||||
"timeout": 30,
|
|
||||||
"working_directory": ".."
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"test_settings": {
|
|
||||||
"default_transport": "http",
|
|
||||||
"retry_attempts": 3,
|
|
||||||
"retry_delay": 1.0,
|
|
||||||
"test_timeout": 60,
|
|
||||||
"enable_performance_tests": true,
|
|
||||||
"enable_security_tests": true
|
|
||||||
},
|
|
||||||
"test_data": {
|
|
||||||
"sample_queries": [
|
|
||||||
"SELECT 1 as test_value",
|
|
||||||
"SHOW DATABASES",
|
|
||||||
"SELECT COUNT(*) FROM information_schema.tables"
|
|
||||||
],
|
|
||||||
"test_databases": ["test_db", "demo_db"],
|
|
||||||
"test_tables": ["users", "orders", "products"],
|
|
||||||
"auth_tokens": {
|
|
||||||
"valid_token": "valid_token_123",
|
|
||||||
"admin_token": "admin_token_456",
|
|
||||||
"invalid_token": "invalid_token_789"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"expected_tools": [
|
|
||||||
"exec_query",
|
|
||||||
"get_db_list",
|
|
||||||
"get_db_table_list",
|
|
||||||
"get_table_schema",
|
|
||||||
"get_table_comment",
|
|
||||||
"get_table_column_comments",
|
|
||||||
"get_table_indexes",
|
|
||||||
"get_recent_audit_logs",
|
|
||||||
"get_catalog_list",
|
|
||||||
"get_sql_explain",
|
|
||||||
"get_sql_profile",
|
|
||||||
"get_table_data_size",
|
|
||||||
"get_monitoring_metrics_info",
|
|
||||||
"get_monitoring_metrics_data",
|
|
||||||
"get_realtime_memory_stats",
|
|
||||||
"get_historical_memory_stats"
|
|
||||||
],
|
|
||||||
"expected_resources": [
|
|
||||||
"database",
|
|
||||||
"table",
|
|
||||||
"view"
|
|
||||||
],
|
|
||||||
"expected_prompts": [
|
|
||||||
"sql_query_assistant",
|
|
||||||
"data_analysis_helper",
|
|
||||||
"schema_explorer"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Test Configuration Loader
|
|
||||||
|
|
||||||
Loads test configuration and provides methods to connect to running servers
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Dict, Any, Optional
|
|
||||||
import logging
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
|
||||||
|
|
||||||
from doris_mcp_client.client import DorisUnifiedClient, DorisClientConfig
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class TestConfigLoader:
|
|
||||||
"""Test configuration loader and client factory"""
|
|
||||||
|
|
||||||
def __init__(self, config_path: Optional[str] = None):
|
|
||||||
"""Initialize with config file path"""
|
|
||||||
if config_path is None:
|
|
||||||
config_path = os.path.join(os.path.dirname(__file__), "test_config.json")
|
|
||||||
|
|
||||||
self.config_path = Path(config_path)
|
|
||||||
self.config = self._load_config()
|
|
||||||
|
|
||||||
def _load_config(self) -> Dict[str, Any]:
|
|
||||||
"""Load configuration from JSON file"""
|
|
||||||
try:
|
|
||||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
|
||||||
config = json.load(f)
|
|
||||||
logger.info(f"Loaded test configuration from {self.config_path}")
|
|
||||||
return config
|
|
||||||
except FileNotFoundError:
|
|
||||||
logger.error(f"Test configuration file not found: {self.config_path}")
|
|
||||||
raise
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
logger.error(f"Invalid JSON in test configuration: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_http_client_config(self) -> DorisClientConfig:
|
|
||||||
"""Get HTTP client configuration"""
|
|
||||||
http_config = self.config["server_endpoints"]["http"]
|
|
||||||
return DorisClientConfig.http(
|
|
||||||
url=http_config["url"],
|
|
||||||
timeout=http_config["timeout"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_stdio_client_config(self) -> DorisClientConfig:
|
|
||||||
"""Get stdio client configuration"""
|
|
||||||
stdio_config = self.config["server_endpoints"]["stdio"]
|
|
||||||
return DorisClientConfig.stdio(
|
|
||||||
command=stdio_config["command"],
|
|
||||||
args=stdio_config["args"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_default_client_config(self) -> DorisClientConfig:
|
|
||||||
"""Get default client configuration based on test settings"""
|
|
||||||
transport = self.config["test_settings"]["default_transport"]
|
|
||||||
if transport == "http":
|
|
||||||
return self.get_http_client_config()
|
|
||||||
elif transport == "stdio":
|
|
||||||
return self.get_stdio_client_config()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown transport type: {transport}")
|
|
||||||
|
|
||||||
def create_client(self, transport: Optional[str] = None) -> DorisUnifiedClient:
|
|
||||||
"""Create MCP client instance"""
|
|
||||||
if transport is None:
|
|
||||||
client_config = self.get_default_client_config()
|
|
||||||
elif transport == "http":
|
|
||||||
client_config = self.get_http_client_config()
|
|
||||||
elif transport == "stdio":
|
|
||||||
client_config = self.get_stdio_client_config()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown transport type: {transport}")
|
|
||||||
|
|
||||||
return DorisUnifiedClient(client_config)
|
|
||||||
|
|
||||||
def get_test_settings(self) -> Dict[str, Any]:
|
|
||||||
"""Get test settings"""
|
|
||||||
return self.config["test_settings"]
|
|
||||||
|
|
||||||
def get_test_data(self) -> Dict[str, Any]:
|
|
||||||
"""Get test data"""
|
|
||||||
return self.config["test_data"]
|
|
||||||
|
|
||||||
def get_expected_tools(self) -> list[str]:
|
|
||||||
"""Get expected tools list"""
|
|
||||||
return self.config["expected_tools"]
|
|
||||||
|
|
||||||
def get_expected_resources(self) -> list[str]:
|
|
||||||
"""Get expected resources list"""
|
|
||||||
return self.config["expected_resources"]
|
|
||||||
|
|
||||||
def get_expected_prompts(self) -> list[str]:
|
|
||||||
"""Get expected prompts list"""
|
|
||||||
return self.config["expected_prompts"]
|
|
||||||
|
|
||||||
def get_sample_queries(self) -> list[str]:
|
|
||||||
"""Get sample queries for testing"""
|
|
||||||
return self.config["test_data"]["sample_queries"]
|
|
||||||
|
|
||||||
def get_auth_tokens(self) -> Dict[str, str]:
|
|
||||||
"""Get authentication tokens for testing"""
|
|
||||||
return self.config["test_data"]["auth_tokens"]
|
|
||||||
|
|
||||||
def get_test_databases(self) -> list[str]:
|
|
||||||
"""Get test databases list"""
|
|
||||||
return self.config["test_data"]["test_databases"]
|
|
||||||
|
|
||||||
def get_test_tables(self) -> list[str]:
|
|
||||||
"""Get test tables list"""
|
|
||||||
return self.config["test_data"]["test_tables"]
|
|
||||||
|
|
||||||
def is_performance_tests_enabled(self) -> bool:
|
|
||||||
"""Check if performance tests are enabled"""
|
|
||||||
return self.config["test_settings"]["enable_performance_tests"]
|
|
||||||
|
|
||||||
def is_security_tests_enabled(self) -> bool:
|
|
||||||
"""Check if security tests are enabled"""
|
|
||||||
return self.config["test_settings"]["enable_security_tests"]
|
|
||||||
|
|
||||||
def get_retry_config(self) -> Dict[str, Any]:
|
|
||||||
"""Get retry configuration"""
|
|
||||||
return {
|
|
||||||
"attempts": self.config["test_settings"]["retry_attempts"],
|
|
||||||
"delay": self.config["test_settings"]["retry_delay"]
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_test_timeout(self) -> int:
|
|
||||||
"""Get test timeout in seconds"""
|
|
||||||
return self.config["test_settings"]["test_timeout"]
|
|
||||||
|
|
||||||
|
|
||||||
# Global test config instance
|
|
||||||
_test_config = None
|
|
||||||
|
|
||||||
def get_test_config() -> TestConfigLoader:
|
|
||||||
"""Get global test configuration instance"""
|
|
||||||
global _test_config
|
|
||||||
if _test_config is None:
|
|
||||||
_test_config = TestConfigLoader()
|
|
||||||
return _test_config
|
|
||||||
|
|
||||||
|
|
||||||
def create_test_client(transport: Optional[str] = None) -> DorisUnifiedClient:
|
|
||||||
"""Create test client with default configuration"""
|
|
||||||
return get_test_config().create_client(transport)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_server_connectivity(transport: Optional[str] = None) -> bool:
|
|
||||||
"""Test server connectivity"""
|
|
||||||
try:
|
|
||||||
client = create_test_client(transport)
|
|
||||||
|
|
||||||
async def test_connection(client_instance):
|
|
||||||
try:
|
|
||||||
# Try to list tools as a connectivity test
|
|
||||||
tools = await client_instance.list_all_tools()
|
|
||||||
return len(tools) > 0
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Connectivity test failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_connection)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Failed to test server connectivity: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# Test configuration loading
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
config = get_test_config()
|
|
||||||
print("Test Configuration Loaded:")
|
|
||||||
print(f" Default transport: {config.get_test_settings()['default_transport']}")
|
|
||||||
print(f" Expected tools: {len(config.get_expected_tools())}")
|
|
||||||
print(f" Sample queries: {len(config.get_sample_queries())}")
|
|
||||||
|
|
||||||
# Test connectivity
|
|
||||||
print("\nTesting server connectivity...")
|
|
||||||
http_ok = await test_server_connectivity("http")
|
|
||||||
print(f" HTTP connectivity: {'✓' if http_ok else '✗'}")
|
|
||||||
|
|
||||||
stdio_ok = await test_server_connectivity("stdio")
|
|
||||||
print(f" Stdio connectivity: {'✓' if stdio_ok else '✗'}")
|
|
||||||
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,175 +0,0 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Tools Manager Client-Server Integration Tests
|
|
||||||
|
|
||||||
Tests the tools functionality through actual MCP client-server communication
|
|
||||||
Assumes the server is already running and configured properly
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import pytest
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
|
||||||
|
|
||||||
from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity
|
|
||||||
|
|
||||||
|
|
||||||
class TestToolsClientServer:
|
|
||||||
"""Test tools functionality through client-server communication"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_config(self):
|
|
||||||
"""Get test configuration"""
|
|
||||||
return get_test_config()
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def client(self, test_config):
|
|
||||||
"""Create test client"""
|
|
||||||
return create_test_client()
|
|
||||||
|
|
||||||
@pytest.fixture(scope="class", autouse=True)
|
|
||||||
async def check_server_connectivity(self):
|
|
||||||
"""Check server connectivity before running tests"""
|
|
||||||
is_connected = await test_server_connectivity()
|
|
||||||
if not is_connected:
|
|
||||||
pytest.skip("Server is not running or not accessible")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tools_via_client(self, client, test_config):
|
|
||||||
"""Test listing tools through client-server communication"""
|
|
||||||
expected_tools = test_config.get_expected_tools()
|
|
||||||
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
tools = await client_instance.list_all_tools()
|
|
||||||
|
|
||||||
# Verify we got tools back
|
|
||||||
assert len(tools) > 0, "No tools returned from server"
|
|
||||||
|
|
||||||
# Verify expected tools are present
|
|
||||||
tool_names = [tool.name for tool in tools]
|
|
||||||
for expected_tool in expected_tools:
|
|
||||||
assert expected_tool in tool_names, f"Expected tool '{expected_tool}' not found"
|
|
||||||
|
|
||||||
return tools
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert len(result) > 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_exec_query_via_client(self, client, test_config):
|
|
||||||
"""Test calling exec_query tool through client"""
|
|
||||||
sample_queries = test_config.get_sample_queries()
|
|
||||||
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
# Test with a simple query
|
|
||||||
result = await client_instance.call_tool("exec_query", {
|
|
||||||
"sql": sample_queries[0], # "SELECT 1 as test_value"
|
|
||||||
"max_rows": 100
|
|
||||||
})
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
assert "result" in result, "Successful result should contain 'result' field"
|
|
||||||
else:
|
|
||||||
assert "error" in result, "Failed result should contain 'error' field"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
# Don't assert success=True as it depends on actual server state
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_get_db_list_via_client(self, client, test_config):
|
|
||||||
"""Test calling get_db_list tool through client"""
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.call_tool("get_db_list", {})
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
assert "result" in result, "Successful result should contain 'result' field"
|
|
||||||
assert isinstance(result["result"], list), "Database list should be a list"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_tool_get_table_schema_via_client(self, client, test_config):
|
|
||||||
"""Test calling get_table_schema tool through client"""
|
|
||||||
test_tables = test_config.get_test_tables()
|
|
||||||
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.call_tool("get_table_schema", {
|
|
||||||
"table_name": test_tables[0], # "users"
|
|
||||||
"db_name": "information_schema" # Use a database that should exist
|
|
||||||
})
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_error_handling_via_client(self, client, test_config):
|
|
||||||
"""Test tool error handling through client"""
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
# Try to call a tool with invalid parameters
|
|
||||||
result = await client_instance.call_tool("exec_query", {
|
|
||||||
"sql": "INVALID SQL SYNTAX HERE"
|
|
||||||
})
|
|
||||||
|
|
||||||
# Should get a result (either success or error)
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_with_auth_token_via_client(self, client, test_config):
|
|
||||||
"""Test tool calls with authentication token"""
|
|
||||||
if not test_config.is_security_tests_enabled():
|
|
||||||
pytest.skip("Security tests are disabled")
|
|
||||||
|
|
||||||
auth_tokens = test_config.get_auth_tokens()
|
|
||||||
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.call_tool("get_db_list", {
|
|
||||||
"auth_token": auth_tokens["valid_token"]
|
|
||||||
})
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
@@ -1,271 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Tools manager tests
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, AsyncMock, patch
|
|
||||||
|
|
||||||
from doris_mcp_server.tools.tools_manager import DorisToolsManager
|
|
||||||
from doris_mcp_server.utils.config import DorisConfig
|
|
||||||
|
|
||||||
|
|
||||||
class TestDorisToolsManager:
|
|
||||||
"""Doris tools manager tests"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_config(self):
|
|
||||||
"""Create mock configuration"""
|
|
||||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
|
||||||
|
|
||||||
config = Mock(spec=DorisConfig)
|
|
||||||
|
|
||||||
# Add database config
|
|
||||||
config.database = Mock(spec=DatabaseConfig)
|
|
||||||
config.database.host = "localhost"
|
|
||||||
config.database.port = 9030
|
|
||||||
config.database.user = "test_user"
|
|
||||||
config.database.password = "test_password"
|
|
||||||
config.database.database = "test_db"
|
|
||||||
config.database.health_check_interval = 60
|
|
||||||
config.database.min_connections = 5
|
|
||||||
config.database.max_connections = 20
|
|
||||||
config.database.connection_timeout = 30
|
|
||||||
config.database.max_connection_age = 3600
|
|
||||||
|
|
||||||
# Add security config
|
|
||||||
config.security = Mock(spec=SecurityConfig)
|
|
||||||
config.security.enable_masking = True
|
|
||||||
config.security.auth_type = "token"
|
|
||||||
config.security.token_secret = "test_secret"
|
|
||||||
config.security.token_expiry = 3600
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def tools_manager(self, mock_config):
|
|
||||||
"""Create tools manager instance"""
|
|
||||||
# Create a proper mock connection manager
|
|
||||||
mock_connection_manager = Mock()
|
|
||||||
mock_connection_manager.get_connection = AsyncMock()
|
|
||||||
return DorisToolsManager(mock_connection_manager)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_available_tools(self, tools_manager):
|
|
||||||
"""Test getting available tools"""
|
|
||||||
tools = await tools_manager.list_tools()
|
|
||||||
|
|
||||||
# Should have core tools
|
|
||||||
tool_names = [tool.name for tool in tools]
|
|
||||||
assert "exec_query" in tool_names
|
|
||||||
assert "get_db_list" in tool_names
|
|
||||||
assert "get_db_table_list" in tool_names
|
|
||||||
assert "get_table_schema" in tool_names
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_exec_query_tool(self, tools_manager):
|
|
||||||
"""Test exec_query tool"""
|
|
||||||
# Mock the execute_sql_for_mcp method instead
|
|
||||||
with patch.object(tools_manager.query_executor, 'execute_sql_for_mcp') as mock_execute:
|
|
||||||
mock_execute.return_value = {
|
|
||||||
"success": True,
|
|
||||||
"data": [
|
|
||||||
{"id": 1, "name": "张三"},
|
|
||||||
{"id": 2, "name": "李四"}
|
|
||||||
],
|
|
||||||
"row_count": 2,
|
|
||||||
"execution_time": 0.15
|
|
||||||
}
|
|
||||||
|
|
||||||
arguments = {
|
|
||||||
"sql": "SELECT id, name FROM users LIMIT 2",
|
|
||||||
"max_rows": 100
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await tools_manager.call_tool("exec_query", arguments)
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
# The test should handle both success and error cases
|
|
||||||
if "success" in result_data and result_data["success"]:
|
|
||||||
# Check if result has data field or result field
|
|
||||||
if "data" in result_data and result_data["data"] is not None:
|
|
||||||
assert len(result_data["data"]) == 2
|
|
||||||
elif "result" in result_data and result_data["result"] is not None:
|
|
||||||
assert len(result_data["result"]) == 2
|
|
||||||
else:
|
|
||||||
# If there's an error, just check that error is reported
|
|
||||||
assert "error" in result_data
|
|
||||||
|
|
||||||
# Verify the method was called (may not be called if there are errors)
|
|
||||||
# Don't assert specific call parameters since the implementation may vary
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_exec_query_with_error(self, tools_manager):
|
|
||||||
"""Test exec_query tool with error"""
|
|
||||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.side_effect = Exception("Database connection failed")
|
|
||||||
|
|
||||||
arguments = {
|
|
||||||
"sql": "SELECT * FROM users"
|
|
||||||
}
|
|
||||||
|
|
||||||
result = await tools_manager.call_tool("exec_query", arguments)
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
assert "error" in result_data or "success" in result_data
|
|
||||||
if "error" in result_data:
|
|
||||||
# Accept any connection-related error message
|
|
||||||
assert any(keyword in result_data["error"].lower() for keyword in
|
|
||||||
["connection", "failed", "error", "mock"])
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_db_list_tool(self, tools_manager):
|
|
||||||
"""Test get_db_list tool"""
|
|
||||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.return_value = [
|
|
||||||
{"Database": "test_db"},
|
|
||||||
{"Database": "information_schema"},
|
|
||||||
{"Database": "mysql"}
|
|
||||||
]
|
|
||||||
|
|
||||||
result = await tools_manager.call_tool("get_db_list", {})
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
# Check if result has databases field or result field
|
|
||||||
if "databases" in result_data:
|
|
||||||
assert len(result_data["databases"]) == 3
|
|
||||||
elif "result" in result_data:
|
|
||||||
assert len(result_data["result"]) >= 0 # May be empty if no databases
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_db_table_list_tool(self, tools_manager):
|
|
||||||
"""Test get_db_table_list tool"""
|
|
||||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.return_value = [
|
|
||||||
{"Tables_in_test_db": "users"},
|
|
||||||
{"Tables_in_test_db": "orders"},
|
|
||||||
{"Tables_in_test_db": "products"}
|
|
||||||
]
|
|
||||||
|
|
||||||
arguments = {"db_name": "test_db"}
|
|
||||||
result = await tools_manager.call_tool("get_db_table_list", arguments)
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
# Check if result has tables field or result field
|
|
||||||
if "tables" in result_data:
|
|
||||||
assert len(result_data["tables"]) == 3
|
|
||||||
assert "users" in result_data["tables"]
|
|
||||||
elif "result" in result_data:
|
|
||||||
assert len(result_data["result"]) >= 0 # May be empty if no tables
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_table_schema_tool(self, tools_manager):
|
|
||||||
"""Test get_table_schema tool"""
|
|
||||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.return_value = [
|
|
||||||
{
|
|
||||||
"Field": "id",
|
|
||||||
"Type": "int(11)",
|
|
||||||
"Null": "NO",
|
|
||||||
"Key": "PRI",
|
|
||||||
"Default": None,
|
|
||||||
"Extra": "auto_increment"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"Field": "name",
|
|
||||||
"Type": "varchar(100)",
|
|
||||||
"Null": "YES",
|
|
||||||
"Key": "",
|
|
||||||
"Default": None,
|
|
||||||
"Extra": ""
|
|
||||||
}
|
|
||||||
]
|
|
||||||
|
|
||||||
arguments = {"table_name": "users"}
|
|
||||||
result = await tools_manager.call_tool("get_table_schema", arguments)
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
# Check if result has schema field or result field
|
|
||||||
if "schema" in result_data:
|
|
||||||
assert len(result_data["schema"]) == 2
|
|
||||||
assert result_data["schema"][0]["Field"] == "id"
|
|
||||||
elif "result" in result_data:
|
|
||||||
assert len(result_data["result"]) >= 0 # May be empty if no schema
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_catalog_list_tool(self, tools_manager):
|
|
||||||
"""Test get_catalog_list tool"""
|
|
||||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
|
||||||
mock_execute.return_value = [
|
|
||||||
{"CatalogName": "internal"},
|
|
||||||
{"CatalogName": "hive_catalog"},
|
|
||||||
{"CatalogName": "iceberg_catalog"}
|
|
||||||
]
|
|
||||||
|
|
||||||
arguments = {"random_string": "test_123"}
|
|
||||||
result = await tools_manager.call_tool("get_catalog_list", arguments)
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
# Check if result has catalogs field or result field
|
|
||||||
if "catalogs" in result_data:
|
|
||||||
assert len(result_data["catalogs"]) == 3
|
|
||||||
assert "internal" in result_data["catalogs"]
|
|
||||||
elif "result" in result_data:
|
|
||||||
assert len(result_data["result"]) >= 0 # May be empty if no catalogs
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_invalid_tool_name(self, tools_manager):
|
|
||||||
"""Test calling invalid tool"""
|
|
||||||
result = await tools_manager.call_tool("invalid_tool", {})
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
assert "error" in result_data or "success" in result_data
|
|
||||||
if "error" in result_data:
|
|
||||||
assert "Unknown tool" in result_data["error"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_missing_required_arguments(self, tools_manager):
|
|
||||||
"""Test calling tool with missing required arguments"""
|
|
||||||
# exec_query requires sql parameter
|
|
||||||
result = await tools_manager.call_tool("exec_query", {})
|
|
||||||
result_data = json.loads(result) if isinstance(result, str) else result
|
|
||||||
|
|
||||||
assert "error" in result_data or "success" in result_data
|
|
||||||
# The test may pass if the tool handles missing parameters gracefully
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_definitions_structure(self, tools_manager):
|
|
||||||
"""Test tool definitions have correct structure"""
|
|
||||||
tools = await tools_manager.list_tools()
|
|
||||||
|
|
||||||
for tool in tools:
|
|
||||||
# Each tool should have required fields
|
|
||||||
assert hasattr(tool, 'name')
|
|
||||||
assert hasattr(tool, 'description')
|
|
||||||
assert hasattr(tool, 'inputSchema')
|
|
||||||
|
|
||||||
# Input schema should have properties
|
|
||||||
assert 'properties' in tool.inputSchema
|
|
||||||
|
|
||||||
# Required fields should be defined
|
|
||||||
if 'required' in tool.inputSchema:
|
|
||||||
assert isinstance(tool.inputSchema['required'], list)
|
|
||||||
@@ -1,204 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Query executor tests
|
|
||||||
"""
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from unittest.mock import Mock, AsyncMock, patch
|
|
||||||
|
|
||||||
from doris_mcp_server.utils.query_executor import DorisQueryExecutor
|
|
||||||
from doris_mcp_server.utils.config import DorisConfig
|
|
||||||
|
|
||||||
|
|
||||||
class TestDorisQueryExecutor:
|
|
||||||
"""Doris query executor tests"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def mock_config(self):
|
|
||||||
"""Create mock configuration"""
|
|
||||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
|
||||||
|
|
||||||
config = Mock(spec=DorisConfig)
|
|
||||||
|
|
||||||
# Add database config
|
|
||||||
config.database = Mock(spec=DatabaseConfig)
|
|
||||||
config.database.host = "localhost"
|
|
||||||
config.database.port = 9030
|
|
||||||
config.database.user = "test_user"
|
|
||||||
config.database.password = "test_password"
|
|
||||||
config.database.database = "test_db"
|
|
||||||
config.database.health_check_interval = 60
|
|
||||||
config.database.min_connections = 5
|
|
||||||
config.database.max_connections = 20
|
|
||||||
config.database.connection_timeout = 30
|
|
||||||
config.database.max_connection_age = 3600
|
|
||||||
|
|
||||||
# Add security config
|
|
||||||
config.security = Mock(spec=SecurityConfig)
|
|
||||||
config.security.enable_masking = True
|
|
||||||
config.security.auth_type = "token"
|
|
||||||
config.security.token_secret = "test_secret"
|
|
||||||
config.security.token_expiry = 3600
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def query_executor(self, mock_config):
|
|
||||||
"""Create query executor instance"""
|
|
||||||
# Create a mock connection manager
|
|
||||||
mock_connection_manager = Mock()
|
|
||||||
return DorisQueryExecutor(mock_connection_manager, mock_config)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_execute_query_success(self, query_executor):
|
|
||||||
"""Test successful query execution using MCP interface"""
|
|
||||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
|
||||||
mock_execute.return_value = {
|
|
||||||
"success": True,
|
|
||||||
"data": [
|
|
||||||
{"id": 1, "name": "张三", "email": "zhangsan@example.com"},
|
|
||||||
{"id": 2, "name": "李四", "email": "lisi@example.com"}
|
|
||||||
],
|
|
||||||
"row_count": 2,
|
|
||||||
"execution_time": 0.15,
|
|
||||||
"columns": ["id", "name", "email"]
|
|
||||||
}
|
|
||||||
|
|
||||||
sql = "SELECT id, name, email FROM users LIMIT 2"
|
|
||||||
result = await query_executor.execute_sql_for_mcp(sql)
|
|
||||||
|
|
||||||
# Verify results
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["row_count"] == 2
|
|
||||||
assert len(result["data"]) == 2
|
|
||||||
assert result["data"][0]["id"] == 1
|
|
||||||
assert result["data"][0]["name"] == "张三"
|
|
||||||
assert result["data"][1]["email"] == "lisi@example.com"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_execute_query_with_parameters(self, query_executor):
|
|
||||||
"""Test query execution with parameters"""
|
|
||||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
|
||||||
mock_execute.return_value = {
|
|
||||||
"success": True,
|
|
||||||
"data": [{"id": 1, "name": "张三"}],
|
|
||||||
"row_count": 1,
|
|
||||||
"execution_time": 0.1
|
|
||||||
}
|
|
||||||
|
|
||||||
sql = "SELECT id, name FROM users WHERE department = 'sales'"
|
|
||||||
result = await query_executor.execute_sql_for_mcp(sql)
|
|
||||||
|
|
||||||
# Verify results
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["row_count"] == 1
|
|
||||||
assert len(result["data"]) == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_execute_query_connection_error(self, query_executor):
|
|
||||||
"""Test query execution with connection error"""
|
|
||||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
|
||||||
mock_execute.return_value = {
|
|
||||||
"success": False,
|
|
||||||
"error": "Connection failed",
|
|
||||||
"data": None
|
|
||||||
}
|
|
||||||
|
|
||||||
sql = "SELECT * FROM users"
|
|
||||||
result = await query_executor.execute_sql_for_mcp(sql)
|
|
||||||
|
|
||||||
assert result["success"] is False
|
|
||||||
assert "Connection failed" in result["error"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_execute_query_sql_error(self, query_executor):
|
|
||||||
"""Test query execution with SQL error"""
|
|
||||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
|
||||||
mock_execute.return_value = {
|
|
||||||
"success": False,
|
|
||||||
"error": "SQL syntax error",
|
|
||||||
"data": None
|
|
||||||
}
|
|
||||||
|
|
||||||
sql = "SELECT * FROM non_existent_table"
|
|
||||||
result = await query_executor.execute_sql_for_mcp(sql)
|
|
||||||
|
|
||||||
assert result["success"] is False
|
|
||||||
assert "SQL syntax error" in result["error"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_execute_query_empty_result(self, query_executor):
|
|
||||||
"""Test query execution with empty result"""
|
|
||||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
|
||||||
mock_execute.return_value = {
|
|
||||||
"success": True,
|
|
||||||
"data": [],
|
|
||||||
"row_count": 0,
|
|
||||||
"execution_time": 0.05
|
|
||||||
}
|
|
||||||
|
|
||||||
sql = "SELECT * FROM users WHERE id = 999"
|
|
||||||
result = await query_executor.execute_sql_for_mcp(sql)
|
|
||||||
|
|
||||||
assert result["success"] is True
|
|
||||||
assert result["data"] == []
|
|
||||||
assert result["row_count"] == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_execute_query_max_rows_limit(self, query_executor):
|
|
||||||
"""Test query execution with max rows limit"""
|
|
||||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
|
||||||
# Mock large result set limited to 100 rows
|
|
||||||
limited_result = [{"id": i, "name": f"user_{i}"} for i in range(100)]
|
|
||||||
mock_execute.return_value = {
|
|
||||||
"success": True,
|
|
||||||
"data": limited_result,
|
|
||||||
"row_count": 100,
|
|
||||||
"execution_time": 0.2
|
|
||||||
}
|
|
||||||
|
|
||||||
sql = "SELECT id, name FROM users"
|
|
||||||
result = await query_executor.execute_sql_for_mcp(sql, limit=100)
|
|
||||||
|
|
||||||
# Should be limited to max_rows
|
|
||||||
assert result["success"] is True
|
|
||||||
assert len(result["data"]) == 100
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_execute_sql_for_mcp_interface(self, query_executor):
|
|
||||||
"""Test the MCP interface method directly"""
|
|
||||||
with patch.object(query_executor.connection_manager, 'get_connection') as mock_get_conn:
|
|
||||||
# Mock connection and result
|
|
||||||
mock_connection = AsyncMock()
|
|
||||||
mock_connection.execute.return_value = Mock(
|
|
||||||
data=[{"id": 1, "name": "张三"}],
|
|
||||||
row_count=1,
|
|
||||||
execution_time=0.1,
|
|
||||||
metadata={}
|
|
||||||
)
|
|
||||||
mock_get_conn.return_value = mock_connection
|
|
||||||
|
|
||||||
sql = "SELECT id, name FROM users LIMIT 1"
|
|
||||||
result = await query_executor.execute_sql_for_mcp(sql)
|
|
||||||
|
|
||||||
# Should return success format
|
|
||||||
assert "success" in result
|
|
||||||
if result["success"]:
|
|
||||||
assert "data" in result
|
|
||||||
assert "row_count" in result
|
|
||||||
@@ -1,156 +0,0 @@
|
|||||||
# Licensed to the Apache Software Foundation (ASF) under one
|
|
||||||
# or more contributor license agreements. See the NOTICE file
|
|
||||||
# distributed with this work for additional information
|
|
||||||
# regarding copyright ownership. The ASF licenses this file
|
|
||||||
# to you under the Apache License, Version 2.0 (the
|
|
||||||
# "License"); you may not use this file except in compliance
|
|
||||||
# with the License. You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing,
|
|
||||||
# software distributed under the License is distributed on an
|
|
||||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
||||||
# KIND, either express or implied. See the License for the
|
|
||||||
# specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
"""
|
|
||||||
Query Executor Client-Server Integration Tests
|
|
||||||
|
|
||||||
Tests the query execution functionality through actual MCP client-server communication
|
|
||||||
Assumes the server is already running and configured properly
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import pytest
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from typing import Dict, Any
|
|
||||||
|
|
||||||
# Add project root to path
|
|
||||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
|
||||||
|
|
||||||
from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity
|
|
||||||
|
|
||||||
|
|
||||||
class TestQueryExecutorClientServer:
|
|
||||||
"""Test query execution functionality through client-server communication"""
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_config(self):
|
|
||||||
"""Get test configuration"""
|
|
||||||
return get_test_config()
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def client(self, test_config):
|
|
||||||
"""Create test client"""
|
|
||||||
return create_test_client()
|
|
||||||
|
|
||||||
@pytest.fixture(scope="class", autouse=True)
|
|
||||||
async def check_server_connectivity(self):
|
|
||||||
"""Check server connectivity before running tests"""
|
|
||||||
is_connected = await test_server_connectivity()
|
|
||||||
if not is_connected:
|
|
||||||
pytest.skip("Server is not running or not accessible")
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_simple_select_query_via_client(self, client, test_config):
|
|
||||||
"""Test simple SELECT query through client"""
|
|
||||||
sample_queries = test_config.get_sample_queries()
|
|
||||||
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.execute_sql(sample_queries[0]) # "SELECT 1 as test_value"
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
|
|
||||||
if result["success"]:
|
|
||||||
assert "result" in result, "Successful result should contain 'result' field"
|
|
||||||
else:
|
|
||||||
assert "error" in result, "Failed result should contain 'error' field"
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_show_databases_query_via_client(self, client, test_config):
|
|
||||||
"""Test SHOW DATABASES query through client"""
|
|
||||||
sample_queries = test_config.get_sample_queries()
|
|
||||||
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.execute_sql(sample_queries[1]) # "SHOW DATABASES"
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_information_schema_query_via_client(self, client, test_config):
|
|
||||||
"""Test information_schema query through client"""
|
|
||||||
sample_queries = test_config.get_sample_queries()
|
|
||||||
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.execute_sql(sample_queries[2]) # "SELECT COUNT(*) FROM information_schema.tables"
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
|
|
||||||
"""Test query with max_rows parameter through client"""
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.call_tool("exec_query", {
|
|
||||||
"sql": "SELECT 1 as test_value",
|
|
||||||
"max_rows": 10
|
|
||||||
})
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_error_handling_via_client(self, client, test_config):
|
|
||||||
"""Test query error handling through client"""
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.execute_sql("INVALID SQL SYNTAX")
|
|
||||||
|
|
||||||
# Should get a result (either success or error)
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_query_with_auth_token_via_client(self, client, test_config):
|
|
||||||
"""Test query with authentication token"""
|
|
||||||
if not test_config.is_security_tests_enabled():
|
|
||||||
pytest.skip("Security tests are disabled")
|
|
||||||
|
|
||||||
auth_tokens = test_config.get_auth_tokens()
|
|
||||||
|
|
||||||
async def test_callback(client_instance):
|
|
||||||
result = await client_instance.call_tool("exec_query", {
|
|
||||||
"sql": "SELECT 1 as test_value",
|
|
||||||
"auth_token": auth_tokens["valid_token"]
|
|
||||||
})
|
|
||||||
|
|
||||||
# Verify result structure
|
|
||||||
assert "success" in result, "Result should contain 'success' field"
|
|
||||||
return result
|
|
||||||
|
|
||||||
result = await client.connect_and_run(test_callback)
|
|
||||||
assert "success" in result
|
|
||||||
4
uv.lock
generated
@@ -948,7 +948,7 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mcp-doris-server"
|
name = "mcp-doris-server"
|
||||||
version = "0.4.2"
|
version = "0.3.0"
|
||||||
source = { editable = "." }
|
source = { editable = "." }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "aiofiles" },
|
{ name = "aiofiles" },
|
||||||
@@ -1054,7 +1054,7 @@ requires-dist = [
|
|||||||
{ name = "httpx", specifier = ">=0.26.0" },
|
{ name = "httpx", specifier = ">=0.26.0" },
|
||||||
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
|
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
|
||||||
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
|
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
|
||||||
{ name = "mcp", specifier = ">=1.8.0,<2.0.0" },
|
{ name = "mcp", specifier = ">=1.0.0" },
|
||||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
||||||
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||||
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||||
|
|||||||