0.3.0 Release Version
This commit is contained in:
54
.env.example
Normal file
54
.env.example
Normal file
@@ -0,0 +1,54 @@
|
||||
# Doris MCP Server Environment Configuration
|
||||
# Copy this file to .env and modify the values as needed
|
||||
|
||||
# Database Configuration
|
||||
DORIS_HOST=localhost
|
||||
DORIS_PORT=9030
|
||||
DORIS_USER=root
|
||||
DORIS_PASSWORD=your_password_here
|
||||
DORIS_DATABASE=your_database_name
|
||||
|
||||
# Connection Pool Settings
|
||||
DORIS_MIN_CONNECTIONS=5
|
||||
DORIS_MAX_CONNECTIONS=20
|
||||
DORIS_CONNECTION_TIMEOUT=30
|
||||
DORIS_HEALTH_CHECK_INTERVAL=60
|
||||
DORIS_MAX_CONNECTION_AGE=3600
|
||||
|
||||
# Security Settings
|
||||
AUTH_TYPE=token
|
||||
TOKEN_SECRET=your_256_bit_secret_key_here
|
||||
TOKEN_EXPIRY=3600
|
||||
MAX_RESULT_ROWS=10000
|
||||
ENABLE_MASKING=true
|
||||
|
||||
# Performance Settings
|
||||
ENABLE_QUERY_CACHE=true
|
||||
CACHE_TTL=300
|
||||
MAX_CACHE_SIZE=1000
|
||||
MAX_CONCURRENT_QUERIES=50
|
||||
QUERY_TIMEOUT=300
|
||||
|
||||
# Logging Configuration
|
||||
LOG_LEVEL=INFO
|
||||
LOG_FILE_PATH=./log/doris-mcp-server.log
|
||||
ENABLE_AUDIT=true
|
||||
AUDIT_FILE_PATH=./log/doris-mcp-audit.log
|
||||
|
||||
# Monitoring Settings
|
||||
ENABLE_METRICS=true
|
||||
METRICS_PORT=3001
|
||||
METRICS_PATH=/metrics
|
||||
HEALTH_CHECK_PORT=3002
|
||||
HEALTH_CHECK_PATH=/health
|
||||
ENABLE_ALERTS=false
|
||||
ALERT_WEBHOOK_URL=
|
||||
|
||||
# Server Settings
|
||||
SERVER_NAME=doris-mcp-server
|
||||
SERVER_VERSION=0.3.0
|
||||
SERVER_PORT=3000
|
||||
|
||||
# Development Settings (for development environment only)
|
||||
DEBUG=false
|
||||
VERBOSE=false
|
||||
49
Dockerfile
Normal file
49
Dockerfile
Normal file
@@ -0,0 +1,49 @@
|
||||
# Use Python 3.12 as base image
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONPATH=/app
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
gcc \
|
||||
g++ \
|
||||
pkg-config \
|
||||
default-libmysqlclient-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements file
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Create necessary directories
|
||||
RUN mkdir -p /app/logs /app/config /app/data
|
||||
|
||||
# Set permissions
|
||||
RUN chmod +x /app/start.sh
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r doris && useradd -r -g doris doris
|
||||
RUN chown -R doris:doris /app
|
||||
USER doris
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
CMD curl -f http://localhost:3000/health || exit 1
|
||||
|
||||
# Expose ports
|
||||
EXPOSE 3000 3001 3002
|
||||
|
||||
# Start command
|
||||
CMD ["/app/start.sh"]
|
||||
201
LICENSE.txt
201
LICENSE.txt
@@ -1,201 +0,0 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
119
Makefile
Normal file
119
Makefile
Normal file
@@ -0,0 +1,119 @@
|
||||
# Doris MCP Server Makefile
|
||||
# Provides convenient commands using UV
|
||||
|
||||
.PHONY: help install sync dev test lint format build clean check start-stdio start-sse
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "Available commands:"
|
||||
@echo " install - Install dependencies using UV"
|
||||
@echo " sync - Sync dependencies and create virtual environment"
|
||||
@echo " dev - Install development dependencies"
|
||||
@echo " test - Run tests"
|
||||
@echo " lint - Run linting tools"
|
||||
@echo " format - Format code with black and isort"
|
||||
@echo " build - Build the package"
|
||||
@echo " clean - Clean build artifacts"
|
||||
@echo " check - Run all checks (format, lint, test)"
|
||||
@echo " start-stdio - Start server in stdio mode"
|
||||
@echo " start-sse - Start server in SSE mode"
|
||||
|
||||
# Install dependencies
|
||||
install:
|
||||
uv sync
|
||||
|
||||
# Sync dependencies with development extras
|
||||
sync:
|
||||
uv sync
|
||||
|
||||
# Install development dependencies
|
||||
dev:
|
||||
uv sync --dev
|
||||
|
||||
# Run tests
|
||||
test:
|
||||
uv run pytest
|
||||
|
||||
# Run linting tools
|
||||
lint:
|
||||
uv run ruff check doris_mcp_server/
|
||||
uv run mypy doris_mcp_server/
|
||||
|
||||
# Format code
|
||||
format:
|
||||
uv run ruff format doris_mcp_server/
|
||||
uv run ruff check --fix doris_mcp_server/
|
||||
|
||||
# Build the package
|
||||
build:
|
||||
uv build
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
rm -rf build/
|
||||
rm -rf dist/
|
||||
rm -rf *.egg-info/
|
||||
find . -type d -name __pycache__ -exec rm -rf {} +
|
||||
find . -type d -name .pytest_cache -exec rm -rf {} +
|
||||
find . -type d -name .mypy_cache -exec rm -rf {} +
|
||||
|
||||
# Run all checks
|
||||
check: format lint test
|
||||
|
||||
# Start server in stdio mode
|
||||
start-stdio:
|
||||
uv run python -m doris_mcp_server.main --transport stdio
|
||||
|
||||
# Start server in SSE mode
|
||||
start-sse:
|
||||
uv run python -m doris_mcp_server.main --transport sse --host 0.0.0.0 --port 8080
|
||||
|
||||
# Start server with custom database settings
|
||||
start-dev:
|
||||
uv run python -m doris_mcp_server.main \
|
||||
--transport stdio \
|
||||
--db-host localhost \
|
||||
--db-port 9030 \
|
||||
--db-user root \
|
||||
--log-level DEBUG
|
||||
|
||||
# Run a single test file
|
||||
test-file:
|
||||
uv run pytest $(FILE) -v
|
||||
|
||||
# Install and run in one command
|
||||
run: install start-stdio
|
||||
|
||||
# Development setup
|
||||
setup: dev
|
||||
@echo "✅ Development environment is ready!"
|
||||
@echo "Run 'make start-stdio' to start the server"
|
||||
|
||||
# Add dependencies
|
||||
add:
|
||||
uv add $(PACKAGE)
|
||||
|
||||
# Add development dependencies
|
||||
add-dev:
|
||||
uv add --dev $(PACKAGE)
|
||||
|
||||
# Show dependency tree
|
||||
deps:
|
||||
uv tree
|
||||
|
||||
# Lock dependencies
|
||||
lock:
|
||||
uv lock
|
||||
|
||||
# Check for outdated dependencies
|
||||
outdated:
|
||||
uv tree --outdated
|
||||
|
||||
# Export requirements.txt
|
||||
export-requirements:
|
||||
uv export --no-hashes > requirements.txt
|
||||
|
||||
# Show UV version and info
|
||||
info:
|
||||
uv --version
|
||||
uv python list
|
||||
645
README.md
645
README.md
@@ -2,23 +2,37 @@
|
||||
|
||||
Doris MCP (Model Context Protocol) Server is a backend service built with Python and FastAPI. It implements the MCP, allowing clients to interact with it through defined "Tools". It's primarily designed to connect to Apache Doris databases, potentially leveraging Large Language Models (LLMs) for tasks like converting natural language queries to SQL (NL2SQL), executing queries, and performing metadata management and analysis.
|
||||
|
||||
## 🚀 What's New in v0.3.0
|
||||
|
||||
- **🔄 Streamlined Communication**: Completely migrated from SSE to Streamable HTTP for better performance and reliability
|
||||
- **🏗️ Unified Architecture**: Consolidated tools management with centralized registration and routing
|
||||
- **⚡ Enhanced Performance**: Improved query execution with advanced caching and optimization
|
||||
- **🔒 Enterprise Security**: Added comprehensive security management with SQL validation and data masking
|
||||
- **📊 Advanced Analytics**: New column analysis and performance monitoring tools
|
||||
- **🛠️ Simplified Development**: Streamlined tool development process with unified interfaces
|
||||
|
||||
> **⚠️ Breaking Changes**: SSE endpoints have been removed. Please update your client configurations to use Streamable HTTP (`/mcp` endpoint).
|
||||
|
||||
## Core Features
|
||||
|
||||
* **MCP Protocol Implementation**: Provides standard MCP interfaces, supporting tool calls, resource management, and prompt interactions.
|
||||
* **Multiple Communication Modes**:
|
||||
* **SSE (Server-Sent Events)**: Served via `/sse` (initialization) and `/mcp/messages` (communication) endpoints (`src/sse_server.py`).
|
||||
* **Streamable HTTP**: Served via the unified `/mcp` endpoint, supporting request/response and streaming (`src/streamable_server.py`).
|
||||
* **(Optional) Stdio**: Interaction possible via standard input/output (`src/stdio_server.py`), requires specific startup configuration.
|
||||
* **Tool-Based Interface**: Core functionalities are encapsulated as MCP tools that clients can call as needed. Currently available key tools focus on direct database interaction with full catalog federation support:
|
||||
* SQL Execution with Catalog Federation (`mcp_doris_exec_query`)
|
||||
* Catalog Management (`mcp_doris_get_catalog_list`)
|
||||
* Database and Table Listing (`mcp_doris_get_db_list`, `mcp_doris_get_db_table_list`)
|
||||
* Metadata Retrieval (`mcp_doris_get_table_schema`, `mcp_doris_get_table_comment`, `mcp_doris_get_table_column_comments`, `mcp_doris_get_table_indexes`)
|
||||
* Audit Log Retrieval (`mcp_doris_get_recent_audit_logs`)
|
||||
*Note: All metadata tools support catalog federation for multi-catalog environments.*
|
||||
* **Database Interaction**: Provides functionality to connect to Apache Doris (or other compatible databases) and execute queries (`src/utils/db.py`).
|
||||
* **Flexible Configuration**: Configured via a `.env` file, supporting settings for database connections, LLM providers/models, API keys, logging levels, etc.
|
||||
* **Metadata Extraction**: Capable of extracting database metadata information with full catalog federation support (`src/utils/schema_extractor.py`).
|
||||
* **Multiple Communication Modes** (Updated in v0.3.0):
|
||||
* **Stdio**: Standard input/output mode for direct integration with MCP clients like Cursor.
|
||||
* **Streamable HTTP**: Unified HTTP endpoint supporting request/response and streaming (Primary mode since v0.3.0).
|
||||
|
||||
> **⚠️ Breaking Change in v0.3.0**: SSE (Server-Sent Events) mode has been completely removed in favor of the more robust Streamable HTTP implementation.
|
||||
* **Enterprise-Grade Architecture**: Modular design with comprehensive functionality:
|
||||
* **Tools Manager**: Centralized tool registration and routing (`doris_mcp_server/tools/tools_manager.py`)
|
||||
* **Resources Manager**: Resource management and metadata exposure (`doris_mcp_server/tools/resources_manager.py`)
|
||||
* **Prompts Manager**: Intelligent prompt templates for data analysis (`doris_mcp_server/tools/prompts_manager.py`)
|
||||
* **Advanced Database Features**:
|
||||
* **Query Execution**: High-performance SQL execution with caching and optimization (`doris_mcp_server/utils/query_executor.py`)
|
||||
* **Security Management**: SQL security validation, data masking, and access control (`doris_mcp_server/utils/security.py`)
|
||||
* **Metadata Extraction**: Comprehensive database metadata with catalog federation support (`doris_mcp_server/utils/schema_extractor.py`)
|
||||
* **Performance Analysis**: Column statistics, performance monitoring, and data analysis tools (`doris_mcp_server/utils/analysis_tools.py`)
|
||||
* **Catalog Federation Support**: Full support for multi-catalog environments (internal Doris tables and external data sources like Hive, MySQL, etc.)
|
||||
* **Enterprise Security**: Comprehensive security framework with authentication, authorization, SQL injection protection, and data masking (`doris_mcp_server/utils/security.py`)
|
||||
* **Flexible Configuration**: Comprehensive configuration management with environment variables, file-based config, and validation (`doris_mcp_server/utils/config.py`)
|
||||
|
||||
## System Requirements
|
||||
|
||||
@@ -43,7 +57,7 @@ pip install -r requirements.txt
|
||||
|
||||
### 3. Configure Environment Variables
|
||||
|
||||
Copy the `.env.example` file to `.env` and modify the settings according to your environment:
|
||||
Copy the `.env.example` file to `.env` and modify the settings according to your environment:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
@@ -52,74 +66,82 @@ cp env.example .env
|
||||
**Key Environment Variables:**
|
||||
|
||||
* **Database Connection**:
|
||||
* `DB_HOST`: Database hostname
|
||||
* `DB_PORT`: Database port (default 9030)
|
||||
* `DB_USER`: Database username
|
||||
* `DB_PASSWORD`: Database password
|
||||
* `DB_DATABASE`: Default database name
|
||||
* **Server Configuration**:
|
||||
* `SERVER_HOST`: Host address the server listens on (default `0.0.0.0`)
|
||||
* `SERVER_PORT`: Port the server listens on (default `3000`)
|
||||
* `ALLOWED_ORIGINS`: CORS allowed origins (comma-separated, `*` allows all)
|
||||
* `MCP_ALLOW_CREDENTIALS`: Whether to allow CORS credentials (default `false`)
|
||||
* `DORIS_HOST`: Database hostname (default: localhost)
|
||||
* `DORIS_PORT`: Database port (default: 9030)
|
||||
* `DORIS_USER`: Database username (default: root)
|
||||
* `DORIS_PASSWORD`: Database password
|
||||
* `DORIS_DATABASE`: Default database name (default: test)
|
||||
* `DORIS_MIN_CONNECTIONS`: Minimum connection pool size (default: 5)
|
||||
* `DORIS_MAX_CONNECTIONS`: Maximum connection pool size (default: 20)
|
||||
* **Security Configuration**:
|
||||
* `AUTH_TYPE`: Authentication type (token/basic/oauth, default: token)
|
||||
* `TOKEN_SECRET`: Token secret key
|
||||
* `ENABLE_MASKING`: Enable data masking (default: true)
|
||||
* `MAX_RESULT_ROWS`: Maximum result rows (default: 10000)
|
||||
* **Performance Configuration**:
|
||||
* `ENABLE_QUERY_CACHE`: Enable query caching (default: true)
|
||||
* `CACHE_TTL`: Cache time-to-live in seconds (default: 300)
|
||||
* `MAX_CONCURRENT_QUERIES`: Maximum concurrent queries (default: 50)
|
||||
* **Logging Configuration**:
|
||||
* `LOG_DIR`: Directory for log files (default `./logs`)
|
||||
* `LOG_LEVEL`: Log level (e.g., `INFO`, `DEBUG`, `WARNING`, `ERROR`, default `INFO`)
|
||||
* `CONSOLE_LOGGING`: Whether to output logs to the console (default `false`)
|
||||
* `LOG_LEVEL`: Log level (DEBUG/INFO/WARNING/ERROR, default: INFO)
|
||||
* `LOG_FILE_PATH`: Log file path
|
||||
* `ENABLE_AUDIT`: Enable audit logging (default: true)
|
||||
|
||||
### Available MCP Tools
|
||||
|
||||
The following table lists the main tools currently available for invocation via an MCP client:
|
||||
|
||||
| Tool Name | Description | Parameters | Status |
|
||||
| :-------------------------------- | :---------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------- | :------- |
|
||||
| `mcp_doris_get_catalog_list` | Get a list of all catalogs with detailed information. | `random_string` (string, Required) | ✅ Active |
|
||||
| `mcp_doris_get_db_list` | Get a list of all database names in the specified catalog. | `random_string` (string, Required), `catalog_name` (string, Optional, defaults to internal catalog) | ✅ Active |
|
||||
| `mcp_doris_get_db_table_list` | Get a list of all table names in the specified database. | `random_string` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `mcp_doris_get_table_schema` | Get detailed structure of the specified table. | `random_string` (string, Required), `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `mcp_doris_get_table_comment` | Get the comment for the specified table. | `random_string` (string, Required), `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `mcp_doris_get_table_column_comments` | Get comments for all columns in the specified table. | `random_string` (string, Required), `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `mcp_doris_get_table_indexes` | Get index information for the specified table. | `random_string` (string, Required), `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `mcp_doris_exec_query` | Execute SQL query with catalog federation support. | `random_string` (string, Required), `sql` (string, Required - MUST use three-part naming), `db_name` (string, Optional), `catalog_name` (string, Optional), `max_rows` (integer, Optional, default 100), `timeout` (integer, Optional, default 30) | ✅ Active |
|
||||
| `mcp_doris_get_recent_audit_logs` | Get audit log records for a recent period. | `random_string` (string, Required), `days` (integer, Optional, default 7), `limit` (integer, Optional, default 100) | ✅ Active |
|
||||
| Tool Name | Description | Parameters | Status |
|
||||
|:----------------------------| :---------------------------------------------------------- | :--------------------------------------------------------------------------------------------------------- | :------- |
|
||||
| `exec_query` | Execute SQL query with catalog federation support. | `sql` (string, Required - MUST use three-part naming), `db_name` (string, Optional), `catalog_name` (string, Optional), `max_rows` (integer, Optional, default 100), `timeout` (integer, Optional, default 30) | ✅ Active |
|
||||
| `get_catalog_list` | Get a list of all catalogs with detailed information. | `random_string` (string, Required) | ✅ Active |
|
||||
| `get_db_list` | Get a list of all database names in the specified catalog. | `catalog_name` (string, Optional, defaults to internal catalog) | ✅ Active |
|
||||
| `get_db_table_list` | Get a list of all table names in the specified database. | `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `get_table_schema` | Get detailed structure of the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `get_table_comment` | Get the comment for the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `get_table_column_comments` | Get comments for all columns in the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `get_table_indexes` | Get index information for the specified table. | `table_name` (string, Required), `db_name` (string, Optional), `catalog_name` (string, Optional) | ✅ Active |
|
||||
| `get_recent_audit_logs` | Get audit log records for a recent period. | `days` (integer, Optional, default 7), `limit` (integer, Optional, default 100) | ✅ Active |
|
||||
| `column_analysis` | Analyze statistical information and data distribution. | `table_name` (string, Required), `column_name` (string, Required), `analysis_type` (string, Optional: basic/distribution/detailed) | ⚠️ Experimental |
|
||||
| `performance_stats` | Get database performance statistics information. | `metric_type` (string, Optional: queries/connections/tables/system), `time_range` (string, Optional: 1h/6h/24h/7d) | ⚠️ Experimental |
|
||||
|
||||
**Note:** All tools require a `random_string` parameter as a call identifier, typically handled automatically by the MCP client. "Optional" and "Required" refer to the tool's internal logic; the client might need to provide values for all parameters depending on its implementation. The tool names listed here are the base names; clients might see them prefixed (e.g., `mcp_doris_stdio3_get_db_list`) depending on the connection mode.
|
||||
**Note:** All metadata tools support catalog federation for multi-catalog environments. The `get_catalog_list` tool requires a `random_string` parameter for compatibility reasons.
|
||||
|
||||
### 4. Run the Service
|
||||
|
||||
If you use SSE mode, execute the following command:
|
||||
Execute the following command to start the server:
|
||||
|
||||
```bash
|
||||
./start_server.sh
|
||||
```
|
||||
|
||||
This command starts the FastAPI application, providing both SSE and Streamable HTTP MCP services by default.
|
||||
This command starts the FastAPI application with the Streamable HTTP MCP service.
|
||||
|
||||
**Service Endpoints:**
|
||||
**Service Endpoints (v0.3.0+):**
|
||||
|
||||
* **SSE Initialization**: `http://<host>:<port>/sse`
|
||||
* **SSE Communication**: `http://<host>:<port>/mcp/messages` (POST)
|
||||
* **Streamable HTTP**: `http://<host>:<port>/mcp` (Supports GET, POST, DELETE, OPTIONS)
|
||||
* **Streamable HTTP**: `http://<host>:<port>/mcp` (Primary MCP endpoint - supports GET, POST, DELETE, OPTIONS)
|
||||
* **Health Check**: `http://<host>:<port>/health`
|
||||
* **(Potential) Status Check**: `http://<host>:<port>/status` (Confirm if implemented in `main.py`)
|
||||
* **Status Check**: `http://<host>:<port>/status`
|
||||
|
||||
> **Note**: Starting from v0.3.0, only Streamable HTTP mode is supported for web-based communication. SSE endpoints have been removed.
|
||||
|
||||
## Usage
|
||||
|
||||
Interaction with the Doris MCP Server requires an **MCP Client**. The client connects to the server's SSE or Streamable HTTP endpoints and sends requests (like `tool_call`) according to the MCP specification to invoke the server's tools.
|
||||
Interaction with the Doris MCP Server requires an **MCP Client**. The client connects to the server's Streamable HTTP endpoint and sends requests according to the MCP specification to invoke the server's tools.
|
||||
|
||||
**Main Interaction Flow:**
|
||||
**Main Interaction Flow (v0.3.0+):**
|
||||
|
||||
1. **Client Initialization**: Connect to `/sse` (SSE) or send an `initialize` method call to `/mcp` (Streamable).
|
||||
2. **(Optional) Discover Tools**: The client can call `mcp/listTools` or `mcp/listOfferings` to get the list of supported tools, their descriptions, and parameter schemas.
|
||||
3. **Call Tool**: The client sends a `tool_call` message/request, specifying the `tool_name` and `arguments`.
|
||||
1. **Client Initialization**: Send an `initialize` method call to `/mcp` (Streamable HTTP).
|
||||
2. **(Optional) Discover Tools**: The client can call `tools/list` to get the list of supported tools, their descriptions, and parameter schemas.
|
||||
3. **Call Tool**: The client sends a `tools/call` request, specifying the `name` and `arguments`.
|
||||
* **Example: Get Table Schema**
|
||||
* `tool_name`: `mcp_doris_get_table_schema` (or the mode-specific name)
|
||||
* `arguments`: Include `random_string`, `table_name`, `db_name`, `catalog_name`.
|
||||
* `name`: `get_table_schema`
|
||||
* `arguments`: Include `table_name`, `db_name`, `catalog_name`.
|
||||
4. **Handle Response**:
|
||||
* **Non-streaming**: The client receives a response containing `result` or `error`.
|
||||
* **Streaming**: The client receives a series of `tools/progress` notifications, followed by a final response containing the `result` or `error`.
|
||||
* **Non-streaming**: The client receives a response containing `content` or `isError`.
|
||||
* **Streaming**: The client receives a series of progress notifications, followed by a final response.
|
||||
|
||||
Specific tool names and parameters should be referenced from the `src/tools/` code or obtained via MCP discovery mechanisms.
|
||||
> **Migration Note**: If you're upgrading from v0.2.x, tool names have been simplified (the `mcp_doris_` prefix has been removed) and web-based communication now uses Streamable HTTP exclusively (SSE endpoints have been removed; Stdio mode is still available).
|
||||
|
||||
### Catalog Federation Support
|
||||
|
||||
@@ -189,20 +211,267 @@ The Doris MCP Server supports **catalog federation**, enabling interaction with
|
||||
}
|
||||
```
|
||||
|
||||
## Security Configuration (v0.3.0+)
|
||||
|
||||
The Doris MCP Server includes a comprehensive security framework that provides enterprise-level protection through authentication, authorization, SQL security validation, and data masking capabilities.
|
||||
|
||||
### Security Features
|
||||
|
||||
* **🔐 Authentication**: Support for token-based and basic authentication
|
||||
* **🛡️ Authorization**: Role-based access control (RBAC) with security levels
|
||||
* **🚫 SQL Security**: SQL injection protection and blocked operations
|
||||
* **🎭 Data Masking**: Automatic sensitive data masking based on user permissions
|
||||
* **📊 Security Levels**: Four-tier security classification (Public, Internal, Confidential, Secret)
|
||||
|
||||
### Authentication Configuration
|
||||
|
||||
Configure authentication in your environment variables:
|
||||
|
||||
```bash
|
||||
# Authentication Type (token/basic/oauth)
|
||||
AUTH_TYPE=token
|
||||
|
||||
# Token Secret for JWT validation
|
||||
TOKEN_SECRET=your_secret_key_here
|
||||
|
||||
# Session timeout (in seconds)
|
||||
SESSION_TIMEOUT=3600
|
||||
```
|
||||
|
||||
#### Token Authentication Example
|
||||
|
||||
```python
|
||||
# Client authentication with token
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "your_jwt_token",
|
||||
"session_id": "unique_session_id"
|
||||
}
|
||||
```
|
||||
|
||||
#### Basic Authentication Example
|
||||
|
||||
```python
|
||||
# Client authentication with username/password
|
||||
auth_info = {
|
||||
"type": "basic",
|
||||
"username": "analyst",
|
||||
"password": "secure_password",
|
||||
"session_id": "unique_session_id"
|
||||
}
|
||||
```
|
||||
|
||||
### Authorization & Security Levels
|
||||
|
||||
The system supports four security levels with hierarchical access control:
|
||||
|
||||
| Security Level | Access Scope | Typical Use Cases |
|
||||
|:---------------|:-------------|:------------------|
|
||||
| **Public** | Unrestricted access | Public reports, general statistics |
|
||||
| **Internal** | Company employees | Internal dashboards, business metrics |
|
||||
| **Confidential** | Authorized personnel | Customer data, financial reports |
|
||||
| **Secret** | Senior management | Strategic data, sensitive analytics |
|
||||
|
||||
#### Role Configuration
|
||||
|
||||
Configure user roles and permissions:
|
||||
|
||||
```python
|
||||
# Example role configuration
|
||||
role_permissions = {
|
||||
"data_analyst": {
|
||||
"security_level": "internal",
|
||||
"permissions": ["read_data", "execute_query"],
|
||||
"allowed_tables": ["sales", "products", "orders"]
|
||||
},
|
||||
"data_admin": {
|
||||
"security_level": "confidential",
|
||||
"permissions": ["read_data", "execute_query", "admin"],
|
||||
"allowed_tables": ["*"]
|
||||
},
|
||||
"executive": {
|
||||
"security_level": "secret",
|
||||
"permissions": ["read_data", "execute_query", "admin"],
|
||||
"allowed_tables": ["*"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### SQL Security Validation
|
||||
|
||||
The system automatically validates SQL queries for security risks:
|
||||
|
||||
#### Blocked Operations
|
||||
|
||||
Configure blocked SQL operations:
|
||||
|
||||
```bash
|
||||
# Environment variable
|
||||
BLOCKED_SQL_OPERATIONS=DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE
|
||||
|
||||
# Maximum query complexity score
|
||||
MAX_QUERY_COMPLEXITY=100
|
||||
```
|
||||
|
||||
#### SQL Injection Protection
|
||||
|
||||
The system automatically detects and blocks:
|
||||
|
||||
* **Union-based injections**: `UNION SELECT` attacks
|
||||
* **Boolean-based injections**: `OR 1=1` patterns
|
||||
* **Time-based injections**: `SLEEP()`, `WAITFOR` functions
|
||||
* **Comment injections**: `--`, `/**/` patterns
|
||||
* **Stacked queries**: Multiple statements separated by `;`
|
||||
|
||||
#### Example Security Validation
|
||||
|
||||
```python
|
||||
# This query would be blocked
|
||||
dangerous_sql = "SELECT * FROM users WHERE id = 1; DROP TABLE users;"
|
||||
|
||||
# This query would be allowed
|
||||
safe_sql = "SELECT name, email FROM users WHERE department = 'sales'"
|
||||
```
|
||||
|
||||
### Data Masking Configuration
|
||||
|
||||
Configure automatic data masking for sensitive information:
|
||||
|
||||
#### Built-in Masking Rules
|
||||
|
||||
```python
|
||||
# Default masking rules
|
||||
masking_rules = [
|
||||
{
|
||||
"column_pattern": r".*phone.*|.*mobile.*",
|
||||
"algorithm": "phone_mask",
|
||||
"parameters": {
|
||||
"mask_char": "*",
|
||||
"keep_prefix": 3,
|
||||
"keep_suffix": 4
|
||||
},
|
||||
"security_level": "internal"
|
||||
},
|
||||
{
|
||||
"column_pattern": r".*email.*",
|
||||
"algorithm": "email_mask",
|
||||
"parameters": {"mask_char": "*"},
|
||||
"security_level": "internal"
|
||||
},
|
||||
{
|
||||
"column_pattern": r".*id_card.*|.*identity.*",
|
||||
"algorithm": "id_mask",
|
||||
"parameters": {
|
||||
"mask_char": "*",
|
||||
"keep_prefix": 6,
|
||||
"keep_suffix": 4
|
||||
},
|
||||
"security_level": "confidential"
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
#### Masking Algorithms
|
||||
|
||||
| Algorithm | Description | Example |
|
||||
|:----------|:------------|:--------|
|
||||
| `phone_mask` | Masks phone numbers | `138****5678` |
|
||||
| `email_mask` | Masks email addresses | `j***n@example.com` |
|
||||
| `id_mask` | Masks ID card numbers | `110101****1234` |
|
||||
| `name_mask` | Masks personal names | `张*明` |
|
||||
| `partial_mask` | Partial masking with ratio | `abc***xyz` |
|
||||
|
||||
#### Custom Masking Rules
|
||||
|
||||
Add custom masking rules in your configuration:
|
||||
|
||||
```python
|
||||
# Custom masking rule
|
||||
custom_rule = {
|
||||
"column_pattern": r".*salary.*|.*income.*",
|
||||
"algorithm": "partial_mask",
|
||||
"parameters": {
|
||||
"mask_char": "*",
|
||||
"mask_ratio": 0.6
|
||||
},
|
||||
"security_level": "confidential"
|
||||
}
|
||||
```
|
||||
|
||||
### Security Configuration Examples
|
||||
|
||||
#### Environment Variables
|
||||
|
||||
```bash
|
||||
# .env file
|
||||
AUTH_TYPE=token
|
||||
TOKEN_SECRET=your_jwt_secret_key
|
||||
ENABLE_MASKING=true
|
||||
MAX_RESULT_ROWS=10000
|
||||
BLOCKED_SQL_OPERATIONS=DROP,DELETE,TRUNCATE,ALTER
|
||||
MAX_QUERY_COMPLEXITY=100
|
||||
ENABLE_AUDIT=true
|
||||
```
|
||||
|
||||
#### Sensitive Tables Configuration
|
||||
|
||||
```python
|
||||
# Configure sensitive tables with security levels
|
||||
sensitive_tables = {
|
||||
"user_profiles": "confidential",
|
||||
"payment_records": "secret",
|
||||
"employee_salaries": "secret",
|
||||
"customer_data": "confidential",
|
||||
"public_reports": "public"
|
||||
}
|
||||
```
|
||||
|
||||
### Security Best Practices
|
||||
|
||||
1. **🔑 Strong Authentication**: Use JWT tokens with proper expiration
|
||||
2. **🎯 Principle of Least Privilege**: Grant minimum required permissions
|
||||
3. **🔍 Regular Auditing**: Enable audit logging for security monitoring
|
||||
4. **🛡️ Input Validation**: All SQL queries are automatically validated
|
||||
5. **🎭 Data Classification**: Properly classify data with security levels
|
||||
6. **🔄 Regular Updates**: Keep security rules and configurations updated
|
||||
|
||||
### Security Monitoring
|
||||
|
||||
The system provides comprehensive security monitoring:
|
||||
|
||||
```python
|
||||
# Security audit log example
|
||||
{
|
||||
"timestamp": "2024-01-15T10:30:00Z",
|
||||
"user_id": "analyst_user",
|
||||
"action": "query_execution",
|
||||
"resource": "customer_data",
|
||||
"result": "blocked",
|
||||
"reason": "insufficient_permissions",
|
||||
"risk_level": "medium"
|
||||
}
|
||||
```
|
||||
|
||||
> **⚠️ Important**: Always test security configurations in a development environment before deploying to production. Regularly review and update security policies based on your organization's requirements.
|
||||
|
||||
## Connecting with Cursor
|
||||
|
||||
You can connect Cursor to this MCP server using either Stdio or SSE mode.
|
||||
You can connect Cursor to this MCP server using Stdio mode (recommended) or Streamable HTTP mode.
|
||||
|
||||
### Stdio Mode
|
||||
|
||||
Stdio mode allows Cursor to manage the server process directly. Configuration is done within Cursor's MCP Server settings file (typically `~/.cursor/mcp.json` or similar).
|
||||
|
||||
If you use stdio mode, please execute the following command to download and build the environment dependency package, **but please note that you need to change the project path to the correct path address**:
|
||||
### Using uv (Recommended)
|
||||
|
||||
If you have `uv` installed, you can run the server directly:
|
||||
|
||||
```bash
|
||||
uv --project /your/path/doris-mcp-server run doris-mcp
|
||||
uv run --project /path/to/doris-mcp-server doris-mcp-server
|
||||
```
|
||||
|
||||
**Note:** Replace `/path/to/doris-mcp-server` with the actual absolute path to your project directory.
|
||||
|
||||
1. **Configure Cursor:** Add an entry like the following to your Cursor MCP configuration:
|
||||
|
||||
```json
|
||||
@@ -210,189 +479,205 @@ uv --project /your/path/doris-mcp-server run doris-mcp
|
||||
"mcpServers": {
|
||||
"doris-stdio": {
|
||||
"command": "uv",
|
||||
"args": ["--project", "/path/to/your/doris-mcp-server", "run", "doris-mcp"],
|
||||
"args": ["run", "--project", "/path/to/your/doris-mcp-server", "doris-mcp-server"],
|
||||
"env": {
|
||||
"DB_HOST": "127.0.0.1",
|
||||
"DB_PORT": "9030",
|
||||
"DB_USER": "root",
|
||||
"DB_PASSWORD": "your_db_password",
|
||||
"DB_DATABASE": "your_default_db"
|
||||
"DORIS_HOST": "127.0.0.1",
|
||||
"DORIS_PORT": "9030",
|
||||
"DORIS_USER": "root",
|
||||
"DORIS_PASSWORD": "your_db_password",
|
||||
"DORIS_DATABASE": "your_default_db",
|
||||
"LOG_LEVEL": "INFO"
|
||||
}
|
||||
},
|
||||
// ... other server configurations ...
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
2. **Key Points:**
|
||||
* Replace `/path/to/your/doris-mcp` with the actual absolute path to the project's root directory on your system. The `--project` argument is crucial for `uv` to find the `pyproject.toml` and run the correct command.
|
||||
* The `command` is set to `uv` (assuming you use `uv` for package management as indicated by `uv.lock`). The `args` include `--project`, the path, `run`, and `mcp-doris` (which should correspond to a script defined in your `pyproject.toml`).
|
||||
* Database connection details (`DB_HOST`, `DB_PORT`, `DB_USER`, `DB_PASSWORD`, `DB_DATABASE`) are set directly in the `env` block within the configuration file. Cursor will pass these to the server process. No `.env` file is needed for this mode when configured via Cursor.
|
||||
* Replace `/path/to/your/doris-mcp-server` with the actual absolute path to the project's root directory on your system.
|
||||
* The `--project` argument is crucial for `uv` to find the `pyproject.toml` and run the correct command.
|
||||
* Database connection details are set directly in the `env` block. Cursor will pass these to the server process.
|
||||
* No `.env` file is needed for this mode when configured via Cursor.
|
||||
|
||||
### SSE Mode
|
||||
### Streamable HTTP Mode (v0.3.0+)
|
||||
|
||||
SSE mode requires you to run the MCP server independently first, and then tell Cursor how to connect to it.
|
||||
Streamable HTTP mode requires you to run the MCP server independently first, and then configure Cursor to connect to it.
|
||||
|
||||
1. **Configure `.env`:** Ensure your database credentials and any other necessary settings (like `SERVER_PORT` if not using the default 3000) are correctly configured in the `.env` file within the project directory.
|
||||
1. **Configure `.env`:** Ensure your database credentials and any other necessary settings are correctly configured in the `.env` file within the project directory.
|
||||
2. **Start the Server:** Run the server from your terminal in the project's root directory:
|
||||
```bash
|
||||
./start_server.sh
|
||||
```
|
||||
This script typically reads the `.env` file and starts the FastAPI server in SSE mode (check the script and `sse_server.py` / `main.py` for specifics). Note the host and port the server is listening on (default is `0.0.0.0:3000`).
|
||||
3. **Configure Cursor:** Add an entry like the following to your Cursor MCP configuration, pointing to the running server's SSE endpoint:
|
||||
This script reads the `.env` file and starts the FastAPI server with Streamable HTTP support. Note the host and port the server is listening on (default is `0.0.0.0:3000`).
|
||||
3. **Configure Cursor:** Add an entry like the following to your Cursor MCP configuration, pointing to the running server's Streamable HTTP endpoint:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"doris-sse": {
|
||||
"url": "http://127.0.0.1:3000/sse" // Adjust host/port if your server runs elsewhere
|
||||
},
|
||||
// ... other server configurations ...
|
||||
"doris-http": {
|
||||
"url": "http://127.0.0.1:3000/mcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
*Note: The example uses the default port `3000`. If your server is configured to run on a different port (like `3010` in the user example), adjust the URL accordingly.*
|
||||
|
||||
After configuring either mode in Cursor, you should be able to select the server (e.g., `doris-stdio` or `doris-sse`) and use its tools.
|
||||
> **Note**: Adjust the host/port if your server runs on a different address. The `/mcp` endpoint is the unified Streamable HTTP interface introduced in v0.3.0.
|
||||
|
||||
After configuring either mode in Cursor, you should be able to select the server (e.g., `doris-stdio` or `doris-http`) and use its tools.
|
||||
|
||||
> **⚠️ Migration from v0.2.x**: If you were using SSE mode (`/sse` endpoint), update your configuration to use the new Streamable HTTP endpoint (`/mcp`).
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
doris-mcp-server/
|
||||
├── doris_mcp_server/ # Source code for the MCP server
|
||||
│ ├── main.py # Main entry point, FastAPI app definition
|
||||
│ ├── mcp_core.py # Core MCP tool registration and Stdio handling
|
||||
│ ├── sse_server.py # SSE server implementation
|
||||
│ ├── streamable_server.py # Streamable HTTP server implementation
|
||||
│ ├── config.py # Configuration loading
|
||||
│ ├── tools/ # MCP tool definitions
|
||||
│ │ ├── mcp_doris_tools.py # Main Doris-related MCP tools
|
||||
│ │ ├── tool_initializer.py # Tool registration helper (used by mcp_core.py)
|
||||
├── doris_mcp_server/ # Main server package
|
||||
│ ├── main.py # Main entry point and FastAPI app
|
||||
│ ├── tools/ # MCP tools implementation
|
||||
│ │ ├── tools_manager.py # Centralized tools management and registration
|
||||
│ │ ├── resources_manager.py # Resource management and metadata exposure
|
||||
│ │ ├── prompts_manager.py # Intelligent prompt templates for data analysis
|
||||
│ │ └── __init__.py
|
||||
│ ├── utils/ # Utility classes and helper functions
|
||||
│ │ ├── db.py # Database connection and operations
|
||||
│ │ ├── logger.py # Logging configuration
|
||||
│ │ ├── schema_extractor.py # Doris metadata/schema extraction logic
|
||||
│ │ ├── sql_executor_tools.py # SQL execution helper (might be legacy)
|
||||
│ ├── utils/ # Core utility modules
|
||||
│ │ ├── config.py # Configuration management with validation
|
||||
│ │ ├── db.py # Database connection management with pooling
|
||||
│ │ ├── query_executor.py # High-performance SQL execution with caching
|
||||
│ │ ├── security.py # Security management and data masking
|
||||
│ │ ├── schema_extractor.py # Metadata extraction with catalog federation
|
||||
│ │ ├── analysis_tools.py # Data analysis and performance monitoring
|
||||
│ │ ├── logger.py # Logging configuration
|
||||
│ │ └── __init__.py
|
||||
│ └── __init__.py
|
||||
├── logs/ # Log file directory (if file logging enabled)
|
||||
├── README.md # This file
|
||||
├── .env.example # Example environment variable file
|
||||
├── requirements.txt # Python dependencies for pip
|
||||
├── pyproject.toml # Project metadata and build system configuration (PEP 518)
|
||||
├── uv.lock # Lock file for 'uv' package manager (alternative to pip)
|
||||
├── start_server.sh # Script to start the server
|
||||
└── restart_server.sh # Script to restart the server
|
||||
├── doris_mcp_client/ # MCP client implementation
|
||||
│ ├── client.py # Unified MCP client for testing and integration
|
||||
│ ├── README.md # Client documentation
|
||||
│ └── __init__.py
|
||||
├── logs/ # Log files directory
|
||||
├── README.md # This documentation
|
||||
├── .env.example # Environment variables template
|
||||
├── requirements.txt # Python dependencies
|
||||
├── pyproject.toml # Project configuration and entry points
|
||||
├── uv.lock # UV package manager lock file
|
||||
├── generate_requirements.py # Requirements generation script
|
||||
├── start_server.sh # Server startup script
|
||||
└── restart_server.sh # Server restart script
|
||||
```
|
||||
|
||||
## Developing New Tools
|
||||
|
||||
This section outlines the process for adding new MCP tools to the Doris MCP Server, considering the current project structure.
|
||||
This section outlines the process for adding new MCP tools to the Doris MCP Server, based on the current modular architecture.
|
||||
|
||||
### 1. Leverage Utility Modules
|
||||
### 1. Leverage Existing Utility Modules
|
||||
|
||||
Before writing new database interaction logic from scratch, check the existing utility modules:
|
||||
The server provides comprehensive utility modules for common database operations:
|
||||
|
||||
* **`doris_mcp_server/utils/db.py`**: Provides basic functions for getting database connections (`get_db_connection`) and executing raw queries (`execute_query`, `execute_query_df`).
|
||||
* **`doris_mcp_server/utils/schema_extractor.py` (`MetadataExtractor` class)**: Offers high-level methods to retrieve database metadata with catalog federation support, such as listing databases/tables (`get_all_databases`, `get_database_tables`), getting table schemas/comments/indexes (`get_table_schema`, `get_table_comment`, `get_column_comments`, `get_table_indexes`), and accessing audit logs (`get_recent_audit_logs`). All methods support optional `catalog_name` parameters for multi-catalog environments. It includes caching mechanisms.
|
||||
* **`doris_mcp_server/utils/sql_executor_tools.py` (`execute_sql_query` function)**: Provides a wrapper around `db.execute_query` that includes security checks (optional, controlled by `ENABLE_SQL_SECURITY_CHECK` env var), adds automatic `LIMIT` to SELECT queries, handles result serialization (dates, decimals), and formats the output into the standard MCP success/error structure. **It's recommended to use this for executing user-provided or generated SQL.**
|
||||
|
||||
You can import and combine functionalities from these modules to build your new tool.
|
||||
* **`doris_mcp_server/utils/db.py`**: Database connection management with connection pooling and health monitoring.
|
||||
* **`doris_mcp_server/utils/query_executor.py`**: High-performance SQL execution with caching, optimization, and performance monitoring.
|
||||
* **`doris_mcp_server/utils/schema_extractor.py`**: Metadata extraction with full catalog federation support.
|
||||
* **`doris_mcp_server/utils/security.py`**: Security management, SQL validation, and data masking.
|
||||
* **`doris_mcp_server/utils/analysis_tools.py`**: Data analysis and statistical tools.
|
||||
* **`doris_mcp_server/utils/config.py`**: Configuration management with validation.
|
||||
|
||||
### 2. Implement Tool Logic
|
||||
|
||||
Implement the core logic for your new tool as an `async` function within `doris_mcp_server/tools/mcp_doris_tools.py`. This keeps the primary tool implementations centralized. Ensure your function returns data in a format that can be easily wrapped into the standard MCP response structure (see `_format_response` in the same file for reference).
|
||||
Add your new tool to the `DorisToolsManager` class in `doris_mcp_server/tools/tools_manager.py`. The tools manager provides a centralized approach to tool registration and execution.
|
||||
|
||||
**Example:** Let's create a simple tool `get_server_time`.
|
||||
**Example:** Adding a new analysis tool:
|
||||
|
||||
```python
|
||||
# In doris_mcp_server/tools/mcp_doris_tools.py
|
||||
import datetime
|
||||
# ... other imports ...
|
||||
from doris_mcp_server.tools.mcp_doris_tools import _format_response # Reuse formatter
|
||||
# In doris_mcp_server/tools/tools_manager.py
|
||||
|
||||
# ... existing tools ...
|
||||
async def your_new_analysis_tool(self, arguments: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Your new analysis tool implementation
|
||||
|
||||
async def mcp_doris_get_server_time() -> Dict[str, Any]:
|
||||
"""Gets the current server time."""
|
||||
logger.info(f"MCP Tool Call: mcp_doris_get_server_time")
|
||||
Args:
|
||||
arguments: Tool arguments from MCP client
|
||||
|
||||
Returns:
|
||||
List of MCP response messages
|
||||
"""
|
||||
try:
|
||||
current_time = datetime.datetime.now().isoformat()
|
||||
# Use the existing formatter for consistency
|
||||
return _format_response(success=True, result={"server_time": current_time})
|
||||
# Use existing utilities
|
||||
result = await self.query_executor.execute_sql_for_mcp(
|
||||
sql="SELECT COUNT(*) FROM your_table",
|
||||
max_rows=arguments.get("max_rows", 100)
|
||||
)
|
||||
|
||||
return [{
|
||||
"type": "text",
|
||||
"text": json.dumps(result, ensure_ascii=False, indent=2)
|
||||
}]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"MCP tool execution failed mcp_doris_get_server_time: {str(e)}", exc_info=True)
|
||||
return _format_response(success=False, error=str(e), message="Error getting server time")
|
||||
|
||||
logger.error(f"Tool execution failed: {str(e)}", exc_info=True)
|
||||
return [{
|
||||
"type": "text",
|
||||
"text": f"Error: {str(e)}"
|
||||
}]
|
||||
```
|
||||
|
||||
### 3. Register the Tool (Dual Registration)
|
||||
### 3. Register the Tool
|
||||
|
||||
Due to the separate handling of SSE/Streamable and Stdio modes, you need to register the tool in two places:
|
||||
|
||||
**A. SSE/Streamable Registration (`tool_initializer.py`)**
|
||||
|
||||
* Import your new tool function from `mcp_doris_tools.py`.
|
||||
* Inside the `register_mcp_tools` function, add a new wrapper function decorated with `@mcp.tool()`.
|
||||
* The wrapper function should call your core tool function.
|
||||
* Define the tool name and provide a detailed description (including parameters if any) in the decorator. Remember to include the mandatory `random_string` parameter description for client compatibility, even if your wrapper doesn't explicitly use it.
|
||||
|
||||
**Example (`tool_initializer.py`):**
|
||||
Add your tool to the `_register_tools` method in the same class:
|
||||
|
||||
```python
|
||||
# In doris_mcp_server/tools/tool_initializer.py
|
||||
# ... other imports ...
|
||||
from doris_mcp_server.tools.mcp_doris_tools import (
|
||||
# ... existing tool imports ...
|
||||
mcp_doris_get_server_time # <-- Import the new tool
|
||||
# In the _register_tools method of DorisToolsManager
|
||||
|
||||
@self.mcp.tool(
|
||||
name="your_new_analysis_tool",
|
||||
description="Description of your new analysis tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parameter1": {
|
||||
"type": "string",
|
||||
"description": "Description of parameter1"
|
||||
},
|
||||
"parameter2": {
|
||||
"type": "integer",
|
||||
"description": "Description of parameter2",
|
||||
"default": 100
|
||||
}
|
||||
},
|
||||
"required": ["parameter1"]
|
||||
}
|
||||
)
|
||||
|
||||
async def register_mcp_tools(mcp):
|
||||
# ... existing tool registrations ...
|
||||
|
||||
# Register Tool: Get Server Time
|
||||
@mcp.tool("get_server_time", description="""[Function Description]: Get the current time of the MCP server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n""")
|
||||
async def get_server_time_tool() -> Dict[str, Any]:
|
||||
"""Wrapper: Get server time"""
|
||||
# Note: No parameters needed for the core function call here
|
||||
return await mcp_doris_get_server_time()
|
||||
|
||||
# ... logging registration count ...
|
||||
async def your_new_analysis_tool_wrapper(arguments: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
return await self.your_new_analysis_tool(arguments)
|
||||
```
|
||||
|
||||
**B. Stdio Registration (`mcp_core.py`)**
|
||||
### 4. Advanced Features
|
||||
|
||||
* Similar to SSE, add a new wrapper function decorated with `@stdio_mcp.tool()`.
|
||||
* **Important:** Import your core tool function (`mcp_doris_get_server_time`) *inside* the wrapper function (delayed import pattern used in this file).
|
||||
* The wrapper calls the core tool function. The wrapper itself *might* need to be `async def` depending on how `FastMCP` handles tools in Stdio mode, even if the underlying function is simple (as seen in the current file structure). Ensure the call matches (e.g., use `await` if calling an async function).
|
||||
For more complex tools, you can leverage:
|
||||
|
||||
**Example (`mcp_core.py`):**
|
||||
* **Caching**: Use the query executor's built-in caching for performance
|
||||
* **Security**: Apply SQL validation and data masking through the security manager
|
||||
* **Prompts**: Use the prompts manager for intelligent query generation
|
||||
* **Resources**: Expose metadata through the resources manager
|
||||
|
||||
### 5. Testing
|
||||
|
||||
Test your new tool using the included MCP client:
|
||||
|
||||
```python
|
||||
# In doris_mcp_server/mcp_core.py
|
||||
# ... other imports and setup ...
|
||||
# Using doris_mcp_client/client.py
|
||||
from doris_mcp_client.client import DorisUnifiedMCPClient
|
||||
|
||||
# ... existing Stdio tool registrations ...
|
||||
|
||||
# Register Tool: Get Server Time (for Stdio)
|
||||
@stdio_mcp.tool("get_server_time", description="""[Function Description]: Get the current time of the MCP server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n""")
|
||||
async def get_server_time_tool_stdio() -> Dict[str, Any]: # Using a slightly different wrapper name for clarity if needed
|
||||
"""Wrapper: Get server time (Stdio)"""
|
||||
from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_server_time # <-- Delayed import
|
||||
# Assuming the Stdio runner handles async wrappers correctly
|
||||
return await mcp_doris_get_server_time()
|
||||
|
||||
# --- Register Tools --- (Or wherever the registrations are finalized)
|
||||
async def test_new_tool():
|
||||
client = DorisUnifiedMCPClient()
|
||||
result = await client.call_tool("your_new_analysis_tool", {
|
||||
"parameter1": "test_value",
|
||||
"parameter2": 50
|
||||
})
|
||||
print(result)
|
||||
```
|
||||
|
||||
### 4. Restart and Test
|
||||
## MCP Client
|
||||
|
||||
After implementing and registering the tool in both files, restart the MCP server (both SSE mode via `./start_server.sh` and ensure the Stdio command used by Cursor is updated if necessary) and test the new tool using your MCP client (like Cursor) in both connection modes.
|
||||
The project includes a unified MCP client (`doris_mcp_client/`) for testing and integration purposes. The client supports multiple connection modes and provides a convenient interface for interacting with the MCP server.
|
||||
|
||||
For detailed client documentation, see [`doris_mcp_client/README.md`](doris_mcp_client/README.md).
|
||||
|
||||
## Contributing
|
||||
|
||||
|
||||
202
docker-compose.yml
Normal file
202
docker-compose.yml
Normal file
@@ -0,0 +1,202 @@
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# Doris MCP Server
|
||||
doris-mcp-server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
container_name: doris-mcp-server
|
||||
ports:
|
||||
- "3000:3000" # MCP service port
|
||||
- "3001:3001" # Monitoring metrics port
|
||||
- "3002:3002" # Health check port
|
||||
environment:
|
||||
# Database configuration
|
||||
- DORIS_HOST=doris-fe
|
||||
- DORIS_PORT=9030
|
||||
- DORIS_USER=root
|
||||
- DORIS_PASSWORD=doris123
|
||||
- DORIS_DATABASE=test_db
|
||||
|
||||
# Connection pool configuration
|
||||
- DORIS_MIN_CONNECTIONS=5
|
||||
- DORIS_MAX_CONNECTIONS=20
|
||||
|
||||
# Security configuration
|
||||
- AUTH_TYPE=token
|
||||
- TOKEN_SECRET=your_secret_key_here
|
||||
- MAX_RESULT_ROWS=10000
|
||||
|
||||
# Performance configuration
|
||||
- ENABLE_QUERY_CACHE=true
|
||||
- MAX_CONCURRENT_QUERIES=50
|
||||
|
||||
# Logging configuration
|
||||
- LOG_LEVEL=INFO
|
||||
- LOG_FILE_PATH=/app/logs/doris-mcp-server.log
|
||||
|
||||
# Monitoring configuration
|
||||
- ENABLE_METRICS=true
|
||||
- METRICS_PORT=3001
|
||||
volumes:
|
||||
- ./logs:/app/logs
|
||||
- ./config:/app/config
|
||||
depends_on:
|
||||
- doris-fe
|
||||
- doris-be
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:3002/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 40s
|
||||
|
||||
# Apache Doris Frontend
|
||||
doris-fe:
|
||||
image: apache/doris:2.0.3-fe-x86_64
|
||||
container_name: doris-fe
|
||||
ports:
|
||||
- "8030:8030" # FE HTTP port
|
||||
- "9030:9030" # FE MySQL port
|
||||
environment:
|
||||
- FE_SERVERS=fe1:doris-fe:9010
|
||||
- FE_ID=1
|
||||
volumes:
|
||||
- doris-fe-data:/opt/apache-doris/fe/doris-meta
|
||||
- doris-fe-log:/opt/apache-doris/fe/log
|
||||
- ./doris-config/fe.conf:/opt/apache-doris/fe/conf/fe.conf
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8030/api/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
start_period: 60s
|
||||
|
||||
# Apache Doris Backend
|
||||
doris-be:
|
||||
image: apache/doris:2.0.3-be-x86_64
|
||||
container_name: doris-be
|
||||
ports:
|
||||
- "8040:8040" # BE HTTP port
|
||||
- "9060:9060" # BE Thrift port (be_port); the heartbeat service listens on 9050
|
||||
environment:
|
||||
- FE_SERVERS=doris-fe:9010
|
||||
- BE_ADDR=doris-be:9050
|
||||
volumes:
|
||||
- doris-be-data:/opt/apache-doris/be/storage
|
||||
- doris-be-log:/opt/apache-doris/be/log
|
||||
- ./doris-config/be.conf:/opt/apache-doris/be/conf/be.conf
|
||||
depends_on:
|
||||
- doris-fe
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8040/api/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
start_period: 60s
|
||||
|
||||
# Redis cache (optional)
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: doris-redis
|
||||
ports:
|
||||
- "6379:6379"
|
||||
command: redis-server --appendonly yes --requirepass redis123
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
# Prometheus monitoring
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
container_name: doris-prometheus
|
||||
ports:
|
||||
- "9090:9090"
|
||||
volumes:
|
||||
- ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
|
||||
- prometheus-data:/prometheus
|
||||
command:
|
||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||
- '--storage.tsdb.path=/prometheus'
|
||||
- '--web.console.libraries=/etc/prometheus/console_libraries'
|
||||
- '--web.console.templates=/etc/prometheus/consoles'
|
||||
- '--storage.tsdb.retention.time=200h'
|
||||
- '--web.enable-lifecycle'
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
|
||||
# Grafana visualization
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
container_name: doris-grafana
|
||||
ports:
|
||||
- "3000:3000"
|
||||
environment:
|
||||
- GF_SECURITY_ADMIN_PASSWORD=admin123
|
||||
volumes:
|
||||
- grafana-data:/var/lib/grafana
|
||||
- ./monitoring/grafana/dashboards:/etc/grafana/provisioning/dashboards
|
||||
- ./monitoring/grafana/datasources:/etc/grafana/provisioning/datasources
|
||||
depends_on:
|
||||
- prometheus
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
|
||||
# Nginx load balancer
|
||||
nginx:
|
||||
image: nginx:alpine
|
||||
container_name: doris-nginx
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||
- ./nginx/ssl:/etc/nginx/ssl
|
||||
- ./nginx/logs:/var/log/nginx
|
||||
depends_on:
|
||||
- doris-mcp-server
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
doris-fe-data:
|
||||
driver: local
|
||||
doris-fe-log:
|
||||
driver: local
|
||||
doris-be-data:
|
||||
driver: local
|
||||
doris-be-log:
|
||||
driver: local
|
||||
redis-data:
|
||||
driver: local
|
||||
prometheus-data:
|
||||
driver: local
|
||||
grafana-data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
doris-network:
|
||||
driver: bridge
|
||||
ipam:
|
||||
config:
|
||||
- subnet: 172.20.0.0/16
|
||||
322
doris_mcp_client/README.md
Normal file
322
doris_mcp_client/README.md
Normal file
@@ -0,0 +1,322 @@
|
||||
# Doris Unified MCP Client
|
||||
|
||||
This is a unified Doris MCP client that supports both **stdio** and **Streamable HTTP** transport modes, providing complete MCP protocol support.
|
||||
|
||||
## 🚀 Features
|
||||
|
||||
- ✅ **Dual Mode Support**: Both stdio and HTTP transport methods
|
||||
- ✅ **Complete MCP Support**: Resources, Tools, and Prompts primitives
|
||||
- ✅ **Unified API**: Same interface for different transport modes
|
||||
- ✅ **Asynchronous Design**: High-performance async client based on asyncio
|
||||
- ✅ **Enterprise Features**: Connection pooling, error handling, logging
|
||||
- ✅ **Convenience Methods**: High-level wrappers for common database operations
|
||||
|
||||
## 📦 Install Dependencies
|
||||
|
||||
```bash
|
||||
pip install mcp
|
||||
```
|
||||
|
||||
## 🎯 Quick Start
|
||||
|
||||
### 1. stdio Mode
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from client import create_stdio_client
|
||||
|
||||
async def main():
|
||||
# Create stdio client
|
||||
client = await create_stdio_client(
|
||||
"python",
|
||||
["-m", "doris_mcp_server.main", "--transport", "stdio"]
|
||||
)
|
||||
|
||||
async def test_client(client):
|
||||
# Get database list
|
||||
db_result = await client.get_database_list()
|
||||
print(f"Databases: {db_result}")
|
||||
|
||||
# Execute SQL query
|
||||
query_result = await client.execute_sql("SELECT 1 as test")
|
||||
print(f"Query result: {query_result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### 2. HTTP Mode
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from client import create_http_client
|
||||
|
||||
async def main():
|
||||
# Create HTTP client
|
||||
client = await create_http_client("http://localhost:3000/mcp")
|
||||
|
||||
async def test_client(client):
|
||||
# Get all tools
|
||||
tools = await client.list_all_tools()
|
||||
print(f"Available tools: {len(tools)}")
|
||||
|
||||
# Execute query
|
||||
result = await client.execute_sql(
|
||||
"SELECT COUNT(*) FROM internal.ssb.lineorder LIMIT 1"
|
||||
)
|
||||
print(f"Query result: {result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## 🔧 API Reference
|
||||
|
||||
### Client Creation
|
||||
|
||||
```python
|
||||
# stdio mode
|
||||
client = await create_stdio_client(command, args)
|
||||
|
||||
# HTTP mode
|
||||
client = await create_http_client(server_url, timeout=60)
|
||||
```
|
||||
|
||||
### Basic Operations
|
||||
|
||||
```python
|
||||
async def test_client(client):
|
||||
# Get server capabilities
|
||||
tools = await client.list_all_tools()
|
||||
resources = await client.list_all_resources()
|
||||
prompts = await client.list_all_prompts()
|
||||
|
||||
# Call tool
|
||||
result = await client.call_tool("tool_name", {"param": "value"})
|
||||
|
||||
# Read resource
|
||||
content = await client.read_resource("resource://uri")
|
||||
|
||||
# Get prompt
|
||||
prompt = await client.get_prompt("prompt_name", {"param": "value"})
|
||||
```
|
||||
|
||||
### Advanced Database Operations
|
||||
|
||||
```python
|
||||
async def database_operations(client):
|
||||
# Execute SQL query
|
||||
result = await client.execute_sql("SELECT * FROM table LIMIT 10")
|
||||
|
||||
# Get database list
|
||||
databases = await client.get_database_list()
|
||||
|
||||
# Get table schema
|
||||
schema = await client.get_table_schema("table_name", "db_name")
|
||||
|
||||
# Column data analysis
|
||||
analysis = await client.analyze_column("table", "column", "basic")
|
||||
```
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
### Run Test Suite
|
||||
|
||||
```bash
|
||||
# Interactive testing
|
||||
python test_unified_client.py
|
||||
|
||||
# Test stdio mode
|
||||
python test_unified_client.py stdio
|
||||
|
||||
# Test HTTP mode
|
||||
python test_unified_client.py http
|
||||
|
||||
# Test both modes
|
||||
python test_unified_client.py both
|
||||
|
||||
# Performance benchmark
|
||||
python test_unified_client.py benchmark
|
||||
```
|
||||
|
||||
### Test Output Example
|
||||
|
||||
```
|
||||
🎯 Doris Unified Client Test Suite
|
||||
============================================================
|
||||
|
||||
🚀 Testing HTTP Mode
|
||||
==================================================
|
||||
📋 Getting server capabilities...
|
||||
✅ Found 11 tools
|
||||
✅ Found 0 resources
|
||||
✅ Found 0 prompts
|
||||
|
||||
🔧 Available tools:
|
||||
1. get_db_list: Get database list
|
||||
2. get_table_list: Get table list for specified database
|
||||
3. get_table_schema: Get table structure information
|
||||
4. exec_query: Execute SQL query
|
||||
5. column_analysis: Analyze column data distribution and statistics
|
||||
...
|
||||
|
||||
🧪 Testing basic functionality...
|
||||
1️⃣ Getting database list...
|
||||
✅ Success: 3 databases
|
||||
2️⃣ Executing simple query...
|
||||
✅ Query successful
|
||||
3️⃣ Executing SSB data query...
|
||||
✅ SSB query successful
|
||||
4️⃣ Getting table structure...
|
||||
✅ Table structure retrieved successfully
|
||||
5️⃣ Column data analysis...
|
||||
✅ Column analysis successful
|
||||
|
||||
✅ HTTP mode testing completed!
|
||||
```
|
||||
|
||||
## 🏗️ Architecture Design
|
||||
|
||||
### Unified Client Architecture
|
||||
|
||||
```
|
||||
DorisUnifiedClient
|
||||
├── DorisResourceClient # Resource management
|
||||
├── DorisToolsClient # Tool invocation
|
||||
├── DorisPromptClient # Prompt management
|
||||
└── Transport Layer
|
||||
├── stdio mode # Standard input/output
|
||||
└── HTTP mode # Streamable HTTP
|
||||
```
|
||||
|
||||
### Key Features
|
||||
|
||||
1. **Unified Interface**: Same API for different transport modes
|
||||
2. **Async Context**: Proper resource management and connection cleanup
|
||||
3. **Error Handling**: Comprehensive exception handling and error recovery
|
||||
4. **Performance Optimization**: Connection reuse and request caching
|
||||
|
||||
## 📚 Usage Examples
|
||||
|
||||
### Complete Example
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from client import DorisUnifiedClient, DorisClientConfig
|
||||
|
||||
async def comprehensive_example():
|
||||
# Create configuration
|
||||
config = DorisClientConfig.stdio(
|
||||
"python",
|
||||
["-m", "doris_mcp_server.main"]
|
||||
)
|
||||
|
||||
client = DorisUnifiedClient(config)
|
||||
|
||||
async def demo_operations(client):
|
||||
print("🔍 Discovering server capabilities...")
|
||||
|
||||
# List all available tools
|
||||
tools = await client.list_all_tools()
|
||||
print(f"Available tools: {[tool.name for tool in tools]}")
|
||||
|
||||
# Get database list
|
||||
print("\n📊 Getting database information...")
|
||||
db_result = await client.get_database_list()
|
||||
print(f"Databases: {db_result}")
|
||||
|
||||
# Execute queries
|
||||
print("\n🔍 Executing queries...")
|
||||
|
||||
# Simple query
|
||||
result1 = await client.execute_sql("SELECT 1 as test_column")
|
||||
print(f"Simple query result: {result1}")
|
||||
|
||||
# Get table schema
|
||||
schema_result = await client.get_table_schema("lineorder", "ssb")
|
||||
print(f"Table schema: {schema_result}")
|
||||
|
||||
# Column analysis
|
||||
analysis_result = await client.analyze_column(
|
||||
"lineorder", "lo_orderkey", "basic"
|
||||
)
|
||||
print(f"Column analysis: {analysis_result}")
|
||||
|
||||
await client.connect_and_run(demo_operations)
|
||||
|
||||
# Run the example
|
||||
asyncio.run(comprehensive_example())
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
async def error_handling_example(client):
|
||||
try:
|
||||
# This might fail
|
||||
result = await client.execute_sql("INVALID SQL")
|
||||
except Exception as e:
|
||||
print(f"SQL execution failed: {e}")
|
||||
|
||||
# Check result status
|
||||
result = await client.get_database_list()
|
||||
if result.get("success", True):
|
||||
print("Operation successful")
|
||||
else:
|
||||
print(f"Operation failed: {result.get('error')}")
|
||||
```
|
||||
|
||||
## 🔧 Configuration
|
||||
|
||||
### Client Configuration Options
|
||||
|
||||
```python
|
||||
# stdio mode with custom arguments
|
||||
config = DorisClientConfig(
|
||||
transport="stdio",
|
||||
server_command="python",
|
||||
server_args=["-m", "doris_mcp_server.main", "--debug"],
|
||||
timeout=30
|
||||
)
|
||||
|
||||
# HTTP mode with custom timeout
|
||||
config = DorisClientConfig(
|
||||
transport="http",
|
||||
server_url="http://localhost:8080/mcp",
|
||||
timeout=60
|
||||
)
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Set default server URL
|
||||
export DORIS_MCP_SERVER_URL="http://localhost:8080"
|
||||
|
||||
# Set default timeout
|
||||
export DORIS_MCP_TIMEOUT=60
|
||||
|
||||
# Enable debug logging
|
||||
export DORIS_MCP_DEBUG=true
|
||||
```
|
||||
|
||||
## 🚀 Performance Tips
|
||||
|
||||
1. **Connection Reuse**: Use the same client instance for multiple operations
|
||||
2. **Batch Operations**: Group related queries together
|
||||
3. **Async Context**: Always use proper async context management
|
||||
4. **Error Recovery**: Implement retry logic for transient failures
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Make your changes
|
||||
4. Add tests
|
||||
5. Submit a pull request
|
||||
|
||||
## 📄 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License.
|
||||
9
doris_mcp_client/__init__.py
Normal file
9
doris_mcp_client/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Doris MCP Client Package
|
||||
|
||||
Unified MCP client supporting both stdio and HTTP transport modes
|
||||
"""
|
||||
|
||||
from .client import DorisUnifiedClient, DorisClientConfig
|
||||
|
||||
__all__ = ["DorisUnifiedClient", "DorisClientConfig"]
|
||||
497
doris_mcp_client/client.py
Normal file
497
doris_mcp_client/client.py
Normal file
@@ -0,0 +1,497 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unified Doris MCP Client - Supports both stdio and Streamable HTTP modes
|
||||
|
||||
Combines the correct HTTP implementation from http_client.py and the complete architecture from client.py
|
||||
Provides complete support for the three major primitives: Resources, Tools, and Prompts
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
from datetime import timedelta
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.types import (
|
||||
Prompt,
|
||||
Resource,
|
||||
Tool,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DorisClientConfig:
    """Doris client configuration class

    Holds the transport kind plus the settings that transport needs: a
    command line for ``"stdio"`` or a URL and timeout for ``"http"``.
    Nothing is validated here; the values are stored as given.
    """

    def __init__(
        self,
        transport: str = "stdio",
        server_command: str | None = None,
        server_args: list[str] | None = None,
        server_url: str | None = None,
        timeout: int = 60,
    ):
        """Store the connection settings.

        Args:
            transport: Transport name, ``"stdio"`` or ``"http"``.
            server_command: Executable used to spawn the server (stdio).
            server_args: Arguments passed to ``server_command`` (stdio).
            server_url: Server endpoint URL (http).
            timeout: Timeout in seconds.
        """
        self.transport = transport
        self.server_command = server_command
        # Copy so that later mutation of the caller's list cannot silently
        # change this configuration.
        self.server_args = list(server_args) if server_args else []
        self.server_url = server_url
        self.timeout = timeout

    @classmethod
    def stdio(cls, command: str, args: list[str] | None = None) -> "DorisClientConfig":
        """Create stdio connection configuration"""
        return cls(
            transport="stdio",
            server_command=command,
            server_args=args,
        )

    @classmethod
    def http(cls, url: str, timeout: int = 60) -> "DorisClientConfig":
        """Create HTTP connection configuration"""
        return cls(
            transport="http",
            server_url=url,
            timeout=timeout,
        )
|
||||
|
||||
|
||||
class DorisResourceClient:
    """Doris resource client - Handles Resources related operations

    Thin wrapper around an MCP session. Failures are logged and turned into
    empty results instead of being raised.
    """

    def __init__(self, session: ClientSession):
        self.session = session
        self.logger = logging.getLogger(f"{__name__}.DorisResourceClient")
        # Result of the last successful list_resources(); None until then.
        self._resources_cache = None

    async def list_resources(self) -> list[Resource]:
        """Get list of all available resources

        Returns:
            The resources reported by the session, or ``[]`` if the request
            failed (the cache is left untouched in that case).
        """
        try:
            self.logger.info("Getting resource list")
            response = await self.session.list_resources()
            resources = response.resources if hasattr(response, "resources") else []
            self._resources_cache = resources
            self.logger.info(f"Retrieved {len(resources)} resources")
            return resources
        except Exception as e:
            self.logger.error(f"Failed to get resource list: {e}")
            return []

    async def read_resource(self, uri: str) -> str | None:
        """Read specified resource content

        Returns:
            The text parts of the response joined with newlines, the string
            form of ``response.content`` when only that attribute exists, or
            ``None`` when nothing was returned or the request failed.
        """
        try:
            self.logger.info(f"Reading resource: {uri}")
            response = await self.session.read_resource(uri)

            if hasattr(response, "contents") and response.contents:
                # Merge all text parts; parts without a ``text`` attribute
                # are skipped.
                text_parts = [
                    part.text for part in response.contents if hasattr(part, "text")
                ]
                merged = "\n".join(text_parts)
                self.logger.info(f"Successfully read resource content: {len(merged)} characters")
                return merged
            if hasattr(response, "content"):
                return str(response.content)

            self.logger.warning(f"Resource {uri} returned no content")
            return None

        except Exception as e:
            self.logger.error(f"Failed to read resource {uri}: {e}")
            return None

    async def filter_resources_by_type(self, resource_type: str) -> list[Resource]:
        """Filter resources by type

        Matching is a plain substring test on the resource URI. Any
        ``resource_type`` other than ``"table"``, ``"view"`` or ``"database"``
        returns every cached resource.
        """
        if not self._resources_cache:
            await self.list_resources()

        # list_resources() leaves the cache as None when the request fails;
        # treat that as "no resources" instead of iterating None.
        resources = self._resources_cache or []

        # The URI may be a URL object rather than a str, so compare against
        # its string form.
        if resource_type == "table":
            return [r for r in resources if "table" in str(r.uri)]
        if resource_type == "view":
            return [r for r in resources if "view" in str(r.uri)]
        if resource_type == "database":
            return [
                r for r in resources
                if "database" in str(r.uri) and "table" not in str(r.uri)
            ]
        return resources
|
||||
|
||||
|
||||
class DorisToolsClient:
    """Doris tools client - Handles Tools related operations

    Thin wrapper around an MCP session. Failures are logged and reported
    through return values instead of being raised.
    """

    def __init__(self, session: ClientSession):
        self.session = session
        self.logger = logging.getLogger(f"{__name__}.DorisToolsClient")
        # Result of the last successful list_tools(); None until then.
        self._tools_cache = None

    async def list_tools(self) -> list[Tool]:
        """Get list of all available tools

        Returns:
            The tools reported by the session, or ``[]`` if the request
            failed (the cache is left untouched in that case).
        """
        try:
            self.logger.info("Getting tool list")
            response = await self.session.list_tools()
            tools = response.tools if hasattr(response, "tools") else []
            self._tools_cache = tools
            self.logger.info(f"Retrieved {len(tools)} tools")
            return tools
        except Exception as e:
            self.logger.error(f"Failed to get tool list: {e}")
            return []

    async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Call specified tool

        The text parts of the response are concatenated and parsed as JSON.

        Returns:
            The parsed JSON value when the text is valid JSON; otherwise
            ``{"success": True, "data": text}``, or
            ``{"success": False, "error": text}`` when the response is
            flagged ``isError``. A response without content, or any raised
            exception, yields ``{"success": False, "error": ...}``.
        """
        try:
            self.logger.info(f"Calling tool: {name}")
            self.logger.debug(f"Tool arguments: {arguments}")

            response = await self.session.call_tool(name, arguments)

            if hasattr(response, "content") and response.content:
                # Concatenate every text part; non-text parts are skipped.
                result_text = "".join(
                    part.text for part in response.content if hasattr(part, "text")
                )

                try:
                    result = json.loads(result_text)
                except json.JSONDecodeError:
                    # Plain-text payload. Honour the error flag so a failed
                    # call is not reported as a success.
                    if getattr(response, "isError", False):
                        self.logger.error(f"Tool {name} reported an error: {result_text}")
                        return {"success": False, "error": result_text}
                    return {"success": True, "data": result_text}

                self.logger.info(f"Tool call successful: {name}")
                return result

            self.logger.warning(f"Tool {name} returned no content")
            return {"success": False, "error": "No response content"}

        except Exception as e:
            self.logger.error(f"Tool call failed {name}: {e}")
            return {"success": False, "error": str(e)}

    async def get_tool_by_name(self, name: str) -> Tool | None:
        """Get tool definition by name

        Returns:
            The first cached tool whose name equals ``name``, else ``None``.
        """
        if not self._tools_cache:
            await self.list_tools()

        # The cache is still None if the listing request failed.
        for tool in self._tools_cache or []:
            if tool.name == name:
                return tool
        return None

    async def get_tools_by_category(self, category: str) -> list[Tool]:
        """Filter tools by category

        A tool matches when ``category`` occurs, case-insensitively, in its
        description or its name.
        """
        if not self._tools_cache:
            await self.list_tools()

        category_lower = category.lower()
        return [
            tool for tool in self._tools_cache or []
            # A tool may have no description; treat it as empty text.
            if category_lower in (tool.description or "").lower()
            or category_lower in tool.name.lower()
        ]
|
||||
|
||||
|
||||
class DorisPromptClient:
    """Doris prompt client - Handles Prompts related operations

    Thin wrapper around an MCP session. Failures are logged and turned into
    empty results instead of being raised.
    """

    def __init__(self, session: ClientSession):
        self.session = session
        self.logger = logging.getLogger(f"{__name__}.DorisPromptClient")
        # Result of the last successful list_prompts(); None until then.
        self._prompts_cache = None

    async def list_prompts(self) -> list[Prompt]:
        """Get list of all available prompts

        Returns:
            The prompts reported by the session, or ``[]`` if the request
            failed.
        """
        try:
            self.logger.info("Getting prompt list")
            response = await self.session.list_prompts()
            prompts = response.prompts if hasattr(response, "prompts") else []
            self._prompts_cache = prompts
            self.logger.info(f"Retrieved {len(prompts)} prompts")
            return prompts
        except Exception as e:
            self.logger.error(f"Failed to get prompt list: {e}")
            return []

    async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> str | None:
        """Get specified prompt content

        Args:
            name: Prompt name.
            arguments: Prompt arguments; may be omitted for prompts that
                take none, in which case ``None`` is forwarded to the session.

        Returns:
            The message contents joined with newlines (the ``text`` of a
            content object when it has one, else its string form), or
            ``None`` when there are no messages or the request failed.
        """
        try:
            self.logger.info(f"Getting prompt: {name}")
            self.logger.debug(f"Prompt arguments: {arguments}")

            response = await self.session.get_prompt(name, arguments)

            if hasattr(response, "messages") and response.messages:
                # Merge all message content; messages without a ``content``
                # attribute are skipped.
                content_parts = []
                for message in response.messages:
                    if not hasattr(message, "content"):
                        continue
                    body = message.content
                    content_parts.append(body.text if hasattr(body, "text") else str(body))

                content = "\n".join(content_parts)
                self.logger.info(f"Successfully retrieved prompt content: {len(content)} characters")
                return content

            self.logger.warning(f"Prompt {name} returned no content")
            return None

        except Exception as e:
            self.logger.error(f"Failed to get prompt {name}: {e}")
            return None
|
||||
|
||||
|
||||
class DorisUnifiedClient:
    """Unified Doris MCP client - Provides complete MCP functionality

    ``connect_and_run`` opens the configured transport, creates the MCP
    session and the three sub-clients (``resources``, ``tools``,
    ``prompts``), then awaits the caller's coroutine callback with this
    client as its only argument. The other methods delegate to those
    sub-clients and are meant to be called from inside that callback.
    """

    def __init__(self, config: DorisClientConfig):
        self.config = config
        self.logger = logging.getLogger(f"{__name__}.DorisUnifiedClient")
        # Populated by connect_and_run() once a session exists.
        self.session = None
        self.resources = None
        self.tools = None
        self.prompts = None

    async def connect_and_run(self, callback_func: Callable):
        """Connect to server and execute callback function

        Args:
            callback_func: Coroutine function awaited as
                ``callback_func(self)`` once the session is initialized.

        Raises:
            ValueError: If ``config.transport`` is neither ``"stdio"`` nor
                ``"http"``.
        """
        if self.config.transport == "stdio":
            await self._run_stdio_mode(callback_func)
        elif self.config.transport == "http":
            await self._run_http_mode(callback_func)
        else:
            raise ValueError(f"Unsupported transport type: {self.config.transport}")

    async def _run_stdio_mode(self, callback_func: Callable):
        """Run in stdio mode"""
        try:
            self.logger.info(f"Starting stdio client: {self.config.server_command}")

            server_params = StdioServerParameters(
                command=self.config.server_command,
                args=self.config.server_args,
            )

            async with stdio_client(server_params) as (read, write):
                await self._run_session(read, write, callback_func)

        except Exception as e:
            self.logger.error(f"stdio mode execution failed: {e}")
            raise

    async def _run_http_mode(self, callback_func: Callable):
        """Run in HTTP mode"""
        try:
            self.logger.info(f"Starting HTTP client: {self.config.server_url}")

            async with streamablehttp_client(
                self.config.server_url,
                timeout=timedelta(seconds=self.config.timeout),
            ) as streams:
                # The streamable HTTP transport yields the read and write
                # streams followed by a session-id getter. Indexing instead
                # of tuple-unpacking accepts that extra element.
                read, write = streams[0], streams[1]
                await self._run_session(read, write, callback_func)

        except Exception as e:
            self.logger.error(f"HTTP mode execution failed: {e}")
            raise

    async def _run_session(self, read, write, callback_func: Callable):
        """Open an MCP session on the given streams and run the callback."""
        async with ClientSession(read, write) as session:
            self.session = session
            self._init_sub_clients()

            # Initialize server
            await session.initialize()
            self.logger.info("Server initialized successfully")

            # Execute callback function
            await callback_func(self)

    def _init_sub_clients(self):
        """Initialize sub-clients"""
        self.resources = DorisResourceClient(self.session)
        self.tools = DorisToolsClient(self.session)
        self.prompts = DorisPromptClient(self.session)

    # Convenience methods
    async def list_all_resources(self) -> list[Resource]:
        """Get all resources"""
        return await self.resources.list_resources()

    async def list_all_tools(self) -> list[Tool]:
        """Get all tools"""
        return await self.tools.list_tools()

    async def list_all_prompts(self) -> list[Prompt]:
        """Get all prompts"""
        return await self.prompts.list_prompts()

    async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Call tool"""
        return await self.tools.call_tool(name, arguments)

    async def read_resource(self, uri: str) -> str | None:
        """Read resource"""
        return await self.resources.read_resource(uri)

    async def get_prompt(self, name: str, arguments: dict[str, Any] | None = None) -> str | None:
        """Get prompt"""
        return await self.prompts.get_prompt(name, arguments)

    # Smart tool finding methods
    async def _find_tool_by_pattern(self, patterns: list[str]) -> str | None:
        """Find tool by name pattern

        Patterns are tried in order, so earlier (more specific) patterns win;
        the first tool whose name contains the pattern is returned.
        """
        tools = await self.list_all_tools()
        for pattern in patterns:
            for tool in tools:
                if pattern in tool.name:
                    return tool.name
        return None

    async def _find_tool_by_function(self, function_keywords: list[str]) -> str | None:
        """Find tool by function keywords

        Returns the name of the first tool whose name or description
        contains any keyword, compared case-insensitively.
        """
        tools = await self.list_all_tools()
        keywords = [keyword.lower() for keyword in function_keywords]
        for tool in tools:
            # A tool may have no description; treat it as empty text.
            searchable = ((tool.description or "").lower(), tool.name.lower())
            if any(keyword in text for keyword in keywords for text in searchable):
                return tool.name
        return None

    # High-level business methods
    async def execute_sql(self, sql: str, **kwargs) -> dict[str, Any]:
        """Execute SQL query

        Extra keyword arguments are forwarded to the tool unchanged.
        """
        tool_name = await self._find_tool_by_pattern(["exec_query", "execute", "query"])
        if not tool_name:
            return {"success": False, "error": "SQL execution tool not found"}

        arguments = {"sql": sql, **kwargs}
        return await self.call_tool(tool_name, arguments)

    async def get_table_schema(self, table_name: str, db_name: str | None = None, **kwargs) -> dict[str, Any]:
        """Get table schema

        ``db_name`` is only sent to the tool when it is truthy.
        """
        tool_name = await self._find_tool_by_pattern(["get_table_schema", "table_schema", "schema"])
        if not tool_name:
            return {"success": False, "error": "Table schema tool not found"}

        arguments = {"table_name": table_name}
        if db_name:
            arguments["db_name"] = db_name
        arguments.update(kwargs)

        return await self.call_tool(tool_name, arguments)

    async def get_database_list(self, **kwargs) -> dict[str, Any]:
        """Get database list"""
        tool_name = await self._find_tool_by_pattern(["get_db_list", "database_list", "db_list"])
        if not tool_name:
            return {"success": False, "error": "Database list tool not found"}

        return await self.call_tool(tool_name, kwargs)

    async def analyze_column(self, table_name: str, column_name: str, analysis_type: str = "basic", **kwargs) -> dict[str, Any]:
        """Analyze column"""
        tool_name = await self._find_tool_by_pattern(["column_analysis", "analyze_column", "column"])
        if not tool_name:
            return {"success": False, "error": "Column analysis tool not found"}

        arguments = {
            "table_name": table_name,
            "column_name": column_name,
            "analysis_type": analysis_type,
            **kwargs,
        }
        return await self.call_tool(tool_name, arguments)

    async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Call tool by function description

        The description is split on whitespace and each word is used as a
        keyword for ``_find_tool_by_function``.
        """
        function_keywords = function_description.lower().split()
        tool_name = await self._find_tool_by_function(function_keywords)

        if not tool_name:
            return {
                "success": False,
                "error": f"No tool found for function: {function_description}",
            }

        return await self.call_tool(tool_name, arguments)
|
||||
|
||||
|
||||
# Convenience factory functions
|
||||
async def create_stdio_client(command: str, args: list[str] | None = None) -> DorisUnifiedClient:
    """Create stdio client

    Builds a stdio configuration from ``command`` and ``args`` and wraps it
    in a client. No process is started here; that happens in
    ``connect_and_run``.
    """
    config = DorisClientConfig.stdio(command, args)
    return DorisUnifiedClient(config)
|
||||
|
||||
|
||||
async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnifiedClient:
    """Create HTTP client

    Wraps an HTTP configuration for ``server_url`` (with ``timeout`` in
    seconds) in a client. No connection is opened here.
    """
    return DorisUnifiedClient(DorisClientConfig.http(server_url, timeout))
|
||||
|
||||
|
||||
# Example usage
|
||||
async def example_stdio():
    """stdio mode example

    Spawns the server as a child process, prints how many resources, tools
    and prompts it exposes, then runs a trivial SQL statement.
    """
    # Launch the server as a module: its main.py uses package-relative
    # imports, so running the file by path would fail at import time.
    client = await create_stdio_client(
        "python",
        ["-m", "doris_mcp_server.main", "--transport", "stdio"],
    )

    async def test_client(client: DorisUnifiedClient):
        # Get server capabilities
        resources = await client.list_all_resources()
        tools = await client.list_all_tools()
        prompts = await client.list_all_prompts()

        print(f"Resources: {len(resources)}")
        print(f"Tools: {len(tools)}")
        print(f"Prompts: {len(prompts)}")

        # Test SQL execution
        result = await client.execute_sql("SELECT 1 as test")
        print(f"SQL execution result: {result}")

    await client.connect_and_run(test_client)
|
||||
|
||||
|
||||
async def example_http():
    """HTTP mode example

    Connects to a running server, prints how many resources and tools it
    exposes, then fetches the database list.
    """
    # NOTE(review): the README examples use URLs ending in "/mcp" - confirm
    # that this endpoint is the one the server actually serves.
    http_client = await create_http_client("http://localhost:8080")

    async def test_client(client: DorisUnifiedClient):
        # Get server capabilities
        available_resources = await client.list_all_resources()
        available_tools = await client.list_all_tools()

        print(f"Resources: {len(available_resources)}")
        print(f"Tools: {len(available_tools)}")

        # Test database list
        databases = await client.get_database_list()
        print(f"Database list: {databases}")

    await http_client.connect_and_run(test_client)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run stdio example
|
||||
asyncio.run(example_stdio())
|
||||
|
||||
# Run HTTP example
|
||||
# asyncio.run(example_http())
|
||||
@@ -1 +1,13 @@
|
||||
# Mark directory as a package
|
||||
"""
|
||||
Doris MCP Server - A Model Context Protocol server for Apache Doris database integration.
|
||||
|
||||
This package provides:
|
||||
- MCP protocol implementation for Apache Doris
|
||||
- Multi-transport support (stdio, SSE, streamable HTTP)
|
||||
- Comprehensive database tools and resources
|
||||
- Enterprise-grade security and monitoring
|
||||
"""
|
||||
|
||||
# Keep in sync with the release version (SERVER_VERSION in .env.example).
__version__ = "0.3.0"
|
||||
__author__ = "Doris MCP Team"
|
||||
__description__ = "Apache Doris MCP Server Implementation"
|
||||
|
||||
8
doris_mcp_server/__main__.py
Normal file
8
doris_mcp_server/__main__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Entry point for running doris_mcp_server as a module
|
||||
"""
|
||||
|
||||
from .main import main_sync
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_sync()
|
||||
@@ -1,33 +0,0 @@
|
||||
# doris_mcp_server/config.py
|
||||
import os
|
||||
import logging
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get Log Level from environment variable, default to 'info'
|
||||
LOG_LEVEL_STR = os.getenv('LOG_LEVEL', 'info').upper()
|
||||
|
||||
# Map string level to logging level constant
|
||||
LOG_LEVEL_MAP = {
|
||||
'DEBUG': logging.DEBUG,
|
||||
'INFO': logging.INFO,
|
||||
'WARNING': logging.WARNING,
|
||||
'ERROR': logging.ERROR,
|
||||
'CRITICAL': logging.CRITICAL
|
||||
}
|
||||
LOG_LEVEL = LOG_LEVEL_MAP.get(LOG_LEVEL_STR, logging.INFO)
|
||||
|
||||
# Function to load config (can be expanded later if needed)
|
||||
def load_config():
|
||||
"""Loads configuration settings."""
|
||||
# Currently, configuration is mainly handled by environment variables
|
||||
# and constants defined in this module.
|
||||
# This function can be used to perform additional setup if required.
|
||||
logging.getLogger(__name__).info("Configuration loaded (mainly from environment variables).")
|
||||
|
||||
# You can add other configuration constants here if needed
|
||||
# Example: DB_HOST = os.getenv("DB_HOST", "localhost")
|
||||
# But often it's better to access os.getenv directly where needed
|
||||
# or pass config dictionaries around.
|
||||
@@ -1,196 +1,515 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Apache Doris MCP Server Main Entry - Primarily handles SSE mode
|
||||
Apache Doris MCP Server - Enterprise Database Service Implementation
|
||||
|
||||
Stdio mode is handled by doris_mcp_server.mcp_core:run_stdio.
|
||||
Based on Apache Doris official MCP Server architecture design, providing complete MCP protocol support
|
||||
Supports independent encapsulation implementation of Resources, Tools, and Prompts
|
||||
Supports both stdio and streamable HTTP startup modes
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from collections.abc import AsyncIterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Any
|
||||
import uvicorn
|
||||
from uvicorn import Config, Server
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from dotenv import load_dotenv
|
||||
from typing import Any
|
||||
|
||||
# Add project root to path
|
||||
PROJECT_ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
|
||||
sys.path.insert(0, PROJECT_ROOT)
|
||||
from mcp.server import Server
|
||||
from mcp.server.models import InitializationOptions
|
||||
|
||||
# SSE related imports
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from doris_mcp_server.sse_server import DorisMCPSseServer
|
||||
from doris_mcp_server.streamable_server import DorisMCPStreamableServer
|
||||
|
||||
# Stdio related imports (only needed for tools now, maybe move tool init?)
|
||||
# from mcp.server.stdio import stdio_server -> No longer used here
|
||||
|
||||
# Config and Tool Initializer
|
||||
from doris_mcp_server.config import load_config # LOG_LEVEL might not be needed here directly
|
||||
from doris_mcp_server.tools.tool_initializer import register_mcp_tools
|
||||
|
||||
# Load environment variables (load early for all modes)
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp-main") # Changed logger name slightly
|
||||
|
||||
# --- Configuration Loading and Logging Setup ---
|
||||
load_config() # Loads .env
|
||||
|
||||
# --- Create FastAPI App (Global Scope for SSE Mode) ---
|
||||
# This 'app' object is targeted by 'mcp run doris_mcp_server/main.py:app --transport sse'
|
||||
# And used when running directly with --sse
|
||||
app = FastAPI(
|
||||
title="Doris MCP Server (SSE Mode)",
|
||||
# Lifespan will be added in start_sse_server
|
||||
from mcp.types import (
|
||||
Prompt,
|
||||
Resource,
|
||||
TextContent,
|
||||
Tool,
|
||||
)
|
||||
|
||||
# --- Removed StdioServerWrapper ---
|
||||
from .tools.tools_manager import DorisToolsManager
|
||||
from .tools.prompts_manager import DorisPromptsManager
|
||||
from .tools.resources_manager import DorisResourcesManager
|
||||
from .utils.config import DorisConfig
|
||||
from .utils.db import DorisConnectionManager
|
||||
from .utils.security import DorisSecurityManager
|
||||
|
||||
# --- Command Line Argument Parsing ---
|
||||
def parse_args():
    """Parse the command line options of the SSE-mode entry point.

    Defaults for ``--host`` and ``--port`` are read from the environment when
    the parser is built: ``SERVER_HOST`` (falling back to ``0.0.0.0``) and
    ``SERVER_PORT``, then ``MCP_PORT``, then ``3000``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``, with the
        attributes ``sse``, ``host``, ``port``, ``debug`` and ``reload``.
    """
    default_host = os.getenv('SERVER_HOST', '0.0.0.0')
    # SERVER_PORT takes precedence over MCP_PORT; '3000' is the last resort.
    default_port = int(os.getenv('SERVER_PORT', os.getenv('MCP_PORT', '3000')))

    cli = argparse.ArgumentParser(
        description="Apache Doris MCP Server (SSE Mode Entry)"
    )
    # Only SSE related options are registered here.
    cli.add_argument(
        '--sse',
        action='store_true',
        help='Start SSE Web server mode (required)',
    )
    cli.add_argument('--host', type=str, default=default_host, help='Host address')
    cli.add_argument('--port', type=int, default=default_port, help='Port number')
    for switch, description in (
        ('--debug', 'Enable debug mode'),
        ('--reload', 'Enable auto-reload'),
    ):
        cli.add_argument(switch, action='store_true', help=description)

    return cli.parse_args()
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# --- SSE Mode Specific Code ---
|
||||
@dataclass
class AppContext:
    """Container for the application-level settings of the SSE server."""

    # Mapping of setting name to value.
    config: Dict[str, Any]
|
||||
|
||||
@asynccontextmanager
async def app_lifespan(app_instance: FastAPI) -> AsyncIterator[None]:
    """Lifespan context for the SSE application.

    On entry, reads the ``DB_*`` environment variables into a dict and stores
    it on ``app_instance.state.config``. Nothing is yielded to the caller; on
    exit only a log message is written.

    Args:
        app_instance: Application whose ``state.config`` is populated.
    """
    logger.info("SSE application lifecycle start...")

    env = os.getenv
    # Simplified config - maybe get from elsewhere?
    app_instance.state.config = dict(
        db_host=env("DB_HOST", "localhost"),
        db_port=int(env("DB_PORT", "9030")),
        db_user=env("DB_USER", "root"),
        db_password=env("DB_PASSWORD", ""),
        db_database=env("DB_DATABASE", "test"),
    )

    try:
        yield
    finally:
        logger.info("Cleaning up SSE application resources...")
|
||||
class DorisServer:
|
||||
"""Apache Doris MCP Server main class"""
|
||||
|
||||
async def start_sse_server(args):
|
||||
"""Start SSE Web server mode (Configures the global 'app')"""
|
||||
logger.info("Starting SSE Web server mode...")
|
||||
global app
|
||||
def __init__(self, config: DorisConfig):
|
||||
self.config = config
|
||||
self.server = Server("doris-mcp-server")
|
||||
|
||||
# --- Initialize MCP and Tools for SSE ---
|
||||
# Create a *separate* MCP instance for SSE mode
|
||||
sse_mcp = FastMCP(
|
||||
name="doris-mcp-sse",
|
||||
description="Apache Doris MCP Server (SSE)",
|
||||
lifespan=None, # Managed by FastAPI
|
||||
dependencies=["fastapi", "uvicorn", "openai", "sse_starlette"]
|
||||
)
|
||||
logger.info("Registering MCP tools for SSE mode...")
|
||||
await register_mcp_tools(sse_mcp) # Register tools for the SSE instance
|
||||
logger.info("MCP tools registered for SSE.")
|
||||
# Initialize security manager
|
||||
self.security_manager = DorisSecurityManager(config)
|
||||
|
||||
# --- Configure Lifespan and CORS for the global app ---
|
||||
app.router.lifespan_context = app_lifespan
|
||||
origins = os.getenv("ALLOWED_ORIGINS", "*").split(",")
|
||||
allow_credentials = os.getenv("MCP_ALLOW_CREDENTIALS", "false").lower() == "true"
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=allow_credentials,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
expose_headers=["Mcp-Session-Id"],
|
||||
)
|
||||
# Initialize connection manager, pass in security manager
|
||||
self.connection_manager = DorisConnectionManager(config, self.security_manager)
|
||||
|
||||
# --- Initialize Handlers and Register Routes (Pass sse_mcp instance) ---
|
||||
logger.info("Initializing SSE server handlers and registering routes...")
|
||||
sse_server_handler = DorisMCPSseServer(sse_mcp, app)
|
||||
streamable_server_handler = DorisMCPStreamableServer(sse_mcp, app)
|
||||
logger.info("SSE Server handlers initialized and routes registered.")
|
||||
# Initialize independent managers
|
||||
self.resources_manager = DorisResourcesManager(self.connection_manager)
|
||||
self.tools_manager = DorisToolsManager(self.connection_manager)
|
||||
self.prompts_manager = DorisPromptsManager(self.connection_manager)
|
||||
|
||||
# --- Print Configuration and Endpoints ---
|
||||
print("--- SSE Mode Configuration ---")
|
||||
print(f"Server Host: {args.host}")
|
||||
print(f"Server Port: {args.port}")
|
||||
print(f"Allowed Origins: {origins}")
|
||||
print(f"Allow Credentials: {allow_credentials}")
|
||||
print(f"Log Level: {os.getenv('LOG_LEVEL', 'info')}")
|
||||
print(f"Debug Mode: {args.debug}")
|
||||
print(f"Reload Mode: {args.reload}")
|
||||
print(f"DB Host: {os.getenv('DB_HOST')}")
|
||||
print(f"DB Port: {os.getenv('DB_PORT')}")
|
||||
print(f"DB User: {os.getenv('DB_USER')}")
|
||||
print(f"DB Database: {os.getenv('DB_DATABASE')}")
|
||||
print(f"Force Refresh Metadata: {os.getenv('FORCE_REFRESH_METADATA', 'false')}")
|
||||
print("------------------------------")
|
||||
base_url = f"http://{args.host}:{args.port}"
|
||||
print(f"Service running at: {base_url}")
|
||||
print(f" Health Check: GET {base_url}/health")
|
||||
print(f" Status Check: GET {base_url}/status")
|
||||
print(f" SSE Init: GET {base_url}/sse")
|
||||
print(f" SSE/Legacy Messages: POST {base_url}/mcp/messages")
|
||||
print(f" Streamable HTTP: GET/POST/DELETE/OPTIONS {base_url}/mcp")
|
||||
print("------------------------------")
|
||||
print("Use Ctrl+C to stop the service")
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisServer")
|
||||
self._setup_handlers()
|
||||
|
||||
# --- Start Uvicorn Server ---
|
||||
config = Config(
|
||||
app=app,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level="debug" if args.debug else "info",
|
||||
reload=args.reload
|
||||
)
|
||||
server = Server(config=config)
|
||||
await server.serve()
|
||||
def _setup_handlers(self):
|
||||
"""Setup MCP protocol handlers"""
|
||||
|
||||
# --- Main Execution Logic (Simplified) ---
|
||||
@self.server.list_resources()
|
||||
async def handle_list_resources() -> list[Resource]:
|
||||
"""Handle resource list request"""
|
||||
try:
|
||||
self.logger.info("Handling resource list request")
|
||||
resources = await self.resources_manager.list_resources()
|
||||
self.logger.info(f"Returning {len(resources)} resources")
|
||||
return resources
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle resource list request: {e}")
|
||||
return []
|
||||
|
||||
def run_main_sync():
|
||||
"""Synchronous wrapper, primarily for SSE mode now."""
|
||||
sync_logger = logging.getLogger("run_main_sync")
|
||||
sync_logger.info("Entering run_main_sync (SSE focus)...")
|
||||
print("DEBUG: Entering run_main_sync (SSE focus)...", file=sys.stderr, flush=True)
|
||||
args = parse_args()
|
||||
@self.server.read_resource()
|
||||
async def handle_read_resource(uri: str) -> str:
|
||||
"""Handle resource read request"""
|
||||
try:
|
||||
self.logger.info(f"Handling resource read request: {uri}")
|
||||
content = await self.resources_manager.read_resource(uri)
|
||||
return content
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle resource read request: {e}")
|
||||
return json.dumps(
|
||||
{"error": f"Failed to read resource: {str(e)}", "uri": uri},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@self.server.list_tools()
|
||||
async def handle_list_tools() -> list[Tool]:
|
||||
"""Handle tool list request"""
|
||||
try:
|
||||
self.logger.info("Handling tool list request")
|
||||
tools = await self.tools_manager.list_tools()
|
||||
self.logger.info(f"Returning {len(tools)} tools")
|
||||
return tools
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle tool list request: {e}")
|
||||
return []
|
||||
|
||||
@self.server.call_tool()
|
||||
async def handle_call_tool(
|
||||
name: str, arguments: dict[str, Any]
|
||||
) -> list[TextContent]:
|
||||
"""Handle tool call request"""
|
||||
try:
|
||||
self.logger.info(f"Handling tool call request: {name}")
|
||||
result = await self.tools_manager.call_tool(name, arguments)
|
||||
|
||||
return [TextContent(type="text", text=result)]
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle tool call request: {e}")
|
||||
error_result = json.dumps(
|
||||
{
|
||||
"error": f"Tool call failed: {str(e)}",
|
||||
"tool_name": name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
return [TextContent(type="text", text=error_result)]
|
||||
|
||||
@self.server.list_prompts()
|
||||
async def handle_list_prompts() -> list[Prompt]:
|
||||
"""Handle prompt list request"""
|
||||
try:
|
||||
self.logger.info("Handling prompt list request")
|
||||
prompts = await self.prompts_manager.list_prompts()
|
||||
self.logger.info(f"Returning {len(prompts)} prompts")
|
||||
return prompts
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle prompt list request: {e}")
|
||||
return []
|
||||
|
||||
@self.server.get_prompt()
|
||||
async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
|
||||
"""Handle prompt get request"""
|
||||
try:
|
||||
self.logger.info(f"Handling prompt get request: {name}")
|
||||
result = await self.prompts_manager.get_prompt(name, arguments)
|
||||
return result
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to handle prompt get request: {e}")
|
||||
error_result = json.dumps(
|
||||
{
|
||||
"error": f"Failed to get prompt: {str(e)}",
|
||||
"prompt_name": name,
|
||||
"arguments": arguments,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return error_result
|
||||
|
||||
async def start_stdio(self):
|
||||
"""Start stdio transport mode"""
|
||||
self.logger.info("Starting Doris MCP Server (stdio mode)")
|
||||
|
||||
if args.sse:
|
||||
try:
|
||||
# Run the async SSE server setup and Uvicorn loop
|
||||
asyncio.run(start_sse_server(args))
|
||||
sync_logger.info("asyncio.run(start_sse_server) completed.")
|
||||
print("DEBUG: asyncio.run(start_sse_server) completed.", file=sys.stderr, flush=True)
|
||||
except KeyboardInterrupt:
|
||||
sync_logger.info("SSE server stopped by KeyboardInterrupt.")
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
self.logger.info("Connection manager initialization completed")
|
||||
|
||||
# Start stdio server - using simpler approach
|
||||
from mcp.server.stdio import stdio_server
|
||||
|
||||
self.logger.info("Creating stdio_server transport...")
|
||||
|
||||
# Try different startup approaches
|
||||
try:
|
||||
async with stdio_server() as streams:
|
||||
read_stream, write_stream = streams
|
||||
self.logger.info("stdio_server streams created successfully")
|
||||
|
||||
# Create initialization options
|
||||
# MCP 1.8.0 requires parameters for get_capabilities
|
||||
from mcp.server.lowlevel.server import NotificationOptions
|
||||
|
||||
capabilities = self.server.get_capabilities(
|
||||
notification_options=NotificationOptions(
|
||||
prompts_changed=True,
|
||||
resources_changed=True,
|
||||
tools_changed=True
|
||||
),
|
||||
experimental_capabilities={}
|
||||
)
|
||||
|
||||
init_options = InitializationOptions(
|
||||
server_name="doris-mcp-server",
|
||||
server_version="1.0.0",
|
||||
capabilities=capabilities,
|
||||
)
|
||||
self.logger.info("Initialization options created successfully")
|
||||
|
||||
# Run server
|
||||
self.logger.info("Starting to run MCP server...")
|
||||
await self.server.run(read_stream, write_stream, init_options)
|
||||
|
||||
except Exception as inner_e:
|
||||
self.logger.error(f"stdio_server internal error: {inner_e}")
|
||||
self.logger.error(f"Error type: {type(inner_e)}")
|
||||
|
||||
# Try to get more error information
|
||||
import traceback
|
||||
self.logger.error("Complete error stack:")
|
||||
self.logger.error(traceback.format_exc())
|
||||
|
||||
# If it's ExceptionGroup, try to parse
|
||||
if hasattr(inner_e, 'exceptions'):
|
||||
self.logger.error(f"ExceptionGroup contains {len(inner_e.exceptions)} exceptions:")
|
||||
for i, exc in enumerate(inner_e.exceptions):
|
||||
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
||||
|
||||
raise inner_e
|
||||
|
||||
except Exception as e:
|
||||
sync_logger.critical(f"Error during asyncio.run(start_sse_server): {e}", exc_info=True)
|
||||
print(f"DEBUG: Error during asyncio.run(start_sse_server): {e}", file=sys.stderr, flush=True)
|
||||
self.logger.error(f"stdio server startup failed: {e}")
|
||||
self.logger.error(f"Error type: {type(e)}")
|
||||
raise
|
||||
else:
|
||||
# If run without --sse, print help/error
|
||||
message = "Error: This entry point requires --sse flag. For stdio mode, use 'uv run mcp-doris' or the appropriate command for your stdio setup."
|
||||
sync_logger.error(message)
|
||||
print(message, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
||||
async def start_http(self, host: str = "localhost", port: int = 3000):
    """Start Streamable HTTP transport mode.

    Initializes the connection manager, wraps ``self.server`` in a
    ``StreamableHTTPSessionManager`` and serves a small ASGI application with
    uvicorn until the server stops. The ASGI application answers ``/health``
    through a Starlette app, forwards ``/mcp`` and ``/mcp/...`` to the session
    manager and replies 404 to every other path.

    Args:
        host: Host address handed to uvicorn.
        port: Port number handed to uvicorn.

    Raises:
        Exception: Any startup or serving error is logged (including the
            members of an exception group) and re-raised unchanged.
    """
    self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")

    try:
        # Ensure connection manager is initialized
        await self.connection_manager.initialize()

        # Use Starlette and StreamableHTTPSessionManager according to official example
        import uvicorn
        import contextlib
        from collections.abc import AsyncIterator
        from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
        from starlette.applications import Starlette
        from starlette.routing import Route
        from starlette.responses import JSONResponse, Response

        # Create session manager
        session_manager = StreamableHTTPSessionManager(
            app=self.server,
            json_response=True,  # Enable JSON response
            stateless=False,  # Maintain session state
        )

        self.logger.info(f"StreamableHTTP session manager created, will start at http://{host}:{port}")

        # Header values that are credentials; they must never be written to the log.
        sensitive_headers = frozenset(
            {b"authorization", b"proxy-authorization", b"cookie", b"x-api-key"}
        )

        def loggable_headers(raw_headers):
            """Return a copy of *raw_headers* with credential values masked."""
            return {
                name: b"<redacted>" if name.lower() in sensitive_headers else value
                for name, value in raw_headers.items()
            }

        # Health check endpoint
        async def health_check(request):
            return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})

        # Lifecycle manager - simplified since we manage session_manager externally
        @contextlib.asynccontextmanager
        async def lifespan(app: Starlette) -> AsyncIterator[None]:
            """Context manager for managing application lifecycle"""
            self.logger.info("Application started!")
            try:
                yield
            finally:
                self.logger.info("Application is shutting down...")

        # Starlette only serves /health and the lifespan protocol. Debug mode
        # is off so that tracebacks are never returned to HTTP clients.
        starlette_app = Starlette(
            debug=False,
            routes=[
                Route("/health", health_check, methods=["GET"]),
            ],
            lifespan=lifespan,
        )

        # Custom ASGI app that handles both /mcp and /mcp/ without redirects
        async def mcp_app(scope, receive, send):
            # Handle lifespan events
            if scope["type"] == "lifespan":
                await starlette_app(scope, receive, send)
                return

            if scope["type"] != "http":
                # For other scope types, just return
                self.logger.warning(f"Unsupported scope type: {scope['type']}")
                return

            # Handle HTTP requests
            path = scope.get("path", "")
            self.logger.info(f"Received request for path: {path}")

            try:
                # Handle health check
                if path.startswith("/health"):
                    await starlette_app(scope, receive, send)
                    return

                # Handle MCP requests - both /mcp and /mcp/ go to session manager
                if path == "/mcp" or path.startswith("/mcp/"):
                    self.logger.info(f"Handling MCP request for path: {path}")
                    # Log request details for debugging
                    method = scope.get("method", "UNKNOWN")
                    headers = dict(scope.get("headers", []))
                    self.logger.info(f"MCP Request - Method: {method}")
                    self.logger.info(f"MCP Request - Headers: {loggable_headers(headers)}")

                    # Handle Dify compatibility for GET requests
                    if method == "GET":
                        accept_header = headers.get(b'accept', b'').decode('utf-8')

                        # A GET that only accepts an event stream gets
                        # application/json appended to its Accept header
                        # before it is forwarded.
                        if 'text/event-stream' in accept_header and 'application/json' not in accept_header:
                            self.logger.info("Adding application/json to Accept header for GET request")
                            new_headers = []
                            for name, value in scope.get("headers", []):
                                if name == b'accept':
                                    new_value = value.decode('utf-8') + ', application/json'
                                    new_headers.append((name, new_value.encode('utf-8')))
                                else:
                                    new_headers.append((name, value))
                            # Work on a copy so the server's scope dict is untouched.
                            scope = dict(scope)
                            scope["headers"] = new_headers
                            self.logger.info(f"Modified Accept header to: {new_value}")

                    await session_manager.handle_request(scope, receive, send)
                    return

                # 404 for other paths
                self.logger.info(f"Path not found: {path}")
                response = Response("Not Found", status_code=404)
                await response(scope, receive, send)
            except Exception as e:
                self.logger.error(f"Error handling request for {path}: {e}")
                import traceback
                self.logger.error(traceback.format_exc())
                response = Response("Internal Server Error", status_code=500)
                await response(scope, receive, send)

        # Start uvicorn server with session manager lifecycle
        config = uvicorn.Config(
            app=mcp_app,
            host=host,
            port=port,
            log_level="info",
        )
        server = uvicorn.Server(config)

        # Run session manager and server together
        async with session_manager.run():
            self.logger.info("Session manager started, now starting HTTP server")
            await server.serve()

    except Exception as e:
        self.logger.error(f"Streamable HTTP server startup failed: {e}")
        import traceback
        self.logger.error("Complete error stack:")
        self.logger.error(traceback.format_exc())

        # If it's ExceptionGroup, try to parse
        if hasattr(e, 'exceptions'):
            self.logger.error(f"ExceptionGroup contains {len(e.exceptions)} exceptions:")
            for i, exc in enumerate(e.exceptions):
                self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
        raise
|
||||
|
||||
async def shutdown(self):
    """Close the connection manager and log the outcome.

    Errors raised while closing are logged and not propagated.
    """
    log = self.logger
    log.info("Shutting down Doris MCP Server")
    try:
        await self.connection_manager.close()
        log.info("Doris MCP Server has been shut down")
    except Exception as close_error:
        log.error(f"Error occurred while shutting down server: {close_error}")
|
||||
|
||||
|
||||
def create_arg_parser():
    """Create command line argument parser.

    Returns:
        An ``argparse.ArgumentParser`` with the transport options
        (``--transport``, ``--host``, ``--port``), the Doris connection
        options (``--db-host``, ``--db-port``, ``--db-user``,
        ``--db-password``, ``--db-database``) and ``--log-level``.
    """
    usage_notes = """
Transport Modes:
stdio - Standard input/output (for local process communication)
http - Streamable HTTP mode (MCP 2025-03-26 protocol)

Examples:
python -m doris_mcp_server --transport stdio
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
"""
    cli = argparse.ArgumentParser(
        description="Apache Doris MCP Server - Enterprise Database Service",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_notes,
    )

    # One (flag, add_argument keyword arguments) pair per option, in help order.
    option_table = (
        ("--transport", {
            "type": str,
            "choices": ["stdio", "http"],
            "default": "stdio",
            "help": "Transport protocol type: stdio (local), http (Streamable HTTP)",
        }),
        ("--host", {
            "type": str,
            "default": "localhost",
            "help": "Host address for HTTP mode (default: localhost)",
        }),
        ("--port", {
            "type": int,
            "default": 3000,
            "help": "Port number for HTTP mode (default: 3000)",
        }),
        ("--db-host", {
            "type": str,
            "default": "localhost",
            "help": "Doris database host address (default: localhost)",
        }),
        ("--db-port", {
            "type": int,
            "default": 9030,
            "help": "Doris database port number (default: 9030)",
        }),
        ("--db-user", {
            "type": str,
            "default": "root",
            "help": "Doris database username (default: root)",
        }),
        ("--db-password", {
            "type": str,
            "default": "",
            "help": "Doris database password",
        }),
        ("--db-database", {
            "type": str,
            "default": "information_schema",
            "help": "Doris database name (default: information_schema)",
        }),
        ("--log-level", {
            "type": str,
            "choices": ["DEBUG", "INFO", "WARNING", "ERROR"],
            "default": "INFO",
            "help": "Log level (default: INFO)",
        }),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)

    return cli
|
||||
|
||||
|
||||
async def main():
    """Main function.

    Parses the command line, builds the configuration (priority: command
    line arguments > .env file > default values), creates a ``DorisServer``
    and runs it with the selected transport.

    The server is shut down exactly once, in the ``finally`` clause, whatever
    way the run ends.

    Returns:
        0 after a normal stop or a keyboard interrupt, 1 after a runtime error
        or an unsupported transport.
    """
    parser = create_arg_parser()
    args = parser.parse_args()

    # Set log level
    logging.getLogger().setLevel(getattr(logging, args.log_level))

    # First load from .env file and environment variables
    config = DorisConfig.from_env()

    # Command line arguments override configuration (if provided).
    # NOTE(review): "provided" is detected by comparing against the parser
    # default, so explicitly passing a default value (e.g. --db-host localhost)
    # does not override a different value loaded from the environment.
    if args.db_host != "localhost":
        config.database.host = args.db_host
    if args.db_port != 9030:
        config.database.port = args.db_port
    if args.db_user != "root":
        config.database.user = args.db_user
    if args.db_password:  # Use password if provided
        config.database.password = args.db_password
    if args.db_database != "information_schema":
        config.database.database = args.db_database
    if args.log_level != "INFO":
        config.logging.level = args.log_level

    # Create server instance
    server = DorisServer(config)

    exit_code = 0
    try:
        if args.transport == "stdio":
            await server.start_stdio()
        elif args.transport == "http":
            await server.start_http(args.host, args.port)
        else:
            logger.error(f"Unsupported transport protocol: {args.transport}")
            exit_code = 1
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down server...")
    except Exception as e:
        logger.error(f"Server runtime error: {e}")
        exit_code = 1
    finally:
        # ``finally`` runs on every path (normal stop, interrupt, error), so
        # this is the only place the server is shut down.
        try:
            await server.shutdown()
        except Exception as shutdown_error:
            logger.error(f"Error occurred while shutting down server: {shutdown_error}")

    return exit_code
|
||||
|
||||
|
||||
def main_sync():
    """Synchronous main function for entry point.

    Runs ``main()`` on a new event loop and exits the process with the
    integer status it returns.

    Raises:
        SystemExit: Always, carrying the exit code returned by ``main()``.
    """
    # The bare ``exit`` builtin is injected by the ``site`` module and is not
    # available in every interpreter configuration; ``sys.exit`` always is.
    import sys

    exit_code = asyncio.run(main())
    sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_main_sync()
|
||||
main_sync()
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Core MCP instance and startup logic for stdio mode.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import traceback
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
# Import necessary components from mcp and our project
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
logger = logging.getLogger("doris-mcp-core")
|
||||
|
||||
# --- Global MCP Instance for Stdio ---
|
||||
# Create the instance when the module is imported.
|
||||
# Tools will be registered synchronously(?) before running.
|
||||
stdio_mcp = FastMCP(
|
||||
name="doris-mcp-stdio-core",
|
||||
description="Apache Doris MCP Server (stdio via core)",
|
||||
)
|
||||
|
||||
# --- Removed async setup functions ---
|
||||
def run_stdio():
    """
    Synchronous entry point for running the stdio server.

    Calls ``stdio_mcp.run()`` and blocks until it returns. A keyboard
    interrupt is logged and treated as a normal stop. If ``stdio_mcp`` has no
    callable ``run`` attribute, or if ``run()`` raises any other exception,
    the problem is logged, reported on stderr and the process exits with
    status 1.
    """
    logger.info("Executing run_stdio (synchronous entry point)...")

    # --- Run the stdio server using the instance's run() method ---
    logger.info("Calling stdio_mcp.run()...")

    # Check for the method up front instead of catching AttributeError around
    # the call: an AttributeError raised inside run() is a real server error
    # and must be reported with its traceback by the generic handler below.
    run_server = getattr(stdio_mcp, "run", None)
    if not callable(run_server):
        logger.critical("Error: stdio_mcp object does not have a '.run()' method suitable for stdio.", exc_info=False)
        print("ERROR: stdio_mcp object does not have a '.run()' method.", file=sys.stderr, flush=True)
        sys.exit(1)

    try:
        run_server()
        logger.info("stdio_mcp.run() completed.")
    except KeyboardInterrupt:
        logger.info("Stdio server stopped by KeyboardInterrupt.")
    except Exception as e:
        logger.critical(f"run_stdio encountered an error during stdio_mcp.run(): {e}", exc_info=True)
        traceback.print_exc(file=sys.stderr)
        sys.exit(1)
|
||||
|
||||
# Register Tool: Execute SQL Query
|
||||
@stdio_mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command with catalog federation support.\n
[Parameter Content]:\n
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog\n
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100\n
- timeout (integer) [Optional] - Query timeout in seconds, default 30\n""")
async def exec_query_tool(sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
    """Forward an ``exec_query`` tool call to ``mcp_doris_exec_query``.

    All arguments are passed through unchanged and the awaited result is
    returned as is.
    """
    # The implementation is imported at call time, not at module import.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_exec_query as run_query

    forwarded = {
        "sql": sql,
        "db_name": db_name,
        "catalog_name": catalog_name,
        "max_rows": max_rows,
        "timeout": timeout,
    }
    return await run_query(**forwarded)
|
||||
|
||||
# Register Tool: Get Table Schema
|
||||
@stdio_mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_table_schema_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Forward a ``get_table_schema`` tool call to ``mcp_doris_get_table_schema``.

    An empty ``table_name`` is answered directly with a JSON error payload
    instead of calling the implementation.
    """
    # The implementation is imported at call time, not at module import.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_schema as fetch_schema

    if table_name:
        return await fetch_schema(
            table_name=table_name,
            db_name=db_name,
            catalog_name=catalog_name,
        )

    failure = json.dumps({"success": False, "error": "Missing table_name parameter"})
    return {"content": [{"type": "text", "text": failure}]}
|
||||
|
||||
# Register Tool: Get Database Table List
|
||||
@stdio_mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
[Parameter Content]:\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_db_table_list_tool(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Forward a ``get_db_table_list`` tool call to ``mcp_doris_get_db_table_list``.

    Both arguments are passed through unchanged and the awaited result is
    returned as is.
    """
    # The implementation is imported at call time, not at module import.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_table_list as list_tables

    return await list_tables(
        db_name=db_name,
        catalog_name=catalog_name,
    )
|
||||
|
||||
# Register Tool: Get Database List
|
||||
# NOTE(review): the description below advertises a required ``random_string``
# parameter, but the function signature does not accept one - confirm which
# of the two is intended.
@stdio_mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_db_list_tool(catalog_name: str = None) -> Dict[str, Any]:
    """Forward a ``get_db_list`` tool call to ``mcp_doris_get_db_list``.

    ``catalog_name`` is passed through unchanged and the awaited result is
    returned as is.
    """
    # The implementation is imported at call time, not at module import.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_list as list_databases

    return await list_databases(
        catalog_name=catalog_name,
    )
|
||||
|
||||
# Register Tool: Get Table Comment
|
||||
@stdio_mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_table_comment_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Forward a ``get_table_comment`` tool call to ``mcp_doris_get_table_comment``.

    An empty ``table_name`` is answered directly with a JSON error payload
    instead of calling the implementation.
    """
    # The implementation is imported at call time, not at module import.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_comment as fetch_comment

    if table_name:
        return await fetch_comment(
            table_name=table_name,
            db_name=db_name,
            catalog_name=catalog_name,
        )

    failure = json.dumps({"success": False, "error": "Missing table_name parameter"})
    return {"content": [{"type": "text", "text": failure}]}
|
||||
|
||||
# Register Tool: Get Table Column Comments
|
||||
@stdio_mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_table_column_comments_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Forward a ``get_table_column_comments`` tool call to ``mcp_doris_get_table_column_comments``.

    An empty ``table_name`` is answered directly with a JSON error payload
    instead of calling the implementation.
    """
    # The implementation is imported at call time, not at module import.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_column_comments as fetch_column_comments

    if table_name:
        return await fetch_column_comments(
            table_name=table_name,
            db_name=db_name,
            catalog_name=catalog_name,
        )

    failure = json.dumps({"success": False, "error": "Missing table_name parameter"})
    return {"content": [{"type": "text", "text": failure}]}
|
||||
|
||||
# Register Tool: Get Table Indexes
|
||||
@stdio_mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.
[Parameter Content]:\n
- table_name (string) [Required] - Name of the table to query\n
- db_name (string) [Optional] - Target database name, defaults to the current database\n
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
async def get_table_indexes_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Stdio wrapper around ``mcp_doris_get_table_indexes``.

    Args:
        table_name: Table whose indexes are requested; must be non-empty.
        db_name: Optional database name, forwarded unchanged.
        catalog_name: Optional catalog name, forwarded unchanged.

    Returns:
        The delegate's result, or a text-content payload carrying a JSON
        error document when ``table_name`` is empty.
    """
    # The delegate is imported at call time, not at module import time.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_indexes

    if table_name:
        return await mcp_doris_get_table_indexes(
            table_name=table_name,
            db_name=db_name,
            catalog_name=catalog_name,
        )

    failure = {"success": False, "error": "Missing table_name parameter"}
    return {"content": [{"type": "text", "text": json.dumps(failure)}]}
|
||||
|
||||
# Register Tool: Get Recent Audit Logs
|
||||
@stdio_mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
[Parameter Content]:\n
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
    """Stdio wrapper around ``mcp_doris_get_recent_audit_logs``.

    Both arguments are coerced with ``int()`` before delegation, so numeric
    strings are accepted.

    Args:
        days: Number of recent days to cover.
        limit: Maximum number of records to return.

    Returns:
        The delegate's result, or a text-content payload carrying a JSON
        error document when either argument cannot be converted to ``int``.
    """
    # The delegate is imported at call time, not at module import time.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_recent_audit_logs

    try:
        day_count, row_limit = int(days), int(limit)
    except (TypeError, ValueError):
        failure = {"success": False, "error": "days and limit parameters must be integers"}
        return {"content": [{"type": "text", "text": json.dumps(failure)}]}

    return await mcp_doris_get_recent_audit_logs(days=day_count, limit=row_limit)
|
||||
|
||||
# Register Tool: Get Catalog List
|
||||
@stdio_mcp.tool("get_catalog_list", description="""[Function Description]: Get a list of all catalog names on the server.\n
[Parameter Content]:\n
- random_string (string) [Required] - Unique identifier for the tool call\n""")
async def get_catalog_list_tool() -> Dict[str, Any]:
    """Stdio wrapper around ``mcp_doris_get_catalog_list``.

    Takes no arguments and returns whatever the delegate returns.
    NOTE(review): the registered description advertises a required
    ``random_string`` parameter that this signature does not accept - confirm
    which of the two is intended.
    """
    # The delegate is imported at call time, not at module import time.
    from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_catalog_list

    catalogs = await mcp_doris_get_catalog_list()
    return catalogs
|
||||
|
||||
# --- Register Tools ---
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,912 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Doris MCP Streamable HTTP Server Implementation
|
||||
|
||||
Implements the MCP 2025-03-26 Streamable HTTP specification.
|
||||
Uses a unified /mcp endpoint for GET, POST, DELETE, OPTIONS.
|
||||
Manages sessions using Mcp-Session-Id header.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Optional, Dict, List
|
||||
from fastapi import FastAPI, Request, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
# Use a distinct logger name
|
||||
logger = logging.getLogger("doris-mcp-streamable")
|
||||
|
||||
# Special marker for closing streams
|
||||
STREAM_END_MARKER = "__MCP_STREAM_END__"
|
||||
|
||||
class DorisMCPStreamableServer:
|
||||
"""Doris MCP Streamable HTTP Server"""
|
||||
|
||||
def __init__(self, mcp_server, app: FastAPI):
    """Bind the server to its MCP backend and FastAPI app, then register routes.

    Args:
        mcp_server: The shared FastMCP server instance.
        app: The main FastAPI application; the ``/mcp`` route is added to it.
    """
    # Session registry for Streamable HTTP clients, keyed by the value of the
    # Mcp-Session-Id header. Each entry has the shape:
    #   {
    #       "created_at": timestamp,
    #       "last_active": timestamp,
    #       "request_queues": {request_id: asyncio.Queue},  # POST /mcp request streams
    #       "general_sse_queues": List[asyncio.Queue],      # GET /mcp server push streams
    #   }
    self.client_sessions: Dict[str, Dict[str, Any]] = {}

    self.app = app  # routes are attached to this application
    self.mcp_server = mcp_server

    # CORS middleware is deliberately not installed here: it is expected to be
    # added once by main.py so it is neither duplicated nor conflicting.

    # Install the unified /mcp endpoint.
    self._setup_streamable_http_routes()

    # No startup hook is registered here either; session cleanup is assumed to
    # be started by main.py as part of the main application's lifespan.
|
||||
|
||||
def _setup_streamable_http_routes(self):
    """Register the unified ``/mcp`` endpoint on ``self.app``.

    One handler serves every HTTP method of the Streamable HTTP transport:

    * ``OPTIONS`` - CORS preflight; replies with an empty JSON object.
    * ``DELETE``  - terminates the session named by ``Mcp-Session-Id``.
    * ``GET``     - opens a server-push SSE stream for an existing session.
    * ``POST``    - ``initialize`` (creates a session) or any other JSON-RPC
      message for an existing session.

    The route uses the "Streamable HTTP" tag in the generated API docs.
    """
    # Function-scope import: fastapi is already a dependency of this module and
    # ``Response`` is only needed for the body-less 204 reply below.
    from fastapi import Response

    @self.app.api_route("/mcp", methods=["GET", "POST", "DELETE", "OPTIONS"], tags=["Streamable HTTP"])
    async def mcp_endpoint_handler(request: Request):
        """Handles GET, POST, DELETE, OPTIONS for the /mcp endpoint."""

        # 1. Handle OPTIONS (CORS preflight)
        if request.method == "OPTIONS":
            # CORS headers are assumed to be supplied by middleware in main.py;
            # only the exposed session header is added here.
            logger.debug("Handling OPTIONS request for /mcp")
            return JSONResponse({}, headers={"Access-Control-Expose-Headers": "Mcp-Session-Id"})

        # Session ID from header is required for most methods
        session_id = request.headers.get("Mcp-Session-Id")

        # 2. Handle DELETE (Terminate Session)
        if request.method == "DELETE":
            if not session_id:
                return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Mcp-Session-Id header is required for DELETE"}}, status_code=400)

            logger.info(f"Handling DELETE request for session [Session ID: {session_id}]")
            session_data = self.client_sessions.pop(session_id, None)
            if session_data:
                await self._cleanup_session_resources(session_id, session_data)
                # A 204 response must not carry a body, so a bare Response is
                # returned instead of a JSONResponse (which would send "{}").
                return Response(status_code=204)
            else:
                logger.warning(f"Attempted DELETE on non-existent session: {session_id}")
                return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found"}}, status_code=404)

        # 3. Handle GET (Server Push SSE Stream)
        if request.method == "GET":
            if not session_id:
                return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32000, "message": "Mcp-Session-Id header is required for GET streams"}}, status_code=400)
            if session_id not in self.client_sessions:
                # Unlike legacy SSE, GET here requires an already-initialized session.
                return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32001, "message": "Session not found. Initialize first."}}, status_code=404)

            accept_header = request.headers.get("Accept", "")
            if "text/event-stream" not in accept_header:
                return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Accept header must include text/event-stream for GET"}}, status_code=406)

            # TODO: Handle Last-Event-ID for stream recovery?

            logger.info(f"Handling GET request, establishing server push SSE stream [Session ID: {session_id}]")

            # Each GET stream gets its own queue, registered on the session so
            # that server-initiated messages can be pushed to it.
            push_queue = asyncio.Queue()
            if self.client_sessions[session_id].get("general_sse_queues") is None:
                self.client_sessions[session_id]["general_sse_queues"] = []
            self.client_sessions[session_id]["general_sse_queues"].append(push_queue)
            self.client_sessions[session_id]["last_active"] = time.time()

            return EventSourceResponse(self._create_general_sse_generator(session_id, push_queue), media_type="text/event-stream")

        # 4. Handle POST (Client Messages & Initialize)
        if request.method == "POST":
            accept_header = request.headers.get("Accept", "")
            content_type = request.headers.get("Content-Type", "")

            # Pre-bound so the generic error handler below can always read an id.
            body = {}
            try:
                if "application/json" not in content_type:
                    return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Content-Type must be application/json"}}, status_code=415)
                body = await request.json()
                if isinstance(body, list):
                    return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32600, "message": "Batch requests not supported"}}, status_code=400)
                if not isinstance(body, dict):
                    return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Invalid JSON received"}}, status_code=400)

                method = body.get("method")
                message_id = body.get("id")  # Can be None for notifications

                # Handle Initialize request (does not require Mcp-Session-Id header)
                if method == "initialize":
                    if "application/json" not in accept_header:
                        return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json for initialize"}}, status_code=406)
                    return await self._handle_initialize(request, body, message_id)

                # Handle other POST requests (require Mcp-Session-Id)
                else:
                    if not session_id:
                        return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Mcp-Session-Id header is required for this request"}}, status_code=400)
                    if session_id not in self.client_sessions:
                        return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32001, "message": "Session not found"}}, status_code=404)
                    # Non-initialize POSTs may be answered with JSON or SSE, so
                    # the client must accept both.
                    if not ("application/json" in accept_header and "text/event-stream" in accept_header):
                        return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Accept header must include application/json and text/event-stream for POST"}}, status_code=406)

                    self.client_sessions[session_id]["last_active"] = time.time()
                    return await self._handle_client_post(request, body, session_id, message_id)

            except json.JSONDecodeError:
                return JSONResponse({"jsonrpc": "2.0", "error": {"code": -32700, "message": "Parse error - Invalid JSON received"}}, status_code=400)
            except Exception as e:
                logger.error(f"Unexpected error handling POST /mcp: {str(e)}", exc_info=True)
                error_id = body.get("id") if isinstance(body, dict) else None
                return JSONResponse({"jsonrpc": "2.0", "id": error_id, "error": {"code": -32000, "message": "Internal server error"}}, status_code=500)

        # Fallback for other methods like PUT, PATCH etc.
        return JSONResponse({"error": "Method Not Allowed"}, status_code=405)
|
||||
|
||||
async def _handle_initialize(self, request: Request, body: Dict, message_id: Any):
    """Create a new session and answer an ``initialize`` call on POST /mcp.

    Args:
        request: Incoming HTTP request (not inspected here).
        body: Parsed JSON-RPC message (its params are not validated here).
        message_id: JSON-RPC ``id`` echoed back in the response.

    Returns:
        A JSON response carrying the InitializeResult, with the freshly
        generated session id in the ``Mcp-Session-Id`` header.
    """
    logger.info("Handling Streamable HTTP initialize request")

    # Every initialize call mints a brand-new session id.
    new_session_id = str(uuid.uuid4())
    logger.info(f"Created new Streamable HTTP session [Session ID: {new_session_id}]")

    # Register the session with empty per-request and server-push queue holders.
    self.client_sessions[new_session_id] = {
        "created_at": time.time(),
        "last_active": time.time(),
        "request_queues": {},
        "general_sse_queues": [],
    }

    # Assemble the InitializeResult piece by piece.
    server_info = {"version": "0.2.0", "name": "Doris MCP Streamable Server"}
    capabilities = {
        "tools": {"supportsStreaming": True, "supportsProgress": True},
        "resources": {"supportsStreaming": False},
        "prompts": {"supported": True},
        "session": {"supported": True},
    }
    initialize_result = {
        "protocolVersion": "2025-03-26",
        "name": self.mcp_server.name,
        "instructions": "Apache Doris MCP Server (Streamable HTTP Mode)",
        "serverInfo": server_info,
        "capabilities": capabilities,
    }

    # The session id travels back to the client in a response header.
    return JSONResponse(
        content={"jsonrpc": "2.0", "id": message_id, "result": initialize_result},
        media_type="application/json",
        headers={"Mcp-Session-Id": new_session_id},
    )
|
||||
|
||||
async def _handle_client_post(self, request: Request, body: Dict, session_id: str, message_id: Any):
    """Handle a non-initialize POST /mcp message for an existing session.

    Three message shapes are distinguished:

    * notification (``method`` without ``id``) or response (``result`` /
      ``error``): acknowledged with 202 and otherwise ignored;
    * request (``method`` and ``id``): answered either with a single JSON
      body or, when a tool call sets ``params.stream``, with an SSE stream
      fed by a background task;
    * anything else: rejected as an invalid JSON-RPC request (400).

    Args:
        request: Incoming HTTP request, forwarded to the processing layer.
        body: Parsed JSON-RPC message (a dict).
        session_id: Id of an existing entry in ``self.client_sessions``.
        message_id: JSON-RPC ``id`` of the message (``None`` if absent).

    Returns:
        A ``JSONResponse`` or an ``EventSourceResponse``.
    """
    method = body.get("method")

    # Handle Notifications/Responses from client
    is_notification = "method" in body and "id" not in body
    is_response = "result" in body or "error" in body
    if is_notification or is_response:
        logger.info(f"Received Streamable HTTP notification/response [Session ID: {session_id}] - Processing needed? (Ignoring for now)")
        # TODO: If the server sends requests that expect responses, process is_response here.
        return JSONResponse({}, status_code=202)  # Accepted

    # Anything that is neither a notification/response nor a full request is malformed.
    if "method" not in body or "id" not in body:
        return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32600, "message": "Invalid JSON-RPC request format"}}, status_code=400)

    logger.info(f"Received Streamable HTTP request [Session ID: {session_id}, ID: {message_id}, Method: {method}]")

    # Only tool calls can ask for a streamed answer. An explicit null is
    # treated like missing params; any other non-object value is reported as a
    # client error instead of surfacing as an AttributeError (generic 500).
    stream_required = False
    if method in ("tools/call", "mcp/callTool"):
        params = body.get("params")
        if params is None:
            params = {}
        if not isinstance(params, dict):
            return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32602, "message": "Invalid params: 'params' must be an object"}}, status_code=400)
        stream_required = params.get("stream", False)

    if stream_required:
        # --- Return SSE stream for response parts ---
        logger.info(f"Using SSE stream for request [Session ID: {session_id}, ID: {message_id}]")
        response_queue = asyncio.Queue()
        # request_queues should have been created during initialize.
        if self.client_sessions[session_id].get("request_queues") is None:
            logger.error(f"Session {session_id} is missing 'request_queues' dictionary!")
            return JSONResponse({"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": "Internal server error: Session state inconsistent"}}, status_code=500)
        self.client_sessions[session_id]["request_queues"][message_id] = response_queue

        # Start the background task that fills the queue. The event loop only
        # keeps weak references to tasks, so a strong reference is held until
        # the task finishes; otherwise it could be garbage-collected mid-run.
        task = asyncio.create_task(self._process_request_and_respond(
            request, body, session_id, message_id, response_queue, is_stream=True
        ))
        background_tasks = getattr(self, "_background_tasks", None)
        if background_tasks is None:
            background_tasks = set()
            self._background_tasks = background_tasks
        background_tasks.add(task)
        task.add_done_callback(background_tasks.discard)

        # The SSE response drains the request-specific queue.
        return EventSourceResponse(self._create_request_sse_generator(session_id, message_id, response_queue), media_type="text/event-stream")

    # --- Return single JSON response ---
    logger.info(f"Using JSON response for request [Session ID: {session_id}, ID: {message_id}]")
    try:
        # Returns the final JSON-RPC body, or raises HTTPException on failure.
        result_or_error_payload = await self._process_request_and_respond(
            request, body, session_id, message_id, None, is_stream=False
        )
        return JSONResponse(content=result_or_error_payload, media_type="application/json")
    except HTTPException as http_exc:
        # Format HTTPException details into a JSON-RPC error with the same status.
        return JSONResponse(
            {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": http_exc.detail}},
            status_code=http_exc.status_code
        )
    except Exception as e:
        # Catch unexpected errors during synchronous processing
        logger.error(f"Error processing non-stream request [Session ID: {session_id}, ID: {message_id}]: {str(e)}", exc_info=True)
        error_response = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32000, "message": f"Internal server error: {str(e)}"}}
        return JSONResponse(content=error_response, status_code=500)
|
||||
|
||||
# === Generator Functions for SSE Streams ===
|
||||
|
||||
async def _create_general_sse_generator(self, session_id: str, queue: asyncio.Queue):
    """Yield SSE events for a GET /mcp server-push stream.

    Messages are taken from ``queue`` until the end marker arrives, the
    session disappears from ``self.client_sessions``, or an unexpected error
    occurs. A ``ping`` event is emitted whenever 60 seconds pass without a
    message. JSON-RPC responses (a dict with ``id`` plus ``result`` or
    ``error``) are logged and dropped rather than sent on this stream.

    Args:
        session_id: Session the stream belongs to.
        queue: Push queue registered for this stream on the session.

    Yields:
        Dicts with ``event`` and ``data`` keys.
    """
    try:
        while True:
            try:
                if session_id not in self.client_sessions:
                    logger.warning(f"General SSE stream generator: Session {session_id} closed.")
                    break

                message = await asyncio.wait_for(queue.get(), timeout=60.0)

                if message == STREAM_END_MARKER:
                    logger.debug(f"General SSE stream received end marker [Session ID: {session_id}]")
                    break

                looks_like_response = (
                    isinstance(message, dict)
                    and ("result" in message or "error" in message)
                    and "id" in message
                )
                if looks_like_response:
                    # Responses are not delivered on the server-push stream.
                    logger.warning(f"Attempted to send response on GET stream, blocked [Session ID: {session_id}]: {message}")
                else:
                    # TODO: Event ID for recovery?
                    yield {"event": "message", "data": json.dumps(message)}
                queue.task_done()

            except asyncio.TimeoutError:
                # Idle period: stop if the session went away, otherwise keep alive.
                if session_id not in self.client_sessions:
                    logger.warning(f"General SSE stream generator (timeout): Session {session_id} closed.")
                    break
                yield {"event": "ping", "data": "keepalive"}
            except asyncio.CancelledError:
                logger.info(f"General SSE stream cancelled [Session ID: {session_id}]")
                raise
            except Exception as e:
                logger.error(f"General SSE stream error [Session ID: {session_id}]: {str(e)}", exc_info=True)
                break
    finally:
        logger.info(f"General SSE stream ended [Session ID: {session_id}]")

        # Unregister this stream's queue from the session, if it still exists.
        if session_id in self.client_sessions:
            registered = self.client_sessions[session_id].get("general_sse_queues")
            if registered is not None:
                try:
                    registered.remove(queue)
                    logger.debug(f"General SSE queue removed from session [Session ID: {session_id}]")
                except ValueError:
                    logger.warning(f"Failed to remove general SSE queue (not found) [Session ID: {session_id}]")
                except Exception as ce:
                    logger.error(f"Error removing general SSE queue [Session ID: {session_id}]: {ce}")

        # Discard anything still waiting in the queue.
        while not queue.empty():
            try:
                queue.get_nowait()
                queue.task_done()
            except asyncio.QueueEmpty:
                break
|
||||
|
||||
async def _create_request_sse_generator(self, session_id: str, request_id: Any, queue: asyncio.Queue):
    """Yield SSE events for the response stream of one POST /mcp request.

    Messages are taken from ``queue`` until the end marker arrives, the
    session or its request-queue registration disappears, or an unexpected
    error occurs. A 120-second wait without a message is only logged; the
    generator keeps waiting for the end marker.

    Args:
        session_id: Session the request belongs to.
        request_id: JSON-RPC id under which ``queue`` is registered.
        queue: Queue fed by the request-processing task.

    Yields:
        Dicts with ``event`` and ``data`` keys.
    """

    def registration_alive() -> bool:
        # True while both the session and this request's queue entry exist.
        return (
            session_id in self.client_sessions
            and request_id in self.client_sessions.get(session_id, {}).get("request_queues", {})
        )

    try:
        while True:
            try:
                if not registration_alive():
                    logger.warning(f"Request SSE stream generator: Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
                    break

                message = await asyncio.wait_for(queue.get(), timeout=120.0)

                if message == STREAM_END_MARKER:
                    logger.debug(f"Request SSE stream received end marker [Session ID: {session_id}, Request ID: {request_id}]")
                    break

                # TODO: Event ID for parts?
                yield {"event": "message", "data": json.dumps(message)}
                queue.task_done()

            except asyncio.TimeoutError:
                if not registration_alive():
                    logger.warning(f"Request SSE stream generator (timeout): Session/Request queue closed [Session ID: {session_id}, Request ID: {request_id}]")
                    break
                # A quiet period may just mean long processing, so no ping is
                # sent here; keep waiting for the end marker.
                logger.debug(f"Request SSE stream timed out waiting for message/end [Session ID: {session_id}, Request ID: {request_id}]")
            except asyncio.CancelledError:
                logger.info(f"Request SSE stream cancelled [Session ID: {session_id}, Request ID: {request_id}]")
                raise
            except Exception as e:
                logger.error(f"Request SSE stream error [Session ID: {session_id}, Request ID: {request_id}]: {str(e)}", exc_info=True)
                break
    finally:
        logger.info(f"Request SSE stream ended [Session ID: {session_id}, Request ID: {request_id}]")

        # Drop this request's queue registration, if the session still exists.
        if session_id in self.client_sessions:
            pending = self.client_sessions[session_id].get("request_queues")
            if pending is not None:
                if pending.pop(request_id, None):
                    logger.debug(f"Request SSE queue removed from session [Session ID: {session_id}, Request ID: {request_id}]")
                else:
                    logger.warning(f"Failed to remove request SSE queue (not found) [Session ID: {session_id}, Request ID: {request_id}]")

        # Discard anything still waiting in the queue.
        while not queue.empty():
            try:
                queue.get_nowait()
                queue.task_done()
            except asyncio.QueueEmpty:
                break
|
||||
|
||||
# === Core Request Processing Logic ===
|
||||
|
||||
async def _process_request_and_respond(
    self, request: Request, body: Dict, session_id: str, message_id: Any,
    response_queue: Optional[asyncio.Queue],  # Queue ONLY for streaming responses
    is_stream: bool  # True if response should go via SSE queue
):
    """Dispatch one client method call and deliver its result or error.

    Non-streaming calls return the complete JSON-RPC response dict and raise
    ``HTTPException`` on failure so the caller can set the HTTP status.
    Streaming calls put every message on ``response_queue``, return ``None``
    and always finish by enqueuing ``STREAM_END_MARKER``.

    Args:
        request: Incoming HTTP request, forwarded to the tool layer.
        body: Parsed JSON-RPC request.
        session_id: Id of the calling session (used for logging and routing).
        message_id: JSON-RPC ``id`` of the request.
        response_queue: Queue for streamed messages; ``None`` when not streaming.
        is_stream: Whether the answer is delivered through ``response_queue``.

    Returns:
        The JSON-RPC response dict for non-streaming calls, otherwise ``None``.

    Raises:
        HTTPException: For any failure of a non-streaming call.
    """
    logger.info(f"Entering _process_request_and_respond for method '{body.get('method')}'...")
    method = body.get("method")
    # An explicit ``"params": null`` is treated like missing params.
    params = body.get("params") or {}
    response_payload = None  # Holds the 'result' part of the JSON-RPC response

    try:
        # --- Handle Method Calls ---
        if method == "mcp/listOfferings":
            tools = await self.mcp_server.list_tools()
            tools_json = self._format_tools(tools)
            resources = await self.mcp_server.list_resources()
            resources_json = self._format_resources(resources)
            prompts = await self.mcp_server.list_prompts()
            prompts_json = self._format_prompts(prompts)
            response_payload = {"tools": tools_json, "resources": resources_json, "prompts": prompts_json}

        elif method == "mcp/listTools" or method == "tools/list":
            tools = await self.mcp_server.list_tools()
            response_payload = {"tools": self._format_tools(tools)}

        elif method == "mcp/listResources":
            resources = await self.mcp_server.list_resources()
            response_payload = {"resources": self._format_resources(resources)}

        elif method == "mcp/listPrompts":
            prompts = await self.mcp_server.list_prompts()
            response_payload = {"prompts": self._format_prompts(prompts)}

        elif method == "mcp/callTool" or method == "tools/call":
            tool_name = params.get("name")
            arguments = params.get("arguments", {})
            if not tool_name:
                # Streaming: report through the queue. Non-streaming: raise so
                # the handler below maps it to an HTTP error.
                error_detail = "Invalid params: tool name ('name') is required"
                if is_stream and response_queue:
                    error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32602, "message": error_detail}}
                    await response_queue.put(error_resp)
                else:
                    raise HTTPException(status_code=400, detail=error_detail)
                return  # finally sends the end marker for the streaming case

            # --- Tool Calling ---
            if is_stream and response_queue:
                # Run the streaming tool to completion *before* leaving this
                # method. Spawning it as a detached task would let the finally
                # block enqueue the end marker first, closing the SSE stream
                # before any progress or result message is delivered. The
                # wrapper puts its own results/errors on the queue.
                logger.info(f"Launching stream tool task [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
                await self._execute_stream_tool_wrapper(
                    tool_name, arguments, message_id, session_id, request, response_queue
                )
                return
            else:
                # Execute tool directly for non-streaming response
                logger.info(f"Executing non-stream tool [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")
                result = await self.call_tool(tool_name, arguments, request, None)  # No callback needed
                logger.debug(f"Raw result from non-stream call_tool: {result}")
                response_payload = self._format_tool_call_result(result)
        else:
            # Method not found
            error_detail = f"Method not found: {method}"
            if is_stream and response_queue:
                error_resp = {"jsonrpc": "2.0", "id": message_id, "error": {"code": -32601, "message": error_detail}}
                await response_queue.put(error_resp)
            else:
                raise HTTPException(status_code=405, detail=error_detail)
            return  # Exit after handling error

        # --- Prepare final response payload (only if not streaming and successful) ---
        if response_payload is not None:
            final_response = {"jsonrpc": "2.0", "id": message_id, "result": response_payload}
            if is_stream and response_queue:  # Should not happen if response_payload is set
                logger.error("Logic error: response_payload set for streaming call?")
                await response_queue.put(final_response)
            elif not is_stream:
                logger.debug(f"Returning successful non-stream payload for {method}")
                return final_response  # Return dict for JSONResponse

    except Exception as e:
        # Handles errors raised by call_tool (ValueError) or other unexpected issues
        logger.error(f"Error processing request [Session: {session_id}, Req: {message_id}, Method: {method}]: {str(e)}", exc_info=True)
        error_code = -32000
        error_message = f"Internal server error: {str(e)}"
        status_code = 500  # Default for unexpected errors

        if isinstance(e, HTTPException):
            # Re-surface an HTTPException raised above (e.g. 400, 405) unchanged.
            error_message = e.detail
            status_code = e.status_code
        elif isinstance(e, ValueError):
            # Errors from call_tool (tool not found, execution error) keep
            # their own message but are still reported as server errors.
            error_message = str(e)

        error_response_payload = {"code": error_code, "message": error_message}

        if is_stream and response_queue:
            # Send error via queue for streaming calls
            final_error_response = {"jsonrpc": "2.0", "id": message_id, "error": error_response_payload}
            logger.debug(f"Putting error response into stream queue [Session: {session_id}, Req: {message_id}]")
            await response_queue.put(final_error_response)
            return  # finally sends the end marker
        else:
            # For non-streaming, raise HTTPException to set status code
            logger.debug(f"Raising HTTPException for non-stream error (Status: {status_code})")
            raise HTTPException(status_code=status_code, detail=error_message)

    finally:
        # For streaming calls the end marker is always the last queued item,
        # whether processing succeeded, failed or returned early.
        if is_stream and response_queue:
            logger.debug(f"Putting stream end marker [Session: {session_id}, Req: {message_id}]")
            await response_queue.put(STREAM_END_MARKER)
|
||||
|
||||
|
||||
async def _execute_stream_tool_wrapper(
    self, tool_name: str, arguments: Dict, message_id: Any, session_id: str,
    request: Request, response_queue: asyncio.Queue
):
    """Run a tool in streaming mode and feed its output into ``response_queue``.

    Partial output arrives through a callback and is queued as
    ``tools/progress`` notifications; the final result, or a JSON-RPC error,
    is queued as the response to ``message_id``. Nothing is queued once the
    session or its request-queue registration has gone away. The stream end
    marker is not sent here; that is left to ``_process_request_and_respond``.

    Args:
        tool_name: Name of the tool to execute.
        arguments: Tool arguments; copied and extended with a ``callback`` entry.
        message_id: JSON-RPC id of the originating request.
        session_id: Session that issued the request.
        request: Incoming HTTP request, forwarded to ``call_tool``.
        response_queue: Queue read by the request's SSE generator.
    """
    logger.info(f"Entering _execute_stream_tool_wrapper for tool '{tool_name}'...")

    def receiver_gone() -> bool:
        # True once the session, or this request's queue entry, no longer exists.
        return (
            session_id not in self.client_sessions
            or message_id not in self.client_sessions.get(session_id, {}).get("request_queues", {})
        )

    try:
        logger.debug(f"Executing stream tool wrapper [Session: {session_id}, Req: {message_id}, Tool: {tool_name}]")

        async def stream_callback(content, metadata=None):
            logger.debug(f"Stream callback received content [Session: {session_id}, Req: {message_id}]")
            formatted_chunk = self._format_tool_call_result(content)

            if receiver_gone():
                logger.warning(f"Stream callback: Session/Queue closed, cannot send partial result [Session: {session_id}, Req: {message_id}]")
                return

            # Partial output is delivered as a progress notification.
            notification_params = {
                "requestId": message_id,
                "toolName": tool_name,
                "progress": formatted_chunk,
            }
            notification = {
                "jsonrpc": "2.0",
                "method": "tools/progress",
                "params": notification_params,
            }
            try:
                await response_queue.put(notification)
            except Exception as e:
                logger.error(f"Stream callback failed to send progress: {str(e)}")

            # Visualization data, when present, is sent separately.
            if metadata and "visualization" in metadata:
                await self.send_visualization_data(session_id, message_id, metadata["visualization"])

        # --- Call Tool ---
        # Simplification: the tool is assumed to accept a callback when
        # streaming was requested, so it is passed both inside the arguments
        # and as the explicit callback parameter.
        call_arguments = dict(arguments)
        call_arguments['callback'] = stream_callback

        # call_tool handles its own internal errors and raises ValueError
        result = await self.call_tool(tool_name, call_arguments, request, stream_callback)
        logger.debug(f"Stream wrapper received final result from call_tool: {result}")

        # --- Send Final Result ---
        if receiver_gone():
            logger.warning(f"Stream tool finished but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
            return  # Cannot send final result

        final_message = {
            "jsonrpc": "2.0",
            "id": message_id,
            "result": self._format_tool_call_result(result),
        }
        logger.debug(f"Putting final stream result into queue [Session: {session_id}, Req: {message_id}]")
        await response_queue.put(final_message)
        logger.info(f"Stream tool execution successful [Session: {session_id}, Req: {message_id}]")

    except Exception as e:
        # Catches errors from call_tool (ValueError) or other wrapper issues
        logger.error(f"Error during stream tool execution wrapper [Session: {session_id}, Req: {message_id}]: {str(e)}", exc_info=True)
        if receiver_gone():
            logger.warning(f"Stream tool failed but Session/Queue closed [Session: {session_id}, Req: {message_id}]")
            return  # Cannot send error

        # ValueError carries its own message and is reported as invalid params.
        if isinstance(e, ValueError):
            error_code = -32602
            error_message = str(e)
        else:
            error_code = -32000
            error_message = f"Tool execution error: {str(e)}"

        error_response = {
            "jsonrpc": "2.0",
            "id": message_id,
            "error": {"code": error_code, "message": error_message},
        }
        try:
            await response_queue.put(error_response)
        except Exception as qe:
            logger.error(f"Failed to put error response into stream queue: {qe}")
    # No finally block here: the end marker is sent by _process_request_and_respond.
|
||||
|
||||
|
||||
async def call_tool(self, tool_name, arguments, request, callback: Optional[callable] = None):
    """Find the target tool and execute it.

    Lookup order:
      1. Tools registered on the FastMCP instance, matched against ``tool_name``.
      2. Fallback: a callable named after the mapped tool name, imported directly
         from ``doris_mcp_server.tools.mcp_doris_tools``.

    Args:
        tool_name: Client-facing tool name.
        arguments: Mapping of tool arguments sent by the client. ``None`` is
            treated as an empty mapping. The mapping is never modified.
        request: Incoming request object; used to recover the most recent user
            query and handed to the simulated tool context.
        callback: Optional callback forwarded to the tool (on the direct-import
            path only when the mapped tool name ends with ``_stream``).

    Returns:
        The raw result produced by the tool.

    Raises:
        ValueError: If the tool cannot be found or its execution fails.
    """
    logger.info(f"Entering call_tool for tool '{tool_name}'...")
    # Work on a private copy: the callback is injected into the argument mapping
    # further down and must never leak into the caller's dict.
    arguments = dict(arguments or {})
    # Log args excluding callback
    log_args = {k: v for k, v in arguments.items() if k != 'callback'}
    logger.info(f"Executing tool: {tool_name}, Args: {json.dumps(log_args, ensure_ascii=False, default=str)}")

    recent_query = self._extract_recent_query(request)
    # Client-facing name -> implementation function name.
    # Keep consistent with tool_initializer.
    tool_mapping = {
        "status": "mcp_doris_status",
        "health": "mcp_doris_health",
        "nl2sql_query": "mcp_doris_nl2sql_query",
        "nl2sql_query_stream": "mcp_doris_nl2sql_query_stream",
        "list_database_tables": "mcp_doris_list_database_tables",
        "explain_table": "mcp_doris_explain_table",
        "get_nl2sql_status": "mcp_doris_get_nl2sql_status",
        "refresh_metadata": "mcp_doris_refresh_metadata",
        "sql_optimize": "mcp_doris_sql_optimize",
        "fix_sql": "mcp_doris_fix_sql",
        "count_chars": "mcp_doris_count_chars",
        "exec_query": "mcp_doris_exec_query",
        "get_schema_list": "mcp_doris_get_schema_list",  # Deprecated?
        "save_metadata": "mcp_doris_save_metadata",  # Likely internal
        "get_metadata": "mcp_doris_get_metadata",  # Likely internal
        "analyze_query_result": "mcp_doris_analyze_query_result",  # Internal?
        "generate_sql": "mcp_doris_generate_sql",  # Likely internal
        "explain_sql": "mcp_doris_explain_sql",  # Internal?
        "modify_sql": "mcp_doris_modify_sql",  # Internal?
        "parse_query": "mcp_doris_parse_query",  # Internal?
        "identify_query_type": "mcp_doris_identify_query_type",  # Internal?
        "validate_sql_syntax": "mcp_doris_validate_sql_syntax",  # Internal?
        "check_sql_security": "mcp_doris_check_sql_security",  # Internal?
        "find_similar_examples": "mcp_doris_find_similar_examples",  # Internal?
        "find_similar_history": "mcp_doris_find_similar_history",  # Internal?
        "calculate_query_similarity": "mcp_doris_calculate_query_similarity",  # Internal?
        "adapt_similar_query": "mcp_doris_adapt_similar_query",  # Internal?
        "get_nl2sql_prompt": "mcp_doris_get_nl2sql_prompt",  # Internal?
    }
    mapped_tool_name = tool_mapping.get(tool_name, tool_name)

    try:
        # 1. Find the registered tool instance/function from FastMCP.
        tool_instance = None
        mcp = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
        registered_tools = await mcp.list_tools()
        for tool in registered_tools:
            # The object returned by list_tools may be the wrapper function
            # defined in tool_initializer; either way we only need its name.
            tool_registered_name = getattr(tool, 'name', getattr(tool, '__name__', None))
            if tool_registered_name == tool_name:  # Match against the name used in @mcp.tool
                tool_instance = tool
                logger.debug(f"Found registered tool wrapper: {tool_registered_name}")
                break

        if not tool_instance:
            # Fallback: import the implementation directly (bypasses registration).
            logger.warning(f"Tool '{tool_name}' not found in registered tools, trying direct import of {mapped_tool_name}")
            # Only the lookup is guarded here, so an error raised by the tool
            # itself is not misreported as "not found or failed to import".
            try:
                import doris_mcp_server.tools.mcp_doris_tools as mcp_tools
                tool_instance = getattr(mcp_tools, mapped_tool_name, None)
                if not tool_instance or not callable(tool_instance):
                    raise ValueError(f"Tool function {mapped_tool_name} not found or not callable in mcp_doris_tools.")
            except (ImportError, AttributeError, ValueError) as import_err:
                logger.error(f"Failed to find or import tool: {tool_name} / {mapped_tool_name}. Error: {import_err}")
                raise ValueError(f"Tool '{tool_name}' not found or failed to import.") from import_err

            logger.debug(f"Using directly imported tool function: {mapped_tool_name}")
            # With a direct import no FastMCP context is available, so the
            # arguments are passed as plain keyword arguments.
            processed_args = self._process_tool_arguments(mapped_tool_name, arguments, recent_query)
            # Only streaming tools receive the callback.
            if callback and mapped_tool_name.endswith("_stream"):
                processed_args['callback'] = callback
            elif callback:
                processed_args.pop('callback', None)
            result = await tool_instance(**processed_args)
            logger.debug(f"Raw result from directly imported tool '{mapped_tool_name}': {result}")
            return result

        # 2. Found via registration: execute through FastMCP if it offers a
        #    by-name call, otherwise simulate the Context the wrapper expects.
        logger.debug(f"Executing registered tool wrapper '{tool_name}'")
        call_params = dict(arguments)  # separate copy for the callback injection
        if callback:
            call_params['callback'] = callback

        if hasattr(mcp, 'call_tool_by_name'):  # Hypothetical method; speculative FastMCP internal
            logger.debug("Attempting execution via mcp.call_tool_by_name")
            result = await mcp.call_tool_by_name(tool_name, params=call_params)
            logger.debug(f"Result from mcp.call_tool_by_name: {result}")
        else:
            # Fallback to manual context simulation if no direct call method exists.
            logger.debug("Falling back to manual context simulation for tool wrapper execution")
            from mcp.server.fastmcp import Context
            pseudo_ctx = Context(mcp=mcp, request=request, params=call_params, tool=tool_instance)
            result = await tool_instance(pseudo_ctx)  # Call the wrapper with simulated context
            logger.debug(f"Result from manual context simulation: {result}")

        logger.debug(f"Raw result received in call_tool from registered tool '{tool_name}': {result}")
        return result

    except Exception as e:
        # Every failure is surfaced to callers as ValueError, with the cause chained.
        logger.error(f"Exception during call_tool for '{tool_name}': {str(e)}", exc_info=True)
        raise ValueError(f"Error executing tool '{tool_name}': {str(e)}") from e
|
||||
|
||||
|
||||
# === Helper Methods (Formatting, Session Cleanup, etc.) ===
|
||||
|
||||
def _format_tools(self, tools):
    """Build the tool-list payload (name, description, inputSchema) for responses.

    Each tool's name is read from its ``name`` attribute, falling back to
    ``__name__``; tools without a usable name are skipped with a warning.
    Description and schema come from ``mcp.tools[name]`` when the MCP instance
    exposes a ``tools`` mapping. That is internal access and therefore fragile.
    Otherwise an empty description and an empty object schema are used.
    """
    formatted = []
    for entry in tools:
        name = getattr(entry, 'name', getattr(entry, '__name__', None))
        if not name:
            logger.warning(f"Could not determine name for tool object: {entry}")
            continue

        server = self.app.state.mcp if hasattr(self.app.state, 'mcp') else self.mcp_server
        spec = server.tools.get(name) if hasattr(server, 'tools') else None

        has_description = bool(spec) and hasattr(spec, 'description')
        # 'parameters' is assumed to hold the JSON schema of the tool input.
        has_parameters = bool(spec) and hasattr(spec, 'parameters')

        formatted.append({
            "name": name,
            "description": spec.description if has_description else "",
            "inputSchema": (
                spec.parameters
                if has_parameters
                else {"type": "object", "properties": {}, "required": []}
            ),
        })
    return formatted
|
||||
|
||||
def _format_resources(self, resources):
    """Return the resources as a list, dumping every item that offers ``model_dump()``.

    Items without a ``model_dump`` attribute are passed through unchanged.
    """
    formatted = []
    for item in resources:
        if hasattr(item, "model_dump"):
            formatted.append(item.model_dump())
        else:
            formatted.append(item)
    return formatted
|
||||
|
||||
def _format_prompts(self, prompts):
    """Return the prompts as a list, dumping every item that offers ``model_dump()``.

    Items without a ``model_dump`` attribute are passed through unchanged.
    """
    formatted = []
    for item in prompts:
        if hasattr(item, "model_dump"):
            formatted.append(item.model_dump())
        else:
            formatted.append(item)
    return formatted
|
||||
|
||||
def _format_tool_call_result(self, result: Any) -> Dict[str, Any]:
    """Normalize a raw tool result into a ``{"content": [...]}`` dictionary.

    Rules, in order:
      * str that parses as JSON: returned as-is if it is a dict with a list under
        ``content``, otherwise wrapped as one ``json`` content item.
      * str that is not JSON: one ``text`` content item.
      * dict/list: returned as-is if it is a dict with a list under ``content``,
        otherwise wrapped as one ``json`` content item.
      * None: one empty ``text`` content item (a warning is logged).
      * anything else: ``str(result)`` as one ``text`` content item.
    """

    def _has_content_list(candidate):
        # True when the value already looks like a formatted result.
        return (
            isinstance(candidate, dict)
            and 'content' in candidate
            and isinstance(candidate['content'], list)
        )

    if isinstance(result, str):
        try:
            decoded = json.loads(result)
        except json.JSONDecodeError:
            # Not JSON, treat as plain text.
            return {"content": [{"type": "text", "text": result}]}
        if _has_content_list(decoded):
            logger.debug("Tool result already seems formatted with 'content', using as is.")
            return decoded
        return {"content": [{"type": "json", "json": decoded}]}

    if isinstance(result, (dict, list)):
        if _has_content_list(result):
            logger.debug("Tool result dictionary has 'content', using as is.")
            return result
        return {"content": [{"type": "json", "json": result}]}

    if result is None:
        logger.warning("_format_tool_call_result received None result")
        return {"content": [{"type": "text", "text": ""}]}

    # Other types: stringify and wrap as text.
    return {"content": [{"type": "text", "text": str(result)}]}
|
||||
|
||||
def _process_tool_arguments(self, tool_name, arguments, recent_query):
    """Return a cleaned copy of ``arguments`` suitable for calling a tool directly.

    * ``callback`` is always removed (callers inject it separately).
    * For ``mcp_doris_*`` tools a ``random_string`` argument is removed. For
      ``mcp_doris_exec_query`` without ``sql`` it serves as a fallback SQL
      source, followed by ``recent_query``.

    Args:
        tool_name: Implementation (mapped) tool name.
        arguments: Mapping of client arguments; not modified.
        recent_query: Most recent user query text, or ``None``.

    Returns:
        A new dict with the processed arguments.
    """
    processed_args = dict(arguments)
    processed_args.pop('callback', None)  # Explicitly remove callback

    if "random_string" in arguments and tool_name.startswith("mcp_doris_"):
        random_string = processed_args.pop("random_string", "")
        logger.debug(f"Processing random_string '{random_string}' for tool {tool_name}")

        # NOTE(review): the original source elides further random_string
        # handling (e.g. a table_name fallback); only the exec_query case is
        # visible and implemented here.
        if tool_name == "mcp_doris_exec_query" and not processed_args.get("sql"):
            sql_fallback = random_string or recent_query
            # The extraction step was elided in the original, which left
            # `sql_extracted` undefined (NameError). The fallback text is used
            # as the SQL after trimming whitespace.
            sql_extracted = sql_fallback.strip() if isinstance(sql_fallback, str) else ""
            if sql_extracted:
                processed_args["sql"] = sql_extracted
            else:
                logger.warning(f"Missing sql for {tool_name}, and fallback failed.")

    return processed_args
|
||||
|
||||
def _extract_recent_query(self, request: "Request") -> Optional[str]:
    """Extract the most recent user message text from a request, if any.

    The payload is taken from the request's cached ``_body`` (parsed as JSON)
    and, if that is missing or unparsable, from its ``_json`` attribute. Within
    the payload, ``params.messages`` is searched from the end for a message with
    role ``user``; failing that, ``params.message`` is used when its role is
    ``user``.

    Returns:
        The message content, or ``None`` when nothing suitable is found or the
        payload cannot be interpreted.
    """
    try:
        body = None
        body_bytes = getattr(request, "_body", None)
        if body_bytes:
            try:
                body = json.loads(body_bytes)
            except (ValueError, TypeError):
                # Covers JSONDecodeError/UnicodeDecodeError and non-text input;
                # fall back to the pre-parsed payload below.
                body = None
        if not body:
            body = getattr(request, "_json", {})

        # A JSON array or scalar carries no "params" to inspect.
        if not isinstance(body, dict):
            return None

        messages = body.get("params", {}).get("messages", [])
        if messages:
            for msg in reversed(messages):
                if msg.get("role") == "user":
                    return msg.get("content", "")

        message = body.get("params", {}).get("message", {})
        if message and message.get("role") == "user":
            return message.get("content", "")

        return None
    except Exception as e:
        # Best effort: a malformed payload must not break the tool call.
        logger.error(f"Error extracting recent query: {str(e)}")
        return None
|
||||
|
||||
async def _cleanup_session_resources(self, session_id: str, session_data: Dict):
    """Signal end-of-stream on every queue that belongs to a deleted session.

    ``STREAM_END_MARKER`` is put on each queue in
    ``session_data["general_sse_queues"]`` and then on each queue in
    ``session_data["request_queues"]``. A failure on one queue is logged as a
    warning and does not stop the remaining queues from being signalled.
    """
    logger.info(f"Cleaning up resources for session [Session ID: {session_id}]")

    # General SSE queues first.
    for sse_queue in session_data.get("general_sse_queues", []):
        try:
            await sse_queue.put(STREAM_END_MARKER)
        except Exception as e:
            logger.warning(f"Error putting end marker in general queue for session {session_id}: {e}")

    # Then the request-specific SSE queues, keyed by request id.
    per_request = session_data.get("request_queues", {})
    for req_id in per_request:
        try:
            await per_request[req_id].put(STREAM_END_MARKER)
        except Exception as e:
            logger.warning(f"Error putting end marker in request queue {req_id} for session {session_id}: {e}")

    logger.info(f"Finished cleaning resources for session {session_id}")
|
||||
|
||||
# This method might belong in the main app or a shared utility if needed by both servers
|
||||
# async def cleanup_idle_sessions(self):
|
||||
# # ... (implementation - needs access to self.client_sessions) ...
|
||||
# pass
|
||||
|
||||
# This method might belong in the main app or a shared utility
|
||||
# async def broadcast_message(self, message):
|
||||
# # ... (implementation - needs access to self.client_sessions of BOTH servers?) ...
|
||||
# pass
|
||||
|
||||
# This method is specific to streamable http tool calls
|
||||
async def send_visualization_data(self, session_id: str, request_id: Any, visualization_data: Any):
    """Push a ``tools/visualization`` notification onto a request's stream queue.

    Nothing is sent, and a warning is logged, when the session is unknown or no
    queue is registered for ``request_id``. Errors while enqueueing are logged
    and not raised.
    """
    sessions = self.client_sessions
    if session_id not in sessions:
        logger.warning(f"Cannot send visualization: Session {session_id} not found.")
        return

    session_entry = sessions.get(session_id, {})
    stream_queue = session_entry.get("request_queues", {}).get(request_id)
    if not stream_queue:
        logger.warning(f"Cannot send visualization: Request queue {request_id} not found for session {session_id}.")
        return

    # JSON-RPC notification: no "id" member.
    payload = {
        "jsonrpc": "2.0",
        "method": "tools/visualization",
        "params": visualization_data
    }
    try:
        await stream_queue.put(payload)
        logger.info(f"Sent visualization notification [Session: {session_id}, Req: {request_id}]")
    except Exception as e:
        logger.error(f"Error sending visualization notification [Session: {session_id}, Req: {request_id}]: {e}")
|
||||
|
||||
# This might belong in main app or shared utility
|
||||
# async def send_periodic_updates(self):
|
||||
# # ... (implementation) ...
|
||||
# pass
|
||||
|
||||
# End of class DorisMCPStreamableServer
|
||||
@@ -1,25 +1,9 @@
|
||||
from .mcp_doris_tools import (
|
||||
mcp_doris_exec_query,
|
||||
mcp_doris_get_table_schema,
|
||||
mcp_doris_get_db_table_list,
|
||||
mcp_doris_get_db_list,
|
||||
mcp_doris_get_table_comment,
|
||||
mcp_doris_get_table_column_comments,
|
||||
mcp_doris_get_table_indexes,
|
||||
mcp_doris_get_recent_audit_logs,
|
||||
mcp_doris_get_catalog_list
|
||||
)
|
||||
"""
|
||||
MCP Tools Package - Contains all MCP tool implementations.
|
||||
|
||||
# The __all__ list should reflect the registered tool names,
|
||||
# even though the implementation functions have the prefix.
|
||||
__all__ = [
|
||||
"exec_query",
|
||||
"get_table_schema",
|
||||
"get_db_table_list",
|
||||
"get_db_list",
|
||||
"get_table_comment",
|
||||
"get_table_column_comments",
|
||||
"get_table_indexes",
|
||||
"get_recent_audit_logs",
|
||||
"get_catalog_list"
|
||||
]
|
||||
This package includes:
|
||||
- Doris database tools
|
||||
- Resource managers
|
||||
- Prompt managers
|
||||
- Tool registration and initialization
|
||||
"""
|
||||
|
||||
@@ -1,230 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Doris MCP Tool Implementations
|
||||
|
||||
Includes exec_query and new tools based on schema_extractor.
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import json
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
import pandas as pd
|
||||
|
||||
# --- Use absolute imports ---
|
||||
from doris_mcp_server.utils.schema_extractor import MetadataExtractor
|
||||
from doris_mcp_server.utils.sql_executor_tools import execute_sql_query
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp-tools")
|
||||
|
||||
# --- Helper Function to format response ---
|
||||
def _format_response(success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
    """Build the standard tool response: one text content item holding a JSON document.

    The JSON document always contains ``success`` and ``timestamp``.

    * On success it contains ``message`` (defaulting to "Operation successful")
      and, when ``result`` is not ``None``, ``result``. A ``pandas.DataFrame``
      result is converted to a list of row records.
    * If DataFrame conversion fails the response is downgraded to a failure:
      ``success`` becomes ``False``, ``error`` describes the problem and
      ``message`` is "Operation failed".
    * On failure it contains ``error`` (defaulting to "Unknown error") and
      ``message`` (defaulting to "Operation failed").

    Args:
        success: Whether the operation succeeded.
        result: Payload to return on success.
        error: Error description for failures.
        message: Optional human-readable message.

    Returns:
        ``{"content": [{"type": "text", "text": <json string>}]}``
    """
    response_data = {
        "success": success,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }
    if success:
        serialized = True
        if isinstance(result, pd.DataFrame):
            try:
                # Convert DataFrame to JSON records format
                response_data["result"] = json.loads(result.to_json(orient='records', date_format='iso'))
            except Exception as df_err:
                logger.error(f"DataFrame to JSON conversion failed: {df_err}")
                serialized = False
                response_data["result"] = {"error": "Failed to serialize DataFrame result"}
                response_data["success"] = False  # Mark as failed if serialization fails
                response_data["error"] = f"DataFrame serialization error: {str(df_err)}"
        elif result is not None:
            response_data["result"] = result
        # The message is reported even when there is no result payload, and it
        # must not claim success after a failed serialization.
        if serialized:
            response_data["message"] = message or "Operation successful"
        else:
            response_data["message"] = "Operation failed"
    else:
        response_data["error"] = error or "Unknown error"
        response_data["message"] = message or "Operation failed"

    return {
        "content": [
            {
                "type": "text",
                # default=str keeps non-JSON-serializable values from raising.
                "text": json.dumps(response_data, ensure_ascii=False, default=str)
            }
        ]
    }
|
||||
|
||||
async def mcp_doris_exec_query(sql: str = None, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
    """
    Execute an SQL query, with catalog federation support, and return its result.

    Args:
        sql (str): SQL statement to run. Table references MUST use three-part naming:
            - Internal tables: internal.db_name.table_name (e.g., "SELECT * FROM internal.ssb.customer")
            - External tables: catalog_name.db_name.table_name (e.g., "SELECT * FROM mysql.ssb.customer")
            - Cross-catalog queries: "SELECT * FROM mysql.ssb.customer m JOIN internal.ssb.orders o ON m.id = o.customer_id"

            Examples:
            - Query internal catalog: "SELECT COUNT(*) FROM internal.ssb.customer"
            - Query MySQL catalog: "SELECT COUNT(*) FROM mysql.ssb.customer"
            - Cross-catalog join: "SELECT * FROM internal.ssb.customer c JOIN mysql.test.user_info u ON c.id = u.customer_id"

        db_name (str, optional): Target database name. Used for connection context only; table names in the SQL must be fully qualified.
        catalog_name (str, optional): Reference catalog name for context. It does not affect SQL execution; table names in the SQL must be fully qualified.
            Available catalogs can be listed with the get_catalog_list tool.
        max_rows (int, optional): Maximum number of rows to return. Defaults to 100.
        timeout (int, optional): Query timeout in seconds. Defaults to 30.

    Returns:
        Dict[str, Any]: A dictionary containing the query result or an error.
    """
    logger.info(f"MCP Tool Call: mcp_doris_exec_query, SQL: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
    try:
        if not sql:
            return _format_response(success=False, error="SQL statement not provided", message="Please provide the SQL statement to execute")

        # Context object in the shape execute_sql_query expects.
        query_params = {
            "sql": sql,
            "db_name": db_name,
            "catalog_name": catalog_name,
            "max_rows": max_rows,
            "timeout": timeout
        }
        exec_result = await execute_sql_query({"params": query_params})

        # Expected executor output: {'content': [{'type': 'text', 'text': json_string}]}
        has_text_payload = (
            exec_result
            and 'content' in exec_result
            and len(exec_result['content']) > 0
            and 'text' in exec_result['content'][0]
        )
        if not has_text_payload:
            logger.error(f"execute_sql_query returned an unexpected format: {exec_result}")
            return _format_response(success=False, error="SQL executor returned invalid format", message="Internal error executing SQL query")

        try:
            # The executor's payload ({"success": ..., "data": ..., "columns": ...} or
            # {"success": false, "error": ...}) is re-serialized and returned directly
            # rather than being wrapped again by _format_response.
            result_data = json.loads(exec_result['content'][0]['text'])
            serialized = json.dumps(result_data, ensure_ascii=False, default=str)
            return {"content": [{"type": "text", "text": serialized}]}
        except json.JSONDecodeError as json_err:
            logger.error(f"Failed to parse execute_sql_query result: {json_err}")
            return _format_response(success=False, error=str(json_err), message="Error parsing SQL execution result")
        except Exception as parse_err:
            logger.error(f"Unexpected error occurred while processing execute_sql_query result: {parse_err}", exc_info=True)
            return _format_response(success=False, error=str(parse_err), message="Unknown error occurred while processing SQL execution result")

    except Exception as e:
        logger.error(f"MCP tool execution failed mcp_doris_exec_query: {str(e)}", exc_info=True)
        return _format_response(success=False, error=str(e), message="Error executing SQL query")
|
||||
|
||||
|
||||
async def mcp_doris_get_table_schema(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Return the schema of ``table_name`` in the standard tool response format.

    Responds with a failure when ``table_name`` is empty, when the extractor
    returns no schema, or when the lookup raises.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_table_schema, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
    if not table_name:
        return _format_response(success=False, error="Missing table_name parameter")

    scope = {"db_name": db_name, "catalog_name": catalog_name}
    try:
        extractor = MetadataExtractor(**scope)
        table_schema = extractor.get_table_schema(table_name=table_name, **scope)
        if table_schema:
            return _format_response(success=True, result=table_schema)
        return _format_response(success=False, error="Table not found or has no columns", message=f"Could not get schema for table {catalog_name or 'default'}.{db_name or extractor.db_name}.{table_name}")
    except Exception as e:
        logger.error(f"MCP tool execution failed mcp_doris_get_table_schema: {str(e)}", exc_info=True)
        return _format_response(success=False, error=str(e), message="Error getting table schema")
|
||||
|
||||
async def mcp_doris_get_db_table_list(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Return the tables of a database in the standard tool response format.

    Any exception raised by the lookup is logged and reported as a failure response.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_db_table_list, DB: {db_name}, Catalog: {catalog_name}")
    scope = {"db_name": db_name, "catalog_name": catalog_name}
    try:
        table_names = MetadataExtractor(**scope).get_database_tables(**scope)
        return _format_response(success=True, result=table_names)
    except Exception as e:
        logger.error(f"MCP tool execution failed mcp_doris_get_db_table_list: {str(e)}", exc_info=True)
        return _format_response(success=False, error=str(e), message="Error getting database table list")
|
||||
|
||||
async def mcp_doris_get_db_list(catalog_name: str = None) -> Dict[str, Any]:
    """Return the databases of a catalog in the standard tool response format.

    Any exception raised by the lookup is logged and reported as a failure response.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_db_list, Catalog: {catalog_name}")
    try:
        database_names = MetadataExtractor(catalog_name=catalog_name).get_all_databases(catalog_name=catalog_name)
        return _format_response(success=True, result=database_names)
    except Exception as e:
        logger.error(f"MCP tool execution failed mcp_doris_get_db_list: {str(e)}", exc_info=True)
        return _format_response(success=False, error=str(e), message="Error getting database list")
|
||||
|
||||
async def mcp_doris_get_table_comment(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Return the comment of ``table_name`` in the standard tool response format.

    Responds with a failure when ``table_name`` is empty or the lookup raises.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_table_comment, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
    if not table_name:
        return _format_response(success=False, error="Missing table_name parameter")

    scope = {"db_name": db_name, "catalog_name": catalog_name}
    try:
        table_comment = MetadataExtractor(**scope).get_table_comment(table_name=table_name, **scope)
        return _format_response(success=True, result=table_comment)
    except Exception as e:
        logger.error(f"MCP tool execution failed mcp_doris_get_table_comment: {str(e)}", exc_info=True)
        return _format_response(success=False, error=str(e), message="Error getting table comment")
|
||||
|
||||
async def mcp_doris_get_table_column_comments(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Return the column comments of ``table_name`` in the standard tool response format.

    Responds with a failure when ``table_name`` is empty or the lookup raises.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_table_column_comments, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
    if not table_name:
        return _format_response(success=False, error="Missing table_name parameter")

    scope = {"db_name": db_name, "catalog_name": catalog_name}
    try:
        column_comments = MetadataExtractor(**scope).get_column_comments(table_name=table_name, **scope)
        return _format_response(success=True, result=column_comments)
    except Exception as e:
        logger.error(f"MCP tool execution failed mcp_doris_get_table_column_comments: {str(e)}", exc_info=True)
        return _format_response(success=False, error=str(e), message="Error getting column comments")
|
||||
|
||||
async def mcp_doris_get_table_indexes(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
    """Return the indexes of ``table_name`` in the standard tool response format.

    Responds with a failure when ``table_name`` is empty or the lookup raises.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_table_indexes, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")
    if not table_name:
        return _format_response(success=False, error="Missing table_name parameter")

    scope = {"db_name": db_name, "catalog_name": catalog_name}
    try:
        table_indexes = MetadataExtractor(**scope).get_table_indexes(table_name=table_name, **scope)
        return _format_response(success=True, result=table_indexes)
    except Exception as e:
        logger.error(f"MCP tool execution failed mcp_doris_get_table_indexes: {str(e)}", exc_info=True)
        return _format_response(success=False, error=str(e), message="Error getting table indexes")
|
||||
|
||||
async def mcp_doris_get_recent_audit_logs(days: int = 7, limit: int = 100) -> Dict[str, Any]:
    """Return recent audit log entries in the standard tool response format.

    ``days`` and ``limit`` are forwarded unchanged to the metadata extractor.
    Any exception raised by the lookup is logged and reported as a failure response.
    """
    logger.info(f"MCP Tool Call: mcp_doris_get_recent_audit_logs, Days: {days}, Limit: {limit}")
    try:
        audit_logs = MetadataExtractor().get_recent_audit_logs(days=days, limit=limit)
        return _format_response(success=True, result=audit_logs)
    except Exception as e:
        logger.error(f"MCP tool execution failed mcp_doris_get_recent_audit_logs: {str(e)}", exc_info=True)
        return _format_response(success=False, error=str(e), message="Error getting audit logs")
|
||||
|
||||
async def mcp_doris_get_catalog_list() -> Dict[str, Any]:
    """
    Return the list of Doris catalogs.

    Returns:
        Dict[str, Any]: Standard tool response holding the catalog list, or
        error information if the lookup raised.
    """
    logger.info("MCP Tool Call: mcp_doris_get_catalog_list")
    try:
        catalog_names = MetadataExtractor().get_catalog_list()
        return _format_response(success=True, result=catalog_names, message="Successfully retrieved catalog list")
    except Exception as e:
        logger.error(f"MCP tool execution failed mcp_doris_get_catalog_list: {str(e)}", exc_info=True)
        return _format_response(success=False, error=str(e), message="Error getting catalog list")
|
||||
455
doris_mcp_server/tools/prompts_manager.py
Normal file
455
doris_mcp_server/tools/prompts_manager.py
Normal file
@@ -0,0 +1,455 @@
|
||||
"""
|
||||
Apache Doris MCP Prompts Manager
|
||||
Provides standardized management of query templates and intelligent prompts
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from mcp.types import (
|
||||
GetPromptResult,
|
||||
Prompt,
|
||||
PromptArgument,
|
||||
PromptMessage,
|
||||
TextContent,
|
||||
)
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
|
||||
|
||||
class PromptTemplate:
    """A named prompt template whose ``{placeholder}`` tokens are filled in by ``render``."""

    def __init__(
        self,
        name: str,
        description: str,
        template: str,
        arguments: "list[PromptArgument] | None" = None,
        category: str = "general",
    ):
        """Store the template definition.

        Args:
            name: Template identifier.
            description: Human-readable description.
            template: Template text containing ``{key}`` placeholders.
            arguments: Declared prompt arguments; ``None`` means no arguments.
            category: Free-form category label.
        """
        self.name = name
        self.description = description
        self.template = template
        self.arguments = arguments or []
        self.category = category
        self.created_at = datetime.now()

    def render(self, arguments: dict[str, Any]) -> str:
        """Return the template with each ``{key}`` replaced by ``str(value)``.

        Substitution is done in a single pass over the template, so a substituted
        value that itself contains ``{other_key}`` is inserted literally and the
        output does not depend on the order of ``arguments``. Placeholders with
        no matching argument are left untouched.
        """
        import re  # local import: only needed here

        if not arguments:
            return self.template

        replacements: dict[str, str] = {}
        for key, value in arguments.items():
            # First occurrence wins if two keys format to the same token.
            replacements.setdefault(f"{{{key}}}", str(value))

        token_pattern = re.compile("|".join(re.escape(token) for token in replacements))
        return token_pattern.sub(lambda match: replacements[match.group(0)], self.template)
|
||||
|
||||
|
||||
class DorisPromptsManager:
    """Apache Doris Prompts Manager.

    Holds the built-in prompt templates and renders them, appending a short
    summary of the current database obtained through the connection manager.
    """

    # Several templates contain a whole "filter line" placeholder whose name
    # differs from the argument that feeds it. Maps
    # argument name -> (template placeholder, label used for the rendered line).
    _FILTER_PLACEHOLDERS = {
        "product_category": ("product_filter", "Product category"),
        "region": ("region_filter", "Sales region"),
        "behavior_type": ("behavior_filter", "Behavior type"),
        "table_name": ("table_scope", "Table scope"),
        "business_unit": ("business_scope", "Business scope"),
    }

    def __init__(self, connection_manager: DorisConnectionManager):
        """Store the connection manager and build the template registry."""
        self.connection_manager = connection_manager
        self.templates = self._init_prompt_templates()

    def _init_prompt_templates(self) -> dict[str, PromptTemplate]:
        """Initialize prompt templates.

        Returns:
            Mapping of template name to its ``PromptTemplate``.
        """
        templates = {}

        # Sales data analysis template
        templates["sales_analysis"] = PromptTemplate(
            name="sales_analysis",
            description="Sales data analysis query template for generating sales statistics and trend analysis queries",
            template="""Please help me analyze sales data with the following requirements:

Analysis time range: {date_range}
{product_filter}
{region_filter}

Please generate SQL queries to analyze the following dimensions:
1. Total sales amount and order quantity
2. Sales trends by time dimension
3. Top-selling product rankings
4. Sales personnel performance statistics

Data table structure reference:
- Order table: Contains order ID, customer ID, salesperson ID, order amount, order time and other fields
- Product table: Contains product ID, product name, product category, price and other fields
- Customer table: Contains customer ID, customer name, region and other fields

Please ensure query results are easy to understand and analyze.""",
            arguments=[
                PromptArgument(
                    name="date_range",
                    description="Date range for analysis, such as 'Q1 2024' or 'last 30 days'",
                    required=True,
                ),
                PromptArgument(
                    name="product_category",
                    description="Product category filter condition, such as 'electronics'",
                    required=False,
                ),
                PromptArgument(
                    name="region",
                    description="Sales region filter condition, such as 'East China'",
                    required=False,
                ),
            ],
            category="business_analysis",
        )

        # User behavior analysis template
        templates["user_behavior_analysis"] = PromptTemplate(
            name="user_behavior_analysis",
            description="User behavior analysis query template for analyzing user activity patterns and preferences",
            template="""Please help me analyze user behavior data, analysis objectives:

User segment: {user_segment}
{behavior_filter}
Analysis period: {time_period}

Please generate SQL queries to analyze the following aspects:
1. User activity statistics (DAU, MAU)
2. User behavior path analysis
3. Feature usage preference statistics
4. User retention rate analysis

Data table structure reference:
- User table: Contains user ID, registration time, user type, region and other fields
- Behavior log table: Contains user ID, behavior type, behavior time, page path and other fields
- Session table: Contains session ID, user ID, session start time, session duration and other fields

Please provide easy-to-understand statistical results and visualization suggestions.""",
            arguments=[
                PromptArgument(
                    name="user_segment",
                    description="User segment conditions, such as 'new users', 'active users'",
                    required=True,
                ),
                PromptArgument(
                    name="behavior_type",
                    description="Behavior type filter, such as 'login', 'purchase', 'browse'",
                    required=False,
                ),
                PromptArgument(
                    name="time_period",
                    description="Analysis time period, such as 'last 7 days', 'this month'",
                    required=False,
                ),
            ],
            category="user_analysis",
        )

        # Performance optimization analysis template
        templates["performance_optimization"] = PromptTemplate(
            name="performance_optimization",
            description="Database performance optimization analysis template for identifying performance bottlenecks and optimization opportunities",
            template="""Please help me with database performance analysis and optimization recommendations:

Focus area: {focus_area}
{table_scope}
Performance metrics: {metrics}

Please generate SQL queries to analyze the following content:
1. Table and query performance statistics
2. Index usage efficiency analysis
3. Slow query identification and analysis
4. Storage space usage

Analysis objectives:
- Identify performance bottlenecks
- Provide optimization recommendations
- Evaluate optimization effects

Please provide specific optimization recommendations and implementation steps.""",
            arguments=[
                PromptArgument(
                    name="focus_area",
                    description="Performance area of focus, such as 'query performance', 'storage optimization'",
                    required=True,
                ),
                PromptArgument(
                    name="table_name",
                    description="Specific table name (optional), if analyzing specific table performance",
                    required=False,
                ),
                PromptArgument(
                    name="metrics",
                    description="Performance metrics of interest, such as 'response time', 'throughput'",
                    required=False,
                ),
            ],
            category="performance",
        )

        # Data quality check template
        templates["data_quality_check"] = PromptTemplate(
            name="data_quality_check",
            description="Data quality check template for detecting data integrity and consistency issues",
            template="""Please help me perform data quality checks:

Check target: {target_table}
{quality_dimensions}
Check level: {check_level}

Please generate SQL queries to check the following data quality issues:
1. Data integrity (null values, duplicate values)
2. Data consistency (format, range)
3. Data accuracy (business rule validation)
4. Data timeliness (update frequency)

Check items:
- Required field null value checks
- Primary key and unique constraint validation
- Data format and type checks
- Business logic consistency validation
- Data distribution anomaly detection

Please provide detailed problem reports and fix recommendations.""",
            arguments=[
                PromptArgument(
                    name="target_table", description="Target table name to check", required=True
                ),
                PromptArgument(
                    name="quality_dimensions",
                    description="Quality check dimensions, such as 'integrity', 'consistency', 'accuracy'",
                    required=False,
                ),
                PromptArgument(
                    name="check_level",
                    description="Check level, such as 'basic check', 'deep check'",
                    required=False,
                ),
            ],
            category="data_quality",
        )

        # Report generation template
        templates["report_generation"] = PromptTemplate(
            name="report_generation",
            description="Business report generation template for creating standardized business reports",
            template="""Please help me generate business reports:

Report type: {report_type}
Report period: {report_period}
{business_scope}

Please generate SQL queries to build the following report content:
1. Key business indicator summary
2. Trend analysis and year-over-year/month-over-month comparison
3. Anomaly data identification and explanation
4. Business insights and recommendations

Report requirements:
- Data accuracy and timeliness
- Clear hierarchical structure
- Easy-to-understand data presentation
- Decision-supporting analytical perspective

Please provide complete report structure and data acquisition logic.""",
            arguments=[
                PromptArgument(
                    name="report_type",
                    description="Report type, such as 'sales report', 'operations report', 'financial report'",
                    required=True,
                ),
                PromptArgument(
                    name="report_period",
                    description="Report period, such as 'daily report', 'weekly report', 'monthly report'",
                    required=True,
                ),
                PromptArgument(
                    name="business_unit",
                    description="Business unit scope, such as 'East China region', 'Product line A'",
                    required=False,
                ),
            ],
            category="reporting",
        )

        # Real-time monitoring template
        templates["real_time_monitoring"] = PromptTemplate(
            name="real_time_monitoring",
            description="Real-time monitoring query template for building real-time data monitoring and alerting",
            template="""Please help me design real-time monitoring queries:

Monitoring target: {monitoring_target}
Alert threshold: {alert_threshold}
Monitoring frequency: {monitoring_frequency}

Please generate SQL queries to implement the following monitoring functions:
1. Real-time statistics of key indicators
2. Anomaly detection and alerting
3. Trend change monitoring
4. System health status checks

Monitoring dimensions:
- Business indicator monitoring (transaction volume, user activity, etc.)
- Technical indicator monitoring (performance, error rate, etc.)
- Data quality monitoring (integrity, consistency, etc.)

Please provide complete monitoring solution and implementation recommendations.""",
            arguments=[
                PromptArgument(
                    name="monitoring_target",
                    description="Monitoring target, such as 'transaction system', 'user activity'",
                    required=True,
                ),
                PromptArgument(
                    name="alert_threshold",
                    description="Alert threshold setting, such as 'error rate > 5%'",
                    required=False,
                ),
                PromptArgument(
                    name="monitoring_frequency",
                    description="Monitoring frequency, such as 'real-time', 'every minute', 'every 5 minutes'",
                    required=False,
                ),
            ],
            category="monitoring",
        )

        return templates

    async def list_prompts(self) -> list[Prompt]:
        """List all available prompt templates as MCP ``Prompt`` objects."""
        return [
            Prompt(
                name=template.name,
                description=template.description,
                arguments=template.arguments,
            )
            for template in self.templates.values()
        ]

    async def get_prompt(self, name: str, arguments: dict[str, Any]) -> GetPromptResult:
        """Render a prompt template and wrap it in a ``GetPromptResult``.

        Args:
            name: Name of a registered template.
            arguments: Values for the template's declared arguments.

        Returns:
            A result with one user message holding the rendered template
            followed by the database context summary.

        Raises:
            ValueError: If the template is unknown or a required argument
                is missing.
        """
        if name not in self.templates:
            raise ValueError(f"Prompt template named '{name}' not found")

        template = self.templates[name]

        # Validate arguments and fill in defaults for optional ones
        processed_args = await self._process_arguments(template, arguments)

        # Render template content
        rendered_content = template.render(processed_args)

        # Add database context information
        context_info = await self._get_database_context()

        full_content = (
            f"{rendered_content}\n"
            "\n"
            "Database context information:\n"
            f"{context_info}\n"
            "\n"
            "Please generate accurate and efficient SQL queries based on the above requirements and database structure."
        )

        return GetPromptResult(
            description=template.description,
            messages=[
                PromptMessage(
                    role="user", content=TextContent(type="text", text=full_content)
                )
            ],
        )

    async def _process_arguments(
        self, template: PromptTemplate, arguments: dict[str, Any]
    ) -> dict[str, Any]:
        """Validate caller arguments and build the render mapping.

        Every declared argument ends up in the result: the caller's value
        when supplied, otherwise the default text for optional arguments.
        For arguments listed in ``_FILTER_PLACEHOLDERS`` an additional
        derived entry is produced so that placeholders such as
        ``{product_filter}`` are filled (``"<label>: <value>"``, or an empty
        string when there is no value) instead of leaking into the prompt.

        Args:
            template: Template whose ``arguments`` declare names and
                ``required`` flags.
            arguments: Caller-supplied values; ``None`` is treated as empty.

        Returns:
            Mapping of placeholder name to value for ``PromptTemplate.render``.

        Raises:
            ValueError: If a required argument is missing.
        """
        supplied = arguments or {}
        processed = {}

        for arg in template.arguments:
            if arg.name in supplied:
                value = supplied[arg.name]
            elif arg.required:
                raise ValueError(f"Missing required parameter: {arg.name}")
            else:
                # Provide default handling for optional parameters
                value = self._get_default_argument_text(arg.name)
            processed[arg.name] = value

            derived = self._FILTER_PLACEHOLDERS.get(arg.name)
            if derived is not None:
                placeholder, label = derived
                text = "" if value is None else str(value)
                processed[placeholder] = f"{label}: {text}" if text else ""

        return processed

    def _get_default_argument_text(self, arg_name: str) -> str:
        """Get default text for optional parameters.

        Returns an empty string for names without a configured default.
        """
        defaults = {
            "product_category": "",
            "region": "",
            "behavior_type": "",
            "time_period": "No time range restriction",
            "table_name": "",
            "metrics": "All performance metrics",
            "quality_dimensions": "All quality dimensions",
            "check_level": "Standard check",
            "business_unit": "Full business scope",
            "alert_threshold": "Use default threshold",
            "monitoring_frequency": "Real-time monitoring",
        }

        return defaults.get(arg_name, "")

    async def _get_database_context(self) -> str:
        """Build a plain-text summary of the current database.

        Runs two ``information_schema.tables`` queries on the ``"system"``
        connection: overall table/row counts and the ten tables with the
        most rows. NULL counts (e.g. ``SUM`` over no rows) are reported as 0.

        Returns:
            The summary text, or an explanatory message if anything fails;
            this method never raises.
        """
        try:
            connection = await self.connection_manager.get_connection("system")

            # Get basic database information
            db_info_sql = """
                SELECT
                    COUNT(*) as table_count,
                    SUM(table_rows) as total_rows
                FROM information_schema.tables
                WHERE table_schema = DATABASE()
                AND table_type = 'BASE TABLE'
            """

            db_result = await connection.execute(db_info_sql)
            db_info = db_result.data[0] if db_result.data else {}

            # Get main table list
            tables_sql = """
                SELECT
                    table_name,
                    table_comment,
                    table_rows
                FROM information_schema.tables
                WHERE table_schema = DATABASE()
                AND table_type = 'BASE TABLE'
                ORDER BY table_rows DESC
                LIMIT 10
            """

            tables_result = await connection.execute(tables_sql)

            # ``or 0`` covers both a missing key and a NULL value; formatting
            # None with ",": would raise and discard the whole context.
            table_count = db_info.get("table_count") or 0
            total_rows = db_info.get("total_rows") or 0

            lines = [
                "Current database statistics:",
                f"- Total number of tables: {table_count}",
                f"- Total data rows: {total_rows:,}",
                "",
                "Main data tables:",
            ]

            for table in tables_result.data or []:
                entry = f"- {table['table_name']}"
                if table.get("table_comment"):
                    entry += f": {table['table_comment']}"
                entry += f" ({(table.get('table_rows') or 0):,} rows)"
                lines.append(entry)

            return "\n".join(lines)

        except Exception as e:
            # Best effort: the prompt is still useful without the context.
            return f"Unable to get database context information: {str(e)}"

    def get_templates_by_category(self, category: str) -> list[PromptTemplate]:
        """Get templates by category."""
        return [
            template
            for template in self.templates.values()
            if template.category == category
        ]

    def get_all_categories(self) -> list[str]:
        """Get all template categories, sorted and de-duplicated."""
        categories = {template.category for template in self.templates.values()}
        return sorted(categories)
|
||||
361
doris_mcp_server/tools/resources_manager.py
Normal file
361
doris_mcp_server/tools/resources_manager.py
Normal file
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Apache Doris MCP Resources Manager
|
||||
Provides standardized abstraction and access interface for database metadata
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from mcp.types import Resource
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
|
||||
|
||||
class TableMetadata:
    """Plain container describing one data table.

    Holds the table name, optional comment, row count, column descriptions
    and creation time exactly as passed in.
    """

    def __init__(
        self,
        name: str,
        comment: str | None = None,
        row_count: int = 0,
        columns: list[dict] | None = None,
        create_time: datetime | None = None,
    ):
        # Identity and descriptive fields
        self.name, self.comment = name, comment
        self.row_count = row_count
        # A falsy ``columns`` (None or empty) becomes a fresh empty list
        self.columns = columns if columns else []
        self.create_time = create_time
|
||||
|
||||
|
||||
class ViewMetadata:
    """Plain container describing one data view (name, comment, definition)."""

    def __init__(
        self,
        name: str,
        comment: str | None = None,
        definition: str | None = None,
    ):
        # Values are stored as given; no validation or normalization happens.
        self.name, self.comment, self.definition = name, comment, definition
|
||||
|
||||
|
||||
class MetadataCache:
    """In-memory metadata cache whose entries expire after a fixed TTL.

    Entries live in ``self.cache`` as ``key -> (value, stored_at)`` where
    ``stored_at`` is the wall-clock timestamp taken when the value was set.
    """

    def __init__(self, ttl_seconds: int = 300):
        self.cache = {}
        self.ttl = ttl_seconds

    async def get(self, key: str) -> Any | None:
        """Return the cached value for ``key``, or ``None`` if absent or expired.

        An expired entry is removed from the cache as a side effect.
        """
        if key not in self.cache:
            return None

        value, stored_at = self.cache[key]
        age = datetime.now().timestamp() - stored_at
        if age < self.ttl:
            return value

        # Entry outlived its TTL: evict it and report a miss.
        del self.cache[key]
        return None

    async def set(self, key: str, value: Any):
        """Store ``value`` under ``key`` together with the current timestamp."""
        stored_at = datetime.now().timestamp()
        self.cache[key] = (value, stored_at)
|
||||
|
||||
|
||||
class DorisResourcesManager:
    """Apache Doris Resources Manager.

    Exposes tables, views and database statistics of the current database
    as ``doris://`` resources, reading metadata from ``information_schema``
    through the connection manager.
    """

    def __init__(self, connection_manager: DorisConnectionManager):
        """Store the connection manager and create an empty metadata cache."""
        self.connection_manager = connection_manager
        self.metadata_cache = MetadataCache()

    async def list_resources(self) -> list[Resource]:
        """List all available database resources.

        Returns:
            One resource per table and view plus a database statistics
            resource. On failure the resources collected so far are returned
            and the error is logged; this method does not raise.
        """
        resources = []

        try:
            # Get metadata for all tables
            tables = await self._get_table_metadata()
            for table in tables:
                resources.append(
                    Resource(
                        uri=f"doris://table/{table.name}",
                        name=f"Data Table: {table.name}",
                        # A NULL row count would break the "," format spec.
                        description=f"{table.comment or 'Data table'} (rows: {(table.row_count or 0):,})",
                        mimeType="application/json",
                    )
                )

            # Get metadata for all views
            views = await self._get_view_metadata()
            for view in views:
                resources.append(
                    Resource(
                        uri=f"doris://view/{view.name}",
                        name=f"Data View: {view.name}",
                        description=f"{view.comment or 'Data view'}",
                        mimeType="application/json",
                    )
                )

            # Add database statistics resource
            resources.append(
                Resource(
                    uri="doris://stats/database",
                    name="Database Statistics",
                    description="Overall database statistics and performance metrics",
                    mimeType="application/json",
                )
            )

        except Exception as e:
            # Log instead of print(): stdout may carry the MCP protocol
            # stream (stdio transport), so stray output there is unsafe.
            import logging

            logging.getLogger(__name__).error(
                "Failed to get resource list: %s", e, exc_info=True
            )

        return resources

    async def read_resource(self, uri: str) -> str:
        """Read detailed information of a specific resource.

        Args:
            uri: Resource URI of the form ``doris://<type>/<name>``.

        Returns:
            A JSON document describing the resource, or a JSON object with
            ``error`` and ``uri`` keys if anything fails.
        """
        try:
            resource_type, resource_name = self._parse_resource_uri(uri)

            if resource_type == "table":
                return await self._get_table_schema(resource_name)
            elif resource_type == "view":
                return await self._get_view_definition(resource_name)
            elif resource_type == "stats" and resource_name == "database":
                return await self._get_database_stats()
            else:
                raise ValueError(f"Unsupported resource type: {resource_type}")

        except Exception as e:
            return json.dumps(
                {"error": f"Failed to read resource: {str(e)}", "uri": uri},
                ensure_ascii=False,
                indent=2,
            )

    async def _get_table_metadata(self) -> list[TableMetadata]:
        """Get metadata for all base tables of the current database.

        Results are cached under ``"table_metadata"``; a cached empty list
        counts as a hit so an empty schema is not re-queried on every call.
        """
        cache_key = "table_metadata"
        cached = await self.metadata_cache.get(cache_key)
        if cached is not None:
            return cached

        connection = await self.connection_manager.get_connection("system")

        # Query basic table information
        tables_query = """
            SELECT
                table_name,
                table_comment,
                table_rows as row_count,
                create_time
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            AND table_type = 'BASE TABLE'
            ORDER BY table_name
        """

        result = await connection.execute(tables_query)
        tables = []

        for row in result.data:
            # Get column information for the table (one query per table)
            columns = await self._get_table_columns(connection, row["table_name"])

            table = TableMetadata(
                name=row["table_name"],
                comment=row.get("table_comment"),
                # NULL row counts are normalized to 0
                row_count=row.get("row_count") or 0,
                columns=columns,
                create_time=row.get("create_time"),
            )
            tables.append(table)

        await self.metadata_cache.set(cache_key, tables)
        return tables

    async def _get_table_columns(self, connection, table_name: str) -> list[dict]:
        """Get column information for a table, ordered by ordinal position."""
        columns_query = """
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_default,
                column_comment,
                column_key
            FROM information_schema.columns
            WHERE table_schema = DATABASE()
            AND table_name = %s
            ORDER BY ordinal_position
        """

        result = await connection.execute(columns_query, (table_name,))
        return [dict(row) for row in result.data]

    async def _get_view_metadata(self) -> list[ViewMetadata]:
        """Get metadata for all views of the current database.

        Results are cached under ``"view_metadata"``; a cached empty list
        counts as a hit.
        """
        cache_key = "view_metadata"
        cached = await self.metadata_cache.get(cache_key)
        if cached is not None:
            return cached

        connection = await self.connection_manager.get_connection("system")

        views_query = """
            SELECT
                table_name,
                table_comment,
                view_definition
            FROM information_schema.views
            WHERE table_schema = DATABASE()
            ORDER BY table_name
        """

        result = await connection.execute(views_query)
        views = [
            ViewMetadata(
                name=row["table_name"],
                comment=row.get("table_comment"),
                definition=row.get("view_definition"),
            )
            for row in result.data
        ]

        await self.metadata_cache.set(cache_key, views)
        return views

    async def _get_table_schema(self, table_name: str) -> str:
        """Get detailed structure information of a table as JSON.

        Raises:
            ValueError: If the table is not found in the current database.
        """
        connection = await self.connection_manager.get_connection("system")

        # Get basic table information
        table_info_query = """
            SELECT
                table_name,
                table_comment,
                table_rows,
                create_time,
                engine
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            AND table_name = %s
        """

        table_result = await connection.execute(table_info_query, (table_name,))
        if not table_result.data:
            raise ValueError(f"Table {table_name} does not exist")

        table_info = table_result.data[0]

        # Get column information
        columns = await self._get_table_columns(connection, table_name)

        # Get index information
        indexes = await self._get_table_indexes(connection, table_name)

        schema_info = {
            "table_name": table_info["table_name"],
            "comment": table_info.get("table_comment"),
            "row_count": table_info.get("table_rows") or 0,
            # Stringified so the value is always JSON-serializable
            "create_time": str(table_info.get("create_time")),
            "engine": table_info.get("engine"),
            "columns": columns,
            "indexes": indexes,
        }

        return json.dumps(schema_info, ensure_ascii=False, indent=2)

    async def _get_table_indexes(self, connection, table_name: str) -> list[dict]:
        """Get index information for a table from ``information_schema.statistics``."""
        indexes_query = """
            SELECT
                index_name,
                column_name,
                index_type,
                non_unique
            FROM information_schema.statistics
            WHERE table_schema = DATABASE()
            AND table_name = %s
            ORDER BY index_name, seq_in_index
        """

        result = await connection.execute(indexes_query, (table_name,))
        return [dict(row) for row in result.data]

    async def _get_view_definition(self, view_name: str) -> str:
        """Get definition information of a view as JSON.

        Raises:
            ValueError: If the view is not found in the current database.
        """
        connection = await self.connection_manager.get_connection("system")

        view_query = """
            SELECT
                table_name,
                table_comment,
                view_definition
            FROM information_schema.views
            WHERE table_schema = DATABASE()
            AND table_name = %s
        """

        result = await connection.execute(view_query, (view_name,))
        if not result.data:
            raise ValueError(f"View {view_name} does not exist")

        view_info = result.data[0]

        schema_info = {
            "view_name": view_info["table_name"],
            "comment": view_info.get("table_comment"),
            "definition": view_info.get("view_definition"),
        }

        return json.dumps(schema_info, ensure_ascii=False, indent=2)

    async def _get_database_stats(self) -> str:
        """Get database statistics (table/view counts, total rows) as JSON."""
        connection = await self.connection_manager.get_connection("system")

        # Get table statistics
        table_stats_query = """
            SELECT
                COUNT(*) as table_count,
                SUM(table_rows) as total_rows
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            AND table_type = 'BASE TABLE'
        """

        table_result = await connection.execute(table_stats_query)
        table_stats = table_result.data[0] if table_result.data else {}

        # Get view statistics
        view_stats_query = """
            SELECT COUNT(*) as view_count
            FROM information_schema.views
            WHERE table_schema = DATABASE()
        """

        view_result = await connection.execute(view_stats_query)
        view_stats = view_result.data[0] if view_result.data else {}

        # int(... or 0): aggregates may come back as NULL, or (depending on
        # the driver) as Decimal, which json.dumps cannot serialize.
        stats_info = {
            "database_name": "current_database",
            "table_count": int(table_stats.get("table_count") or 0),
            "view_count": int(view_stats.get("view_count") or 0),
            "total_rows": int(table_stats.get("total_rows") or 0),
            "last_updated": datetime.now().isoformat(),
        }

        return json.dumps(stats_info, ensure_ascii=False, indent=2)

    def _parse_resource_uri(self, uri: str) -> tuple[str, str]:
        """Parse a ``doris://<type>/<name>`` resource URI.

        Path segments after the second one are ignored.

        Returns:
            ``(resource_type, resource_name)``.

        Raises:
            ValueError: If the scheme is not ``doris://`` or the type or
                name segment is missing or empty.
        """
        prefix = "doris://"
        if not uri.startswith(prefix):
            raise ValueError("Invalid resource URI format")

        parts = uri[len(prefix):].split("/")

        if len(parts) < 2 or not parts[0] or not parts[1]:
            raise ValueError("Incomplete resource URI format")

        return parts[0], parts[1]
|
||||
@@ -1,157 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
Tool Initialization Module
|
||||
|
||||
Centralized initialization of all tools, ensuring they are correctly registered with MCP
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Dict, Any, Optional
|
||||
import json
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
|
||||
# Import Context
|
||||
from mcp.server.fastmcp import Context
|
||||
|
||||
# Import doris mcp tools
|
||||
from doris_mcp_server.tools.mcp_doris_tools import (
|
||||
mcp_doris_exec_query,
|
||||
mcp_doris_get_table_schema,
|
||||
mcp_doris_get_db_table_list,
|
||||
mcp_doris_get_db_list,
|
||||
mcp_doris_get_table_comment,
|
||||
mcp_doris_get_table_column_comments,
|
||||
mcp_doris_get_table_indexes,
|
||||
mcp_doris_get_recent_audit_logs,
|
||||
mcp_doris_get_catalog_list
|
||||
)
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp-tools-initializer")
|
||||
|
||||
async def register_mcp_tools(mcp):
|
||||
"""Register MCP tool functions
|
||||
|
||||
Args:
|
||||
mcp: FastMCP instance
|
||||
"""
|
||||
logger.info("Starting to register MCP tools...")
|
||||
|
||||
try:
|
||||
# Register Tool: Execute SQL Query (Using long description string including parameters)
|
||||
@mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command with catalog federation support.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog\n
|
||||
- max_rows (integer) [Optional] - Maximum number of rows to return, default 100
|
||||
- timeout (integer) [Optional] - Query timeout in seconds, default 30""")
|
||||
async def exec_query_tool(sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""Wrapper: Execute SQL query and return result command"""
|
||||
# Note: ctx parameter is no longer needed here as we receive named parameters directly
|
||||
return await mcp_doris_exec_query(sql=sql, db_name=db_name, catalog_name=catalog_name, max_rows=max_rows, timeout=timeout)
|
||||
|
||||
# Register Tool: Get Table Schema (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_schema_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table schema"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Database Table List (Keep long description string including parameters)
|
||||
@mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_db_table_list_tool(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get database table list"""
|
||||
return await mcp_doris_get_db_table_list(db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Database List (Keep long description string including parameters)
|
||||
# Note: Although the description mentions random_string, the wrapper function signature does not. See how mcp handles this.
|
||||
@mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_db_list_tool(catalog_name: str = None) -> Dict[str, Any]: # Function signature has no parameters
|
||||
"""Wrapper: Get database list"""
|
||||
return await mcp_doris_get_db_list(catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Table Comment (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_comment_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table comment"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Table Column Comments (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_column_comments_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table column comments"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Table Indexes (Keep long description string including parameters)
|
||||
@mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- table_name (string) [Required] - Name of the table to query\n
|
||||
- db_name (string) [Optional] - Target database name, defaults to the current database\n
|
||||
- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""")
|
||||
async def get_table_indexes_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
|
||||
"""Wrapper: Get table indexes"""
|
||||
if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]}
|
||||
return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
# Register Tool: Get Recent Audit Logs (Keep long description string including parameters)
|
||||
@mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n
|
||||
- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7\n
|
||||
- limit (integer) [Optional] - Maximum number of records to return, default is 100\n""")
|
||||
async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[str, Any]:
|
||||
"""Wrapper: Get recent audit logs"""
|
||||
try:
|
||||
days = int(days)
|
||||
limit = int(limit)
|
||||
except (ValueError, TypeError):
|
||||
return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]}
|
||||
return await mcp_doris_get_recent_audit_logs(days=days, limit=limit)
|
||||
|
||||
# Register Tool: Get Catalog List (Keep long description string including parameters)
|
||||
@mcp.tool("get_catalog_list", description="""[Function Description]: Get a list of all catalog names on the server.\n
|
||||
[Parameter Content]:\n
|
||||
- random_string (string) [Required] - Unique identifier for the tool call\n""")
|
||||
async def get_catalog_list_tool() -> Dict[str, Any]:
|
||||
"""Wrapper: Get catalog list"""
|
||||
return await mcp_doris_get_catalog_list()
|
||||
|
||||
# Get tool count
|
||||
tools_count = len(await mcp.list_tools())
|
||||
logger.info(f"Registered all MCP tools, total {tools_count} tools")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error registering MCP tools: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
766
doris_mcp_server/tools/tools_manager.py
Normal file
766
doris_mcp_server/tools/tools_manager.py
Normal file
@@ -0,0 +1,766 @@
|
||||
"""
|
||||
Apache Doris MCP Tools Manager
|
||||
Responsible for tool registration, management, scheduling and routing, does not contain specific business logic implementation
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from mcp.types import Tool
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.query_executor import DorisQueryExecutor
|
||||
from ..utils.analysis_tools import TableAnalyzer, PerformanceMonitor
|
||||
from ..utils.schema_extractor import MetadataExtractor
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
|
||||
class DorisToolsManager:
|
||||
"""Apache Doris Tools Manager"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
    """Bind the manager to a connection manager and build its processors.

    Args:
        connection_manager: Connection manager handed to every business
            logic processor created here and kept on the instance.
    """
    manager = connection_manager
    self.connection_manager = manager

    # One processor per business area; all of them receive the same manager.
    self.query_executor = DorisQueryExecutor(manager)
    self.table_analyzer = TableAnalyzer(manager)
    self.performance_monitor = PerformanceMonitor(manager)
    # MetadataExtractor takes the manager by keyword, unlike its siblings.
    self.metadata_extractor = MetadataExtractor(connection_manager=manager)

    logger.info("DorisToolsManager initialized with business logic processors")
|
||||
|
||||
async def register_tools_with_mcp(self, mcp):
    """Register all tools to MCP server

    Declares one wrapper coroutine per tool through the ``mcp.tool``
    decorator. Every wrapper packs its arguments into a dict and forwards
    them to ``self.call_tool`` under a fixed routing name, returning whatever
    string ``call_tool`` produces.

    Args:
        mcp: Server object exposing a ``tool(name, description=..., ...)``
            decorator factory.
            NOTE(review): only the first two registrations pass
            ``inputSchema``; confirm the decorator in use accepts that
            keyword.
    """
    logger.info("Starting to register MCP tools")

    # Column statistical analysis tool
    @mcp.tool(
        "column_analysis",
        description="""[Function Description]: Analyze statistical information and data distribution of the specified column.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to analyze

- column_name (string) [Required] - Name of the column to analyze

- analysis_type (string) [Optional] - Type of analysis to perform, default is "basic"
* "basic": Basic statistics (count, null values, distinct values)
* "distribution": Data distribution analysis (frequency, percentiles)
* "detailed": Comprehensive analysis including all above plus patterns and outliers
""",
        inputSchema={
            "type": "object",
            "properties": {
                "table_name": {"type": "string", "description": "Table name"},
                "column_name": {
                    "type": "string",
                    "description": "Column name to analyze",
                },
                "analysis_type": {
                    "type": "string",
                    "enum": ["basic", "distribution", "detailed"],
                    "description": "Analysis type",
                    "default": "basic",
                },
            },
            "required": ["table_name", "column_name"],
        }
    )
    async def column_analysis_tool(
        table_name: str,
        column_name: str,
        analysis_type: str = "basic"
    ) -> str:
        """Column statistical analysis tool"""
        return await self.call_tool("column_analysis", {
            "table_name": table_name,
            "column_name": column_name,
            "analysis_type": analysis_type
        })

    # Database performance monitoring tool
    # NOTE(review): the registered name carries an "[Experimental]" suffix,
    # while the wrapper below routes to "performance_stats" - confirm the
    # suffix is intended as part of the public tool name.
    @mcp.tool(
        "performance_stats[Experimental]",
        description="""[Important]: This tool is experimental and may not be fully functional!
[Function Description]: Get database performance statistics information.

[Parameter Content]:

- metric_type (string) [Optional] - Type of performance metrics to retrieve, default is "queries"
* "queries": Query performance metrics (execution time, frequency, etc.)
* "connections": Connection statistics (active connections, connection pool status)
* "tables": Table-level statistics (size, row count, access patterns)
* "system": System-level metrics (CPU, memory, disk usage)

- time_range (string) [Optional] - Time range for statistics, default is "1h"
* "1h": Last 1 hour
* "6h": Last 6 hours
* "24h": Last 24 hours
* "7d": Last 7 days
""",
        inputSchema={
            "type": "object",
            "properties": {
                "metric_type": {
                    "type": "string",
                    "enum": ["queries", "connections", "tables", "system"],
                    "description": "Performance metric type",
                    "default": "queries",
                },
                "time_range": {
                    "type": "string",
                    "enum": ["1h", "6h", "24h", "7d"],
                    "description": "Time range",
                    "default": "1h",
                },
            },
        }
    )
    async def performance_stats_tool(
        metric_type: str = "queries",
        time_range: str = "1h"
    ) -> str:
        """Database performance monitoring tool"""
        return await self.call_tool("performance_stats", {
            "metric_type": metric_type,
            "time_range": time_range
        })

    # SQL query execution tool (supports catalog federation queries)
    @mcp.tool(
        "exec_query",
        description="""[Function Description]: Execute SQL query and return result command with catalog federation support.

[Parameter Content]:

- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog

- max_rows (integer) [Optional] - Maximum number of rows to return, default 100

- timeout (integer) [Optional] - Query timeout in seconds, default 30
""",
    )
    async def exec_query_tool(
        sql: str,
        db_name: str = None,
        catalog_name: str = None,
        max_rows: int = 100,
        timeout: int = 30,
    ) -> str:
        """Execute SQL query (supports federation queries)"""
        return await self.call_tool("exec_query", {
            "sql": sql,
            "db_name": db_name,
            "catalog_name": catalog_name,
            "max_rows": max_rows,
            "timeout": timeout
        })

    # Get table schema tool
    @mcp.tool(
        "get_table_schema",
        description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_table_schema_tool(
        table_name: str, db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get table schema information"""
        return await self.call_tool("get_table_schema", {
            "table_name": table_name,
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get database table list tool
    @mcp.tool(
        "get_db_table_list",
        description="""[Function Description]: Get a list of all table names in the specified database.

[Parameter Content]:

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_db_table_list_tool(
        db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get database table list"""
        return await self.call_tool("get_db_table_list", {
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get database list tool
    @mcp.tool(
        "get_db_list",
        description="""[Function Description]: Get a list of all database names on the server.

[Parameter Content]:

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_db_list_tool(catalog_name: str = None) -> str:
        """Get database list"""
        return await self.call_tool("get_db_list", {
            "catalog_name": catalog_name
        })

    # Get table comment tool
    @mcp.tool(
        "get_table_comment",
        description="""[Function Description]: Get the comment information for the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_table_comment_tool(
        table_name: str, db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get table comment"""
        return await self.call_tool("get_table_comment", {
            "table_name": table_name,
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get table column comments tool
    @mcp.tool(
        "get_table_column_comments",
        description="""[Function Description]: Get comment information for all columns in the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_table_column_comments_tool(
        table_name: str, db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get table column comments"""
        return await self.call_tool("get_table_column_comments", {
            "table_name": table_name,
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get table indexes tool
    @mcp.tool(
        "get_table_indexes",
        description="""[Function Description]: Get index information for the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
    )
    async def get_table_indexes_tool(
        table_name: str, db_name: str = None, catalog_name: str = None
    ) -> str:
        """Get table indexes"""
        return await self.call_tool("get_table_indexes", {
            "table_name": table_name,
            "db_name": db_name,
            "catalog_name": catalog_name
        })

    # Get audit logs tool
    @mcp.tool(
        "get_recent_audit_logs",
        description="""[Function Description]: Get audit log records for a recent period.

[Parameter Content]:

- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7

- limit (integer) [Optional] - Maximum number of records to return, default is 100
""",
    )
    async def get_recent_audit_logs_tool(
        days: int = 7, limit: int = 100
    ) -> str:
        """Get audit logs"""
        return await self.call_tool("get_recent_audit_logs", {
            "days": days,
            "limit": limit
        })

    # Get catalog list tool
    # random_string is the only parameter; it is forwarded unchanged to call_tool.
    @mcp.tool(
        "get_catalog_list",
        description="""[Function Description]: Get a list of all catalog names on the server.

[Parameter Content]:

- random_string (string) [Required] - Unique identifier for the tool call
""",
    )
    async def get_catalog_list_tool(random_string: str) -> str:
        """Get catalog list"""
        return await self.call_tool("get_catalog_list", {
            "random_string": random_string
        })

    # NOTE(review): the tool count in this message is hard-coded; keep it in
    # sync with the registrations above.
    logger.info("Successfully registered 11 tools to MCP server (2 core tools + 9 migrated tools)")
|
||||
|
||||
async def list_tools(self) -> List[Tool]:
    """List all available query tools (for stdio mode)

    Every advertised ``name`` must be a key that ``call_tool`` can route.
    The column analysis tool was previously advertised as
    ``column_analysis[Experimental]``, a name the router does not know, so a
    client calling it by its advertised name received "Unknown tool". It is
    now advertised under its routable name; the experimental warning is kept
    in the description text.

    Returns:
        One ``Tool`` descriptor (name, description, JSON input schema) per
        supported tool.
    """
    return [
        Tool(
            # Name must match the routing key in call_tool.
            name="column_analysis",
            description="""[Important]: This tool is experimental and may not be fully functional!
[Function Description]: Analyze statistical information and data distribution of the specified column.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to analyze

- column_name (string) [Required] - Name of the column to analyze

- analysis_type (string) [Optional] - Type of analysis to perform, default is "basic"
* "basic": Basic statistics (count, null values, distinct values)
* "distribution": Data distribution analysis (frequency, percentiles)
* "detailed": Comprehensive analysis including all above plus patterns and outliers
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {"type": "string", "description": "Table name"},
                    "column_name": {
                        "type": "string",
                        "description": "Column name to analyze",
                    },
                    "analysis_type": {
                        "type": "string",
                        "enum": ["basic", "distribution", "detailed"],
                        "description": "Analysis type",
                        "default": "basic",
                    },
                },
                "required": ["table_name", "column_name"],
            },
        ),
        Tool(
            name="performance_stats",
            description="""[Function Description]: Get database performance statistics information.

[Parameter Content]:

- metric_type (string) [Optional] - Type of performance metrics to retrieve, default is "queries"
* "queries": Query performance metrics (execution time, frequency, etc.)
* "connections": Connection statistics (active connections, connection pool status)
* "tables": Table-level statistics (size, row count, access patterns)
* "system": System-level metrics (CPU, memory, disk usage)

- time_range (string) [Optional] - Time range for statistics, default is "1h"
* "1h": Last 1 hour
* "6h": Last 6 hours
* "24h": Last 24 hours
* "7d": Last 7 days
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "metric_type": {
                        "type": "string",
                        "enum": ["queries", "connections", "tables", "system"],
                        "description": "Performance metric type",
                        "default": "queries",
                    },
                    "time_range": {
                        "type": "string",
                        "enum": ["1h", "6h", "24h", "7d"],
                        "description": "Time range",
                        "default": "1h",
                    },
                },
            },
        ),
        Tool(
            name="exec_query",
            description="""[Function Description]: Execute SQL query and return result command with catalog federation support.

[Parameter Content]:

- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog

- max_rows (integer) [Optional] - Maximum number of rows to return, default 100

- timeout (integer) [Optional] - Query timeout in seconds, default 30
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "sql": {"type": "string", "description": "SQL statement to execute, must use three-part naming"},
                    "db_name": {"type": "string", "description": "Target database name"},
                    "catalog_name": {"type": "string", "description": "Catalog name"},
                    "max_rows": {"type": "integer", "description": "Maximum number of rows to return", "default": 100},
                    "timeout": {"type": "integer", "description": "Timeout in seconds", "default": 30},
                },
                "required": ["sql"],
            },
        ),
        Tool(
            name="get_table_schema",
            description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {"type": "string", "description": "Table name"},
                    "db_name": {"type": "string", "description": "Database name"},
                    "catalog_name": {"type": "string", "description": "Catalog name"},
                },
                "required": ["table_name"],
            },
        ),
        Tool(
            name="get_db_table_list",
            description="""[Function Description]: Get a list of all table names in the specified database.

[Parameter Content]:

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "db_name": {"type": "string", "description": "Database name"},
                    "catalog_name": {"type": "string", "description": "Catalog name"},
                },
            },
        ),
        Tool(
            name="get_db_list",
            description="""[Function Description]: Get a list of all database names on the server.

[Parameter Content]:

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "catalog_name": {"type": "string", "description": "Catalog name"},
                },
            },
        ),
        Tool(
            name="get_table_comment",
            description="""[Function Description]: Get the comment information for the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {"type": "string", "description": "Table name"},
                    "db_name": {"type": "string", "description": "Database name"},
                    "catalog_name": {"type": "string", "description": "Catalog name"},
                },
                "required": ["table_name"],
            },
        ),
        Tool(
            name="get_table_column_comments",
            description="""[Function Description]: Get comment information for all columns in the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {"type": "string", "description": "Table name"},
                    "db_name": {"type": "string", "description": "Database name"},
                    "catalog_name": {"type": "string", "description": "Catalog name"},
                },
                "required": ["table_name"],
            },
        ),
        Tool(
            name="get_table_indexes",
            description="""[Function Description]: Get index information for the specified table.

[Parameter Content]:

- table_name (string) [Required] - Name of the table to query

- db_name (string) [Optional] - Target database name, defaults to the current database

- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "table_name": {"type": "string", "description": "Table name"},
                    "db_name": {"type": "string", "description": "Database name"},
                    "catalog_name": {"type": "string", "description": "Catalog name"},
                },
                "required": ["table_name"],
            },
        ),
        Tool(
            name="get_recent_audit_logs",
            description="""[Function Description]: Get audit log records for a recent period.

[Parameter Content]:

- days (integer) [Optional] - Number of recent days of logs to retrieve, default is 7

- limit (integer) [Optional] - Maximum number of records to return, default is 100
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "days": {"type": "integer", "description": "Number of recent days", "default": 7},
                    "limit": {"type": "integer", "description": "Maximum number of records", "default": 100},
                },
            },
        ),
        Tool(
            name="get_catalog_list",
            description="""[Function Description]: Get a list of all catalog names on the server.

[Parameter Content]:

- random_string (string) [Required] - Unique identifier for the tool call
""",
            inputSchema={
                "type": "object",
                "properties": {
                    "random_string": {"type": "string", "description": "Unique identifier"},
                },
                "required": ["random_string"],
            },
        ),
    ]
|
||||
|
||||
async def call_tool(self, name: str, arguments: Dict[str, Any]) -> str:
    """
    Call the specified query tool (tool routing and scheduling center)

    The tool name is looked up in a routing table and the matching private
    routing method is awaited with ``arguments``. A trailing
    ``[Experimental]`` marker on the name is ignored for routing, so tools
    advertised with that suffix remain callable.

    Args:
        name: Tool name, optionally suffixed with ``[Experimental]``.
        arguments: Tool arguments, passed unchanged to the routing method.

    Returns:
        A JSON string. On success it is the handler's result; dict results
        additionally get an ``_execution_info`` entry (tool name, elapsed
        seconds, timestamp). On any failure, including an unknown tool name,
        it is an error object with ``error``, ``tool_name``, ``arguments``
        and ``timestamp``. This method does not raise for handler errors.
    """
    experimental_suffix = "[Experimental]"
    # Routing table: public tool name -> name of the routing method on self.
    routes = {
        "column_analysis": "_column_analysis_tool",
        "performance_stats": "_performance_stats_tool",
        # ===== 9 tool routes migrated from source project =====
        "exec_query": "_exec_query_tool",
        "get_table_schema": "_get_table_schema_tool",
        "get_db_table_list": "_get_db_table_list_tool",
        "get_db_list": "_get_db_list_tool",
        "get_table_comment": "_get_table_comment_tool",
        "get_table_column_comments": "_get_table_column_comments_tool",
        "get_table_indexes": "_get_table_indexes_tool",
        "get_recent_audit_logs": "_get_recent_audit_logs_tool",
        "get_catalog_list": "_get_catalog_list_tool",
    }

    try:
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        start_time = time.perf_counter()

        route_key = name
        if isinstance(name, str) and name.endswith(experimental_suffix):
            route_key = name[: -len(experimental_suffix)]

        handler_name = routes.get(route_key) if isinstance(route_key, str) else None
        if handler_name is None:
            raise ValueError(f"Unknown tool: {name}")

        # Tool routing - dispatch request to the corresponding business logic processor
        result = await getattr(self, handler_name)(arguments)

        execution_time = time.perf_counter() - start_time

        # Add execution information
        if isinstance(result, dict):
            result["_execution_info"] = {
                "tool_name": name,
                "execution_time": round(execution_time, 3),
                "timestamp": datetime.now().isoformat(),
            }

        return json.dumps(result, ensure_ascii=False, indent=2)

    except Exception as e:
        logger.error(f"Tool call failed {name}: {str(e)}")
        error_result = {
            "error": str(e),
            "tool_name": name,
            "arguments": arguments,
            "timestamp": datetime.now().isoformat(),
        }
        # default=str keeps the error response itself from raising when the
        # echoed arguments contain values JSON cannot encode.
        return json.dumps(error_result, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
# The following are tool routing methods, responsible for calling corresponding business logic processors
|
||||
|
||||
async def _column_analysis_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Column statistical analysis tool routing

    Args:
        arguments: Tool arguments. ``table_name`` and ``column_name`` are
            required; ``analysis_type`` defaults to ``"basic"``.

    Returns:
        The result of ``self.table_analyzer.analyze_column``, or
        ``{"success": False, "error": ...}`` when a required argument is
        missing or empty.
    """
    table_name = arguments.get("table_name")
    column_name = arguments.get("column_name")
    analysis_type = arguments.get("analysis_type", "basic")

    # Reject missing required arguments here instead of forwarding None to
    # the analyzer.
    for param, value in (("table_name", table_name), ("column_name", column_name)):
        if not value:
            return {"success": False, "error": f"Missing {param} parameter"}

    # Delegate to table analyzer for processing
    return await self.table_analyzer.analyze_column(
        table_name, column_name, analysis_type
    )
|
||||
|
||||
async def _performance_stats_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Route a performance statistics request to the performance monitor.

    Args:
        arguments: Tool arguments; ``metric_type`` falls back to
            ``"queries"`` and ``time_range`` to ``"1h"`` when absent.

    Returns:
        Whatever ``self.performance_monitor.get_performance_stats`` returns.
    """
    requested_metric = arguments.get("metric_type", "queries")
    requested_range = arguments.get("time_range", "1h")

    # The monitor owns the business logic; this method only unpacks arguments.
    monitor = self.performance_monitor
    stats = await monitor.get_performance_stats(requested_metric, requested_range)
    return stats
|
||||
|
||||
async def _exec_query_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """SQL query execution tool routing (supports federation queries)

    Args:
        arguments: Tool arguments. ``sql`` is required; ``db_name`` and
            ``catalog_name`` are optional; ``max_rows`` defaults to 100 and
            ``timeout`` to 30.

    Returns:
        The result of ``self.metadata_extractor.exec_query_for_mcp``, or
        ``{"success": False, "error": ...}`` when ``sql`` is missing or
        empty.
    """
    sql = arguments.get("sql")
    db_name = arguments.get("db_name")
    catalog_name = arguments.get("catalog_name")
    max_rows = arguments.get("max_rows", 100)
    timeout = arguments.get("timeout", 30)

    # Reject a missing statement here instead of forwarding None downstream.
    if not sql:
        return {"success": False, "error": "Missing sql parameter"}

    # Delegate to metadata extractor for processing
    return await self.metadata_extractor.exec_query_for_mcp(
        sql, db_name, catalog_name, max_rows, timeout
    )
|
||||
|
||||
async def _get_table_schema_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Get table schema tool routing

    Args:
        arguments: Tool arguments. ``table_name`` is required; ``db_name``
            and ``catalog_name`` are optional.

    Returns:
        The result of ``self.metadata_extractor.get_table_schema_for_mcp``,
        or ``{"success": False, "error": ...}`` when ``table_name`` is
        missing or empty.
    """
    table_name = arguments.get("table_name")
    db_name = arguments.get("db_name")
    catalog_name = arguments.get("catalog_name")

    # Reject a missing table name here instead of forwarding None downstream.
    if not table_name:
        return {"success": False, "error": "Missing table_name parameter"}

    # Delegate to metadata extractor for processing
    return await self.metadata_extractor.get_table_schema_for_mcp(
        table_name, db_name, catalog_name
    )
|
||||
|
||||
async def _get_db_table_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Route a table-list request to the metadata extractor.

    Args:
        arguments: Tool arguments; ``db_name`` and ``catalog_name`` are both
            optional and are passed as ``None`` when absent.

    Returns:
        Whatever ``self.metadata_extractor.get_db_table_list_for_mcp``
        returns.
    """
    database, catalog = arguments.get("db_name"), arguments.get("catalog_name")

    # The extractor owns the business logic; this method only unpacks arguments.
    extractor = self.metadata_extractor
    table_list = await extractor.get_db_table_list_for_mcp(database, catalog)
    return table_list
|
||||
|
||||
async def _get_db_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Route a database-list request to the metadata extractor.

    Args:
        arguments: Tool arguments; the optional ``catalog_name`` is passed
            as ``None`` when absent.

    Returns:
        Whatever ``self.metadata_extractor.get_db_list_for_mcp`` returns.
    """
    catalog = arguments.get("catalog_name")

    # The extractor owns the business logic; this method only unpacks arguments.
    extractor = self.metadata_extractor
    database_list = await extractor.get_db_list_for_mcp(catalog)
    return database_list
|
||||
|
||||
async def _get_table_comment_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Route a table-comment request to the metadata extractor.

    ``table_name``, ``db_name`` and ``catalog_name`` come from
    ``arguments`` (``None`` when absent) and are handed positionally to
    ``metadata_extractor.get_table_comment_for_mcp``.
    """
    table = arguments.get("table_name")
    database = arguments.get("db_name")
    catalog = arguments.get("catalog_name")
    fetch_comment = self.metadata_extractor.get_table_comment_for_mcp
    return await fetch_comment(table, database, catalog)
|
||||
|
||||
async def _get_table_column_comments_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Route a column-comments request to the metadata extractor.

    Looks up ``table_name``, ``db_name`` and ``catalog_name`` in
    ``arguments`` (``None`` when absent) and forwards them positionally
    to ``metadata_extractor.get_table_column_comments_for_mcp``.
    """
    lookup = [arguments.get(key) for key in ("table_name", "db_name", "catalog_name")]
    return await self.metadata_extractor.get_table_column_comments_for_mcp(*lookup)
|
||||
|
||||
async def _get_table_indexes_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Route a table-indexes request to the metadata extractor.

    Passes ``table_name``, ``db_name`` and ``catalog_name`` from
    ``arguments`` (``None`` when absent) positionally to
    ``metadata_extractor.get_table_indexes_for_mcp``.
    """
    return await self.metadata_extractor.get_table_indexes_for_mcp(
        arguments.get("table_name"),
        arguments.get("db_name"),
        arguments.get("catalog_name"),
    )
|
||||
|
||||
async def _get_recent_audit_logs_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Route a recent-audit-logs request to the metadata extractor.

    ``days`` defaults to 7 and ``limit`` to 100 when missing from
    ``arguments``; both are forwarded positionally to
    ``metadata_extractor.get_recent_audit_logs_for_mcp``.
    """
    lookback_days = arguments.get("days", 7)
    row_limit = arguments.get("limit", 100)
    fetch_logs = self.metadata_extractor.get_recent_audit_logs_for_mcp
    return await fetch_logs(lookback_days, row_limit)
|
||||
|
||||
async def _get_catalog_list_tool(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
    """Route a "list catalogs" request to the metadata extractor.

    ``arguments`` is accepted for signature parity with the other tool
    routers but is not read: the source project requires a
    ``random_string`` parameter that plays no part in the business logic.
    """
    extractor = self.metadata_extractor
    return await extractor.get_catalog_list_for_mcp()
|
||||
@@ -1 +1,10 @@
|
||||
# Mark directory as a package
|
||||
"""
|
||||
Utilities Package - Contains utility classes and helper functions.
|
||||
|
||||
This package includes:
|
||||
- Database connection and operations
|
||||
- Configuration management
|
||||
- Security utilities
|
||||
- Query execution helpers
|
||||
- Logging configuration
|
||||
"""
|
||||
|
||||
318
doris_mcp_server/utils/analysis_tools.py
Normal file
318
doris_mcp_server/utils/analysis_tools.py
Normal file
@@ -0,0 +1,318 @@
|
||||
"""
|
||||
Data Analysis Tools Module
|
||||
Provides data analysis functions including table analysis, column statistics, performance monitoring, etc.
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TableAnalyzer:
    """Table analyzer.

    Runs metadata and statistics queries through connections obtained from
    the injected connection manager. Caller-supplied names never reach the
    SQL text unescaped: values compared against ``information_schema``
    columns are sent as bound parameters, and table/column identifiers are
    backtick-quoted.
    """

    def __init__(self, connection_manager: "DorisConnectionManager"):
        self.connection_manager = connection_manager

    @staticmethod
    def _quote_identifier(identifier: str) -> str:
        """Return ``identifier`` as one backtick-quoted SQL identifier.

        A name that is already wrapped in backticks is unwrapped first so it
        is not quoted twice. Embedded backticks are doubled, which keeps the
        whole value inside a single identifier token.

        Raises:
            ValueError: if ``identifier`` is not a string or is empty.
        """
        if not isinstance(identifier, str):
            raise ValueError(f"Invalid identifier: {identifier!r}")
        name = identifier.strip()
        if len(name) >= 2 and name.startswith("`") and name.endswith("`"):
            name = name[1:-1].replace("``", "`")
        if not name:
            raise ValueError(f"Invalid identifier: {identifier!r}")
        return "`" + name.replace("`", "``") + "`"

    @staticmethod
    def _quote_table_name(table_name: str) -> str:
        """Quote a possibly dot-qualified table name part by part."""
        if not isinstance(table_name, str):
            raise ValueError(f"Invalid identifier: {table_name!r}")
        return ".".join(
            TableAnalyzer._quote_identifier(part) for part in table_name.split(".")
        )

    async def get_table_summary(
        self,
        table_name: str,
        include_sample: bool = True,
        sample_size: int = 10
    ) -> Dict[str, Any]:
        """Get table summary information.

        Args:
            table_name: Table to look up in the current database
                (``table_schema = DATABASE()``).
            include_sample: When true and ``sample_size`` is positive, a
                ``sample_data`` entry with up to ``sample_size`` rows is added.
            sample_size: Row limit for the sample query.

        Returns:
            Dict with ``table_name``, ``comment``, ``row_count``,
            ``create_time``, ``engine``, ``column_count``, ``columns`` and,
            optionally, ``sample_data``.

        Raises:
            ValueError: if the table is not found in ``information_schema``.
        """
        connection = await self.connection_manager.get_connection("query")

        # The table name is a bound parameter, not interpolated text.
        table_info_sql = """
            SELECT
                table_name,
                table_comment,
                table_rows,
                create_time,
                engine
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            AND table_name = %s
        """

        table_info_result = await connection.execute(table_info_sql, (table_name,))
        if not table_info_result.data:
            raise ValueError(f"Table {table_name} does not exist")

        table_info = table_info_result.data[0]

        columns_sql = """
            SELECT
                column_name,
                data_type,
                is_nullable,
                column_comment
            FROM information_schema.columns
            WHERE table_schema = DATABASE()
            AND table_name = %s
            ORDER BY ordinal_position
        """

        columns_result = await connection.execute(columns_sql, (table_name,))

        summary = {
            "table_name": table_info["table_name"],
            "comment": table_info.get("table_comment"),
            "row_count": table_info.get("table_rows", 0),
            "create_time": str(table_info.get("create_time")),
            "engine": table_info.get("engine"),
            "column_count": len(columns_result.data),
            "columns": columns_result.data,
        }

        if include_sample and sample_size > 0:
            # Identifiers cannot be bound as parameters, so quote the name
            # and force the limit to an integer.
            sample_sql = (
                f"SELECT * FROM {self._quote_table_name(table_name)} "
                f"LIMIT {int(sample_size)}"
            )
            sample_result = await connection.execute(sample_sql)
            summary["sample_data"] = sample_result.data

        return summary

    async def analyze_column(
        self,
        table_name: str,
        column_name: str,
        analysis_type: str = "basic"
    ) -> Dict[str, Any]:
        """Analyze column statistics.

        Args:
            table_name: Table containing the column (may be dot-qualified).
            column_name: Column to analyze.
            analysis_type: ``"basic"`` for counts only; ``"distribution"``
                adds the 20 most frequent values; ``"detailed"`` adds those
                plus min/max/avg.

        Returns:
            On success, a dict with ``column_name``, the count columns,
            ``success`` (True) and ``analysis_type``, plus the optional
            sections above. On any failure, a dict with ``success`` (False),
            ``error``, ``column_name`` and ``table_name``; exceptions are
            not propagated.
        """
        try:
            quoted_table = self._quote_table_name(table_name)
            quoted_column = self._quote_identifier(column_name)

            connection = await self.connection_manager.get_connection("query")

            basic_stats_sql = f"""
                SELECT
                    COUNT(*) as total_count,
                    COUNT({quoted_column}) as non_null_count,
                    COUNT(DISTINCT {quoted_column}) as distinct_count
                FROM {quoted_table}
            """

            basic_result = await connection.execute(basic_stats_sql)
            if not basic_result.data:
                return {
                    "success": False,
                    "error": f"Unable to get statistics for table {table_name} column {column_name}"
                }

            # The column name is added here rather than selected as a SQL
            # literal, so it never has to be embedded in the statement text.
            analysis = {"column_name": column_name, **basic_result.data[0]}
            analysis["success"] = True
            analysis["analysis_type"] = analysis_type

            if analysis_type in ["distribution", "detailed"]:
                distribution_sql = f"""
                    SELECT
                        {quoted_column} as value,
                        COUNT(*) as frequency
                    FROM {quoted_table}
                    WHERE {quoted_column} IS NOT NULL
                    GROUP BY {quoted_column}
                    ORDER BY frequency DESC
                    LIMIT 20
                """

                distribution_result = await connection.execute(distribution_sql)
                analysis["value_distribution"] = distribution_result.data

            if analysis_type == "detailed":
                # Best effort: the aggregate query can fail on non-numeric
                # columns, in which case the extra statistics are omitted.
                try:
                    numeric_stats_sql = f"""
                        SELECT
                            MIN({quoted_column}) as min_value,
                            MAX({quoted_column}) as max_value,
                            AVG({quoted_column}) as avg_value
                        FROM {quoted_table}
                        WHERE {quoted_column} IS NOT NULL
                    """

                    numeric_result = await connection.execute(numeric_stats_sql)
                    if numeric_result.data:
                        analysis.update(numeric_result.data[0])
                except Exception as e:
                    logger.debug(
                        f"Numeric statistics unavailable for {table_name}.{column_name}: {e}"
                    )

            return analysis

        except Exception as e:
            logger.error(f"Column analysis failed: {e}")
            return {
                "success": False,
                "error": str(e),
                "column_name": column_name,
                "table_name": table_name
            }

    async def analyze_table_relationships(
        self,
        table_name: str,
        depth: int = 2
    ) -> Dict[str, Any]:
        """Analyze table relationships.

        Looks up ``table_name`` in the current database and lists every
        other base table in it. ``depth`` is echoed back in the result and
        is not otherwise used here.

        Raises:
            ValueError: if the table is not found in ``information_schema``.
        """
        connection = await self.connection_manager.get_connection("system")

        table_info_sql = """
            SELECT
                table_name,
                table_comment,
                table_rows
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            AND table_name = %s
        """

        table_result = await connection.execute(table_info_sql, (table_name,))
        if not table_result.data:
            raise ValueError(f"Table {table_name} does not exist")

        # Every other base table is reported as a candidate relation.
        all_tables_sql = """
            SELECT
                table_name,
                table_comment
            FROM information_schema.tables
            WHERE table_schema = DATABASE()
            AND table_type = 'BASE TABLE'
            AND table_name != %s
        """

        all_tables_result = await connection.execute(all_tables_sql, (table_name,))

        return {
            "center_table": table_result.data[0],
            "related_tables": all_tables_result.data,
            "depth": depth,
            "note": "Table relationship analysis based on column name similarity and business logic inference",
        }
|
||||
|
||||
|
||||
class PerformanceMonitor:
    """Performance monitor.

    Reports query, connection, table and system metrics. The ``queries``
    and ``system`` metrics are simulated placeholder values.
    """

    def __init__(self, connection_manager: "DorisConnectionManager"):
        self.connection_manager = connection_manager

    async def get_performance_stats(
        self,
        metric_type: str = "queries",
        time_range: str = "1h"
    ) -> Dict[str, Any]:
        """Get performance statistics.

        Args:
            metric_type: One of ``"queries"``, ``"connections"``,
                ``"tables"`` or ``"system"``.
            time_range: Label echoed back in the result (e.g. ``"1h"``).
                None of the metrics is filtered by it.

        Returns:
            Dict containing ``metric_type``, ``time_range``, ``timestamp``
            and the metric-specific fields.

        Raises:
            ValueError: if ``metric_type`` is not supported.
        """
        if metric_type == "queries":
            # Placeholder values: no query statistics are collected here.
            return {
                "metric_type": "queries",
                "time_range": time_range,
                "timestamp": datetime.now().isoformat(),
                "total_queries": 0,
                "avg_execution_time": 0.0,
                "slow_queries": 0,
                "error_queries": 0,
                "note": "Query performance statistics (simulated data)"
            }

        if metric_type == "connections":
            connection_metrics = await self.connection_manager.get_metrics()
            return {
                "metric_type": "connections",
                "time_range": time_range,
                "timestamp": datetime.now().isoformat(),
                "total_connections": connection_metrics.total_connections,
                "active_connections": connection_metrics.active_connections,
                "idle_connections": connection_metrics.idle_connections,
                "failed_connections": connection_metrics.failed_connections,
                "connection_errors": connection_metrics.connection_errors,
                "avg_connection_time": connection_metrics.avg_connection_time,
                "last_health_check": connection_metrics.last_health_check.isoformat() if connection_metrics.last_health_check else None
            }

        if metric_type == "tables":
            # Only this metric queries the database, so the connection is
            # acquired here and not before the metric type is known.
            connection = await self.connection_manager.get_connection("system")
            tables_sql = """
                SELECT
                    table_name,
                    table_rows,
                    data_length,
                    index_length,
                    create_time,
                    update_time
                FROM information_schema.tables
                WHERE table_schema = DATABASE()
                AND table_type = 'BASE TABLE'
                ORDER BY table_rows DESC
                LIMIT 20
            """

            tables_result = await connection.execute(tables_sql)
            return {
                "metric_type": "tables",
                "time_range": time_range,
                "timestamp": datetime.now().isoformat(),
                "table_count": len(tables_result.data),
                "tables": tables_result.data
            }

        if metric_type == "system":
            # Hard-coded sample values, not measurements.
            return {
                "metric_type": "system",
                "time_range": time_range,
                "timestamp": datetime.now().isoformat(),
                "cpu_usage": 45.2,
                "memory_usage": 68.5,
                "disk_usage": 72.1,
                "network_io": {
                    "bytes_sent": 1024000,
                    "bytes_received": 2048000
                },
                "note": "System metrics (simulated data)"
            }

        raise ValueError(f"Unsupported metric type: {metric_type}")

    async def get_query_history(
        self,
        limit: int = 50,
        order_by: str = "time"
    ) -> Dict[str, Any]:
        """Get query history.

        Always returns an empty history: no query is issued, and ``limit``
        and ``order_by`` are only echoed back in the result.
        """
        # Since Doris doesn't have a built-in query history table,
        # we return simulated data
        return {
            "total_queries": 0,
            "queries": [],
            "limit": limit,
            "order_by": order_by,
            "note": "Query history feature requires audit log configuration"
        }
|
||||
608
doris_mcp_server/utils/config.py
Normal file
608
doris_mcp_server/utils/config.py
Normal file
@@ -0,0 +1,608 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Doris Configuration Management Module
|
||||
Implements configuration loading, validation and management functionality
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
try:
|
||||
from dotenv import load_dotenv
|
||||
except ImportError:
|
||||
load_dotenv = None
|
||||
|
||||
|
||||
@dataclass
class DatabaseConfig:
    """Connection settings and pool limits for the Doris database."""

    # Endpoint and credentials
    host: str = "localhost"
    port: int = 9030
    user: str = "root"
    password: str = ""
    database: str = "test"
    charset: str = "utf8mb4"

    # Connection pool limits and timers
    min_connections: int = 5
    max_connections: int = 20
    connection_timeout: int = 30
    health_check_interval: int = 60
    max_connection_age: int = 3600
|
||||
|
||||
|
||||
@dataclass
class SecurityConfig:
    """Authentication, SQL guard and data-masking settings."""

    # Authentication: one of token, basic, oauth
    auth_type: str = "token"
    token_secret: str = "default_secret"
    token_expiry: int = 3600

    # SQL guard: keywords on the block list, plus complexity and row caps
    blocked_keywords: list[str] = field(
        default_factory=lambda: (
            "DROP DELETE TRUNCATE ALTER CREATE INSERT UPDATE GRANT REVOKE".split()
        )
    )
    max_query_complexity: int = 100
    max_result_rows: int = 10000

    # Sensitive tables
    sensitive_tables: dict[str, str] = field(default_factory=dict)

    # Data masking
    enable_masking: bool = True
    masking_rules: list[dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class PerformanceConfig:
    """Query cache, concurrency and pool-tuning settings."""

    # Query cache
    enable_query_cache: bool = True
    cache_ttl: int = 300
    max_cache_size: int = 1000

    # Concurrency control
    max_concurrent_queries: int = 50
    query_timeout: int = 300

    # Connection pool tuning
    connection_pool_size: int = 20
    idle_timeout: int = 1800
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoggingConfig:
|
||||
"""Logging configuration"""
|
||||
|
||||
level: str = "INFO"
|
||||
format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
file_path: str | None = None
|
||||
max_file_size: int = 10 * 1024 * 1024 # 10MB
|
||||
backup_count: int = 5
|
||||
|
||||
# Audit log configuration
|
||||
enable_audit: bool = True
|
||||
audit_file_path: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MonitoringConfig:
|
||||
"""Monitoring configuration"""
|
||||
|
||||
# Metrics collection configuration
|
||||
enable_metrics: bool = True
|
||||
metrics_port: int = 8081
|
||||
metrics_path: str = "/metrics"
|
||||
|
||||
# Health check configuration
|
||||
health_check_port: int = 8082
|
||||
health_check_path: str = "/health"
|
||||
|
||||
# Alert configuration
|
||||
enable_alerts: bool = False
|
||||
alert_webhook_url: str | None = None
|
||||
|
||||
|
||||
@dataclass
class DorisConfig:
    """Doris MCP Server complete configuration.

    Aggregates the server identity fields with the database, security,
    performance, logging and monitoring sub-configurations, and provides
    loaders (JSON file, environment variables), a serializer and a
    validator.
    """

    # Basic configuration
    server_name: str = "doris-mcp-server"
    server_version: str = "1.0.0"
    server_port: int = 8080

    # Sub-configuration modules
    database: DatabaseConfig = field(default_factory=DatabaseConfig)
    security: SecurityConfig = field(default_factory=SecurityConfig)
    performance: PerformanceConfig = field(default_factory=PerformanceConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)
    monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)

    # Custom configuration
    custom_config: dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_file(cls, config_path: str) -> "DorisConfig":
        """Load configuration from a JSON file.

        Raises:
            FileNotFoundError: if ``config_path`` does not exist.
            ValueError: if the suffix is not ``.json`` or the file cannot
                be read or parsed; the underlying error is chained.
        """
        config_file = Path(config_path)

        if not config_file.exists():
            raise FileNotFoundError(f"Configuration file does not exist: {config_path}")

        try:
            # Only JSON is supported; reject other suffixes before reading.
            if config_file.suffix.lower() != ".json":
                raise ValueError(f"Unsupported configuration file format: {config_file.suffix}")

            with open(config_file, encoding="utf-8") as f:
                config_data = json.load(f)

            return cls._from_dict(config_data)

        except Exception as e:
            raise ValueError(f"Failed to load configuration file: {e}") from e

    @staticmethod
    def _env_int(name: str, default: int) -> int:
        """Read an integer environment variable, naming it on failure."""
        raw = os.getenv(name)
        if raw is None:
            return default
        try:
            return int(raw)
        except ValueError as e:
            raise ValueError(
                f"Environment variable {name} must be an integer, got {raw!r}"
            ) from e

    @staticmethod
    def _env_bool(name: str, default: bool) -> bool:
        """Read a boolean environment variable.

        ``true``, ``1``, ``yes`` and ``on`` (case-insensitive) are true;
        any other value is false. An unset variable yields ``default``.
        """
        raw = os.getenv(name)
        if raw is None:
            return default
        return raw.strip().lower() in ("true", "1", "yes", "on")

    @classmethod
    def from_env(cls, env_file: str | None = None) -> "DorisConfig":
        """Load configuration from environment variables

        Args:
            env_file: .env file path, if None, search in the following order:
                .env, .env.local, .env.production, .env.development
                (only the first file found is loaded)

        Raises:
            ValueError: if an integer-valued variable is not a valid integer.
        """
        log = logging.getLogger(__name__)

        # Load .env file
        if load_dotenv is not None:
            if env_file:
                # Load specified .env file
                if Path(env_file).exists():
                    load_dotenv(env_file)
                    log.info(f"Loaded environment configuration file: {env_file}")
                else:
                    log.warning(f"Environment configuration file does not exist: {env_file}")
            else:
                # Load .env files in priority order
                env_files = [".env", ".env.local", ".env.production", ".env.development"]
                for env_path in env_files:
                    if Path(env_path).exists():
                        load_dotenv(env_path)
                        log.info(f"Loaded environment configuration file: {env_path}")
                        break
                else:
                    log.info("No .env configuration file found, using system environment variables")
        else:
            log.warning("python-dotenv not installed, cannot load .env files")

        config = cls()
        env_int = cls._env_int
        env_bool = cls._env_bool

        # Database configuration
        db = config.database
        db.host = os.getenv("DORIS_HOST", db.host)
        db.port = env_int("DORIS_PORT", db.port)
        db.user = os.getenv("DORIS_USER", db.user)
        db.password = os.getenv("DORIS_PASSWORD", db.password)
        db.database = os.getenv("DORIS_DATABASE", db.database)

        # Connection pool configuration
        db.min_connections = env_int("DORIS_MIN_CONNECTIONS", db.min_connections)
        db.max_connections = env_int("DORIS_MAX_CONNECTIONS", db.max_connections)
        db.connection_timeout = env_int("DORIS_CONNECTION_TIMEOUT", db.connection_timeout)
        db.health_check_interval = env_int("DORIS_HEALTH_CHECK_INTERVAL", db.health_check_interval)
        db.max_connection_age = env_int("DORIS_MAX_CONNECTION_AGE", db.max_connection_age)

        # Security configuration
        sec = config.security
        sec.auth_type = os.getenv("AUTH_TYPE", sec.auth_type)
        sec.token_secret = os.getenv("TOKEN_SECRET", sec.token_secret)
        sec.token_expiry = env_int("TOKEN_EXPIRY", sec.token_expiry)
        sec.max_result_rows = env_int("MAX_RESULT_ROWS", sec.max_result_rows)
        sec.max_query_complexity = env_int("MAX_QUERY_COMPLEXITY", sec.max_query_complexity)
        sec.enable_masking = env_bool("ENABLE_MASKING", sec.enable_masking)

        # Performance configuration
        perf = config.performance
        perf.enable_query_cache = env_bool("ENABLE_QUERY_CACHE", perf.enable_query_cache)
        perf.cache_ttl = env_int("CACHE_TTL", perf.cache_ttl)
        perf.max_cache_size = env_int("MAX_CACHE_SIZE", perf.max_cache_size)
        perf.max_concurrent_queries = env_int("MAX_CONCURRENT_QUERIES", perf.max_concurrent_queries)
        perf.query_timeout = env_int("QUERY_TIMEOUT", perf.query_timeout)

        # Logging configuration
        log_cfg = config.logging
        log_cfg.level = os.getenv("LOG_LEVEL", log_cfg.level)
        log_cfg.file_path = os.getenv("LOG_FILE_PATH", log_cfg.file_path)
        log_cfg.enable_audit = env_bool("ENABLE_AUDIT", log_cfg.enable_audit)
        log_cfg.audit_file_path = os.getenv("AUDIT_FILE_PATH", log_cfg.audit_file_path)

        # Monitoring configuration
        mon = config.monitoring
        mon.enable_metrics = env_bool("ENABLE_METRICS", mon.enable_metrics)
        mon.metrics_port = env_int("METRICS_PORT", mon.metrics_port)
        mon.metrics_path = os.getenv("METRICS_PATH", mon.metrics_path)
        mon.health_check_port = env_int("HEALTH_CHECK_PORT", mon.health_check_port)
        mon.health_check_path = os.getenv("HEALTH_CHECK_PATH", mon.health_check_path)
        mon.enable_alerts = env_bool("ENABLE_ALERTS", mon.enable_alerts)
        # An empty ALERT_WEBHOOK_URL= entry means "no webhook".
        mon.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", mon.alert_webhook_url) or None

        # Server configuration
        config.server_name = os.getenv("SERVER_NAME", config.server_name)
        config.server_version = os.getenv("SERVER_VERSION", config.server_version)
        config.server_port = env_int("SERVER_PORT", config.server_port)

        return config

    @classmethod
    def _from_dict(cls, config_data: dict[str, Any]) -> "DorisConfig":
        """Create configuration object from dictionary.

        Keys that do not match an attribute of the corresponding
        sub-configuration are ignored. The ``custom`` entry, if present,
        becomes ``custom_config``.
        """
        config = cls()

        # Update basic configuration
        for key in ("server_name", "server_version", "server_port"):
            if key in config_data:
                setattr(config, key, config_data[key])

        # Update each sub-configuration section the same way
        for section in ("database", "security", "performance", "logging", "monitoring"):
            if section not in config_data:
                continue
            target = getattr(config, section)
            for key, value in config_data[section].items():
                if hasattr(target, key):
                    setattr(target, key, value)

        # Custom configuration
        config.custom_config = config_data.get("custom", {})

        return config

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary format.

        The database password and token secret are replaced by ``"***"``
        and ``masking_rules`` is reported as a count, so the result is a
        display form and not a lossless serialization.
        """
        return {
            "server_name": self.server_name,
            "server_version": self.server_version,
            "server_port": self.server_port,
            "database": {
                "host": self.database.host,
                "port": self.database.port,
                "user": self.database.user,
                "password": "***",  # Hide password
                "database": self.database.database,
                "charset": self.database.charset,
                "min_connections": self.database.min_connections,
                "max_connections": self.database.max_connections,
                "connection_timeout": self.database.connection_timeout,
                "health_check_interval": self.database.health_check_interval,
                "max_connection_age": self.database.max_connection_age,
            },
            "security": {
                "auth_type": self.security.auth_type,
                "token_secret": "***",  # Hide secret key
                "token_expiry": self.security.token_expiry,
                "blocked_keywords": self.security.blocked_keywords,
                "max_query_complexity": self.security.max_query_complexity,
                "max_result_rows": self.security.max_result_rows,
                "sensitive_tables": self.security.sensitive_tables,
                "enable_masking": self.security.enable_masking,
                "masking_rules": len(self.security.masking_rules),
            },
            "performance": {
                "enable_query_cache": self.performance.enable_query_cache,
                "cache_ttl": self.performance.cache_ttl,
                "max_cache_size": self.performance.max_cache_size,
                "max_concurrent_queries": self.performance.max_concurrent_queries,
                "query_timeout": self.performance.query_timeout,
                "connection_pool_size": self.performance.connection_pool_size,
                "idle_timeout": self.performance.idle_timeout,
            },
            "logging": {
                "level": self.logging.level,
                "format": self.logging.format,
                "file_path": self.logging.file_path,
                "max_file_size": self.logging.max_file_size,
                "backup_count": self.logging.backup_count,
                "enable_audit": self.logging.enable_audit,
                "audit_file_path": self.logging.audit_file_path,
            },
            "monitoring": {
                "enable_metrics": self.monitoring.enable_metrics,
                "metrics_port": self.monitoring.metrics_port,
                "metrics_path": self.monitoring.metrics_path,
                "health_check_port": self.monitoring.health_check_port,
                "health_check_path": self.monitoring.health_check_path,
                "enable_alerts": self.monitoring.enable_alerts,
                "alert_webhook_url": self.monitoring.alert_webhook_url,
            },
            "custom": self.custom_config,
        }

    def save_to_file(self, config_path: str):
        """Save configuration to a JSON file.

        Raises:
            ValueError: if the suffix is not ``.json`` or writing fails.
        """
        # NOTE(review): this writes to_dict(), so the saved file contains
        # "***" for the password and token secret and a count for
        # masking_rules; loading it back with from_file will not restore
        # those values. Confirm whether that is intended.
        config_file = Path(config_path)

        # Reject unsupported formats before creating or truncating anything.
        if config_file.suffix.lower() != ".json":
            raise ValueError(
                "Failed to save configuration file: "
                f"Unsupported configuration file format: {config_file.suffix}"
            )

        config_file.parent.mkdir(parents=True, exist_ok=True)

        try:
            with open(config_file, "w", encoding="utf-8") as f:
                json.dump(self.to_dict(), f, indent=2, ensure_ascii=False)
        except Exception as e:
            raise ValueError(f"Failed to save configuration file: {e}") from e

    def validate(self) -> list[str]:
        """Validate configuration validity.

        Returns:
            A list of human-readable error messages; empty when valid.
        """
        errors = []

        # Validate database configuration
        if not self.database.host:
            errors.append("Database host address cannot be empty")

        if not (1 <= self.database.port <= 65535):
            errors.append("Database port must be in the range 1-65535")

        if not self.database.user:
            errors.append("Database username cannot be empty")

        if self.database.min_connections <= 0:
            errors.append("Minimum connections must be greater than 0")

        if self.database.max_connections <= self.database.min_connections:
            errors.append("Maximum connections must be greater than minimum connections")

        # Validate security configuration
        if self.security.auth_type not in ["token", "basic", "oauth"]:
            errors.append("Authentication type must be one of token, basic, or oauth")

        if self.security.token_expiry <= 0:
            errors.append("Token expiry time must be greater than 0")

        if self.security.max_query_complexity <= 0:
            errors.append("Maximum query complexity must be greater than 0")

        if self.security.max_result_rows <= 0:
            errors.append("Maximum result rows must be greater than 0")

        # Validate performance configuration
        if self.performance.cache_ttl <= 0:
            errors.append("Cache TTL must be greater than 0")

        if self.performance.max_concurrent_queries <= 0:
            errors.append("Maximum concurrent queries must be greater than 0")

        if self.performance.query_timeout <= 0:
            errors.append("Query timeout must be greater than 0")

        # Validate logging configuration. The level is compared
        # case-insensitively because ConfigManager.setup_logging upper-cases
        # it before use.
        if str(self.logging.level).upper() not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
            errors.append("Log level must be one of DEBUG, INFO, WARNING, ERROR, or CRITICAL")

        if self.logging.max_file_size <= 0:
            errors.append("Maximum log file size must be greater than 0")

        if self.logging.backup_count < 0:
            errors.append("Log backup count cannot be negative")

        # Validate monitoring configuration
        if not (1 <= self.monitoring.metrics_port <= 65535):
            errors.append("Monitoring port must be in the range 1-65535")

        if not (1 <= self.monitoring.health_check_port <= 65535):
            errors.append("Health check port must be in the range 1-65535")

        return errors

    def get_connection_string(self) -> str:
        """Get database connection string (hide password)"""
        return f"mysql://{self.database.user}:***@{self.database.host}:{self.database.port}/{self.database.database}"

    def get_config_summary(self) -> dict[str, Any]:
        """Get configuration summary information"""
        return {
            "server": f"{self.server_name} v{self.server_version}",
            "database": f"{self.database.host}:{self.database.port}/{self.database.database}",
            "connection_pool": f"{self.database.min_connections}-{self.database.max_connections}",
            "security": {
                "auth_type": self.security.auth_type,
                "masking_enabled": self.security.enable_masking,
                "blocked_keywords_count": len(self.security.blocked_keywords),
            },
            "performance": {
                "cache_enabled": self.performance.enable_query_cache,
                "max_concurrent": self.performance.max_concurrent_queries,
                "query_timeout": self.performance.query_timeout,
            },
            "monitoring": {
                "metrics_enabled": self.monitoring.enable_metrics,
                "alerts_enabled": self.monitoring.enable_alerts,
            },
        }
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""Configuration manager class"""
|
||||
|
||||
def __init__(self, config: DorisConfig):
    """Store the configuration and bind a module-named logger."""
    self.logger = logging.getLogger(__name__)
    self.config = config
|
||||
|
||||
def setup_logging(self):
|
||||
"""Setup logging configuration"""
|
||||
# Configure root logger
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(getattr(logging, self.config.logging.level.upper()))
|
||||
|
||||
# Clear existing handlers
|
||||
for handler in root_logger.handlers[:]:
|
||||
root_logger.removeHandler(handler)
|
||||
|
||||
# Create formatter
|
||||
formatter = logging.Formatter(self.config.logging.format)
|
||||
|
||||
# Console handler
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
||||
# File handler (if configured)
|
||||
if self.config.logging.file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
file_handler = RotatingFileHandler(
|
||||
self.config.logging.file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
file_handler.setFormatter(formatter)
|
||||
root_logger.addHandler(file_handler)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup file logging: {e}")
|
||||
|
||||
# Audit log handler (if configured)
|
||||
if self.config.logging.enable_audit and self.config.logging.audit_file_path:
|
||||
try:
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
audit_logger = logging.getLogger("audit")
|
||||
audit_handler = RotatingFileHandler(
|
||||
self.config.logging.audit_file_path,
|
||||
maxBytes=self.config.logging.max_file_size,
|
||||
backupCount=self.config.logging.backup_count,
|
||||
encoding="utf-8",
|
||||
)
|
||||
audit_handler.setFormatter(formatter)
|
||||
audit_logger.addHandler(audit_handler)
|
||||
audit_logger.setLevel(logging.INFO)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to setup audit logging: {e}")
|
||||
|
||||
def validate_config(self) -> bool:
|
||||
"""Validate configuration"""
|
||||
errors = self.config.validate()
|
||||
if errors:
|
||||
self.logger.error("Configuration validation failed:")
|
||||
for error in errors:
|
||||
self.logger.error(f" - {error}")
|
||||
return False
|
||||
|
||||
self.logger.info("Configuration validation passed")
|
||||
return True
|
||||
|
||||
def log_config_summary(self):
|
||||
"""Log configuration summary"""
|
||||
summary = self.config.get_config_summary()
|
||||
self.logger.info("Configuration Summary:")
|
||||
self.logger.info(f" Server: {summary['server']}")
|
||||
self.logger.info(f" Database: {summary['database']}")
|
||||
self.logger.info(f" Connection Pool: {summary['connection_pool']}")
|
||||
self.logger.info(f" Security: {summary['security']}")
|
||||
self.logger.info(f" Performance: {summary['performance']}")
|
||||
self.logger.info(f" Monitoring: {summary['monitoring']}")
|
||||
|
||||
|
||||
def create_default_config_file(config_path: str):
    """Write a configuration file holding the default settings to *config_path*.

    A ``DorisConfig`` built with no arguments is saved through its
    ``save_to_file`` method, then a confirmation line is printed.
    """
    defaults = DorisConfig()
    defaults.save_to_file(config_path)
    print(f"Default configuration file created: {config_path}")
|
||||
|
||||
|
||||
# Example usage: validate the default configuration, then log and save it
if __name__ == "__main__":
    # Start from the built-in defaults. Alternative sources:
    #   config = DorisConfig.from_env()               # environment variables
    #   config = DorisConfig.from_file("config.json") # JSON file
    config = DorisConfig()

    config_manager = ConfigManager(config)
    if not config_manager.validate_config():
        print("Configuration validation failed")
    else:
        config_manager.setup_logging()
        config_manager.log_config_summary()

        # Persist the validated configuration
        config.save_to_file("example_config.json")
        print("Configuration saved to example_config.json")
|
||||
@@ -1,100 +1,479 @@
|
||||
import os
|
||||
import json
|
||||
import pymysql
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dotenv import load_dotenv
|
||||
import re
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apache Doris Database Connection Management Module
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
Provides high-performance database connection pool management, automatic reconnection mechanism and connection health check functionality
|
||||
Supports asynchronous operations and concurrent connection management, ensuring stability and performance for enterprise applications
|
||||
"""
|
||||
|
||||
# Database configuration
|
||||
DB_CONFIG = {
|
||||
"host": os.getenv("DB_HOST", "localhost"),
|
||||
"port": int(os.getenv("DB_PORT", "9030")),
|
||||
"user": os.getenv("DB_USER", "root"),
|
||||
"password": os.getenv("DB_PASSWORD", ""),
|
||||
"database": os.getenv("DB_DATABASE", ""),
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": pymysql.cursors.DictCursor
|
||||
}
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
def get_db_connection(db_name: Optional[str] = None):
|
||||
import aiomysql
|
||||
from aiomysql import Connection, Pool
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionMetrics:
|
||||
"""Connection pool performance metrics"""
|
||||
|
||||
total_connections: int = 0
|
||||
active_connections: int = 0
|
||||
idle_connections: int = 0
|
||||
failed_connections: int = 0
|
||||
connection_errors: int = 0
|
||||
avg_connection_time: float = 0.0
|
||||
last_health_check: datetime | None = None
|
||||
|
||||
|
||||
@dataclass
class QueryResult:
    """Outcome of a single SQL statement executed through ``DorisConnection``."""

    # Result rows as dicts; empty for statements that return no rows
    data: list[dict[str, Any]]
    # Holds "columns", "query" and "params", plus "security_check" when validated
    metadata: dict[str, Any]
    # Wall-clock duration in seconds, measured with ``time.time()``
    execution_time: float
    # Number of fetched rows, or the cursor's rowcount when nothing was fetched
    row_count: int
|
||||
|
||||
|
||||
class DorisConnection:
    """Wrap one aiomysql connection with per-session bookkeeping.

    Tracks creation time, last use, query count and a health flag. When a
    security manager is supplied it validates SQL before execution and masks
    result rows afterwards.
    """

    def __init__(self, connection: Connection, session_id: str, security_manager=None):
        self.connection = connection
        self.session_id = session_id
        self.created_at = datetime.utcnow()
        self.last_used = datetime.utcnow()
        self.query_count = 0
        self.is_healthy = True
        self.security_manager = security_manager

    async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult:
        """Execute *sql* and return its rows and metadata.

        Args:
            sql: Statement to run.
            params: Optional parameters passed to the cursor's ``execute``.
            auth_context: When given together with a security manager, the SQL
                is validated first and the result rows are masked.

        Returns:
            A ``QueryResult``. ``data`` holds the fetched rows for statements
            that produce a result set and is empty otherwise.

        Raises:
            ValueError: if the security manager rejects the statement.
            Exception: any driver error, re-raised after being logged; the
                connection is flagged unhealthy.
        """
        start_time = time.time()

        try:
            # If security manager exists, perform SQL security check
            security_result = None
            if self.security_manager and auth_context:
                validation_result = await self.security_manager.validate_sql_security(sql, auth_context)
                if not validation_result.is_valid:
                    raise ValueError(f"SQL security validation failed: {validation_result.error_message}")
                security_result = {
                    "is_valid": validation_result.is_valid,
                    "risk_level": validation_result.risk_level,
                    "blocked_operations": validation_result.blocked_operations,
                }

            async with self.connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, params)

                # Ask the driver whether a result set exists instead of
                # guessing from the first SQL keyword: ``description`` is None
                # for statements that return no rows (PEP 249). A keyword
                # list misses CTEs ("WITH ... SELECT"), parenthesised selects
                # and statements with leading comments.
                if cursor.description is not None:
                    data = await cursor.fetchall()
                    row_count = len(data)
                else:
                    data = []
                    row_count = cursor.rowcount

                execution_time = time.time() - start_time
                self.last_used = datetime.utcnow()
                self.query_count += 1

                # Get column information
                columns = []
                if cursor.description:
                    columns = [desc[0] for desc in cursor.description]

                # If security manager exists and has auth context, apply data masking
                final_data = list(data) if data else []
                if self.security_manager and auth_context and final_data:
                    final_data = await self.security_manager.apply_data_masking(final_data, auth_context)

                metadata = {"columns": columns, "query": sql, "params": params}
                if security_result:
                    metadata["security_check"] = security_result

                return QueryResult(
                    data=final_data,
                    metadata=metadata,
                    execution_time=execution_time,
                    row_count=row_count,
                )

        except Exception as e:
            self.is_healthy = False
            logging.error(f"Query execution failed: {e}")
            raise

    async def ping(self) -> bool:
        """Ping the server and update ``is_healthy``.

        Returns:
            ``True`` if the ping succeeded, ``False`` if it raised.
        """
        try:
            await self.connection.ping()
            self.is_healthy = True
            return True
        except Exception:
            self.is_healthy = False
            return False

    async def close(self):
        """Close the underlying connection if it is still open.

        Errors are logged and not re-raised.
        """
        try:
            if self.connection and not self.connection.closed:
                await self.connection.ensure_closed()
        except Exception as e:
            logging.error(f"Error occurred while closing connection: {e}")
|
||||
|
||||
|
||||
class DorisConnectionManager:
|
||||
"""Doris database connection manager
|
||||
|
||||
Provides connection pool management, connection health monitoring, fault recovery and other functions
|
||||
Supports session-level connection reuse and intelligent load balancing
|
||||
Integrates security manager to provide unified security validation and data masking
|
||||
"""
|
||||
Get database connection
|
||||
|
||||
Args:
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
def __init__(self, config, security_manager=None):
|
||||
self.config = config
|
||||
self.pool: Pool | None = None
|
||||
self.session_connections: dict[str, DorisConnection] = {}
|
||||
self.metrics = ConnectionMetrics()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.security_manager = security_manager
|
||||
|
||||
Returns:
|
||||
Database connection
|
||||
# Health check configuration
|
||||
self.health_check_interval = config.database.health_check_interval or 60
|
||||
self.max_connection_age = config.database.max_connection_age or 3600
|
||||
self.connection_timeout = config.database.connection_timeout or 30
|
||||
|
||||
# Start background tasks
|
||||
self._health_check_task = None
|
||||
self._cleanup_task = None
|
||||
|
||||
async def initialize(self):
    """Create the aiomysql pool and start the background maintenance tasks.

    Pool size falls back to 5 (minimum) and 20 (maximum) when the config
    leaves those values unset. After the pool is created, the health-check
    and cleanup loops are scheduled as asyncio tasks.

    Raises:
        Exception: whatever pool creation or task scheduling raises; the
            error is logged first and then re-raised.
    """
    try:
        db = self.config.database
        pool_options = {
            "host": db.host,
            "port": db.port,
            "user": db.user,
            "password": db.password,
            "db": db.database,
            "charset": "utf8",
            "minsize": db.min_connections or 5,
            "maxsize": db.max_connections or 20,
            "autocommit": True,
            "connect_timeout": self.connection_timeout,
        }
        self.pool = await aiomysql.create_pool(**pool_options)

        # The message reports the configured values, not the fallbacks
        self.logger.info(
            f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
            f"max connections: {self.config.database.max_connections}"
        )

        # Background monitoring tasks; both are cancelled again in close()
        self._health_check_task = asyncio.create_task(self._health_check_loop())
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())

    except Exception as exc:
        self.logger.error(f"Connection pool initialization failed: {exc}")
        raise
|
||||
|
||||
async def get_connection(self, session_id: str) -> DorisConnection:
    """Return the connection bound to *session_id*, creating one if needed.

    An existing session connection is reused only if it answers a ping.
    Otherwise it is cleaned up and replaced by a new one.
    """
    existing = self.session_connections.get(session_id)
    if existing is not None:
        if await existing.ping():
            return existing
        # The cached connection no longer responds: drop it before replacing it
        await self._cleanup_session_connection(session_id)

    return await self._create_new_connection(session_id)
|
||||
|
||||
async def _create_new_connection(self, session_id: str) -> DorisConnection:
    """Acquire a pooled connection, wrap it and register it for *session_id*.

    Raises:
        RuntimeError: if the pool has not been initialized.
        Exception: any error raised while acquiring the connection. Every
            failure, including the ``RuntimeError`` above, increments
            ``metrics.connection_errors`` and is logged before re-raising.
    """
    try:
        if not self.pool:
            raise RuntimeError("Connection pool not initialized")

        raw = await self.pool.acquire()
        wrapped = DorisConnection(raw, session_id, self.security_manager)

        # Remember the wrapper so later calls for this session reuse it
        self.session_connections[session_id] = wrapped
        self.metrics.total_connections += 1
        self.logger.debug(f"Created new connection for session: {session_id}")
        return wrapped

    except Exception as exc:
        self.metrics.connection_errors += 1
        self.logger.error(f"Failed to create connection for session {session_id}: {exc}")
        raise
|
||||
|
||||
async def release_connection(self, session_id: str):
    """Clean up the connection held for *session_id*; do nothing for unknown sessions."""
    if session_id not in self.session_connections:
        return
    await self._cleanup_session_connection(session_id)
|
||||
|
||||
async def _cleanup_session_connection(self, session_id: str):
    """Tear down the connection bound to *session_id* and free its pool slot.

    The session entry is removed first, the connection is closed, and the
    closed connection is then released so the pool stops counting it as in
    use. Calling this for an unknown session is a no-op. Errors are logged
    and never re-raised.
    """
    # Detach the entry before awaiting anything: a concurrent caller must not
    # pick up a connection that is being torn down, and a second cleanup of
    # the same session must not fail on a missing key.
    conn = self.session_connections.pop(session_id, None)
    if conn is None:
        return

    try:
        # Close BEFORE releasing. Releasing a live connection and closing it
        # afterwards would leave a dead connection in the pool's free list.
        await conn.close()
    except Exception as e:
        self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")

    try:
        # Release even when the connection is already closed; otherwise the
        # pool keeps counting it as in use and the slot is lost.
        if self.pool and conn.connection:
            self.pool.release(conn.connection)
    except Exception as e:
        self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")

    self.logger.debug(f"Cleaned up connection for session: {session_id}")
|
||||
|
||||
async def _health_check_loop(self):
    """Run a health check every ``health_check_interval`` seconds until cancelled.

    Cancellation ends the loop. Any other exception is logged and the loop
    carries on with the next cycle.
    """
    running = True
    while running:
        try:
            await asyncio.sleep(self.health_check_interval)
            await self._perform_health_check()
        except asyncio.CancelledError:
            running = False
        except Exception as exc:
            self.logger.error(f"Health check error: {exc}")
|
||||
|
||||
async def _perform_health_check(self):
    """Ping every session connection and drop the ones that do not answer.

    Each dropped session increments ``metrics.failed_connections``.
    Afterwards the connection metrics are refreshed and
    ``metrics.last_health_check`` is set to the current naive UTC time.
    Errors are logged and not re-raised.
    """
    try:
        unhealthy_sessions = []

        # Iterate over a snapshot: ping() awaits, so other coroutines may add
        # or remove sessions in the meantime. Iterating the live dict would
        # then raise "dictionary changed size during iteration" and the whole
        # check would be skipped.
        for session_id, conn in list(self.session_connections.items()):
            if not await conn.ping():
                unhealthy_sessions.append(session_id)

        # Clean up unhealthy connections
        for session_id in unhealthy_sessions:
            await self._cleanup_session_connection(session_id)
            self.metrics.failed_connections += 1

        # Update metrics
        await self._update_connection_metrics()
        self.metrics.last_health_check = datetime.utcnow()

        if unhealthy_sessions:
            self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")

    except Exception as e:
        self.logger.error(f"Health check failed: {e}")
|
||||
|
||||
async def _cleanup_loop(self, interval: float = 300):
    """Periodically remove aged session connections until cancelled.

    Args:
        interval: Seconds to sleep between cleanup passes. Defaults to 300
            (five minutes), the value that used to be hard-coded.

    Cancellation ends the loop. Any other exception is logged and the loop
    carries on with the next cycle.
    """
    while True:
        try:
            await asyncio.sleep(interval)
            await self._cleanup_idle_connections()
        except asyncio.CancelledError:
            break
        except Exception as e:
            self.logger.error(f"Cleanup loop error: {e}")
|
||||
|
||||
async def _cleanup_idle_connections(self):
    """Drop session connections older than ``max_connection_age`` seconds.

    Despite the name, selection is by age since creation (``created_at``),
    not by time since last use.
    """
    now = datetime.utcnow()
    expired = [
        session_id
        for session_id, conn in self.session_connections.items()
        if (now - conn.created_at).total_seconds() > self.max_connection_age
    ]

    # The dict is only mutated here, after the scan above has finished
    for session_id in expired:
        await self._cleanup_session_connection(session_id)

    if expired:
        self.logger.info(f"Cleaned up {len(expired)} idle connections")
|
||||
|
||||
async def _update_connection_metrics(self):
    """Refresh the active and idle connection counts in ``self.metrics``.

    The active count is the number of tracked session connections. The idle
    count is the pool's ``freesize`` and is left unchanged when no pool exists.
    """
    metrics = self.metrics
    metrics.active_connections = len(self.session_connections)
    pool = self.pool
    if pool:
        metrics.idle_connections = pool.freesize
|
||||
|
||||
async def get_metrics(self) -> ConnectionMetrics:
    """Refresh the connection counts and return the manager's metrics object.

    The returned object is the live ``self.metrics`` instance, not a copy.
    """
    await self._update_connection_metrics()
    current = self.metrics
    return current
|
||||
|
||||
async def execute_query(
    self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
) -> QueryResult:
    """Run *sql* on the connection bound to *session_id* and return its result.

    The session's connection is obtained via ``get_connection``. *params*
    and *auth_context* are passed through unchanged to the connection's
    ``execute``.
    """
    session_conn = await self.get_connection(session_id)
    result = await session_conn.execute(sql, params, auth_context)
    return result
|
||||
|
||||
@asynccontextmanager
async def get_connection_context(self, session_id: str):
    """Async context manager yielding the connection bound to *session_id*.

    Nothing is released on exit: the connection stays registered for the
    session so later calls can reuse it.
    """
    session_conn = await self.get_connection(session_id)
    yield session_conn
|
||||
|
||||
async def close(self):
    """Shut the manager down.

    In order: cancel and await the two background tasks, clean up every
    session connection, then close the pool and wait for it to finish.
    Errors are logged and not re-raised.
    """
    try:
        # Stop the background tasks one after the other
        for task_attr in ("_health_check_task", "_cleanup_task"):
            task = getattr(self, task_attr)
            if not task:
                continue
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass

        # Copy the keys first: cleanup removes entries from the dict
        for session_id in list(self.session_connections.keys()):
            await self._cleanup_session_connection(session_id)

        pool = self.pool
        if pool:
            pool.close()
            await pool.wait_closed()

        self.logger.info("Connection manager closed successfully")

    except Exception as exc:
        self.logger.error(f"Error closing connection manager: {exc}")
|
||||
|
||||
async def test_connection(self) -> bool:
    """Check the database by running ``SELECT 1`` on a pooled connection.

    Returns:
        ``True`` if the query returned a row. ``False`` if the pool is not
        initialized, no row came back, or any error occurred (the error is
        logged).
    """
    try:
        pool = self.pool
        if not pool:
            return False

        async with pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute("SELECT 1")
                row = await cursor.fetchone()
                return row is not None

    except Exception as exc:
        self.logger.error(f"Connection test failed: {exc}")
        return False
|
||||
|
||||
|
||||
class ConnectionPoolMonitor:
|
||||
"""Connection pool monitor
|
||||
|
||||
Provides detailed monitoring and reporting capabilities for connection pool status
|
||||
"""
|
||||
if db_name:
|
||||
# Use default config but override database name
|
||||
config = DB_CONFIG.copy()
|
||||
config["database"] = db_name
|
||||
return pymysql.connect(**config)
|
||||
else:
|
||||
# Use default config
|
||||
return pymysql.connect(**DB_CONFIG)
|
||||
|
||||
def get_db_name() -> str:
|
||||
"""Get the currently configured default database name"""
|
||||
return DB_CONFIG["database"] or os.getenv("DB_DATABASE", "")
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
def execute_query(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
async def get_pool_status(self) -> dict[str, Any]:
|
||||
"""Get connection pool status"""
|
||||
metrics = await self.connection_manager.get_metrics()
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
status = {
|
||||
"pool_size": self.connection_manager.pool.size if self.connection_manager.pool else 0,
|
||||
"free_connections": self.connection_manager.pool.freesize if self.connection_manager.pool else 0,
|
||||
"active_sessions": len(self.connection_manager.session_connections),
|
||||
"total_connections": metrics.total_connections,
|
||||
"failed_connections": metrics.failed_connections,
|
||||
"connection_errors": metrics.connection_errors,
|
||||
"avg_connection_time": metrics.avg_connection_time,
|
||||
"last_health_check": metrics.last_health_check.isoformat() if metrics.last_health_check else None,
|
||||
}
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
return status
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
async def get_session_details(self) -> list[dict[str, Any]]:
|
||||
"""Get session connection details"""
|
||||
sessions = []
|
||||
|
||||
def execute_query_df(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return pandas DataFrame
|
||||
for session_id, conn in self.connection_manager.session_connections.items():
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"created_at": conn.created_at.isoformat(),
|
||||
"last_used": conn.last_used.isoformat(),
|
||||
"query_count": conn.query_count,
|
||||
"is_healthy": conn.is_healthy,
|
||||
"connection_age": (datetime.utcnow() - conn.created_at).total_seconds(),
|
||||
}
|
||||
sessions.append(session_info)
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
return sessions
|
||||
|
||||
Returns:
|
||||
pandas DataFrame
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
# Use a temporary cursor to execute the query and get results
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
async def generate_health_report(self) -> dict[str, Any]:
|
||||
"""Generate connection health report"""
|
||||
pool_status = await self.get_pool_status()
|
||||
session_details = await self.get_session_details()
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
# Calculate health statistics
|
||||
healthy_sessions = sum(1 for s in session_details if s["is_healthy"])
|
||||
total_sessions = len(session_details)
|
||||
health_ratio = healthy_sessions / total_sessions if total_sessions > 0 else 1.0
|
||||
|
||||
# If no results, return empty DataFrame
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
report = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"pool_status": pool_status,
|
||||
"session_summary": {
|
||||
"total_sessions": total_sessions,
|
||||
"healthy_sessions": healthy_sessions,
|
||||
"health_ratio": health_ratio,
|
||||
},
|
||||
"session_details": session_details,
|
||||
"recommendations": [],
|
||||
}
|
||||
|
||||
# Manually convert dict results to DataFrame
|
||||
df = pd.DataFrame(result)
|
||||
return df
|
||||
finally:
|
||||
conn.close()
|
||||
# Add recommendations based on health status
|
||||
if health_ratio < 0.8:
|
||||
report["recommendations"].append("Consider checking database connectivity and network stability")
|
||||
|
||||
if pool_status["connection_errors"] > 10:
|
||||
report["recommendations"].append("High connection error rate detected, review connection configuration")
|
||||
|
||||
if pool_status["active_sessions"] > pool_status["pool_size"] * 0.9:
|
||||
report["recommendations"].append("Connection pool utilization is high, consider increasing pool size")
|
||||
|
||||
return report
|
||||
|
||||
@@ -1,226 +1,85 @@
|
||||
"""
|
||||
Unified Logging Configuration Module
|
||||
|
||||
Provides unified logging configuration, including:
|
||||
- General logs: Record all program execution information
|
||||
- Audit logs: Record JSON data for key operations and processing results
|
||||
- Error logs: Specifically record program exceptions and errors
|
||||
Logging configuration for Doris MCP Server.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import logging.handlers
|
||||
import logging.config
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
from typing import Any
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get project root directory
|
||||
PROJECT_ROOT = Path(__file__).parents[2].absolute()
|
||||
def setup_logging(
|
||||
level: str = "INFO",
|
||||
log_file: str | None = None,
|
||||
log_format: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Setup logging configuration.
|
||||
|
||||
# Get log configuration from environment variables
|
||||
LOG_DIR = os.getenv("LOG_DIR", str(PROJECT_ROOT / "logs"))
|
||||
LOG_PREFIX = os.getenv("LOG_PREFIX", "doris_mcp")
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
LOG_MAX_DAYS = int(os.getenv("LOG_MAX_DAYS", "30"))
|
||||
# Whether to output logs to the console (should be disabled when running as a service)
|
||||
CONSOLE_LOGGING = os.getenv("CONSOLE_LOGGING", "false").lower() == "true"
|
||||
# Whether stdio transport mode is being used
|
||||
STDIO_MODE = os.getenv("MCP_TRANSPORT_TYPE", "").lower() == "stdio"
|
||||
Args:
|
||||
level: Logging level (DEBUG, INFO, WARNING, ERROR)
|
||||
log_file: Optional log file path
|
||||
log_format: Optional custom log format
|
||||
"""
|
||||
if log_format is None:
|
||||
log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
|
||||
def purge_old_logs():
|
||||
"""Clean up expired log files"""
|
||||
# --- Only perform cleanup in non-Stdio mode ---
|
||||
if STDIO_MODE:
|
||||
return
|
||||
try:
|
||||
now = datetime.now()
|
||||
log_dir = Path(LOG_DIR)
|
||||
# Check if directory exists and is readable/writable
|
||||
if not log_dir.is_dir() or not os.access(LOG_DIR, os.W_OK):
|
||||
if not STDIO_MODE: # Avoid printing to stdout in stdio mode
|
||||
print(f"Warning: Log directory {LOG_DIR} not accessible, skipping log purge.", file=sys.stderr)
|
||||
return
|
||||
# Base configuration
|
||||
config: dict[str, Any] = {
|
||||
"version": 1,
|
||||
"disable_existing_loggers": False,
|
||||
"formatters": {
|
||||
"default": {"format": log_format, "datefmt": "%Y-%m-%d %H:%M:%S"}
|
||||
},
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"stream": sys.stdout,
|
||||
}
|
||||
},
|
||||
"root": {"level": level, "handlers": ["console"]},
|
||||
"loggers": {
|
||||
"doris_mcp_server": {
|
||||
"level": level,
|
||||
"handlers": ["console"],
|
||||
"propagate": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
for log_file in log_dir.glob(f"{LOG_PREFIX}*.20*"):
|
||||
# Parse date
|
||||
file_name = log_file.name
|
||||
date_str = None
|
||||
# Add file handler if log_file is specified
|
||||
if log_file:
|
||||
# Ensure log directory exists
|
||||
log_path = Path(log_file)
|
||||
log_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Try to find the date part
|
||||
parts = file_name.split('.')
|
||||
for part in parts:
|
||||
if part.startswith('20') and len(part) == 8: # 20YYMMDD format
|
||||
date_str = part
|
||||
break
|
||||
config["handlers"]["file"] = {
|
||||
"class": "logging.handlers.RotatingFileHandler",
|
||||
"level": level,
|
||||
"formatter": "default",
|
||||
"filename": log_file,
|
||||
"maxBytes": 10485760, # 10MB
|
||||
"backupCount": 5,
|
||||
}
|
||||
|
||||
if date_str:
|
||||
try:
|
||||
file_date = datetime.strptime(date_str, '%Y%m%d')
|
||||
days_old = (now - file_date).days
|
||||
# Add file handler to root and package loggers
|
||||
config["root"]["handlers"].append("file")
|
||||
config["loggers"]["doris_mcp_server"]["handlers"].append("file")
|
||||
|
||||
if days_old > LOG_MAX_DAYS:
|
||||
os.remove(log_file)
|
||||
if not STDIO_MODE:
|
||||
print(f"Deleted expired log file: {log_file}")
|
||||
except (ValueError, OSError) as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error processing log file {file_name}: {e}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error cleaning up logs: {e}", file=sys.stderr)
|
||||
|
||||
# Force disable console log output if in stdio mode
|
||||
if STDIO_MODE:
|
||||
CONSOLE_LOGGING = False
|
||||
|
||||
# --- Only create log directory and clean old logs in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
try:
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
# Clean up expired logs on startup (also moved here, as it only handles file logs)
|
||||
purge_old_logs()
|
||||
except OSError as e:
|
||||
# If directory creation fails (e.g., permission issue), print warning but continue to avoid startup failure
|
||||
print(f"Warning: Failed to create log directory {LOG_DIR} or purge logs: {e}", file=sys.stderr)
|
||||
|
||||
# Log file paths (definition still needed, but files might not be created/used)
|
||||
LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.log")
|
||||
AUDIT_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.audit")
|
||||
ERROR_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.error")
|
||||
|
||||
# Log level mapping
|
||||
LOG_LEVELS = {
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARNING": logging.WARNING,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL
|
||||
}
|
||||
|
||||
# Log format
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
AUDIT_FORMAT = '%(asctime)s - %(name)s - %(message)s'
|
||||
ERROR_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(message)s'
|
||||
|
||||
# Dedicated audit log level
|
||||
AUDIT = 25 # Level between INFO and WARNING
|
||||
logging.addLevelName(AUDIT, "AUDIT")
|
||||
|
||||
# Logger object cache
|
||||
_loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
# Handler type mapping, used to ensure no duplicates are added
|
||||
_handler_types = {
|
||||
'console': logging.StreamHandler,
|
||||
'file': logging.handlers.TimedRotatingFileHandler,
|
||||
'audit': logging.handlers.TimedRotatingFileHandler,
|
||||
'error': logging.handlers.TimedRotatingFileHandler
|
||||
}
|
||||
logging.config.dictConfig(config)
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger with the specified name
|
||||
Get a logger instance.
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
|
||||
Returns:
|
||||
logging.Logger: Configured logger
|
||||
Logger instance
|
||||
"""
|
||||
if name in _loggers:
|
||||
return _loggers[name]
|
||||
|
||||
# Create logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(LOG_LEVELS.get(LOG_LEVEL, logging.INFO))
|
||||
|
||||
# Avoid duplicate logs caused by propagation
|
||||
logger.propagate = False
|
||||
|
||||
# Check if handlers already exist to avoid duplicates
|
||||
handler_types = set(type(h) for h in logger.handlers)
|
||||
|
||||
# Add audit log method
|
||||
def audit(self, message, *args, **kwargs):
|
||||
self.log(AUDIT, message, *args, **kwargs)
|
||||
|
||||
logger.audit = audit.__get__(logger)
|
||||
|
||||
# General log handler - output to console (only if enabled)
|
||||
if CONSOLE_LOGGING and _handler_types['console'] not in handler_types:
|
||||
# Use stderr instead of stdout to avoid conflicts with MCP communication
|
||||
console_handler = logging.StreamHandler(sys.stderr)
|
||||
console_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# --- Only add file handlers in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
# General log handler - daily rotating file
|
||||
if _handler_types['file'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
file_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
file_handler.suffix = "%Y%m%d"
|
||||
logger.addHandler(file_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add file log handler for {LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Audit log handler - only logs AUDIT level
|
||||
if _handler_types['audit'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
audit_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
AUDIT_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
audit_handler.setFormatter(logging.Formatter(AUDIT_FORMAT))
|
||||
audit_handler.suffix = "%Y%m%d"
|
||||
audit_handler.setLevel(AUDIT)
|
||||
audit_handler.addFilter(lambda record: record.levelno == AUDIT)
|
||||
logger.addHandler(audit_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add audit log handler for {AUDIT_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Error log handler - only logs ERROR level and above
|
||||
if _handler_types['error'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
error_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
ERROR_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
error_handler.setFormatter(logging.Formatter(ERROR_FORMAT))
|
||||
error_handler.suffix = "%Y%m%d"
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
logger.addHandler(error_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add error log handler for {ERROR_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Cache logger
|
||||
_loggers[name] = logger
|
||||
|
||||
return logger
|
||||
|
||||
# Default logger
|
||||
logger = get_logger('doris_mcp')
|
||||
|
||||
# Audit logger - for recording processing results, business operations, etc.
|
||||
audit_logger = get_logger('audit')
|
||||
|
||||
# Call to clean logs moved after directory creation, and added non-stdio check
|
||||
return logging.getLogger(name)
|
||||
|
||||
800
doris_mcp_server/utils/query_executor.py
Normal file
800
doris_mcp_server/utils/query_executor.py
Normal file
@@ -0,0 +1,800 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Doris Query Execution Module
|
||||
Implements query optimization, cache management and performance monitoring functionality
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import os
|
||||
import uuid
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, date
|
||||
from typing import Any, Dict
|
||||
from decimal import Decimal
|
||||
|
||||
from .db import DorisConnectionManager, QueryResult
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryRequest:
|
||||
"""Query request wrapper"""
|
||||
|
||||
sql: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
parameters: dict[str, Any] | None = None
|
||||
timeout: int | None = None
|
||||
cache_enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
class CachedQuery:
    """A query result held in the cache plus expiry and usage bookkeeping."""

    # The stored query result.
    result: QueryResult
    # Creation time; compared against ``datetime.utcnow()``, so it must be a
    # naive UTC datetime.
    created_at: datetime
    # Lifetime in seconds; zero or negative means the entry never expires.
    ttl: int
    # Number of times ``access()`` has been called on this entry.
    access_count: int = 0
    # Time of the most recent ``access()`` call, or None if never accessed.
    last_accessed: datetime | None = None

    def is_expired(self) -> bool:
        """Return True once the entry is older than its TTL.

        Entries with a TTL of zero or less are treated as permanent.
        """
        never_expires = self.ttl <= 0
        if never_expires:
            return False
        age = datetime.utcnow() - self.created_at
        return age.total_seconds() > self.ttl

    def access(self):
        """Record one read of this entry (bumps the counter and timestamp)."""
        self.last_accessed = datetime.utcnow()
        self.access_count += 1
|
||||
|
||||
|
||||
@dataclass
class QueryMetrics:
    """Running counters describing query execution.

    All fields start at zero and are mutated in place by the executor.
    """

    # Every query submitted, whatever its outcome.
    total_queries: int = 0
    # Queries that finished without raising.
    successful_queries: int = 0
    # Queries that raised an exception.
    failed_queries: int = 0
    # Lookups answered from the result cache.
    cache_hits: int = 0
    # Lookups that had to go to the database.
    cache_misses: int = 0
    # Mean execution time in seconds, derived from total_execution_time.
    avg_execution_time: float = 0.0
    # Sum of measured execution times in seconds.
    total_execution_time: float = 0.0
    # Queries whose execution time exceeded the slow-query threshold.
    slow_queries: int = 0
    # Queries currently in flight.
    concurrent_queries: int = 0
|
||||
|
||||
|
||||
class QueryCache:
    """In-memory cache of query results keyed by SQL text and parameters.

    Entries are ``CachedQuery`` objects stored in ``self.cache``. When the
    cache is full, the entry with the oldest ``created_at`` is evicted.
    Lookups are counted so that ``get_stats()`` can report a real hit rate.
    """

    def __init__(self, max_size: int = 1000, default_ttl: int = 300):
        """Create an empty cache.

        Args:
            max_size: Number of entries at which inserting a new key evicts
                the oldest entry.
            default_ttl: Lifetime in seconds applied when ``set`` is called
                without an explicit ``ttl``.
        """
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.cache: dict[str, CachedQuery] = {}
        self.logger = logging.getLogger(__name__)
        # Lookup counters backing the hit rate reported by get_stats().
        self._hits = 0
        self._misses = 0

    def _generate_cache_key(
        self, sql: str, parameters: dict[str, Any] | None = None
    ) -> str:
        """Return a stable hex digest identifying ``sql`` plus ``parameters``.

        Only surrounding whitespace is normalised. The statement is NOT
        lowercased: string literals are case-sensitive, so folding case would
        make ``name = 'Bob'`` and ``name = 'bob'`` share one cache entry.
        """
        cache_data = {"sql": sql.strip(), "parameters": parameters or {}}
        # default=str keeps non-JSON parameter values (dates, Decimals, ...)
        # from raising TypeError while building the key.
        cache_string = json.dumps(cache_data, sort_keys=True, default=str)
        # MD5 is used purely as a fingerprint, not for security.
        return hashlib.md5(
            cache_string.encode("utf-8"), usedforsecurity=False
        ).hexdigest()

    async def get(
        self, sql: str, parameters: dict[str, Any] | None = None
    ) -> CachedQuery | None:
        """Return the live cache entry for the query, or None on a miss.

        An expired entry is removed and reported as a miss.
        """
        cache_key = self._generate_cache_key(sql, parameters)
        cached_query = self.cache.get(cache_key)

        if cached_query is None:
            self._misses += 1
            return None

        if cached_query.is_expired():
            # Clean up expired cache
            del self.cache[cache_key]
            self._misses += 1
            self.logger.debug(f"Cache expired, cleaned up: {cache_key}")
            return None

        cached_query.access()
        self._hits += 1
        self.logger.debug(f"Cache hit: {cache_key}")
        return cached_query

    async def set(
        self,
        sql: str,
        result: QueryResult,
        parameters: dict[str, Any] | None = None,
        ttl: int | None = None,
    ) -> str:
        """Store ``result`` for the query and return the cache key used.

        Args:
            sql: SQL text the result belongs to.
            result: Query result to store.
            parameters: Bind parameters that were used with ``sql``.
            ttl: Lifetime in seconds; ``None`` selects ``default_ttl``. An
                explicit ``0`` is honoured (the entry never expires).
        """
        cache_key = self._generate_cache_key(sql, parameters)

        # Overwriting an existing key does not grow the cache, so only evict
        # when a genuinely new key would exceed the limit.
        if cache_key not in self.cache and len(self.cache) >= self.max_size:
            await self._evict_oldest()

        effective_ttl = self.default_ttl if ttl is None else ttl
        self.cache[cache_key] = CachedQuery(
            result=result, created_at=datetime.utcnow(), ttl=effective_ttl
        )
        self.logger.debug(f"Cache set: {cache_key}")

        return cache_key

    async def _evict_oldest(self):
        """Remove the entry with the earliest ``created_at``, if any."""
        if not self.cache:
            return

        oldest_key = min(self.cache, key=lambda k: self.cache[k].created_at)
        del self.cache[oldest_key]
        self.logger.debug(f"Cleaned up oldest cache: {oldest_key}")

    async def clear_expired(self):
        """Remove every entry whose TTL has elapsed."""
        # Collect first: deleting while iterating the dict would raise.
        expired_keys = [
            key for key, cached_query in self.cache.items() if cached_query.is_expired()
        ]

        for key in expired_keys:
            del self.cache[key]

        if expired_keys:
            self.logger.info(f"Cleaned up {len(expired_keys)} expired cache items")

    async def clear_all(self):
        """Remove every entry from the cache."""
        cache_count = len(self.cache)
        self.cache.clear()
        self.logger.info(f"Cleaned up all cache, total {cache_count} items")

    def get_stats(self) -> dict[str, Any]:
        """Return cache size, capacity, access total and lookup hit rate.

        ``total_access`` sums ``access_count`` over the entries currently
        stored. ``hit_rate`` is hits divided by all ``get`` calls made on
        this cache (0.0 before the first lookup).
        """
        total_access = sum(cached.access_count for cached in self.cache.values())
        lookups = self._hits + self._misses

        return {
            "cache_size": len(self.cache),
            "max_size": self.max_size,
            "total_access": total_access,
            "hit_rate": self._hits / lookups if lookups else 0.0,
        }
|
||||
|
||||
|
||||
class QueryOptimizer:
    """Rule-based SQL rewriter applied before a statement is executed.

    Each rule is a dict with a ``name``, an optional regex ``pattern`` the
    SQL must match, optional ``conditions`` evaluated against a context dict,
    and an ``action`` naming the rewrite to perform with its ``params``.
    """

    def __init__(self, config):
        """Store ``config`` and load the built-in rule set."""
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.optimization_rules = self._load_optimization_rules()

    def _load_optimization_rules(self) -> list[dict[str, Any]]:
        """Return the built-in optimization rules, in application order."""
        return [
            {
                "name": "add_limit_clause",
                "description": "Add default limit for SELECT queries without LIMIT",
                # (?is): case-insensitive, and '.' spans newlines so a LIMIT
                # on a later line is seen. The lookahead sits directly after
                # SELECT; placed after a greedy '.*' it can never reject.
                "pattern": r"(?is)^\s*select\b(?!.*\blimit\s+\d+)",
                "action": "add_limit",
                "params": {"default_limit": 1000},
            },
            {
                "name": "optimize_count_query",
                "description": "Optimize COUNT queries",
                "pattern": r"select\s+count\(\*\)\s+from\s+(\w+)",
                "action": "optimize_count",
                "params": {},
            },
        ]

    async def optimize_query(self, sql: str, context: dict[str, Any]) -> str:
        """Apply every matching rule to ``sql`` and return the rewritten text.

        Args:
            sql: Statement to optimize.
            context: Caller-supplied values consulted by rule conditions
                (for example ``user_roles``). May be empty or None.

        Returns:
            The SQL after all applicable rules have run, in rule order.
        """
        optimized_sql = sql

        for rule in self.optimization_rules:
            # Expose the current SQL to conditions such as "query_size";
            # a caller-provided "sql" entry still takes precedence.
            rule_context = {"sql": optimized_sql, **(context or {})}

            if not self._should_apply_rule(rule, optimized_sql, rule_context):
                continue

            rewritten = await self._apply_optimization_rule(
                optimized_sql, rule, rule_context
            )
            # Only report rules that actually changed the statement.
            if rewritten != optimized_sql:
                self.logger.debug(f"Applied optimization rule: {rule['name']}")
            optimized_sql = rewritten

        return optimized_sql

    def _should_apply_rule(
        self, rule: dict[str, Any], sql: str, context: dict[str, Any]
    ) -> bool:
        """Return True when ``rule``'s pattern and conditions all accept ``sql``."""
        import re

        # Check pattern match
        if "pattern" in rule and not re.search(rule["pattern"], sql, re.IGNORECASE):
            return False

        # Check conditions
        return all(
            self._check_condition(condition, context)
            for condition in rule.get("conditions", ())
        )

    def _check_condition(
        self, condition: dict[str, Any], context: dict[str, Any]
    ) -> bool:
        """Evaluate a single rule condition against ``context``.

        Supported ``type`` values:
            ``user_role``: true when any of ``roles`` is in
                ``context["user_roles"]``.
            ``query_size``: true when ``len(context["sql"])`` does not
                exceed ``max_size`` (default 1000).
        Unknown condition types are treated as satisfied.
        """
        condition_type = condition.get("type")

        if condition_type == "user_role":
            required_roles = condition.get("roles", [])
            user_roles = context.get("user_roles", [])
            return any(role in user_roles for role in required_roles)

        if condition_type == "query_size":
            max_size = condition.get("max_size", 1000)
            return len(context.get("sql", "")) <= max_size

        return True

    async def _apply_optimization_rule(
        self, sql: str, rule: dict[str, Any], context: dict[str, Any]
    ) -> str:
        """Dispatch ``rule``'s action; unknown actions leave ``sql`` unchanged."""
        action = rule.get("action")
        params = rule.get("params", {})

        if action == "add_limit":
            return await self._add_limit_clause(sql, params)
        if action == "optimize_count":
            return await self._optimize_count_query(sql, params)
        if action == "add_hints":
            return await self._add_query_hints(sql, params)

        return sql

    async def _add_limit_clause(self, sql: str, params: dict[str, Any]) -> str:
        """Append ``LIMIT <default_limit>`` unless the statement already has one.

        Surrounding whitespace and a single trailing semicolon are removed
        before appending, so the LIMIT never lands after the terminator.
        """
        import re

        default_limit = params.get("default_limit", 1000)

        # Check if LIMIT already exists (anywhere, including subqueries).
        if re.search(r"\blimit\s+\d+", sql, re.IGNORECASE):
            return sql

        statement = sql.strip()
        if statement.endswith(";"):
            statement = statement[:-1].rstrip()

        return f"{statement} LIMIT {default_limit}"

    async def _optimize_count_query(self, sql: str, params: dict[str, Any]) -> str:
        """Rewrite ``COUNT(*)`` to ``COUNT(1)`` regardless of letter case.

        The rule's pattern matches case-insensitively, so the rewrite has to
        as well; a plain ``str.replace`` would miss ``count(*)``.
        """
        import re

        return re.sub(r"\bcount\(\*\)", "COUNT(1)", sql, flags=re.IGNORECASE)

    async def _add_query_hints(self, sql: str, params: dict[str, Any]) -> str:
        """Prefix ``sql`` with a ``/*+ ... */`` comment built from ``params["hints"]``."""
        hints = params.get("hints", [])
        if not hints:
            return sql

        hint_string = "/*+ " + " ".join(hints) + " */"
        return f"{hint_string} {sql}"
|
||||
|
||||
|
||||
class DorisQueryExecutor:
    """Doris query executor with caching and optimization.

    Combines a connection manager with a result cache, a rule-based SQL
    rewriter and running execution metrics. A background task purges
    expired cache entries when the executor is created inside a running
    event loop.

    NOTE(review): cached results are keyed by SQL text and parameters only,
    so a result cached for one caller is returned to any other caller issuing
    the same statement, whatever ``auth_context`` is passed. Confirm this is
    acceptable where per-user permissions or masking apply.
    """

    # Only statements starting with one of these (case-insensitive) use the
    # result cache. Caching anything else would let a repeated INSERT/UPDATE/
    # DELETE be answered from the cache without ever reaching the database.
    _CACHEABLE_PREFIXES = ("select", "with", "show", "desc", "explain")

    def __init__(self, connection_manager: DorisConnectionManager, config=None):
        """Build the executor and start its background maintenance.

        Args:
            connection_manager: Source of database connections.
            config: Optional configuration object. Its ``performance``
                attribute may provide ``max_cache_size``, ``cache_ttl`` and
                ``max_concurrent_queries``; missing values fall back to
                1000, 300 and 50.
        """
        self.connection_manager = connection_manager
        self.config = config or self._create_default_config()
        self.logger = logging.getLogger(__name__)

        # getattr() on a missing (None) performance section yields the default.
        performance = getattr(self.config, "performance", None)
        cache_size = getattr(performance, "max_cache_size", 1000)
        cache_ttl = getattr(performance, "cache_ttl", 300)

        # Initialize components
        self.query_cache = QueryCache(max_size=cache_size, default_ttl=cache_ttl)
        self.query_optimizer = QueryOptimizer(self.config)
        self.metrics = QueryMetrics()

        # Performance monitoring
        self.slow_query_threshold = 5.0  # seconds
        self.max_concurrent_queries = getattr(
            performance, "max_concurrent_queries", 50
        )

        # Background tasks
        self._background_tasks = []
        self._start_background_tasks()

    def _create_default_config(self):
        """Return a minimal config object carrying default performance settings."""

        class DefaultPerformanceConfig:
            def __init__(self):
                self.max_cache_size = 1000
                self.cache_ttl = 300
                self.max_concurrent_queries = 50

        class DefaultConfig:
            def __init__(self):
                self.performance = DefaultPerformanceConfig()

        return DefaultConfig()

    def _start_background_tasks(self):
        """Schedule the cache cleanup loop if an event loop is running.

        The running loop is looked up first so that no coroutine object is
        created (and left un-awaited) when there is no loop, e.g. in tests.
        """
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No event loop running (e.g., in tests), skip background tasks
            self.logger.debug("No event loop running, skipping background tasks")
            return

        # Cache cleanup task
        self._background_tasks.append(loop.create_task(self._cache_cleanup_loop()))

    async def _cache_cleanup_loop(self):
        """Purge expired cache entries every five minutes until cancelled."""
        while True:
            try:
                await asyncio.sleep(300)  # Run every 5 minutes
                await self.query_cache.clear_expired()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; one failed sweep must not stop cleanup.
                self.logger.error(f"Cache cleanup error: {e}")

    @classmethod
    def _is_cacheable(cls, sql: str) -> bool:
        """Return True when ``sql`` starts with a read-only statement keyword."""
        return sql.lstrip().lower().startswith(cls._CACHEABLE_PREFIXES)

    async def execute_query(
        self, query_request: QueryRequest, auth_context=None
    ) -> QueryResult:
        """Execute a query, serving it from or storing it in the cache.

        Args:
            query_request: Statement, parameters and execution options.
            auth_context: Opaque authorization object forwarded to the
                connection and the optimizer.

        Returns:
            The query result, either cached or freshly executed.

        Raises:
            Exception: Whatever the underlying execution raises; the failure
                is counted and logged before being re-raised.
        """
        start_time = time.perf_counter()
        self.metrics.total_queries += 1
        self.metrics.concurrent_queries += 1

        use_cache = query_request.cache_enabled and self._is_cacheable(
            query_request.sql
        )

        try:
            # Check cache first
            if use_cache:
                cached_result = await self.query_cache.get(
                    query_request.sql, query_request.parameters
                )
                if cached_result is not None:
                    self.metrics.cache_hits += 1
                    # A cache hit is a successfully answered query.
                    self.metrics.successful_queries += 1
                    self.logger.debug(f"Cache hit for query: {query_request.sql[:50]}...")
                    return cached_result.result

                self.metrics.cache_misses += 1

            # Execute query
            result = await self._execute_query_internal(query_request, auth_context)

            # Cache result if enabled
            if use_cache and result.row_count > 0:
                await self.query_cache.set(
                    query_request.sql, result, query_request.parameters
                )

            self.metrics.successful_queries += 1
            return result

        except Exception as e:
            self.metrics.failed_queries += 1
            self.logger.error(f"Query execution failed: {e}")
            raise

        finally:
            execution_time = time.perf_counter() - start_time
            self.metrics.concurrent_queries -= 1
            self._update_execution_metrics(execution_time)

    async def _execute_query_internal(
        self, query_request: QueryRequest, auth_context
    ) -> QueryResult:
        """Optimize the statement and run it on a connection for the session.

        Raises:
            TimeoutError: When ``query_request.timeout`` seconds elapse first.
        """
        # Optimize query
        optimized_sql = await self.query_optimizer.optimize_query(
            query_request.sql, {"user_roles": getattr(auth_context, 'roles', [])}
        )

        connection = await self.connection_manager.get_connection(
            query_request.session_id
        )
        execution = connection.execute(
            optimized_sql, query_request.parameters, auth_context
        )

        # No timeout requested: wait for as long as the query takes.
        if not query_request.timeout:
            return await execution

        try:
            return await asyncio.wait_for(execution, timeout=query_request.timeout)
        except asyncio.TimeoutError as exc:
            raise TimeoutError(
                f"Query timeout after {query_request.timeout} seconds"
            ) from exc

    def _update_execution_metrics(self, execution_time: float):
        """Fold one finished query's duration into the running metrics."""
        self.metrics.total_execution_time += execution_time

        # total_execution_time covers successes and failures alike, so the
        # average must be taken over both.
        completed = self.metrics.successful_queries + self.metrics.failed_queries
        if completed > 0:
            self.metrics.avg_execution_time = (
                self.metrics.total_execution_time / completed
            )

        # Check for slow queries
        if execution_time > self.slow_query_threshold:
            self.metrics.slow_queries += 1
            self.logger.warning(
                f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
            )

    async def execute_batch_queries(
        self, query_requests: list[QueryRequest], auth_context=None
    ) -> list[QueryResult]:
        """Execute several queries concurrently.

        Returns:
            One item per request, in request order. A request that failed
            contributes its exception object instead of a result.

        Raises:
            ValueError: When the batch is larger than
                ``max_concurrent_queries``.
        """
        # Check concurrent query limit
        if len(query_requests) > self.max_concurrent_queries:
            raise ValueError(
                f"Batch size {len(query_requests)} exceeds maximum concurrent queries {self.max_concurrent_queries}"
            )

        tasks = [
            self.execute_query(request, auth_context) for request in query_requests
        ]
        # return_exceptions=True: failures come back as list items, not raises.
        return await asyncio.gather(*tasks, return_exceptions=True)

    async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
        """Run ``EXPLAIN`` for ``sql`` and return the plan rows."""
        explain_sql = f"EXPLAIN {sql}"

        connection = await self.connection_manager.get_connection(session_id)
        result = await connection.execute(explain_sql)

        return {
            "query": sql,
            "execution_plan": result.data,
            "estimated_cost": "N/A",  # Doris doesn't provide cost estimates
        }

    async def get_query_stats(self) -> dict[str, Any]:
        """Return a snapshot of query and cache metrics."""
        metrics = self.metrics
        cache_lookups = metrics.cache_hits + metrics.cache_misses

        return {
            "query_metrics": {
                "total_queries": metrics.total_queries,
                "successful_queries": metrics.successful_queries,
                "failed_queries": metrics.failed_queries,
                "success_rate": (
                    metrics.successful_queries / metrics.total_queries
                    if metrics.total_queries > 0
                    else 0.0
                ),
                "avg_execution_time": metrics.avg_execution_time,
                "slow_queries": metrics.slow_queries,
                "concurrent_queries": metrics.concurrent_queries,
            },
            "cache_metrics": {
                # Spread the cache's own stats first so the hit rate computed
                # from the executor's counters below is not overwritten.
                **self.query_cache.get_stats(),
                "cache_hits": metrics.cache_hits,
                "cache_misses": metrics.cache_misses,
                "hit_rate": (
                    metrics.cache_hits / cache_lookups if cache_lookups > 0 else 0.0
                ),
            },
        }

    async def clear_cache(self):
        """Clear query cache"""
        await self.query_cache.clear_all()

    @staticmethod
    def _apply_default_limit(sql: str, limit: int) -> str:
        """Append ``LIMIT <limit>`` to a SELECT that has no LIMIT keyword.

        Whitespace and one trailing semicolon are trimmed before appending.
        ``LIMIT`` is matched as a whole word, so identifiers that merely
        contain it (``rate_limit``) do not suppress the limit. Statements
        that are not SELECTs, or already mention LIMIT, are returned as-is.
        """
        import re

        statement = sql.strip()
        if statement.endswith(";"):
            statement = statement[:-1].rstrip()

        is_select = re.match(r"select\b", statement, re.IGNORECASE) is not None
        has_limit = re.search(r"\blimit\b", statement, re.IGNORECASE) is not None
        if is_select and not has_limit:
            return f"{statement} LIMIT {limit}"
        return sql

    async def execute_sql_for_mcp(
        self,
        sql: str,
        limit: int = 1000,
        timeout: int = 30,
        session_id: str = "mcp_session",
        user_id: str = "mcp_user"
    ) -> Dict[str, Any]:
        """Execute SQL query for MCP interface - unified method.

        Args:
            sql: Statement to run; a default LIMIT is appended to SELECTs
                without one.
            limit: Row limit used for that default LIMIT.
            timeout: Query timeout in seconds.
            session_id: Session used to obtain a connection.
            user_id: User recorded on the request and auth context.

        Returns:
            A dict with ``success``, ``data`` and ``error``, plus ``metadata``
            once a query was attempted. Errors are reported in the dict
            rather than raised.
        """
        try:
            if not sql or not sql.strip():
                return {
                    "success": False,
                    "error": "SQL query is required",
                    "data": None
                }

            # Add LIMIT if not present and it's a SELECT query
            sql = self._apply_default_limit(sql, limit)

            # Create auth context for MCP calls
            class MockAuthContext:
                def __init__(self):
                    self.user_id = user_id
                    self.roles = ["data_analyst"]
                    self.permissions = ["read_data", "execute_query"]
                    self.session_id = session_id
                    self.security_level = "internal"

            auth_context = MockAuthContext()

            # Create query request
            query_request = QueryRequest(
                sql=sql,
                session_id=session_id,
                user_id=user_id,
                timeout=timeout,
                cache_enabled=True
            )

            # Execute query
            result = await self.execute_query(query_request, auth_context)

            # Process results
            processed_data = [
                self._serialize_row_data(row) for row in (result.data or [])
            ]

            return {
                "success": True,
                "data": processed_data,
                "metadata": {
                    "row_count": result.row_count,
                    "execution_time": result.execution_time,
                    # Tolerate results that carry no metadata mapping.
                    "columns": (result.metadata or {}).get("columns", []),
                    "query": sql
                },
                "error": None
            }

        except Exception as e:
            error_msg = str(e)
            self.logger.error(f"SQL execution error: {error_msg}")

            # Analyze error for better user feedback
            error_analysis = self._analyze_error(error_msg)

            return {
                "success": False,
                "error": error_analysis.get("user_message", error_msg),
                "error_type": error_analysis.get("error_type", "execution_error"),
                "data": None,
                "metadata": {
                    "query": sql,
                    "error_details": error_msg
                }
            }

    def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one result row into JSON-friendly values.

        ``None``, ``str``, ``int``, ``float`` and ``bool`` pass through;
        ``Decimal`` becomes ``float``; dates and datetimes become ISO-8601
        strings; ``bytes`` are decoded as UTF-8 (falling back to ``str()``);
        anything else is converted with ``str()``.
        """
        serialized = {}

        for key, value in row_data.items():
            if value is None or isinstance(value, (str, int, float, bool)):
                serialized[key] = value
            elif isinstance(value, Decimal):
                serialized[key] = float(value)
            elif isinstance(value, (datetime, date)):
                serialized[key] = value.isoformat()
            elif isinstance(value, bytes):
                try:
                    serialized[key] = value.decode('utf-8')
                except UnicodeDecodeError:
                    serialized[key] = str(value)
            else:
                serialized[key] = str(value)

        return serialized

    def _analyze_error(self, error_message: str) -> Dict[str, str]:
        """Classify an error message and return a user-friendly explanation.

        Returns:
            A dict with ``error_type`` and ``user_message``. Classification
            is by case-insensitive substring; the first matching category
            wins, and unmatched messages are reported as ``general_error``.
        """
        error_msg_lower = error_message.lower()

        if "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
            return {
                "error_type": "table_not_found",
                "user_message": "The specified table does not exist. Please check the table name and database."
            }
        if "column" in error_msg_lower and ("unknown" in error_msg_lower or "doesn't exist" in error_msg_lower):
            return {
                "error_type": "column_not_found",
                "user_message": "One or more columns in the query do not exist. Please check column names."
            }
        if "syntax error" in error_msg_lower or "sql syntax" in error_msg_lower:
            return {
                "error_type": "syntax_error",
                "user_message": "SQL syntax error. Please check your query syntax."
            }
        if "access denied" in error_msg_lower or "permission" in error_msg_lower:
            return {
                "error_type": "permission_denied",
                "user_message": "Access denied. You don't have permission to execute this query."
            }
        if "timeout" in error_msg_lower:
            return {
                "error_type": "timeout",
                "user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
            }
        return {
            "error_type": "general_error",
            "user_message": f"Query execution failed: {error_message}"
        }

    async def close(self):
        """Close query executor and cleanup resources"""
        # Cancel background tasks
        for task in self._background_tasks:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        # Drop the finished tasks so a second close() has nothing to await.
        self._background_tasks.clear()

        # Clear cache
        await self.query_cache.clear_all()

        self.logger.info("Query executor closed")
|
||||
|
||||
|
||||
class QueryPerformanceMonitor:
    """Keeps a rolling log of executed queries and summarises it on demand."""

    def __init__(self, query_executor: DorisQueryExecutor):
        """Attach to ``query_executor`` and start with an empty record list."""
        self.query_executor = query_executor
        self.logger = logging.getLogger(__name__)
        self.performance_records = []

    async def record_query_performance(
        self, query_request: QueryRequest, result: QueryResult, execution_time: float
    ):
        """Append one execution record, retaining only the newest 1000.

        Args:
            query_request: The request that was executed.
            result: Its result; only ``row_count`` is read.
            execution_time: Measured duration of the query.
        """
        self.performance_records.append(
            {
                "timestamp": datetime.utcnow(),
                "sql": query_request.sql,
                "user_id": query_request.user_id,
                "session_id": query_request.session_id,
                "execution_time": execution_time,
                "row_count": result.row_count,
                "cache_hit": False,  # This would need to be passed from executor
            }
        )

        # Bound memory use: once past 1000 entries, keep the most recent 1000.
        overflow = len(self.performance_records) - 1000
        if overflow > 0:
            self.performance_records = self.performance_records[-1000:]

    async def get_performance_report(
        self, time_range_minutes: int = 60
    ) -> dict[str, Any]:
        """Summarise the records from the last ``time_range_minutes`` minutes.

        Returns:
            Aggregate timing and row-count statistics plus a query
            distribution, or a dict holding only ``message`` when no record
            falls inside the window.
        """
        window_start = datetime.utcnow() - timedelta(minutes=time_range_minutes)
        in_window = [
            entry
            for entry in self.performance_records
            if window_start <= entry["timestamp"]
        ]

        if not in_window:
            return {"message": "No performance data available for the specified time range"}

        # Calculate statistics
        durations = [entry["execution_time"] for entry in in_window]
        rows = [entry["row_count"] for entry in in_window]

        return {
            "time_range_minutes": time_range_minutes,
            "total_queries": len(in_window),
            "avg_execution_time": sum(durations) / len(durations),
            "max_execution_time": max(durations),
            "min_execution_time": min(durations),
            "avg_row_count": sum(rows) / len(rows),
            "query_distribution": self._analyze_query_distribution(in_window),
        }

    def _analyze_query_distribution(
        self, records: list[dict[str, Any]]
    ) -> dict[str, Any]:
        """Count ``records`` by statement type and by user.

        The statement type is the leading keyword (SELECT, INSERT, UPDATE or
        DELETE, matched case-insensitively after trimming); everything else
        is counted as OTHER.
        """
        known_types = ("SELECT", "INSERT", "UPDATE", "DELETE")
        type_counts = {}
        user_counts = {}

        for entry in records:
            statement = entry["sql"].strip().upper()
            kind = next(
                (keyword for keyword in known_types if statement.startswith(keyword)),
                "OTHER",
            )
            type_counts[kind] = type_counts.get(kind, 0) + 1

            who = entry["user_id"]
            user_counts[who] = user_counts.get(who, 0) + 1

        return {"query_types": type_counts, "user_distribution": user_counts}
|
||||
|
||||
|
||||
# Unified convenience function for MCP integration
|
||||
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
    """Run ``sql`` through a short-lived executor and return an MCP response dict.

    Args:
        sql: Statement to execute.
        connection_manager: Connection manager handed to the executor.
        **kwargs: Optional ``limit`` (default 1000), ``timeout`` (default 30),
            ``session_id`` (default "mcp_session") and ``user_id``
            (default "mcp_user").

    Returns:
        The dict produced by the executor's ``execute_sql_for_mcp``. If
        creating, using or closing the executor raises, a failure dict with
        ``success``, ``error`` and ``data`` is returned instead.
    """
    try:
        # A fresh executor per call; it is always closed, even on failure.
        executor = DorisQueryExecutor(connection_manager)
        try:
            return await executor.execute_sql_for_mcp(
                sql=sql,
                limit=kwargs.get("limit", 1000),
                timeout=kwargs.get("timeout", 30),
                session_id=kwargs.get("session_id", "mcp_session"),
                user_id=kwargs.get("user_id", "mcp_user"),
            )
        finally:
            await executor.close()
    except Exception as exc:
        return {
            "success": False,
            "error": f"Query execution failed: {str(exc)}",
            "data": None
        }
|
||||
@@ -8,6 +8,8 @@ import os
|
||||
import json
|
||||
import pandas as pd
|
||||
import re
|
||||
import uuid
|
||||
import time
|
||||
from typing import Dict, List, Any, Optional, Tuple
|
||||
from dotenv import load_dotenv
|
||||
from datetime import datetime, timedelta
|
||||
@@ -26,23 +28,25 @@ ENABLE_MULTI_DATABASE=os.getenv("ENABLE_MULTI_DATABASE",True)
|
||||
MULTI_DATABASE_NAMES=os.getenv("MULTI_DATABASE_NAMES","")
|
||||
|
||||
# Import local modules
|
||||
from doris_mcp_server.utils.db import execute_query_df, execute_query
|
||||
from .db import DorisConnectionManager
|
||||
|
||||
class MetadataExtractor:
|
||||
"""Apache Doris Metadata Extractor"""
|
||||
|
||||
def __init__(self, db_name: str = None, catalog_name: str = None):
|
||||
def __init__(self, db_name: str = None, catalog_name: str = None, connection_manager=None):
|
||||
"""
|
||||
Initialize the metadata extractor
|
||||
|
||||
Args:
|
||||
db_name: Default database name, uses the currently connected database if not specified
|
||||
catalog_name: Default catalog name for federation queries, uses the current catalog if not specified
|
||||
connection_manager: DorisConnectionManager instance for database operations
|
||||
"""
|
||||
# Get configuration from environment variables
|
||||
self.db_name = db_name or os.getenv("DB_DATABASE", "")
|
||||
self.catalog_name = catalog_name # Store catalog name for federation support
|
||||
self.metadata_db = METADATA_DB_NAME # Use constant
|
||||
self.connection_manager = connection_manager
|
||||
|
||||
# Caching system
|
||||
self.metadata_cache = {}
|
||||
@@ -65,6 +69,9 @@ class MetadataExtractor:
|
||||
# List of excluded system databases
|
||||
self.excluded_databases = self._load_excluded_databases()
|
||||
|
||||
# Session ID for database queries
|
||||
self._session_id = f"metadata_extractor_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
def _load_excluded_databases(self) -> List[str]:
|
||||
"""
|
||||
Load the list of excluded databases configuration
|
||||
@@ -482,7 +489,7 @@ class MetadataExtractor:
|
||||
TABLE_SCHEMA = '{db_name}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
table_type_result = execute_query(table_type_query)
|
||||
table_type_result = self._execute_query(table_type_query)
|
||||
if table_type_result:
|
||||
schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "")
|
||||
schema["engine"] = table_type_result[0].get("ENGINE", "")
|
||||
@@ -633,31 +640,52 @@ class MetadataExtractor:
|
||||
else:
|
||||
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
|
||||
|
||||
df = execute_query_df(query)
|
||||
try:
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
|
||||
# Process results
|
||||
indexes = []
|
||||
current_index = None
|
||||
# Process results
|
||||
indexes = []
|
||||
current_index = None
|
||||
|
||||
for _, row in df.iterrows():
|
||||
index_name = row['Key_name']
|
||||
column_name = row['Column_name']
|
||||
if not df.empty:
|
||||
for _, row in df.iterrows():
|
||||
try:
|
||||
index_name = row['Key_name']
|
||||
column_name = row['Column_name']
|
||||
|
||||
if current_index is None or current_index['name'] != index_name:
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name],
|
||||
'unique': row['Non_unique'] == 0,
|
||||
'type': row['Index_type']
|
||||
}
|
||||
else:
|
||||
current_index['columns'].append(column_name)
|
||||
except Exception as row_error:
|
||||
logger.warning(f"Failed to process index row data: {row_error}")
|
||||
continue
|
||||
|
||||
if current_index is None or current_index['name'] != index_name:
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
|
||||
current_index = {
|
||||
'name': index_name,
|
||||
'columns': [column_name],
|
||||
'unique': row['Non_unique'] == 0,
|
||||
'type': row['Index_type']
|
||||
}
|
||||
else:
|
||||
current_index['columns'].append(column_name)
|
||||
|
||||
if current_index is not None:
|
||||
indexes.append(current_index)
|
||||
except Exception as df_error:
|
||||
logger.warning(f"DataFrame processing failed, trying regular query: {df_error}")
|
||||
# Fall back to regular query
|
||||
result = self._execute_query(query, return_dataframe=False)
|
||||
indexes = []
|
||||
if result:
|
||||
# Simple processing, no complex index grouping
|
||||
for row in result:
|
||||
if isinstance(row, dict):
|
||||
indexes.append({
|
||||
'name': row.get('Key_name', ''),
|
||||
'columns': [row.get('Column_name', '')],
|
||||
'unique': row.get('Non_unique', 1) == 0,
|
||||
'type': row.get('Index_type', '')
|
||||
})
|
||||
|
||||
# Update cache
|
||||
self.metadata_cache[cache_key] = indexes
|
||||
@@ -748,7 +776,7 @@ class MetadataExtractor:
|
||||
ORDER BY time DESC
|
||||
LIMIT {limit}
|
||||
"""
|
||||
df = execute_query_df(query)
|
||||
df = self._execute_query(query, return_dataframe=True)
|
||||
return df
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting audit logs: {str(e)}")
|
||||
@@ -768,7 +796,7 @@ class MetadataExtractor:
|
||||
try:
|
||||
# Use SHOW CATALOGS command to get catalog list
|
||||
query = "SHOW CATALOGS"
|
||||
result = execute_query(query)
|
||||
result = self._execute_query(query)
|
||||
|
||||
if not result:
|
||||
catalogs = []
|
||||
@@ -1057,7 +1085,7 @@ class MetadataExtractor:
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
"""
|
||||
|
||||
partitions = execute_query(query)
|
||||
partitions = self._execute_query(query)
|
||||
|
||||
if not partitions:
|
||||
return {}
|
||||
@@ -1099,10 +1127,511 @@ class MetadataExtractor:
|
||||
# Replace 'information_schema' with 'catalog_name.information_schema'
|
||||
modified_query = query.replace('information_schema', f'{catalog_name}.information_schema')
|
||||
logger.info(f"Modified query for catalog {catalog_name}: {modified_query}")
|
||||
return execute_query(modified_query, db_name)
|
||||
return self._execute_query(modified_query, db_name)
|
||||
else:
|
||||
# Execute the original query
|
||||
return execute_query(query, db_name)
|
||||
return self._execute_query(query, db_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query with catalog: {str(e)}")
|
||||
raise
|
||||
|
||||
async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
"""
|
||||
Execute database query asynchronously
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use (optional)
|
||||
return_dataframe: Whether to return a pandas DataFrame instead of list
|
||||
|
||||
Returns:
|
||||
Query result data (list of dictionaries or pandas DataFrame)
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
# Use the injected connection manager directly (async)
|
||||
result = await self.connection_manager.execute_query(self._session_id, query, None)
|
||||
|
||||
# Extract data from QueryResult
|
||||
if hasattr(result, 'data'):
|
||||
data = result.data
|
||||
else:
|
||||
data = result
|
||||
|
||||
# Convert to DataFrame if requested
|
||||
if return_dataframe and data:
|
||||
import pandas as pd
|
||||
return pd.DataFrame(data)
|
||||
elif return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return data
|
||||
else:
|
||||
# Fallback: Return empty result
|
||||
logger.warning("No connection manager provided, returning empty result")
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query: {str(e)}")
|
||||
# Return empty result instead of raising exception to prevent cascade failures
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
|
||||
def _execute_query(self, query: str, db_name: str = None, return_dataframe: bool = False):
|
||||
"""
|
||||
Execute database query with proper session management (sync wrapper)
|
||||
|
||||
Args:
|
||||
query: SQL query to execute
|
||||
db_name: Database name to use (optional)
|
||||
return_dataframe: Whether to return a pandas DataFrame instead of list
|
||||
|
||||
Returns:
|
||||
Query result data (list of dictionaries or pandas DataFrame)
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
import asyncio
|
||||
|
||||
# Try to run the async query
|
||||
try:
|
||||
# Check if there's a running event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
# If we're in an async context, we need to run in a separate thread
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_new_loop():
|
||||
new_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(new_loop)
|
||||
try:
|
||||
return new_loop.run_until_complete(
|
||||
self._execute_query_async(query, db_name, return_dataframe)
|
||||
)
|
||||
finally:
|
||||
new_loop.close()
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_new_loop)
|
||||
return future.result(timeout=30)
|
||||
|
||||
except RuntimeError:
|
||||
# No running loop, we can safely create one
|
||||
return asyncio.run(
|
||||
self._execute_query_async(query, db_name, return_dataframe)
|
||||
)
|
||||
else:
|
||||
# Fallback: Return empty result
|
||||
logger.warning("No connection manager provided, returning empty result")
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"Error executing query: {str(e)}")
|
||||
# Return empty result instead of raising exception to prevent cascade failures
|
||||
if return_dataframe:
|
||||
import pandas as pd
|
||||
return pd.DataFrame()
|
||||
else:
|
||||
return []
|
||||
|
||||
async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]:
    """
    Asynchronously get table schema information

    Issues a ``DESCRIBE`` statement and maps each result row to a column
    description.

    Args:
        table_name: Table to describe
        db_name: Database name; falls back to the extractor's default
        catalog_name: Catalog name; falls back to the extractor's default

    Returns:
        One dict per column with the keys ``column_name``, ``data_type``,
        ``is_nullable``, ``default_value``, ``comment``, ``key`` and
        ``extra``. An empty list when the query returns nothing or fails.
    """
    try:
        effective_catalog = catalog_name or self.catalog_name
        effective_db = db_name or self.db_name

        # Build query statement
        if effective_catalog and effective_catalog != "internal":
            target = f"`{effective_catalog}`.`{effective_db}`.`{table_name}`"
        elif effective_db:
            target = f"`{effective_db}`.`{table_name}`"
        else:
            # No database given and no default configured: an empty
            # qualifier (``.`table`) is not a valid identifier, so leave the
            # table unqualified and let the session's database apply.
            target = f"`{table_name}`"
        query = f"DESCRIBE {target}"

        # Execute async query
        result = await self._execute_query_async(query, db_name)

        if not result:
            return []

        # Process results; rows that are not dicts are ignored
        return [
            {
                'column_name': row.get('Field', ''),
                'data_type': row.get('Type', ''),
                'is_nullable': row.get('Null', 'NO') == 'YES',
                'default_value': row.get('Default', None),
                'comment': row.get('Comment', ''),
                'key': row.get('Key', ''),
                'extra': row.get('Extra', '')
            }
            for row in result
            if isinstance(row, dict)
        ]

    except Exception as e:
        logger.error(f"Failed to get table schema: {e}")
        return []
|
||||
|
||||
async def get_all_databases_async(self, catalog_name: str = None) -> List[str]:
    """
    Asynchronously get all database list

    Args:
        catalog_name: Catalog to list; falls back to the extractor's default.
            No catalog, or "internal", produces a plain ``SHOW DATABASES``.

    Returns:
        Database names taken from the first field of each result row; an
        empty list when the query returns nothing or fails.
    """
    try:
        target_catalog = catalog_name or self.catalog_name

        query = "SHOW DATABASES"
        if target_catalog and target_catalog != "internal":
            query = f"SHOW DATABASES FROM `{target_catalog}`"

        rows = await self._execute_query_async(query)
        if not rows:
            return []

        # Extract database names
        names = []
        for row in rows:
            if not isinstance(row, dict) or not row:
                continue
            # Get the value of the first field (usually Database field)
            first_value = next(iter(row.values()))
            if first_value:
                names.append(first_value)
        return names

    except Exception as e:
        logger.error(f"Failed to get database list: {e}")
        return []
|
||||
|
||||
async def get_database_tables_async(self, db_name: str = None, catalog_name: str = None) -> List[str]:
    """
    Asynchronously get table list in database

    Args:
        db_name: Database name; falls back to the extractor's default
        catalog_name: Catalog name; falls back to the extractor's default

    Returns:
        Table names taken from the first field of each result row; an empty
        list when the query returns nothing or fails.
    """
    try:
        effective_catalog = catalog_name or self.catalog_name
        effective_db = db_name or self.db_name

        if effective_catalog and effective_catalog != "internal":
            query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`"
        elif effective_db:
            query = f"SHOW TABLES FROM `{effective_db}`"
        else:
            # No database given and no default configured: "FROM ``" is not
            # a valid identifier, so omit the FROM clause and let the
            # session's database apply.
            query = "SHOW TABLES"

        result = await self._execute_query_async(query, effective_db)

        if not result:
            return []

        # Extract table names
        tables = []
        for row in result:
            if not isinstance(row, dict) or not row:
                continue
            # Get the value of the first field (usually Tables_in_xxx field)
            table_name = next(iter(row.values()))
            if table_name:
                tables.append(table_name)

        return tables

    except Exception as e:
        logger.error(f"Failed to get table list: {e}")
        return []
|
||||
|
||||
async def get_catalog_list_async(self) -> List[str]:
    """
    Asynchronously get catalog list

    Returns:
        Catalog names as strings, read from the ``CatalogName`` field of each
        ``SHOW CATALOGS`` row (or a positional fallback when that field is
        absent); an empty list when the query returns nothing or fails.
    """
    try:
        rows = await self._execute_query_async("SHOW CATALOGS")
        if not rows:
            return []

        # Extract catalog names
        names = []
        for row in rows:
            if not isinstance(row, dict):
                continue

            # SHOW CATALOGS returns fields including: CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
            if 'CatalogName' in row:
                name = row['CatalogName']
            else:
                # No CatalogName field: prefer the second field, then the first
                cells = list(row.values())
                if len(cells) > 1:
                    name = cells[1]
                elif cells:
                    name = cells[0]
                else:
                    name = None

            if name:
                names.append(str(name))

        return names

    except Exception as e:
        logger.error(f"Failed to get catalog list: {e}")
        return []
|
||||
|
||||
# ==================== Business layer methods (original metadata_tools.py functionality) ====================
|
||||
|
||||
def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]:
|
||||
"""Format response result"""
|
||||
response_data = {
|
||||
"success": success,
|
||||
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
|
||||
}
|
||||
if success and result is not None:
|
||||
response_data["result"] = result
|
||||
response_data["message"] = message or "Operation successful"
|
||||
elif not success:
|
||||
response_data["error"] = error or "Unknown error"
|
||||
response_data["message"] = message or "Operation failed"
|
||||
|
||||
return response_data
|
||||
|
||||
async def exec_query_for_mcp(
    self,
    sql: str,
    db_name: str = None,
    catalog_name: str = None,
    max_rows: int = 100,
    timeout: int = 30
) -> Dict[str, Any]:
    """
    Execute SQL query and return results - unified interface for MCP tools

    The statement is handed to ``execute_sql_query`` together with this
    extractor's connection manager; ``max_rows`` is passed as its ``limit``.
    ``db_name`` and ``catalog_name`` are only logged here, they are not
    forwarded to the executor.

    Returns:
        Whatever ``execute_sql_query`` returns, or a formatted failure
        response when ``sql`` is empty or an exception is raised.
    """
    logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")

    try:
        if not sql:
            return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")

        # Import query executor (resolved at call time)
        from .query_executor import execute_sql_query

        # Delegate execution; exceptions are still caught below because the
        # await happens inside this try block
        return await execute_sql_query(
            sql=sql,
            connection_manager=self.connection_manager,
            limit=max_rows,
            timeout=timeout
        )

    except Exception as e:
        logger.error(f"Failed to execute SQL query: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while executing SQL query")
|
||||
|
||||
async def get_table_schema_for_mcp(
    self,
    table_name: str,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """
    Get detailed schema information for specified table - MCP interface

    Returns:
        A formatted success response whose result is the column list from
        ``get_table_schema_async``; a formatted failure response when
        ``table_name`` is missing, the column list is empty, or an exception
        is raised.
    """
    logger.info(f"Getting table schema: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")

    if not table_name:
        return self._format_response(success=False, error="Missing table_name parameter")

    try:
        columns = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)

        if columns:
            return self._format_response(success=True, result=columns)

        # An empty column list is reported as a failure
        return self._format_response(
            success=False,
            error="Table does not exist or has no columns",
            message=f"Unable to get schema for table {catalog_name or 'default'}.{db_name or self.db_name}.{table_name}"
        )
    except Exception as e:
        logger.error(f"Failed to get table schema: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting table schema")
|
||||
|
||||
async def get_db_table_list_for_mcp(
    self,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """
    Get list of all table names in specified database - MCP interface

    Returns:
        A formatted success response wrapping the result of
        ``get_database_tables_async``, or a formatted failure response when
        an exception is raised.
    """
    logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")

    try:
        table_names = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
    except Exception as e:
        logger.error(f"Failed to get database table list: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting database table list")
    else:
        return self._format_response(success=True, result=table_names)
|
||||
|
||||
async def get_db_list_for_mcp(self, catalog_name: str = None) -> Dict[str, Any]:
    """
    Get list of all database names on server - MCP interface

    Returns:
        A formatted success response wrapping the result of
        ``get_all_databases_async``, or a formatted failure response when an
        exception is raised.
    """
    logger.info(f"Getting database list: Catalog: {catalog_name}")

    try:
        database_names = await self.get_all_databases_async(catalog_name=catalog_name)
    except Exception as e:
        logger.error(f"Failed to get database list: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting database list")
    else:
        return self._format_response(success=True, result=database_names)
|
||||
|
||||
async def get_table_comment_for_mcp(
    self,
    table_name: str,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """
    Get comment information for specified table - MCP interface

    Returns:
        A formatted success response wrapping the value returned by
        ``get_table_comment``; a formatted failure response when
        ``table_name`` is missing or an exception is raised.
    """
    logger.info(f"Getting table comment: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")

    if not table_name:
        return self._format_response(success=False, error="Missing table_name parameter")

    try:
        # NOTE(review): get_table_comment is called without await, so it
        # runs synchronously inside this coroutine.
        table_comment = self.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
    except Exception as e:
        logger.error(f"Failed to get table comment: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting table comment")
    else:
        return self._format_response(success=True, result=table_comment)
|
||||
|
||||
async def get_table_column_comments_for_mcp(
    self,
    table_name: str,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """
    Get comment information for all columns in specified table - MCP interface

    Returns:
        A formatted success response wrapping the value returned by
        ``get_column_comments``; a formatted failure response when
        ``table_name`` is missing or an exception is raised.
    """
    logger.info(f"Getting table column comments: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")

    if not table_name:
        return self._format_response(success=False, error="Missing table_name parameter")

    try:
        # NOTE(review): get_column_comments is called without await, so it
        # runs synchronously inside this coroutine.
        column_comments = self.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
    except Exception as e:
        logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting table column comments")
    else:
        return self._format_response(success=True, result=column_comments)
|
||||
|
||||
async def get_table_indexes_for_mcp(
    self,
    table_name: str,
    db_name: str = None,
    catalog_name: str = None
) -> Dict[str, Any]:
    """
    Get index information for specified table - MCP interface

    Returns:
        A formatted success response wrapping the value returned by
        ``get_table_indexes``; a formatted failure response when
        ``table_name`` is missing or an exception is raised.
    """
    logger.info(f"Getting table indexes: Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}")

    if not table_name:
        return self._format_response(success=False, error="Missing table_name parameter")

    try:
        # NOTE(review): get_table_indexes is called without await, so it
        # runs synchronously inside this coroutine.
        index_list = self.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
    except Exception as e:
        logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting table indexes")
    else:
        return self._format_response(success=True, result=index_list)
|
||||
|
||||
def _serialize_datetime_objects(self, data):
|
||||
"""Serialize datetime objects to JSON compatible format"""
|
||||
if isinstance(data, list):
|
||||
return [self._serialize_datetime_objects(item) for item in data]
|
||||
elif isinstance(data, dict):
|
||||
return {key: self._serialize_datetime_objects(value) for key, value in data.items()}
|
||||
elif hasattr(data, 'isoformat'): # datetime, date, time objects
|
||||
return data.isoformat()
|
||||
elif hasattr(data, 'strftime'): # pandas Timestamp objects
|
||||
return data.strftime('%Y-%m-%d %H:%M:%S')
|
||||
else:
|
||||
return data
|
||||
|
||||
async def get_recent_audit_logs_for_mcp(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
    """
    Get recent audit log records - MCP interface

    Fetches the logs with ``get_recent_audit_logs``, converts a DataFrame-like
    result to a list of record dicts, and serializes date/time values.

    Returns:
        A formatted success response wrapping the records, or a formatted
        failure response when an exception is raised.
    """
    logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}")

    try:
        audit_logs = self.get_recent_audit_logs(days=days, limit=limit)

        if not hasattr(audit_logs, 'to_dict'):
            # Not DataFrame-like: serialize whatever was returned
            records = self._serialize_datetime_objects(audit_logs)
        else:
            # Convert DataFrame to JSON format
            try:
                records = audit_logs.to_dict('records')
            except Exception as e:
                logger.warning(f"DataFrame.to_dict failed, trying manual conversion: {e}")
                # Manually convert DataFrame to records format
                records = []
                if not audit_logs.empty:
                    for _, row in audit_logs.iterrows():
                        records.append(dict(row))
            # Serialize datetime objects
            records = self._serialize_datetime_objects(records)

        return self._format_response(success=True, result=records)
    except Exception as e:
        logger.error(f"Failed to get audit logs: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting audit logs")
|
||||
|
||||
async def get_catalog_list_for_mcp(self) -> Dict[str, Any]:
    """
    Get Doris catalog list - MCP interface

    Returns:
        A formatted success response wrapping the result of
        ``get_catalog_list_async``, or a formatted failure response when an
        exception is raised.
    """
    logger.info("Getting catalog list")

    try:
        catalog_names = await self.get_catalog_list_async()
    except Exception as e:
        logger.error(f"Failed to get catalog list: {str(e)}", exc_info=True)
        return self._format_response(success=False, error=str(e), message="Error occurred while getting catalog list")
    else:
        return self._format_response(success=True, result=catalog_names, message="Successfully retrieved catalog list")
|
||||
|
||||
|
||||
# ==================== Compatibility aliases ====================
|
||||
|
||||
# For backward compatibility, provide a MetadataManager wrapper class
|
||||
class MetadataManager:
    """
    Metadata manager - backward compatibility class

    Thin wrapper that owns a MetadataExtractor and forwards each call to the
    extractor's matching ``*_for_mcp`` coroutine.
    """

    def __init__(self, connection_manager=None):
        self.extractor = MetadataExtractor(connection_manager=connection_manager)

    async def exec_query(self, sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]:
        """Execute SQL query and return results, supports catalog federation queries"""
        return await self.extractor.exec_query_for_mcp(
            sql=sql,
            db_name=db_name,
            catalog_name=catalog_name,
            max_rows=max_rows,
            timeout=timeout,
        )

    async def get_table_schema(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Get detailed schema information for specified table (columns, types, comments, etc.)"""
        return await self.extractor.get_table_schema_for_mcp(
            table_name=table_name, db_name=db_name, catalog_name=catalog_name
        )

    async def get_db_table_list(self, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Get list of all table names in specified database"""
        return await self.extractor.get_db_table_list_for_mcp(
            db_name=db_name, catalog_name=catalog_name
        )

    async def get_db_list(self, catalog_name: str = None) -> Dict[str, Any]:
        """Get list of all database names on server"""
        return await self.extractor.get_db_list_for_mcp(catalog_name=catalog_name)

    async def get_table_comment(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Get comment information for specified table"""
        return await self.extractor.get_table_comment_for_mcp(
            table_name=table_name, db_name=db_name, catalog_name=catalog_name
        )

    async def get_table_column_comments(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Get comment information for all columns in specified table"""
        return await self.extractor.get_table_column_comments_for_mcp(
            table_name=table_name, db_name=db_name, catalog_name=catalog_name
        )

    async def get_table_indexes(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]:
        """Get index information for specified table"""
        return await self.extractor.get_table_indexes_for_mcp(
            table_name=table_name, db_name=db_name, catalog_name=catalog_name
        )

    async def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> Dict[str, Any]:
        """Get recent audit log records"""
        return await self.extractor.get_recent_audit_logs_for_mcp(days=days, limit=limit)

    async def get_catalog_list(self) -> Dict[str, Any]:
        """Get Doris catalog list"""
        return await self.extractor.get_catalog_list_for_mcp()
|
||||
861
doris_mcp_server/utils/security.py
Normal file
861
doris_mcp_server/utils/security.py
Normal file
@@ -0,0 +1,861 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Doris Security Management Module
|
||||
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import sqlparse
|
||||
from sqlparse.sql import Statement
|
||||
from sqlparse.tokens import Keyword, Name
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
    """Security classification levels.

    Each member's value is its lower-case string form, so a level can be
    looked up from text with ``SecurityLevel("internal")``.
    """

    PUBLIC = "public"
    INTERNAL = "internal"
    CONFIDENTIAL = "confidential"
    SECRET = "secret"
|
||||
|
||||
|
||||
@dataclass
class AuthContext:
    """Identity and permission data attached to an authenticated request."""

    # Required identity fields
    user_id: str
    roles: list[str]
    permissions: list[str]
    session_id: str

    # Optional timestamps; None until a value is supplied
    login_time: datetime | None = None
    last_activity: datetime | None = None

    # Defaults to INTERNAL when no level is supplied
    security_level: SecurityLevel = SecurityLevel.INTERNAL
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Validation result"""
|
||||
|
||||
is_valid: bool
|
||||
error_message: str | None = None
|
||||
risk_level: str = "low"
|
||||
blocked_operations: list[str] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.blocked_operations is None:
|
||||
self.blocked_operations = []
|
||||
|
||||
|
||||
@dataclass
class MaskingRule:
    """One data masking rule: which columns it targets and how to mask them."""

    # Regular-expression text describing the column names the rule targets
    column_pattern: str
    # Name of the masking algorithm, e.g. "phone_mask"
    algorithm: str
    # Algorithm-specific options, e.g. {"mask_char": "*"}
    parameters: dict[str, Any]
    # Security level associated with the rule
    security_level: SecurityLevel
|
||||
|
||||
|
||||
class DorisSecurityManager:
    """Doris security manager

    Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking.
    The manager builds one provider per concern and delegates to it; it also
    loads the rule configuration (blocked keywords, sensitive tables, masking
    rules) from built-in defaults merged with ``config``.
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)

        # Initialize security components
        self.auth_provider = AuthenticationProvider(config)
        self.authz_provider = AuthorizationProvider(config)
        self.sql_validator = SQLSecurityValidator(config)
        self.masking_processor = DataMaskingProcessor(config)

        # Security rule configuration
        self.blocked_keywords = self._load_blocked_keywords()
        self.sensitive_tables = self._load_sensitive_tables()
        self.masking_rules = self._load_masking_rules()

    def _config_value(self, key: str, default):
        """Read ``key`` from a mapping-style config.

        Returns ``default`` when the config has no ``get`` method or when the
        stored value is None, so callers can always iterate the result.
        """
        if not hasattr(self.config, 'get'):
            return default
        value = self.config.get(key, default)
        return default if value is None else value

    def _coerce_security_level(self, level):
        """Convert a configured level to SecurityLevel.

        Strings are matched case-insensitively against the enum values; an
        unknown string falls back to INTERNAL. Non-string values are returned
        unchanged.
        """
        if not isinstance(level, str):
            return level
        try:
            return SecurityLevel(level.lower())
        except ValueError:
            self.logger.warning(
                "Unknown security level %r in configuration, using 'internal'", level
            )
            return SecurityLevel.INTERNAL

    def _load_blocked_keywords(self) -> set[str]:
        """Load blocked SQL keywords (built-in defaults plus ``blocked_keywords`` from config)"""
        default_blocked = {
            "DROP",
            "DELETE",
            "TRUNCATE",
            "ALTER",
            "CREATE",
            "INSERT",
            "UPDATE",
            "GRANT",
            "REVOKE",
            "EXEC",
            "EXECUTE",
            "SHUTDOWN",
            "KILL",
        }

        # Load custom rules from configuration file
        custom_blocked = set(self._config_value("blocked_keywords", []))

        return default_blocked.union(custom_blocked)

    def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
        """Load sensitive table configuration (defaults overridden by ``sensitive_tables`` from config)"""
        tables = {
            "user_info": SecurityLevel.CONFIDENTIAL,
            "payment_records": SecurityLevel.SECRET,
            "employee_data": SecurityLevel.CONFIDENTIAL,
            "public_reports": SecurityLevel.PUBLIC,
        }

        # Convert string values to SecurityLevel enum
        for table_name, level in self._config_value("sensitive_tables", {}).items():
            tables[table_name] = self._coerce_security_level(level)

        return tables

    def _load_masking_rules(self) -> list[MaskingRule]:
        """Load data masking rules (built-in defaults followed by custom rules from config)"""
        rules = [
            MaskingRule(
                column_pattern=r".*phone.*|.*mobile.*",
                algorithm="phone_mask",
                parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*email.*",
                algorithm="email_mask",
                parameters={"mask_char": "*"},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*id_card.*|.*identity.*",
                algorithm="id_mask",
                parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
                security_level=SecurityLevel.CONFIDENTIAL,
            ),
        ]

        # Load custom rules from configuration
        custom_rules = []
        if hasattr(self.config, 'get'):
            custom_rules = self._config_value("masking_rules", [])
        elif hasattr(self.config, 'security') and hasattr(self.config.security, 'masking_rules'):
            custom_rules = self.config.security.masking_rules or []

        for rule_config in custom_rules:
            if isinstance(rule_config, dict):
                # Copy so the caller's config dict is not modified, and
                # convert a textual level the same way sensitive tables do.
                rule_kwargs = dict(rule_config)
                if "security_level" in rule_kwargs:
                    rule_kwargs["security_level"] = self._coerce_security_level(
                        rule_kwargs["security_level"]
                    )
                rules.append(MaskingRule(**rule_kwargs))
            elif isinstance(rule_config, MaskingRule):
                rules.append(rule_config)
            # Entries of any other type are ignored

        return rules

    async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
        """Validate request authentication information"""
        return await self.auth_provider.authenticate(auth_info)

    async def authorize_resource_access(
        self, auth_context: AuthContext, resource_uri: str
    ) -> bool:
        """Validate resource access permissions (checks the "read" action)"""
        return await self.authz_provider.check_permission(
            auth_context, resource_uri, "read"
        )

    async def validate_sql_security(
        self, sql: str, auth_context: AuthContext
    ) -> ValidationResult:
        """Validate SQL query security"""
        return await self.sql_validator.validate(sql, auth_context)

    async def apply_data_masking(
        self, data: list[dict[str, Any]], auth_context: AuthContext
    ) -> list[dict[str, Any]]:
        """Apply data masking processing"""
        return await self.masking_processor.process(data, auth_context)
|
||||
|
||||
|
||||
class AuthenticationProvider:
    """Authentication provider.

    Turns raw authentication info (a token, or a username/password pair)
    into an ``AuthContext``. The token and credential tables below are
    hard-coded, simplified test fixtures; a real deployment should validate
    a JWT or query an authentication service / user database instead.
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.session_cache = {}

    async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
        """Perform identity authentication.

        Args:
            auth_info: Authentication data. ``"type"`` selects the scheme
                (``"token"`` by default, or ``"basic"``).

        Returns:
            The authenticated user's ``AuthContext``.

        Raises:
            ValueError: Unsupported type, missing fields, or bad credentials.
        """
        auth_type = auth_info.get("type", "token")

        if auth_type == "token":
            return await self._authenticate_token(auth_info)
        if auth_type == "basic":
            return await self._authenticate_basic(auth_info)
        raise ValueError(f"Unsupported authentication type: {auth_type}")

    async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
        """Token authentication."""
        token = auth_info.get("token")
        if not token:
            raise ValueError("Missing authentication token")

        # Validate token (simplified implementation, should validate JWT or
        # query authentication service in practice)
        user_info = await self._validate_token(token)
        return self._build_auth_context(user_info, auth_info)

    async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
        """Basic authentication (username password)."""
        username = auth_info.get("username")
        password = auth_info.get("password")

        if not username or not password:
            raise ValueError("Missing username or password")

        # Validate username password (simplified implementation)
        user_info = await self._validate_credentials(username, password)
        return self._build_auth_context(user_info, auth_info)

    def _build_auth_context(
        self, user_info: dict[str, Any], auth_info: dict[str, Any]
    ) -> AuthContext:
        """Build the AuthContext shared by the token and basic auth paths."""
        return AuthContext(
            user_id=user_info["user_id"],
            roles=user_info["roles"],
            permissions=user_info["permissions"],
            session_id=auth_info.get("session_id", "default"),
            login_time=datetime.utcnow(),
            security_level=SecurityLevel(user_info.get("security_level", "internal")),
        )

    async def _validate_token(self, token: str) -> dict[str, Any]:
        """Validate token validity.

        Raises:
            ValueError: If the token is not one of the known test tokens.
        """
        # Simplified implementation for testing, should parse JWT or query
        # authentication service in practice
        valid_tokens = {
            "valid_token_123": {
                "user_id": "test_user",
                "roles": ["data_analyst"],
                "permissions": ["read_data"],
                "security_level": SecurityLevel.INTERNAL,
            },
            "admin_token_456": {
                "user_id": "admin_user",
                "roles": ["data_admin"],
                "permissions": ["admin"],
                "security_level": SecurityLevel.SECRET,
            },
        }

        if token in valid_tokens:
            return valid_tokens[token]
        raise ValueError("Invalid token")

    async def _validate_credentials(
        self, username: str, password: str
    ) -> dict[str, Any]:
        """Validate user credentials.

        Returns:
            The user's record without the ``"password"`` entry.

        Raises:
            ValueError: If the username is unknown or the password is wrong.
        """
        # Function-scope import keeps this block self-contained.
        import hmac

        # Simplified implementation for testing, should query user database in practice
        valid_users = {
            "admin": {
                "password": "admin123",
                "user_id": "admin_user",
                "roles": ["data_admin"],
                "permissions": ["admin", "read_data", "write_data"],
                "security_level": SecurityLevel.SECRET,
            },
            "analyst": {
                "password": "analyst123",
                "user_id": "analyst_user",
                "roles": ["data_analyst"],
                "permissions": ["read_data"],
                "security_level": SecurityLevel.INTERNAL,
            },
        }

        account = valid_users.get(username)
        # Always run the comparison, even for unknown users, and do it in
        # constant time, so response timing does not leak which usernames
        # exist or how much of a password matched.
        expected = account["password"] if account is not None else ""
        supplied = password if isinstance(password, str) else ""
        password_ok = hmac.compare_digest(
            expected.encode("utf-8", "surrogatepass"),
            supplied.encode("utf-8", "surrogatepass"),
        )

        if account is None or not password_ok:
            raise ValueError("Incorrect username or password")

        # Remove password from returned info
        return {key: value for key, value in account.items() if key != "password"}
|
||||
|
||||
|
||||
class AuthorizationProvider:
    """Authorization provider.

    Decides whether an authenticated caller may perform an action on a
    resource, combining a security-level gate with role- and
    permission-based grants.
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.permission_cache = {}

        # Table name -> SecurityLevel, defaults merged with configuration
        self.sensitive_tables = self._load_sensitive_tables()

    def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
        """Return the built-in sensitive-table map, overlaid with config entries.

        Config entries are only read when ``self.config`` exposes ``get``.
        String levels are lower-cased and converted to ``SecurityLevel``;
        an unrecognised string falls back to ``SecurityLevel.INTERNAL``.
        """
        tables = {
            "user_info": SecurityLevel.CONFIDENTIAL,
            "payment_records": SecurityLevel.SECRET,
            "employee_data": SecurityLevel.CONFIDENTIAL,
            "public_reports": SecurityLevel.PUBLIC,
        }

        if not hasattr(self.config, 'get'):
            return tables

        for table_name, level in self.config.get("sensitive_tables", {}).items():
            if not isinstance(level, str):
                tables[table_name] = level
                continue
            try:
                tables[table_name] = SecurityLevel(level.lower())
            except ValueError:
                tables[table_name] = SecurityLevel.INTERNAL
        return tables

    async def check_permission(
        self, auth_context: AuthContext, resource_uri: str, action: str
    ) -> bool:
        """Return True when the caller may perform ``action`` on the resource.

        The security-level check is mandatory; after it passes, either a
        role grant or a user-permission grant is sufficient.
        """
        resource = self._parse_resource_uri(resource_uri)

        level_ok = await self._check_security_level_permission(auth_context, resource)
        if not level_ok:
            return False

        if await self._check_role_permission(auth_context, resource, action):
            return True

        if await self._check_user_permission(auth_context, resource, action):
            return True

        return False

    def _parse_resource_uri(self, uri: str) -> dict[str, str]:
        """Split a resource URI on "/" into type, name and schema fields.

        The third segment is the type; the fourth and fifth, when present,
        are the name and schema. URIs with fewer than three segments yield
        the "unknown" placeholder.
        """
        segments = uri.split("/")
        if len(segments) < 3:
            return {"type": "unknown", "name": "", "schema": "default"}

        resource_type = segments[2]  # table, view, etc.
        resource_name = segments[3] if len(segments) > 3 else ""
        resource_schema = segments[4] if len(segments) > 4 else "default"
        return {
            "type": resource_type,
            "name": resource_name,
            "schema": resource_schema,
        }

    async def _check_role_permission(
        self, auth_context: AuthContext, resource_info: dict[str, str], action: str
    ) -> bool:
        """Return True if any of the caller's roles grants ``action`` on the resource type."""
        grants_by_role = {
            "data_analyst": {"table": ["read"], "view": ["read"]},
            "data_admin": {
                "table": ["read", "write", "admin"],
                "view": ["read", "write", "admin"],
            },
        }

        return any(
            action in grants_by_role.get(role, {}).get(resource_info["type"], [])
            for role in auth_context.roles
        )

    async def _check_user_permission(
        self, auth_context: AuthContext, resource_info: dict[str, str], action: str
    ) -> bool:
        """Return True if the caller's own permissions allow ``action``.

        "admin" allows everything; "read_data" allows the "read" action.
        """
        granted = auth_context.permissions
        if "admin" in granted:
            return True
        return action == "read" and "read_data" in granted

    async def _check_security_level_permission(
        self, auth_context: AuthContext, resource_info: dict[str, str]
    ) -> bool:
        """Return True if the caller's security level is at least the resource's."""
        required = self._get_resource_security_level(resource_info)

        rank = {
            SecurityLevel.PUBLIC: 0,
            SecurityLevel.INTERNAL: 1,
            SecurityLevel.CONFIDENTIAL: 2,
            SecurityLevel.SECRET: 3,
        }

        # Levels missing from the ranking are treated as 0
        caller_rank = rank.get(auth_context.security_level, 0)
        required_rank = rank.get(required, 0)
        return caller_rank >= required_rank

    def _get_resource_security_level(
        self, resource_info: dict[str, str]
    ) -> SecurityLevel:
        """Look up the security level of the named resource.

        Names not listed in ``self.sensitive_tables`` default to
        ``SecurityLevel.INTERNAL``; string entries are converted the same
        way as in ``_load_sensitive_tables``.
        """
        level = self.sensitive_tables.get(
            resource_info.get("name", ""), SecurityLevel.INTERNAL
        )
        if not isinstance(level, str):
            return level

        try:
            return SecurityLevel(level.lower())
        except ValueError:
            return SecurityLevel.INTERNAL
|
||||
|
||||
|
||||
class SQLSecurityValidator:
    """SQL security validator.

    Runs four checks over a SQL string: blocked keywords, injection
    patterns, a keyword-based complexity score, and table access. Every
    statement in the input is checked, not just the first one.
    """

    # Patterns matched case-insensitively against the raw SQL text.
    _INJECTION_PATTERNS = (
        r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
        r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
        r"(\s|^)(or|and)\s+['\"].*['\"]",
        r";\s*(drop|delete|truncate|alter|create)",
        r"(exec|execute|sp_|xp_)",
        r"(script|javascript|vbscript)",
        r"(char|ascii|substring|concat)\s*\(",
    )

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)

        # Handle DorisConfig object or dictionary configuration
        if hasattr(config, 'get'):
            # Dictionary configuration
            configured_keywords = config.get("blocked_keywords", []) or []
            self.max_query_complexity = config.get("max_query_complexity", 100)
        else:
            # DorisConfig object, use default values
            configured_keywords = [
                "DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE",
            ]
            self.max_query_complexity = 100

        # Token values are upper-cased before lookup, so store keywords the
        # same way. Blank entries are dropped: an empty keyword would be a
        # substring of every token and block all queries.
        normalized = (str(keyword).upper().strip() for keyword in configured_keywords)
        self.blocked_keywords = {keyword for keyword in normalized if keyword}

    async def validate(self, sql: str, auth_context: "AuthContext") -> "ValidationResult":
        """Validate SQL query security.

        Args:
            sql: SQL text; may contain several ``;``-separated statements.
            auth_context: Caller context used for the table access check.

        Returns:
            The first failing check's result, or a valid result when every
            check passes. Parsing or other errors yield an invalid result
            with ``risk_level="high"``.
        """
        try:
            # Check every statement: validating only the first one would let
            # a second, stacked statement through unchecked.
            statements = [
                statement for statement in sqlparse.parse(sql) if str(statement).strip()
            ]
            if not statements:
                return ValidationResult(
                    is_valid=False,
                    error_message="SQL parsing error: empty SQL statement",
                    risk_level="high",
                )

            # Check blocked operations first (more specific)
            for parsed in statements:
                keyword_result = await self._check_blocked_keywords(parsed)
                if not keyword_result.is_valid:
                    return keyword_result

            # Check SQL injection risks (works on the raw text)
            injection_result = await self._check_sql_injection(sql, statements[0])
            if not injection_result.is_valid:
                return injection_result

            # Check query complexity
            for parsed in statements:
                complexity_result = await self._check_query_complexity(parsed)
                if not complexity_result.is_valid:
                    return complexity_result

            # Check table access permissions
            for parsed in statements:
                table_result = await self._check_table_access(parsed, auth_context)
                if not table_result.is_valid:
                    return table_result

            return ValidationResult(is_valid=True)

        except Exception as e:
            self.logger.error(f"SQL security validation failed: {e}")
            return ValidationResult(
                is_valid=False,
                error_message=f"SQL parsing error: {str(e)}",
                risk_level="high",
            )

    async def _check_sql_injection(
        self, sql: str, parsed: "Statement"
    ) -> "ValidationResult":
        """Check SQL injection risks on the raw SQL text (``parsed`` is unused)."""
        sql_lower = sql.lower()
        for pattern in self._INJECTION_PATTERNS:
            if re.search(pattern, sql_lower, re.IGNORECASE):
                return ValidationResult(
                    is_valid=False,
                    error_message="Potential SQL injection risk detected",
                    risk_level="high",
                )

        # Check suspicious quotes and comments
        if self._has_suspicious_quotes_or_comments(sql):
            return ValidationResult(
                is_valid=False,
                error_message="Suspicious quote or comment pattern detected",
                risk_level="medium",
            )

        return ValidationResult(is_valid=True)

    def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
        """Return True for unbalanced quotes or any ``--`` / ``/*`` comment marker."""
        # An odd number of either quote character means an unmatched quote
        if sql.count("'") % 2 != 0 or sql.count('"') % 2 != 0:
            return True

        # Check SQL comments
        return "--" in sql or "/*" in sql

    async def _check_blocked_keywords(self, parsed: "Statement") -> "ValidationResult":
        """Check blocked keywords in one parsed statement.

        Keyword and name tokens must equal a blocked keyword; any other
        token is rejected if it merely contains one.
        """
        blocked_operations = set()

        for token in parsed.flatten():
            token_value = (token.value or "").upper().strip()
            if not token_value:
                continue

            # Keyword covers plain keywords; the startswith test also covers
            # sub-types such as DML/DDL keywords.
            is_keyword_or_name = (
                token.ttype is Keyword
                or token.ttype is Name
                or (token.ttype and str(token.ttype).startswith('Token.Keyword'))
            )
            if is_keyword_or_name:
                if token_value in self.blocked_keywords:
                    blocked_operations.add(token_value)
            else:
                blocked_operations.update(
                    keyword for keyword in self.blocked_keywords if keyword in token_value
                )

        if blocked_operations:
            # Sorted so the message and list are deterministic
            found = sorted(blocked_operations)
            return ValidationResult(
                is_valid=False,
                error_message=f"Contains blocked operations: {', '.join(found)}",
                risk_level="high",
                blocked_operations=found,
            )

        return ValidationResult(is_valid=True)

    async def _check_query_complexity(self, parsed: "Statement") -> "ValidationResult":
        """Score one statement by its keywords and compare with the limit."""
        complexity_score = 0

        for token in parsed.flatten():
            if token.ttype is not Keyword:
                continue
            # Collapse internal whitespace so "GROUP   BY" still matches
            keyword = " ".join(token.value.upper().split())
            # The lexer can emit combined tokens such as "LEFT JOIN" or
            # "UNION ALL"; match on the JOIN suffix / UNION prefix so those
            # are counted as well.
            if keyword.endswith("JOIN") or keyword in ("INNER", "LEFT", "RIGHT", "FULL"):
                complexity_score += 10
            elif keyword.startswith("UNION") or keyword in ("INTERSECT", "EXCEPT"):
                complexity_score += 15
            elif keyword in ("GROUP BY", "ORDER BY", "HAVING"):
                complexity_score += 5
            elif keyword in ("SUBQUERY", "EXISTS", "IN"):
                complexity_score += 8

        if complexity_score > self.max_query_complexity:
            return ValidationResult(
                is_valid=False,
                error_message=f"Query complexity too high (score: {complexity_score}, limit: {self.max_query_complexity})",
                risk_level="medium",
            )

        return ValidationResult(is_valid=True)

    async def _check_table_access(
        self, parsed: "Statement", auth_context: "AuthContext"
    ) -> "ValidationResult":
        """Check table access permissions for the tables one statement reads."""
        tables = self._extract_table_names(parsed)

        # Should call authorization provider to check permissions.
        # Simplified implementation, assume some tables require special permissions.
        # NOTE(review): the role names used elsewhere in this file are
        # "data_admin" / "data_analyst", so the "admin" role test below may
        # never be true - confirm the intended role name.
        unauthorized_tables = [
            table
            for table in tables
            if table.lower() in ["sensitive_data", "admin_logs"]
            and "admin" not in auth_context.roles
        ]

        if unauthorized_tables:
            return ValidationResult(
                is_valid=False,
                error_message=f"No access to tables: {', '.join(unauthorized_tables)}",
                risk_level="high",
            )

        return ValidationResult(is_valid=True)

    def _extract_table_names(self, parsed: "Statement") -> list[str]:
        """Extract table names that follow FROM or a JOIN keyword.

        Handles comma-separated lists, optional aliases and qualified names
        (``db.table`` yields ``table``); backticks are stripped. Subqueries
        are covered because every FROM/JOIN in the flattened token stream is
        visited. Table names that the lexer classifies as keywords are not
        recognised.
        """
        tables = []

        tokens = [
            token
            for token in parsed.flatten()
            if not token.is_whitespace
            and not str(token.ttype).startswith("Token.Comment")
        ]
        total = len(tokens)
        index = 0

        while index < total:
            token = tokens[index]
            index += 1
            if token.ttype is not Keyword:
                continue
            keyword = " ".join(token.value.upper().split())
            if keyword != "FROM" and not keyword.endswith("JOIN"):
                continue

            # Read the table list that follows the keyword
            while index < total and tokens[index].ttype is Name:
                name = tokens[index].value
                index += 1
                # For a qualified name keep only the last component
                while (
                    index + 1 < total
                    and tokens[index].value == "."
                    and tokens[index + 1].ttype is Name
                ):
                    name = tokens[index + 1].value
                    index += 2
                tables.append(name.strip("`"))

                # Skip an optional alias ("t AS x" or "t x")
                if (
                    index < total
                    and tokens[index].ttype is Keyword
                    and tokens[index].value.upper() == "AS"
                ):
                    index += 1
                if index < total and tokens[index].ttype is Name:
                    index += 1

                # A comma means another table follows
                if index < total and tokens[index].value == ",":
                    index += 1
                else:
                    break

        return tables
|
||||
|
||||
|
||||
class DataMaskingProcessor:
    """Data masking processor.

    Masks column values in result rows according to ``MaskingRule`` entries
    whose ``column_pattern`` regex matches the column name. Which rules
    apply depends on the caller's security level.
    """

    def __init__(self, config):
        self.config = config
        self.logger = logging.getLogger(__name__)
        self.masking_algorithms = self._init_masking_algorithms()
        self.masking_rules = self._load_masking_rules()

    def _load_masking_rules(self) -> "list[MaskingRule]":
        """Load data masking rules: built-in defaults followed by configured rules.

        Configured rules are read from ``config.get("masking_rules")`` when
        the config exposes ``get``. Dict entries are copied before their
        ``security_level`` string is converted, so the caller's
        configuration is never modified.
        """
        rules = [
            MaskingRule(
                column_pattern=r".*phone.*|.*mobile.*",
                algorithm="phone_mask",
                parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*email.*",
                algorithm="email_mask",
                parameters={"mask_char": "*"},
                security_level=SecurityLevel.INTERNAL,
            ),
            MaskingRule(
                column_pattern=r".*id_card.*|.*identity.*",
                algorithm="id_mask",
                parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
                security_level=SecurityLevel.CONFIDENTIAL,
            ),
        ]

        # Load custom rules from configuration
        if hasattr(self.config, 'get'):
            for rule_config in self.config.get("masking_rules", []) or []:
                if isinstance(rule_config, dict):
                    rule_kwargs = dict(rule_config)  # copy: leave the config untouched
                    level = rule_kwargs.get('security_level')
                    # Convert string security level to enum
                    if isinstance(level, str):
                        try:
                            rule_kwargs['security_level'] = SecurityLevel(level.lower())
                        except ValueError:
                            rule_kwargs['security_level'] = SecurityLevel.INTERNAL
                    rules.append(MaskingRule(**rule_kwargs))
                elif isinstance(rule_config, MaskingRule):
                    rules.append(rule_config)

        return rules

    def _init_masking_algorithms(self) -> dict[str, callable]:
        """Initialize masking algorithms, keyed by the name used in rules."""
        return {
            "phone_mask": self._mask_phone,
            "email_mask": self._mask_email,
            "id_mask": self._mask_id_card,
            "name_mask": self._mask_name,
            "partial_mask": self._mask_partial,
        }

    async def process(
        self, data: list[dict[str, Any]], auth_context: "AuthContext"
    ) -> list[dict[str, Any]]:
        """Return a masked copy of ``data`` for this caller.

        Empty input is returned as-is. Input rows are not modified.
        """
        if not data:
            return data

        # Get applicable masking rules
        applicable_rules = self._get_applicable_rules(auth_context)

        masked_data = []
        for row in data:
            masked_row = {}
            for column, value in row.items():
                masked_row[column] = await self._apply_masking_rules(
                    column, value, applicable_rules
                )
            masked_data.append(masked_row)

        return masked_data

    def _get_applicable_rules(self, auth_context: "AuthContext") -> "list[MaskingRule]":
        """Return the masking rules that apply to this caller."""
        return [
            rule
            for rule in self.masking_rules
            if self._should_apply_rule(rule, auth_context)
        ]

    def _should_apply_rule(self, rule: "MaskingRule", auth_context: "AuthContext") -> bool:
        """Determine whether masking rule should be applied."""
        # Admin users can see original data
        # NOTE(review): the role names used elsewhere in this file are
        # "data_admin" / "data_analyst", so this "admin" role test may never
        # be true - confirm the intended role name.
        if "admin" in auth_context.roles:
            return False

        # Decide based on security level
        security_hierarchy = {
            SecurityLevel.PUBLIC: 0,
            SecurityLevel.INTERNAL: 1,
            SecurityLevel.CONFIDENTIAL: 2,
            SecurityLevel.SECRET: 3,
        }

        user_level = security_hierarchy.get(auth_context.security_level, 0)
        rule_level = security_hierarchy.get(rule.security_level, 0)

        # Apply masking if user level is less than or equal to rule level
        return user_level <= rule_level

    async def _apply_masking_rules(
        self, column: str, value: Any, rules: "list[MaskingRule]"
    ) -> Any:
        """Apply the first matching rule with a known algorithm.

        ``None`` is passed through. A masked value is always a string,
        because the value is converted with ``str`` before masking.
        """
        if value is None:
            return value

        for rule in rules:
            if re.match(rule.column_pattern, column, re.IGNORECASE):
                algorithm = self.masking_algorithms.get(rule.algorithm)
                if algorithm:
                    return algorithm(str(value), rule.parameters)

        return value

    def _mask_keep_ends(
        self, value: str, params: dict[str, Any], default_prefix: int, default_suffix: int
    ) -> str:
        """Mask the middle of ``value``, keeping a prefix and suffix visible."""
        mask_char = params.get("mask_char", "*")
        keep_prefix = max(0, params.get("keep_prefix", default_prefix))
        keep_suffix = max(0, params.get("keep_suffix", default_suffix))

        if len(value) <= keep_prefix + keep_suffix:
            return mask_char * len(value)

        # Slice from an explicit start index: value[-keep_suffix:] would
        # return the whole string when keep_suffix is 0 and leak the value.
        suffix_start = len(value) - keep_suffix
        masked_middle = mask_char * (suffix_start - keep_prefix)
        return value[:keep_prefix] + masked_middle + value[suffix_start:]

    def _mask_phone(self, value: str, params: dict[str, Any]) -> str:
        """Phone number masking; values shorter than 7 characters are returned unchanged."""
        if len(value) < 7:
            return value
        return self._mask_keep_ends(value, params, default_prefix=3, default_suffix=4)

    def _mask_email(self, value: str, params: dict[str, Any]) -> str:
        """Email masking: hide the local part except its first and last character."""
        if "@" not in value:
            return value

        mask_char = params.get("mask_char", "*")
        local, domain = value.split("@", 1)

        if len(local) <= 2:
            masked_local = mask_char * len(local)
        else:
            masked_local = local[0] + mask_char * (len(local) - 2) + local[-1]

        return f"{masked_local}@{domain}"

    def _mask_id_card(self, value: str, params: dict[str, Any]) -> str:
        """ID card number masking; values shorter than 10 characters are returned unchanged."""
        if len(value) < 10:
            return value
        return self._mask_keep_ends(value, params, default_prefix=6, default_suffix=4)

    def _mask_name(self, value: str, params: dict[str, Any]) -> str:
        """Name masking: keep the first (and, for 3+ characters, last) character."""
        if len(value) <= 1:
            return value

        mask_char = params.get("mask_char", "*")

        if len(value) == 2:
            return value[0] + mask_char
        return value[0] + mask_char * (len(value) - 2) + value[-1]

    def _mask_partial(self, value: str, params: dict[str, Any]) -> str:
        """Partial masking: mask a centred run covering ``mask_ratio`` of the value."""
        mask_char = params.get("mask_char", "*")
        mask_ratio = params.get("mask_ratio", 0.5)

        mask_length = int(len(value) * mask_ratio)
        start_pos = (len(value) - mask_length) // 2

        result = list(value)
        for i in range(start_pos, start_pos + mask_length):
            if i < len(result):
                result[i] = mask_char

        return "".join(result)
|
||||
@@ -1,352 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
SQL Execution Tool
|
||||
|
||||
Responsible for executing SQL queries and handling results
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
import re
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
# Module-level logger for the SQL executor
logger = logging.getLogger("doris-mcp.sql-executor")

# SQL security checks stay enabled unless the ENABLE_SQL_SECURITY_CHECK
# environment variable is set to something other than "true" (case-insensitive)
ENABLE_SQL_SECURITY_CHECK = (
    os.getenv("ENABLE_SQL_SECURITY_CHECK", "true").lower() == "true"
)
|
||||
|
||||
async def execute_sql_query(ctx) -> Dict[str, Any]:
    """
    Execute SQL query and return results

    Reads ``sql``, ``db_name`` (default: the ``DB_DATABASE`` environment
    variable) and ``max_rows`` (default 1000) from the request parameters.
    SELECT statements without a LIMIT keyword get ``LIMIT max_rows``
    appended. ``catalog_name`` and ``timeout`` parameters are not used by
    this function.

    Args:
        ctx: Context object or dictionary containing request parameters

    Returns:
        Dict[str, Any]: Execution result, a ``{"content": [...]}`` envelope
        whose single text item is a JSON document with a ``success`` flag.
        Failures are reported in that document rather than raised.
    """
    try:
        # Support the case where the passed argument is a dictionary
        if isinstance(ctx, dict) and 'params' in ctx:
            params = ctx['params']
        else:
            params = ctx.params

        sql = params.get("sql")
        db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))

        if not sql:
            return _text_response({
                "success": False,
                "error": "Missing SQL parameter",
                "message": "Please provide the SQL query to execute"
            })

        # max_rows is interpolated into the SQL text after the security
        # check, so it must be a plain integer and nothing else.
        try:
            max_rows = int(params.get("max_rows", 1000))
        except (TypeError, ValueError):
            max_rows = -1
        if max_rows < 0:
            return _text_response({
                "success": False,
                "error": "Invalid max_rows parameter",
                "message": "max_rows must be a non-negative integer"
            })

        # First check SQL security
        security_result = await _check_sql_security(sql)
        if not security_result.get("is_safe", False):
            return _text_response({
                "success": False,
                "error": "SQL security check failed",
                "message": "Query contains unsafe operations and cannot be executed",
                "security_issues": security_result.get("security_issues", [])
            })

        # Import database connection tool
        from doris_mcp_server.utils.db import execute_query

        # Ensure SELECT statements include a LIMIT clause
        sql = _apply_row_limit(sql, max_rows)

        # Start timer
        start_time = time.time()

        # Execute query
        try:
            # For federation queries, SQL must use three-part naming: catalog_name.db_name.table_name
            # This is enforced at the tool description level
            result = execute_query(sql, db_name)

            # Calculate execution time
            execution_time = time.time() - start_time

            if isinstance(result, list):
                row_count = len(result)
                # An empty result set is a success with no rows, not an error
                columns, data = _rows_to_records(result) if result else ([], [])
                return _text_response({
                    "success": True,
                    "sql": sql,
                    "row_count": row_count,
                    "columns": columns,
                    "data": data[:max_rows],  # Limit returned rows
                    "execution_time": execution_time,
                    "truncated": row_count > max_rows
                })

            # Handle other types of results
            return _text_response(_serialize_row_data({
                "success": True,
                "sql": sql,
                "result": str(result),
                "execution_time": execution_time
            }))

        except Exception as db_error:
            error_message = str(db_error)
            # Ensure error response is also serializable
            return _text_response(_serialize_row_data({
                "success": False,
                "error": error_message,
                "error_details": _describe_db_error(error_message),
                "sql": sql,
                "db_name": db_name
            }))

    except Exception as e:
        logger.error(f"Failed to execute SQL query: {str(e)}")
        logger.error(traceback.format_exc())

        # Ensure error response is also serializable
        return _text_response(_serialize_row_data({
            "success": False,
            "error": str(e),
            "message": "Error occurred while executing SQL query"
        }))


def _text_response(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap ``payload`` as the single JSON text item of a tool response."""
    return {
        "content": [
            {"type": "text", "text": json.dumps(payload, ensure_ascii=False)}
        ]
    }


def _apply_row_limit(sql: str, max_rows: int) -> str:
    """Append ``LIMIT max_rows`` to a SELECT that has no LIMIT keyword."""
    normalized = sql.strip().lower()
    # Match LIMIT as a whole word so names like "rate_limit" do not count
    if normalized.startswith("select") and not re.search(r"\blimit\b", normalized):
        return sql.rstrip().rstrip(";").rstrip() + f" LIMIT {max_rows};"
    return sql


def _rows_to_records(rows: list) -> tuple:
    """Convert a non-empty list of rows to ``(columns, serialized row dicts)``.

    Rows may be named tuples, dictionaries or plain sequences; plain
    sequences get positional column names ``col_0``, ``col_1``, ...
    """
    first = rows[0]
    if hasattr(first, "_fields"):
        # If it's a named tuple
        columns = list(first._fields)
    elif isinstance(first, dict):
        columns = list(first.keys())
    else:
        columns = [f"col_{index}" for index in range(len(first))]

    records = []
    for row in rows:
        if hasattr(row, "_asdict"):
            record = row._asdict()
        elif isinstance(row, dict):
            record = row
        else:
            record = dict(zip(columns, row))
        # Handle special types to make them JSON serializable
        records.append(_serialize_row_data(record))
    return columns, records


def _describe_db_error(error_message: str) -> Dict[str, str]:
    """Classify a database error message and suggest a next step."""
    lowered = error_message.lower()
    if "timeout" in lowered:
        return {
            "type": "timeout",
            "suggestion": "Query timed out, please optimize SQL or increase timeout",
        }
    if "syntax" in lowered:
        return {
            "type": "syntax",
            "suggestion": "SQL syntax error, please check syntax",
        }
    if "not found" in lowered or "doesn't exist" in lowered:
        return {
            "type": "not_found",
            "suggestion": "Table or column not found, please check table and column names",
        }
    return {
        "type": "unknown",
        "suggestion": "Please check the SQL statement and try simplifying the query",
    }
|
||||
|
||||
# Helper function
|
||||
async def _check_sql_security(sql: str) -> Dict[str, Any]:
    """Check SQL security

    The SQL text is split on ``;`` and every fragment is checked on its
    own. A fragment that starts with SELECT/SHOW/DESC/DESCRIBE/EXPLAIN is
    read-only: DDL/DML keywords inside it are reported only when directly
    followed by an object type (e.g. ``drop table``), and comment markers
    are not reported. Any other fragment is checked strictly. Splitting is
    not quote-aware, so a ``;`` inside a string literal makes the rest of
    that literal a separate, strictly checked fragment.

    Args:
        sql: SQL text to inspect

    Returns:
        Dict[str, Any]: ``is_safe`` plus ``security_issues``, one entry per
        distinct operation found. Always safe when the
        ``ENABLE_SQL_SECURITY_CHECK`` module setting is off.
    """
    # If environment variable is set to disable security check, return safe immediately
    if not ENABLE_SQL_SECURITY_CHECK:
        return {
            "is_safe": True,
            "security_issues": []
        }

    # Define list of dangerous operations (checked for both read-only and non-read-only queries)
    dangerous_operations = [
        (r'\bdelete\b', "DELETE operation"),
        (r'\bdrop\b', "DROP TABLE/DATABASE operation"),
        (r'\btruncate\b', "TRUNCATE TABLE operation"),
        (r'\bupdate\b', "UPDATE operation"),
        (r'\binsert\b', "INSERT operation"),
        (r'\balter\b', "ALTER TABLE structure operation"),
        (r'\bcreate\b', "CREATE TABLE/DATABASE operation"),
        (r'\bgrant\b', "GRANT operation"),
        (r'\brevoke\b', "REVOKE permission operation"),
        (r'\bexec\b', "EXECUTE stored procedure"),
        (r'\bxp_', "Extended stored procedure, potential security risk"),
        (r'\bshutdown\b', "SHUTDOWN database operation"),
        (r'\binto\s+outfile\b', "Write to file operation"),
        (r'\bload_file\b', "Load file operation")
    ]

    # Keywords that a read-only statement may contain as plain words; there
    # they only count when followed by an object type.
    context_sensitive = {
        r'\bcreate\b', r'\bdrop\b', r'\bdelete\b',
        r'\binsert\b', r'\bupdate\b', r'\balter\b',
    }
    object_types = r'\s+(?:table|database|view|index|procedure|function|trigger|event)'

    # Dangerous operations checked only for non-read-only queries
    non_readonly_operations = [
        (r'--', "SQL comment, potential SQL injection"),
        (r'/\*', "SQL block comment, potential SQL injection")
    ]

    security_issues = []
    reported = set()

    def report(operation: str, description: str, severity: str) -> None:
        # Record each operation once, even if several fragments contain it
        label = operation.replace(r'\b', '').replace(r'\s+', ' ')
        if label not in reported:
            reported.add(label)
            security_issues.append({
                "operation": label,
                "description": description,
                "severity": severity
            })

    # Decide read-only status per statement: judging only the start of the
    # whole string would let "select 1; delete from t" through.
    for statement in sql.lower().split(";"):
        statement = statement.strip()
        if not statement:
            continue

        # Check if it's a read-only query type (any whitespace may follow the keyword)
        is_read_only = re.match(
            r'(?:select|show|desc|describe|explain)\s', statement
        ) is not None

        for operation, description in dangerous_operations:
            if not re.search(operation, statement):
                continue
            if is_read_only and operation in context_sensitive:
                # Check if used as DDL/DML keyword, e.g., CREATE TABLE, DROP DATABASE
                if re.search(operation + object_types, statement):
                    report(operation, description, "High")
            else:
                report(operation, description, "High")

        if not is_read_only:
            for operation, description in non_readonly_operations:
                if re.search(operation, statement):
                    report(operation, description, "Medium")

    return {
        "is_safe": len(security_issues) == 0,
        "security_issues": security_issues
    }
|
||||
|
||||
|
||||
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert special types in row data (like date, time, Decimal) to JSON serializable format

    Conversion is applied recursively: values nested inside dictionaries,
    lists and tuples are converted as well, so a list of Decimal or date
    values no longer leaks non-serializable objects into the result.

    Args:
        row_data: Row data dictionary

    Returns:
        Dict[str, Any]: Processed serializable dictionary (a new dictionary;
        the input is not modified). Lists and tuples are both returned as lists.
    """

    def _convert(value: Any) -> Any:
        """Return a JSON-friendly equivalent of a single value."""
        # datetime.datetime is a subclass of datetime.date; datetime.time is
        # unrelated to both and therefore has to be listed explicitly.
        if isinstance(value, (datetime.date, datetime.time)):
            return value.isoformat()
        if isinstance(value, Decimal):
            # Convert Decimal type to float
            return float(value)
        if isinstance(value, dict):
            return _serialize_row_data(value)
        if isinstance(value, (list, tuple)):
            # Recurse into every element, not only into nested dictionaries
            return [_convert(item) for item in value]
        # None, str, int, float, bool and anything unknown pass through as-is
        return value

    return {key: _convert(value) for key, value in row_data.items()}
|
||||
61
env.example
61
env.example
@@ -1,61 +0,0 @@
|
||||
# Doris MCP Server Example Configuration File
|
||||
# Copy this file to .env and modify it for your configuration
|
||||
# Comment out unused configuration items with #
|
||||
|
||||
#===============================
|
||||
# Database Configuration
|
||||
#===============================
|
||||
# Database connection information
|
||||
DB_HOST=localhost
|
||||
DB_PORT=9030
|
||||
DB_WEB_PORT=8030
|
||||
DB_USER=root
|
||||
DB_PASSWORD=
|
||||
# Default database
|
||||
DB_DATABASE=test
|
||||
|
||||
# Multi-database support
|
||||
# ENABLE_MULTI_DATABASE=false
|
||||
# List of multi-database names (different databases using the same connection), JSON array format
|
||||
# MULTI_DATABASE_NAMES=["test", "sales", "user", "product"]
|
||||
|
||||
#===============================
|
||||
# Table Hierarchy Matching Configuration
|
||||
#===============================
|
||||
# Whether to enable table hierarchy priority matching
|
||||
# ENABLE_TABLE_HIERARCHY_MATCHING=false
|
||||
# Table hierarchy matching regular expressions, sorted by priority from high to low, JSON format
|
||||
# TABLE_HIERARCHY_PATTERNS=["^ads_.*$","^dim_.*$","^dws_.*$","^dwd_.*$","^ods_.*$","^tmp_.*$","^stg_.*$","^.*$"]
|
||||
# Table hierarchy matching timeout (seconds)
|
||||
# TABLE_HIERARCHY_TIMEOUT=10
|
||||
|
||||
# List of excluded databases, these databases will not be scanned and metadata processed, JSON format
|
||||
# EXCLUDED_DATABASES=["information_schema", "mysql", "performance_schema", "sys", "doris_metadata"]
|
||||
|
||||
|
||||
#===============================
|
||||
# Server Configuration
|
||||
#===============================
|
||||
SERVER_HOST=0.0.0.0
|
||||
SERVER_PORT=3000
|
||||
# LOG_LEVEL=INFO # Defined below
|
||||
|
||||
# Cache Configuration
|
||||
CACHE_TTL=86400
|
||||
|
||||
#===============================
|
||||
# Logging Configuration
|
||||
#===============================
|
||||
# Log directory path
|
||||
LOG_DIR=logs
|
||||
# Log file prefix
|
||||
LOG_PREFIX=doris_mcp
|
||||
# Log level: DEBUG, INFO, WARNING, ERROR, CRITICAL
|
||||
LOG_LEVEL=INFO
|
||||
# Log retention days
|
||||
LOG_MAX_DAYS=30
|
||||
# Whether to enable console log output (should be set to false when running as a service)
|
||||
CONSOLE_LOGGING=false
|
||||
|
||||
# CORS Configuration
|
||||
ALLOWED_ORIGINS=*
|
||||
153
generate_requirements.py
Normal file
153
generate_requirements.py
Normal file
@@ -0,0 +1,153 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Generate requirements.txt from pyproject.toml
|
||||
Ensures consistency in dependency management
|
||||
"""
|
||||
|
||||
import re
|
||||
import toml
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
def generate_requirements():
    """Write requirements.txt (core plus dev dependencies) from pyproject.toml.

    Both files are resolved relative to the current working directory.
    The process exits with status 1 when pyproject.toml does not exist.
    """
    try:
        with open('pyproject.toml', 'r', encoding='utf-8') as handle:
            config = toml.load(handle)
    except FileNotFoundError:
        print("❌ pyproject.toml not found")
        sys.exit(1)

    project_table = config.get('project', {})
    core_deps = project_table.get('dependencies', [])
    dev_deps = project_table.get('optional-dependencies', {}).get('dev', [])

    # Header comments, then the core section, a blank separator line and
    # finally the development section.
    output_lines = [
        "# Main dependencies - auto-generated from pyproject.toml",
        "# Do not edit this file manually, use 'python generate_requirements.py' to regenerate",
        "",
        "# === Core Dependencies ===",
        *core_deps,
        "",
        "# === Development Dependencies ===",
        *dev_deps,
    ]

    with open('requirements.txt', 'w', encoding='utf-8') as handle:
        handle.write('\n'.join(output_lines))

    print(f"✅ Generated requirements.txt")
    print(f" Main dependencies: {len(core_deps)} items")
    print(f" Dev dependencies: {len(dev_deps)} items")
|
||||
|
||||
def generate_requirements_dev():
    """Write requirements-dev.txt containing only the 'dev' optional dependencies.

    Reads pyproject.toml from the current working directory. When that file
    is missing an error is printed and nothing is written.
    """
    source = Path("pyproject.toml")
    if not source.exists():
        print("Error: pyproject.toml not found")
        return

    with source.open('r', encoding='utf-8') as handle:
        config = toml.load(handle)

    dev_deps = config.get('project', {}).get('optional-dependencies', {}).get('dev', [])

    # Two header comments, a blank line, then one requirement per line
    output_lines = [
        "# Development dependencies - auto-generated from pyproject.toml",
        "# Installation command: pip install -r requirements-dev.txt",
        "",
    ]
    output_lines.extend(dev_deps)

    target = Path("requirements-dev.txt")
    with target.open('w', encoding='utf-8') as handle:
        handle.write('\n'.join(output_lines))

    print(f"✅ Generated requirements-dev.txt ({len(dev_deps)} development dependencies)")
|
||||
|
||||
def verify_consistency():
    """Verify dependency consistency

    Compares the set of package names found in requirements.txt with the
    names declared in pyproject.toml (main dependencies plus the 'dev'
    optional dependencies). Both files are read from the current working
    directory as UTF-8, matching how the generator functions write them.

    A missing file is reported on stdout and treated as an empty package set.

    Returns:
        bool: True when both sets are identical, False otherwise.
    """

    def package_name(requirement):
        """Return the lower-cased package name of a requirement string.

        Everything from the first version operator (including ``~=`` and
        ``!=``), extras bracket, environment marker or whitespace onwards
        is dropped.
        """
        return re.split(r'[~!>=<\[;\s]', requirement.strip(), maxsplit=1)[0].lower()

    def extract_packages_from_requirements():
        """Extract package names from requirements.txt"""
        packages = set()
        try:
            with open('requirements.txt', 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    # Skip blank lines and comment lines
                    if line and not line.startswith('#'):
                        name = package_name(line)
                        if name:
                            packages.add(name)
        except FileNotFoundError:
            print("requirements.txt not found")
        return packages

    def extract_packages_from_pyproject():
        """Extract package names from pyproject.toml"""
        packages = set()
        try:
            with open('pyproject.toml', 'r', encoding='utf-8') as f:
                data = toml.load(f)

            project = data.get('project', {})
            # Main dependencies followed by development dependencies
            declared = list(project.get('dependencies', []))
            declared.extend(project.get('optional-dependencies', {}).get('dev', []))
            for dep in declared:
                name = package_name(dep)
                if name:
                    packages.add(name)
        except FileNotFoundError:
            print("pyproject.toml not found")
        return packages

    req_packages = extract_packages_from_requirements()
    toml_packages = extract_packages_from_pyproject()

    only_in_req = req_packages - toml_packages
    only_in_toml = toml_packages - req_packages

    if not only_in_req and not only_in_toml:
        print("✅ Dependency consistency verification passed!")
        return True

    print("⚠️ Found dependency inconsistencies:")
    if only_in_req:
        print(f" Only in requirements.txt: {sorted(only_in_req)}")
    if only_in_toml:
        print(f" Only in pyproject.toml: {sorted(only_in_toml)}")
    return False
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: regenerate both requirements files, then check
    # that they agree with pyproject.toml.
    print("🔄 Generating requirements.txt from pyproject.toml...")
    generate_requirements()
    generate_requirements_dev()
    print()
    print("🔍 Verifying dependency consistency...")
    consistent = verify_consistency()
    print("✨ Completed!")
    # Report a failed verification through the exit status so that callers
    # such as CI jobs or hooks can detect it instead of always seeing success.
    sys.exit(0 if consistent else 1)
|
||||
294
pyproject.toml
294
pyproject.toml
@@ -1,44 +1,280 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=42", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "doris-mcp"
|
||||
version = "0.2.0"
|
||||
description = "Doris MCP Server for Cursor integration"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
license = "Apache-2.0"
|
||||
name = "doris-mcp-server"
|
||||
version = "0.3.0"
|
||||
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
|
||||
authors = [
|
||||
{name = "Doris MCP Team - Yijia Su"}
|
||||
{name = "Yijia Su", email = "freeoneplus@apache.org"}
|
||||
]
|
||||
readme = "README.md"
|
||||
license = {text = "Apache-2.0"}
|
||||
requires-python = ">=3.12"
|
||||
keywords = ["doris", "mcp", "model-context-protocol", "database", "analytics"]
|
||||
classifiers = [
|
||||
"Development Status :: 4 - Beta",
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Operating System :: OS Independent",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Database",
|
||||
"Topic :: Software Development :: Libraries :: Python Modules",
|
||||
"Topic :: Scientific/Engineering :: Information Analysis",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
"mcp[cli]>=1.0.0",
|
||||
"pymysql>=1.0.2",
|
||||
"pandas>=1.5.0",
|
||||
"numpy>=1.20.0",
|
||||
"scikit-learn>=1.0.0",
|
||||
"python-dotenv>=0.19.0",
|
||||
"pydantic>=1.10.0",
|
||||
"requests>=2.28.0",
|
||||
"openai>=1.66.3",
|
||||
"fastapi>=0.95.0",
|
||||
"uvicorn>=0.21.0",
|
||||
"simplejson>=3.17.0"
|
||||
# Core MCP dependencies
|
||||
"mcp>=1.0.0",
|
||||
# Database drivers
|
||||
"aiomysql>=0.2.0",
|
||||
"PyMySQL>=1.1.0",
|
||||
# Async and utility libraries
|
||||
"asyncio-mqtt>=0.16.0",
|
||||
"aiofiles>=23.0.0",
|
||||
"aiohttp>=3.9.0",
|
||||
"aioredis>=2.0.0",
|
||||
# Data processing
|
||||
"pandas>=2.0.0",
|
||||
"numpy>=1.24.0",
|
||||
"python-dateutil>=2.8.0",
|
||||
"orjson>=3.9.0",
|
||||
# Configuration and serialization
|
||||
"pydantic>=2.5.0",
|
||||
"pydantic-settings>=2.1.0",
|
||||
"toml>=0.10.0",
|
||||
"PyYAML>=6.0.0",
|
||||
"python-dotenv>=1.0.0",
|
||||
# Security and authentication
|
||||
"cryptography>=41.0.0",
|
||||
"PyJWT>=2.8.0",
|
||||
"passlib[bcrypt]>=1.7.0",
|
||||
"bcrypt>=4.1.0",
|
||||
"sqlparse>=0.4.4",
|
||||
"python-jose[cryptography]>=3.3.0",
|
||||
"python-multipart>=0.0.6",
|
||||
# Monitoring and logging
|
||||
"prometheus-client>=0.19.0",
|
||||
"structlog>=23.2.0",
|
||||
"rich>=13.7.0",
|
||||
# HTTP and networking
|
||||
"httpx>=0.26.0",
|
||||
"websockets>=12.0",
|
||||
"uvicorn[standard]>=0.25.0",
|
||||
"fastapi>=0.108.0",
|
||||
"starlette>=0.27.0",
|
||||
# Development utilities
|
||||
"click>=8.1.0",
|
||||
"typer>=0.9.0",
|
||||
"requests>=2.31.0",
|
||||
"tqdm>=4.66.0",
|
||||
"pytest>=8.4.0",
|
||||
"pytest-asyncio>=1.0.0",
|
||||
"pytest-cov>=6.1.1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=7.0.0",
|
||||
"black>=23.0.0",
|
||||
"isort>=5.12.0"
|
||||
# Testing
|
||||
"pytest>=7.4.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
"pytest-cov>=4.1.0",
|
||||
"pytest-mock>=3.12.0",
|
||||
"pytest-xdist>=3.5.0",
|
||||
|
||||
# Code quality
|
||||
"ruff>=0.1.0",
|
||||
"black>=23.12.0",
|
||||
"isort>=5.13.0",
|
||||
"flake8>=7.0.0",
|
||||
"mypy>=1.8.0",
|
||||
"bandit>=1.7.0",
|
||||
"safety>=2.3.0",
|
||||
|
||||
# Documentation
|
||||
"sphinx>=7.2.0",
|
||||
"sphinx-rtd-theme>=2.0.0",
|
||||
"myst-parser>=2.0.0",
|
||||
|
||||
# Development tools
|
||||
"pre-commit>=3.6.0",
|
||||
"tox>=4.11.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
# Updated entry point for stdio mode back to mcp_core
|
||||
doris-mcp = "doris_mcp_server.mcp_core:run_stdio"
|
||||
docs = [
|
||||
"sphinx>=7.2.0",
|
||||
"sphinx-rtd-theme>=2.0.0",
|
||||
"myst-parser>=2.0.0",
|
||||
"sphinx-autoapi>=3.0.0",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
# Explicitly list the package found in the root directory
|
||||
performance = [
|
||||
"uvloop>=0.19.0", # High-performance event loop
|
||||
"orjson>=3.9.0", # Fast JSON serialization
|
||||
"cchardet>=2.1.0", # Fast character encoding detection
|
||||
]
|
||||
|
||||
monitoring = [
|
||||
"prometheus-client>=0.19.0",
|
||||
"grafana-client>=3.5.0",
|
||||
"jaeger-client>=4.8.0",
|
||||
"opentelemetry-api>=1.21.0",
|
||||
"opentelemetry-sdk>=1.21.0",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/apache/doris-mcp-server"
|
||||
Documentation = "https://doris.apache.org/docs/"
|
||||
Repository = "https://github.com/apache/doris-mcp-server.git"
|
||||
Issues = "https://github.com/apache/doris-mcp-server/issues"
|
||||
Changelog = "https://github.com/apache/doris-mcp-server/blob/main/CHANGELOG.md"
|
||||
|
||||
[project.scripts]
|
||||
doris-mcp-server = "doris_mcp_server.main:main_sync"
|
||||
doris-mcp-client = "doris_mcp_server.client:main"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["doris_mcp_server"]
|
||||
|
||||
[tool.hatch.build.targets.sdist]
|
||||
include = [
|
||||
"/doris_mcp_server",
|
||||
"/README.md",
|
||||
"/LICENSE",
|
||||
]
|
||||
|
||||
# Black configuration
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ['py310', 'py311', 'py312']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
/(
|
||||
# directories
|
||||
\.eggs
|
||||
| \.git
|
||||
| \.hg
|
||||
| \.mypy_cache
|
||||
| \.tox
|
||||
| \.venv
|
||||
| build
|
||||
| dist
|
||||
)/
|
||||
'''
|
||||
|
||||
# isort configuration
|
||||
[tool.isort]
|
||||
profile = "black"
|
||||
multi_line_output = 3
|
||||
line_length = 88
|
||||
known_first_party = ["doris_mcp_server"]
|
||||
known_third_party = ["mcp", "aiomysql", "pydantic", "click"]
|
||||
|
||||
# MyPy configuration
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
disallow_incomplete_defs = true
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_decorators = true
|
||||
no_implicit_optional = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
warn_no_return = true
|
||||
warn_unreachable = true
|
||||
strict_equality = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"aiomysql.*",
|
||||
"pymysql.*",
|
||||
"prometheus_client.*",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
# Pytest configuration
|
||||
[tool.pytest.ini_options]
|
||||
minversion = "7.0"
|
||||
addopts = [
|
||||
"--strict-markers",
|
||||
"--strict-config",
|
||||
"--cov=doris_mcp_server",
|
||||
"--cov-report=term-missing",
|
||||
"--cov-report=html",
|
||||
"--cov-report=xml",
|
||||
]
|
||||
testpaths = ["tests"]
|
||||
python_files = ["test_*.py", "*_test.py"]
|
||||
python_classes = ["Test*"]
|
||||
python_functions = ["test_*"]
|
||||
markers = [
|
||||
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
|
||||
"integration: marks tests as integration tests",
|
||||
"unit: marks tests as unit tests",
|
||||
]
|
||||
asyncio_mode = "auto"
|
||||
|
||||
# Coverage configuration
|
||||
[tool.coverage.run]
|
||||
source = ["doris_mcp_server"]
|
||||
omit = [
|
||||
"*/tests/*",
|
||||
"*/test_*",
|
||||
"*/__pycache__/*",
|
||||
"*/venv/*",
|
||||
"*/.*",
|
||||
]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"pragma: no cover",
|
||||
"def __repr__",
|
||||
"if self.debug:",
|
||||
"if settings.DEBUG",
|
||||
"raise AssertionError",
|
||||
"raise NotImplementedError",
|
||||
"if 0:",
|
||||
"if __name__ == .__main__.:",
|
||||
"class .*\\bProtocol\\):",
|
||||
"@(abc\\.)?abstractmethod",
|
||||
]
|
||||
|
||||
# Bandit security linter configuration
|
||||
[tool.bandit]
|
||||
exclude_dirs = ["tests", "build", "dist"]
|
||||
tests = ["B201", "B301"]
|
||||
skips = ["B101", "B601"]
|
||||
|
||||
# Ruff configuration (modern Python linter)
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
line-length = 88
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"B", # flake8-bugbear
|
||||
"C4", # flake8-comprehensions
|
||||
"UP", # pyupgrade
|
||||
]
|
||||
ignore = [
|
||||
"E501", # line too long, handled by black
|
||||
"B008", # do not perform function calls in argument defaults
|
||||
"C901", # too complex
|
||||
"B904", # raise from err
|
||||
]
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"__init__.py" = ["F401"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"ruff>=0.11.13",
|
||||
]
|
||||
|
||||
20
requirements-dev.txt
Normal file
20
requirements-dev.txt
Normal file
@@ -0,0 +1,20 @@
|
||||
# Development dependencies - auto-generated from pyproject.toml
|
||||
# Installation command: pip install -r requirements-dev.txt
|
||||
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.23.0
|
||||
pytest-cov>=4.1.0
|
||||
pytest-mock>=3.12.0
|
||||
pytest-xdist>=3.5.0
|
||||
ruff>=0.1.0
|
||||
black>=23.12.0
|
||||
isort>=5.13.0
|
||||
flake8>=7.0.0
|
||||
mypy>=1.8.0
|
||||
bandit>=1.7.0
|
||||
safety>=2.3.0
|
||||
sphinx>=7.2.0
|
||||
sphinx-rtd-theme>=2.0.0
|
||||
myst-parser>=2.0.0
|
||||
pre-commit>=3.6.0
|
||||
tox>=4.11.0
|
||||
@@ -1,15 +1,58 @@
|
||||
mcp[cli]>=1.0.0
|
||||
pymysql>=1.0.2
|
||||
pandas>=1.5.0
|
||||
numpy>=1.20.0
|
||||
scikit-learn>=1.0.0
|
||||
python-dotenv>=0.19.0
|
||||
pydantic>=1.10.0
|
||||
requests>=2.28.0
|
||||
openai>=1.66.3
|
||||
uv>=0.6.8
|
||||
psutil>=5.9.0
|
||||
simplejson>=3.17.0
|
||||
fastapi>=0.115.4
|
||||
uvicorn>=0.29.0
|
||||
sse-starlette>=1.6.5
|
||||
# Main dependencies - auto-generated from pyproject.toml
|
||||
# Do not edit this file manually, use 'python generate_requirements.py' to regenerate
|
||||
|
||||
# === Core Dependencies ===
|
||||
mcp>=1.0.0
|
||||
aiomysql>=0.2.0
|
||||
PyMySQL>=1.1.0
|
||||
asyncio-mqtt>=0.16.0
|
||||
aiofiles>=23.0.0
|
||||
aiohttp>=3.9.0
|
||||
aioredis>=2.0.0
|
||||
pandas>=2.0.0
|
||||
numpy>=1.24.0
|
||||
python-dateutil>=2.8.0
|
||||
orjson>=3.9.0
|
||||
pydantic>=2.5.0
|
||||
pydantic-settings>=2.1.0
|
||||
toml>=0.10.0
|
||||
PyYAML>=6.0.0
|
||||
python-dotenv>=1.0.0
|
||||
cryptography>=41.0.0
|
||||
PyJWT>=2.8.0
|
||||
passlib[bcrypt]>=1.7.0
|
||||
bcrypt>=4.1.0
|
||||
sqlparse>=0.4.4
|
||||
python-jose[cryptography]>=3.3.0
|
||||
python-multipart>=0.0.6
|
||||
prometheus-client>=0.19.0
|
||||
structlog>=23.2.0
|
||||
rich>=13.7.0
|
||||
httpx>=0.26.0
|
||||
websockets>=12.0
|
||||
uvicorn[standard]>=0.25.0
|
||||
fastapi>=0.108.0
|
||||
starlette>=0.27.0
|
||||
click>=8.1.0
|
||||
typer>=0.9.0
|
||||
requests>=2.31.0
|
||||
tqdm>=4.66.0
|
||||
|
||||
# === Development Dependencies ===
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.23.0
|
||||
pytest-cov>=4.1.0
|
||||
pytest-mock>=3.12.0
|
||||
pytest-xdist>=3.5.0
|
||||
ruff>=0.1.0
|
||||
black>=23.12.0
|
||||
isort>=5.13.0
|
||||
flake8>=7.0.0
|
||||
mypy>=1.8.0
|
||||
bandit>=1.7.0
|
||||
safety>=2.3.0
|
||||
sphinx>=7.2.0
|
||||
sphinx-rtd-theme>=2.0.0
|
||||
myst-parser>=2.0.0
|
||||
pre-commit>=3.6.0
|
||||
tox>=4.11.0
|
||||
@@ -1,167 +0,0 @@
|
||||
#!/bin/bash
# Doris MCP Server Restart Script
# Finds whatever is occupying the server port or running the server module,
# stops it, and then launches the server again.

# Server configuration
MCP_PORT=3000
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
START_SCRIPT="${SCRIPT_DIR}/start_server.sh"

# ANSI colour escapes used by every message below
NC='\033[0m' # No Color
GREEN='\033[0;32m'
YELLOW='\033[0;33m'
RED='\033[0;31m'

echo -e "${GREEN}========== Doris MCP Server Restart Script ==========${NC}"

# Abort early when the companion start script is not next to this script
[ -f "$START_SCRIPT" ] || {
    echo -e "${RED}Error: Start script $START_SCRIPT does not exist${NC}"
    exit 1
}
|
||||
|
||||
# Check port usage
|
||||
check_port() {
    # Looks up the PID(s) bound to $MCP_PORT via lsof and stores them in the
    # global PORT_PID. Returns 0 when the port is in use, 1 when it is free.
    echo -e "${YELLOW}Checking port $MCP_PORT usage...${NC}"
    PORT_PID=$(lsof -ti:$MCP_PORT)

    if [ -z "$PORT_PID" ]; then
        echo -e "${GREEN}Port $MCP_PORT is not in use${NC}"
        return 1
    fi

    echo -e "${YELLOW}Port $MCP_PORT is used by process $PORT_PID${NC}"
    return 0
}
|
||||
|
||||
# Check if Python process is running
|
||||
check_python_process() {
    # Searches the process list for the SSE-mode server command line and
    # stores the matching PID(s) in the global PYTHON_PID.
    # Returns 0 when a match exists, 1 otherwise.
    echo -e "${YELLOW}Checking if Python process is running doris_mcp_server.main...${NC}"
    # The [p] bracket keeps the grep process itself out of the matches
    PYTHON_PID=$(ps aux | grep "[p]ython.*-m doris_mcp_server.main --sse" | awk '{print $2}')

    if [ -z "$PYTHON_PID" ]; then
        echo -e "${GREEN}No Python process running doris_mcp_server.main detected${NC}"
        return 1
    fi

    echo -e "${YELLOW}Detected Python process $PYTHON_PID running doris_mcp_server.main --sse${NC}"
    return 0
}
|
||||
|
||||
# Kill process
|
||||
kill_process() {
    # Terminates every PID passed as an argument.
    #
    # Each process first receives SIGTERM and gets up to five seconds to
    # exit; if it is still alive afterwards it is killed with SIGKILL.
    # Callers pass the output of lsof / ps unquoted, which may expand to
    # several PIDs, so all arguments are processed rather than only the first.
    #
    # Returns 0 when every process is gone, 1 when at least one survived.
    local pid attempt status=0

    for pid in "$@"; do
        echo -e "${YELLOW}Terminating process $pid...${NC}"
        kill "$pid" 2>/dev/null

        # Wait for process termination
        for attempt in 1 2 3 4 5; do
            if ! ps -p "$pid" > /dev/null 2>&1; then
                echo -e "${GREEN}Process $pid has terminated${NC}"
                continue 2
            fi
            echo -e "${YELLOW}Waiting for process termination (${attempt}/5)...${NC}"
            sleep 1
        done

        # If process is still running, force kill
        if ps -p "$pid" > /dev/null 2>&1; then
            echo -e "${YELLOW}Process still running, force killing process $pid...${NC}"
            kill -9 "$pid" 2>/dev/null
            sleep 1
            if ! ps -p "$pid" > /dev/null 2>&1; then
                echo -e "${GREEN}Process $pid has been force killed${NC}"
            else
                echo -e "${RED}Failed to terminate process $pid${NC}"
                status=1
            fi
        fi
    done

    return $status
}
|
||||
|
||||
# Clean up all process and port usage
|
||||
cleanup() {
    # Stops whatever holds the server port and any running server process,
    # then removes Python bytecode caches below $SCRIPT_DIR.
    # Returns 1 when the port is still occupied afterwards, 0 otherwise.

    # Terminate the process bound to the port, if there is one
    if check_port; then
        kill_process $PORT_PID
    fi

    # Terminate a running server Python process, if there is one
    if check_python_process; then
        kill_process $PYTHON_PID
    fi

    # Look at the port once more to make sure it was actually released
    if check_port; then
        echo -e "${RED}Warning: Failed to release port $MCP_PORT, please check the process manually${NC}"
        return 1
    fi

    # Remove compiled bytecode files and cache directories
    echo -e "${YELLOW}Cleaning Python bytecode cache...${NC}"
    find "$SCRIPT_DIR" -name "*.pyc" -delete
    find "$SCRIPT_DIR" -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true

    echo -e "${GREEN}Cleanup complete${NC}"
    return 0
}
|
||||
|
||||
# Start server
|
||||
start_server() {
    # Launches the server in SSE mode as a background process and checks
    # after a short delay that it is still running.
    # Log paths are relative to the current working directory.
    # Returns 0 when the process is found running, 1 otherwise.
    echo -e "${YELLOW}Stopping existing Doris MCP server process (SSE mode)...${NC}"
    pkill -f "python -m doris_mcp_server.main --sse" || true

    # Wait for the process to stop completely
    sleep 2

    # The redirections below fail when logs/ does not exist, which would
    # prevent the server from being launched at all, so create it first.
    mkdir -p logs

    echo -e "${YELLOW}Starting Doris MCP server (SSE mode)...${NC}"
    nohup python -m doris_mcp_server.main --sse >> logs/doris_mcp.log 2>> logs/doris_mcp.error &

    # Wait for server startup
    sleep 5

    echo -e "${YELLOW}Checking if the server started successfully (SSE mode)...${NC}"
    if pgrep -f "python -m doris_mcp_server.main --sse" > /dev/null; then
        echo -e "${GREEN}Doris MCP server (SSE mode) started successfully${NC}"
        echo -e "${GREEN}Service address: http://localhost:$MCP_PORT/${NC}"
        return 0
    else
        echo -e "${RED}Server startup failed, please check the log files${NC}"
        # Show the most recent error output to help with diagnosis
        tail -n 20 logs/doris_mcp.error
        return 1
    fi
}
|
||||
|
||||
# Main function
|
||||
main() {
    # Orchestrates the restart: clean up old processes, wait briefly, start
    # the server and print the service URLs. Exits the script with status 1
    # when cleanup or startup fails.
    echo -e "${YELLOW}Starting Doris MCP server restart...${NC}"

    # Clean up existing processes
    if ! cleanup; then
        echo -e "${RED}Failed to clean up existing processes, restart aborted${NC}"
        exit 1
    fi

    # Wait for port to be fully released
    sleep 2

    # Start the server
    if ! start_server; then
        echo -e "${RED}Server startup failed${NC}"
        exit 1
    fi

    echo -e "${GREEN}Server restarted successfully${NC}"
    echo -e "${YELLOW}Service running at: http://localhost:$MCP_PORT${NC}"
    echo -e "${YELLOW}Health check: http://localhost:$MCP_PORT/health${NC}"
    # Reset the colour at the end so it does not bleed into the user's shell
    echo -e "${YELLOW}SSE test endpoint: http://localhost:$MCP_PORT/sse${NC}"
}

# Run main function
main
|
||||
118
start_server.sh
118
start_server.sh
@@ -1,99 +1,81 @@
|
||||
#!/bin/bash
|
||||
# Doris MCP Server Start Script
|
||||
# Ensures the service runs in SSE mode
|
||||
# Doris MCP Server Start Script (Streamable HTTP Mode)
|
||||
# Ensures the service runs in Streamable HTTP mode for web-based MCP clients
|
||||
|
||||
# Set colors
|
||||
GREEN='\033[0;32m'
|
||||
CYAN='\033[0;36m'
|
||||
YELLOW='\033[1;33m'
|
||||
RED='\033[0;31m'
|
||||
NC='\033[0m' # No Color
|
||||
|
||||
echo -e "${GREEN}========== Doris MCP Server Start Script ==========${NC}"
|
||||
echo -e "${GREEN}========== Doris MCP Server Start Script (HTTP Mode) ==========${NC}"
|
||||
|
||||
# Check virtual environment
|
||||
if [ -d "venv" ]; then
|
||||
echo -e "${CYAN}Virtual environment found, activating...${NC}" # Found virtual environment, activating...
|
||||
if [ -d ".venv" ]; then
|
||||
echo -e "${CYAN}Virtual environment found, activating...${NC}"
|
||||
source .venv/bin/activate
|
||||
elif [ -d "venv" ]; then
|
||||
echo -e "${CYAN}Virtual environment found, activating...${NC}"
|
||||
source venv/bin/activate
|
||||
else
|
||||
echo -e "${YELLOW}Warning: No virtual environment found${NC}"
|
||||
fi
|
||||
|
||||
# Clean cache files
|
||||
echo -e "${CYAN}Cleaning cache files...${NC}" # Cleaning cache files...
|
||||
echo -e "${CYAN}Cleaning Python cache files...${NC}" # Cleaning Python cache files...
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
echo -e "${CYAN}Cleaning temporary files...${NC}" # Cleaning temporary files...
|
||||
echo -e "${CYAN}Cleaning cache files...${NC}"
|
||||
echo -e "${CYAN}Cleaning Python cache files...${NC}"
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
echo -e "${CYAN}Cleaning temporary files...${NC}"
|
||||
rm -rf .pytest_cache 2>/dev/null || true
|
||||
echo -e "${CYAN}Cleaning log files...${NC}" # Cleaning log files...
|
||||
find ./log -type f -name "*.log" -delete 2>/dev/null || true
|
||||
echo -e "${CYAN}Cleaning log files...${NC}"
|
||||
find ./logs -type f -name "*.log" -delete 2>/dev/null || true
|
||||
|
||||
# Create necessary directories
|
||||
mkdir -p logs
|
||||
mkdir -p tmp
|
||||
|
||||
# Reload environment variables
|
||||
if [ -f .env ]; then
|
||||
echo -e "${CYAN}Loading environment variables from .env file...${NC}" # Loading environment variables from .env file...
|
||||
echo -e "${CYAN}Loading environment variables from .env file...${NC}"
|
||||
set -a # automatically export all variables
|
||||
source .env
|
||||
set +a # stop automatically exporting
|
||||
else
|
||||
echo -e "${YELLOW}Warning: .env file not found${NC}"
|
||||
fi
|
||||
|
||||
# Output key environment variables before starting
|
||||
echo -e "${CYAN}Database settings:${NC}" # Database settings:
|
||||
echo "DB_HOST=${DB_HOST}"
|
||||
echo "DB_PORT=${DB_PORT}"
|
||||
echo "DB_DATABASE=${DB_DATABASE}"
|
||||
echo "FORCE_REFRESH_METADATA=${FORCE_REFRESH_METADATA}"
|
||||
|
||||
# Start the server (using -m and new package path)
|
||||
python -m doris_mcp_server.main --sse
|
||||
|
||||
# Clean cache files (This section seems redundant and possibly misplaced after the server starts)
|
||||
echo -e "${YELLOW}Cleaning cache files...${NC}" # Cleaning cache files...
|
||||
|
||||
# Backend cache cleanup
|
||||
echo -e "${GREEN}Cleaning Python cache files...${NC}" # Cleaning Python cache files...
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} +
|
||||
find . -type f -name "*.pyc" -delete
|
||||
rm -rf ./.pytest_cache
|
||||
|
||||
# Clean temporary files
|
||||
echo -e "${GREEN}Cleaning temporary files...${NC}" # Cleaning temporary files...
|
||||
rm -rf ./tmp
|
||||
mkdir -p tmp
|
||||
|
||||
# Clean log files
|
||||
echo -e "${GREEN}Cleaning log files...${NC}" # Cleaning log files...
|
||||
rm -rf ./logs/*.log
|
||||
mkdir -p logs
|
||||
|
||||
# Set environment variables, force SSE mode (This section also seems redundant if variables are set in .env and the command uses --sse)
|
||||
export MCP_PORT=3000
|
||||
export ALLOWED_ORIGINS="*"
|
||||
export LOG_LEVEL="info"
|
||||
export MCP_ALLOW_CREDENTIALS="false"
|
||||
# Set HTTP-specific environment variables
|
||||
export MCP_TRANSPORT_TYPE="http"
|
||||
export MCP_HOST="${MCP_HOST:-0.0.0.0}"
|
||||
export MCP_PORT="${MCP_PORT:-3000}"
|
||||
export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}"
|
||||
export LOG_LEVEL="${LOG_LEVEL:-info}"
|
||||
export MCP_ALLOW_CREDENTIALS="${MCP_ALLOW_CREDENTIALS:-false}"
|
||||
|
||||
# Add adapter debug support
|
||||
export MCP_DEBUG_ADAPTER="true"
|
||||
export PYTHONPATH="$(pwd):$PYTHONPATH" # Ensure modules can be imported
|
||||
export PYTHONPATH="$(pwd):$PYTHONPATH"
|
||||
|
||||
# Create log directory
|
||||
mkdir -p logs
|
||||
echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}"
|
||||
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${MCP_PORT}/health${NC}"
|
||||
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Local access: http://localhost:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}"
|
||||
|
||||
# Debug info
|
||||
echo -e "${GREEN}Environment Variables:${NC}" # Environment Variables:
|
||||
echo -e "MCP_TRANSPORT_TYPE=${MCP_TRANSPORT_TYPE}"
|
||||
echo -e "MCP_PORT=${MCP_PORT}"
|
||||
echo -e "ALLOWED_ORIGINS=${ALLOWED_ORIGINS}"
|
||||
echo -e "LOG_LEVEL=${LOG_LEVEL}"
|
||||
echo -e "MCP_ALLOW_CREDENTIALS=${MCP_ALLOW_CREDENTIALS}"
|
||||
echo -e "MCP_DEBUG_ADAPTER=${MCP_DEBUG_ADAPTER}"
|
||||
# Start the server in HTTP mode (Streamable HTTP)
|
||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${MCP_PORT}
|
||||
|
||||
echo -e "${GREEN}Starting MCP server (SSE mode)...${NC}" # Starting MCP server (SSE mode)...
|
||||
echo -e "${YELLOW}Service will run on http://localhost:3000/mcp${NC}" # Service will run on http://localhost:3000/mcp
|
||||
echo -e "${YELLOW}Health Check: http://localhost:3000/health${NC}" # Health Check: http://localhost:3000/health
|
||||
echo -e "${YELLOW}SSE Test: http://localhost:3000/sse${NC}" # SSE Test: http://localhost:3000/sse
|
||||
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}" # Use Ctrl+C to stop the service
|
||||
|
||||
# If the server exits abnormally, output error message
|
||||
# Check exit status
|
||||
if [ $? -ne 0 ]; then
|
||||
echo -e "${RED}Server exited abnormally! Check logs for more information${NC}" # Server exited abnormally! Check logs for more information
|
||||
echo -e "${RED}Server exited abnormally! Check logs for more information${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Show browser cache clearing prompt
|
||||
echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}" # Tip: If the page displays abnormally, please clear your browser cache or use incognito mode
|
||||
echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}" # Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)
|
||||
# Show usage tips
|
||||
echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}"
|
||||
echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}"
|
||||
echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}"
|
||||
echo -e "${CYAN} curl -X POST http://localhost:${MCP_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"
|
||||
245
test/README.md
Normal file
245
test/README.md
Normal file
@@ -0,0 +1,245 @@
|
||||
# Doris MCP Server Testing System
|
||||
|
||||
## Overview
|
||||
|
||||
This testing system adopts a layered architecture, including unit tests, integration tests, and client-server tests. The testing system assumes the server is already properly started and focuses on testing functionality rather than startup configuration.
|
||||
|
||||
## Testing Architecture
|
||||
|
||||
### 1. Unit Tests
|
||||
- **Location**: `test/security/`, `test/utils/`, `test/tools/`
|
||||
- **Purpose**: Test individual module functionality
|
||||
- **Features**: Uses Mock objects, no dependency on external services
|
||||
|
||||
### 2. Integration Tests
|
||||
- **Location**: `test/integration/`
|
||||
- **Purpose**: Test collaboration between modules
|
||||
- **Features**: Test complete workflows
|
||||
|
||||
### 3. Client-Server Tests
|
||||
- **Location**: `test/tools/test_tools_client_server.py`, `test/utils/test_query_executor_client_server.py`
|
||||
- **Purpose**: Test actual server functionality through MCP client
|
||||
- **Features**: Assumes server is running, skips tests if server is not available
|
||||
|
||||
## Configuration Files
|
||||
|
||||
### test_config.json
|
||||
Test configuration file defines how to connect to the running server:
|
||||
|
||||
```json
|
||||
{
|
||||
"server_endpoints": {
|
||||
"http": {
|
||||
"url": "http://localhost:3000/mcp",
|
||||
"timeout": 30
|
||||
},
|
||||
"stdio": {
|
||||
"command": "uv",
|
||||
"args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"],
|
||||
"timeout": 30
|
||||
}
|
||||
},
|
||||
"test_settings": {
|
||||
"default_transport": "http",
|
||||
"retry_attempts": 3,
|
||||
"retry_delay": 1.0,
|
||||
"test_timeout": 60,
|
||||
"enable_performance_tests": true,
|
||||
"enable_security_tests": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Start the Server
|
||||
|
||||
Before running client-server tests, you need to start the server first:
|
||||
|
||||
#### HTTP Mode (Recommended)
|
||||
```bash
|
||||
# Start HTTP server
|
||||
./start_server.sh
|
||||
# or
|
||||
uv run python -m doris_mcp_server.main --transport http --port 3000
|
||||
```
|
||||
|
||||
#### Stdio Mode
|
||||
```bash
|
||||
# In stdio mode the client launches the server process itself, so nothing needs to be started beforehand
|
||||
```
|
||||
|
||||
### 2. Run Tests
|
||||
|
||||
#### Run All Tests
|
||||
```bash
|
||||
python -m pytest test/ -v
|
||||
```
|
||||
|
||||
#### Run Unit Tests
|
||||
```bash
|
||||
# Security module tests
|
||||
python -m pytest test/security/ -v
|
||||
|
||||
# Tools module tests
|
||||
python -m pytest test/tools/test_tools_manager.py -v
|
||||
|
||||
# Query executor tests
|
||||
python -m pytest test/utils/test_query_executor.py -v
|
||||
```
|
||||
|
||||
#### Run Integration Tests
|
||||
```bash
|
||||
python -m pytest test/integration/ -v
|
||||
```
|
||||
|
||||
#### Run Client-Server Tests
|
||||
```bash
|
||||
# Tools Client-Server tests
|
||||
python -m pytest test/tools/test_tools_client_server.py -v
|
||||
|
||||
# QueryExecutor Client-Server tests
|
||||
python -m pytest test/utils/test_query_executor_client_server.py -v
|
||||
```
|
||||
|
||||
### 3. Test Configuration
|
||||
|
||||
#### Modify Server Endpoints
|
||||
Edit the `test/test_config.json` file:
|
||||
|
||||
```json
|
||||
{
|
||||
"server_endpoints": {
|
||||
"http": {
|
||||
"url": "http://your-server:port/mcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
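#### Enable/Disable Specific Tests

Note: the `//` comments in the example below are explanatory only. Strict JSON does not allow comments, so omit them in the real `test/test_config.json`.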
|
||||
```json
|
||||
{
|
||||
"test_settings": {
|
||||
"enable_performance_tests": false, // Disable performance tests
|
||||
"enable_security_tests": true // Enable security tests
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Test Status
|
||||
|
||||
### ✅ Completed Test Modules
|
||||
|
||||
1. **Security Module** (100% Pass)
|
||||
- Authentication tests: 5/5 passed
|
||||
- Authorization tests: 7/7 passed
|
||||
- Data masking tests: 13/13 passed
|
||||
- SQL validation tests: 10/10 passed
|
||||
- Security manager tests: 7/7 passed
|
||||
- Coverage: 88%
|
||||
|
||||
2. **Client-Server Test Architecture** (Implemented)
|
||||
- Automatic server connection status detection
|
||||
- Automatically skip tests when server is not running
|
||||
- Support for both HTTP and Stdio transport modes
|
||||
|
||||
### 🔄 Tests Requiring Server Running
|
||||
|
||||
1. **Tools Client-Server Tests**
|
||||
- Tool list retrieval
|
||||
- SQL query execution
|
||||
- Database list retrieval
|
||||
- Table schema queries
|
||||
- Performance statistics
|
||||
- Error handling
|
||||
- Security authentication
|
||||
|
||||
2. **QueryExecutor Client-Server Tests**
|
||||
- Simple query execution
|
||||
- Database queries
|
||||
- Information schema queries
|
||||
- Parameterized queries
|
||||
- Error handling
|
||||
- Security authentication
|
||||
|
||||
## Testing Best Practices
|
||||
|
||||
### 1. Server Startup Check
|
||||
All client-server tests automatically check server connection status:
|
||||
- If server is running normally, execute actual tests
|
||||
- If server is not running, skip tests and display appropriate message
|
||||
|
||||
### 2. Test Isolation
|
||||
- Unit tests use Mock objects, no dependency on external services
|
||||
- Integration tests use controlled test environments
|
||||
- Client-server tests connect to actually running servers
|
||||
|
||||
### 3. Error Handling
|
||||
- Tests don't assume specific success/failure results
|
||||
- Verify response structure rather than specific content
|
||||
- Gracefully handle connection failures and timeouts
|
||||
|
||||
### 4. Configuration Management
|
||||
- Use configuration files to manage test parameters
|
||||
- Support configuration switching for different environments
|
||||
- Provide reasonable default values
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### 1. Server Connection Failure
|
||||
```
|
||||
ERROR: Server is not running or not accessible
|
||||
```
|
||||
**Solution**: Ensure the server is started and listening on the correct port
|
||||
|
||||
### 2. Import Errors
|
||||
```
|
||||
ImportError: cannot import name 'DorisUnifiedClient'
|
||||
```
|
||||
**Solution**: Check Python path and dependency installation
|
||||
|
||||
### 3. Test Timeouts
|
||||
```
|
||||
TimeoutError: Test execution timeout
|
||||
```
|
||||
**Solution**: Increase timeout settings in `test_config.json`
|
||||
|
||||
## Development Guide
|
||||
|
||||
### Adding New Client-Server Tests
|
||||
|
||||
1. Add test methods in the appropriate test file
|
||||
2. Use `@pytest.mark.asyncio` decorator
|
||||
3. Get test client through `client` fixture
|
||||
4. Implement test callback function
|
||||
5. Verify response structure
|
||||
|
||||
Example:
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_feature_via_client(self, client, test_config):
|
||||
"""Test new feature through client"""
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.call_tool("new_tool", {
|
||||
"param": "value"
|
||||
})
|
||||
|
||||
assert "success" in result
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
```
|
||||
|
||||
### Modifying Test Configuration
|
||||
|
||||
Edit the `test/test_config.json` file to adjust:
|
||||
- Server endpoints
|
||||
- Timeout settings
|
||||
- Test data
|
||||
- Feature switches
|
||||
|
||||
## Summary
|
||||
|
||||
This testing system covers the server at three levels: unit tests, integration tests, and end-to-end client-server tests. Configurable endpoints and automatic server-connection detection allow the same suite to run reliably in different environments.
|
||||
0
test/__init__.py
Normal file
0
test/__init__.py
Normal file
91
test/conftest.py
Normal file
91
test/conftest.py
Normal file
@@ -0,0 +1,91 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Pytest configuration and fixtures for Doris MCP Server tests
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Add project root to Python path (this file lives one directory below it)
_tests_dir = Path(__file__).parent
project_root = _tests_dir.parent
sys.path.insert(0, str(project_root))

# Configure logging for tests
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def event_loop():
    """Yield one event loop for the whole test session and close it afterwards.

    The loop is created through the current event-loop policy, handed to the
    tests, and closed once the session-scoped fixture is torn down.
    """
    policy = asyncio.get_event_loop_policy()
    session_loop = policy.new_event_loop()
    yield session_loop
    session_loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
def test_config():
    """Return the test configuration dictionary.

    The mapping combines Doris connection settings with the security-related
    entries (blocked SQL keywords, per-table sensitivity labels and the
    maximum query complexity).
    """
    connection_settings = {
        "doris_host": "localhost",
        "doris_port": 9030,
        "doris_user": "test_user",
        "doris_password": "test_password",
        "doris_database": "test_db",
    }
    security_settings = {
        "blocked_keywords": ["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"],
        "sensitive_tables": {
            "user_info": "confidential",
            "payment_records": "secret",
            "employee_data": "confidential",
            "public_reports": "public",
        },
        "max_query_complexity": 100,
    }
    # Merge in this order so the key order matches a single flat literal.
    return {**connection_settings, **security_settings}
|
||||
|
||||
|
||||
@pytest.fixture
def sample_data():
    """Return two sample rows containing personal data fields.

    Each row is a dict with the keys ``id``, ``name``, ``phone``, ``email``,
    ``id_card`` and ``salary``.
    """
    columns = ("id", "name", "phone", "email", "id_card", "salary")
    rows = (
        (1, "张三", "13812345678", "zhangsan@example.com", "110101199001011234", 50000),
        (2, "李四", "13987654321", "lisi@example.com", "110101199002022345", 60000),
    )
    return [dict(zip(columns, row)) for row in rows]
|
||||
|
||||
|
||||
@pytest.fixture
def test_sql_queries():
    """Return named SQL statements used as test inputs.

    The mapping holds one plain SELECT, one DROP statement, three
    injection-style statements (stacked query, UNION, trailing comment) and a
    multi-line JOIN query.
    """
    queries = {}
    queries["safe_select"] = "SELECT name, email FROM users WHERE department = 'sales'"
    queries["dangerous_drop"] = "DROP TABLE users"
    queries["sql_injection"] = "SELECT * FROM users WHERE id = 1; DROP TABLE users;"
    queries["union_injection"] = "SELECT name FROM users UNION SELECT password FROM admin_users"
    queries["comment_injection"] = "SELECT * FROM users WHERE id = 1 -- AND password = 'secret'"
    queries["complex_query"] = """
            SELECT u.name, u.email, d.department_name
            FROM users u
            JOIN departments d ON u.department_id = d.id
            WHERE u.status = 'active'
            ORDER BY u.created_at DESC
        """
    return queries
|
||||
283
test/integration/test_end_to_end.py
Normal file
283
test/integration/test_end_to_end.py
Normal file
@@ -0,0 +1,283 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
End-to-end integration tests
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from doris_mcp_server.main import DorisServer
|
||||
from doris_mcp_server.utils.config import DorisConfig
|
||||
from doris_mcp_server.utils.security import SecurityLevel, AuthContext
|
||||
|
||||
|
||||
class TestEndToEndIntegration:
    """End-to-end integration tests.

    The tests build a ``DorisServer`` from a ``Mock``-based configuration and
    replace individual security-manager and query-executor methods with
    ``unittest.mock`` patches before driving the workflow under test.
    """

    @staticmethod
    def _analyst_context():
        """Return the INTERNAL-level analyst context shared by several tests."""
        return AuthContext(
            user_id="analyst1",
            roles=["data_analyst"],
            permissions=["read_data"],
            session_id="session_123",
            security_level=SecurityLevel.INTERNAL,
        )

    @pytest.fixture
    def mock_config(self):
        """Create a mock ``DorisConfig`` with database and security sub-configs."""
        from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig

        config = Mock(spec=DorisConfig)
        config.doris_host = "localhost"
        config.doris_port = 9030
        config.doris_user = "test_user"
        config.doris_password = "test_password"
        config.doris_database = "test_db"
        config.server_host = "localhost"
        config.server_port = 8000
        config.enable_security = True

        # Add database config
        config.database = Mock(spec=DatabaseConfig)
        config.database.host = "localhost"
        config.database.port = 9030
        config.database.user = "test_user"
        config.database.password = "test_password"
        config.database.database = "test_db"
        config.database.health_check_interval = 60
        config.database.min_connections = 5
        config.database.max_connections = 20
        config.database.connection_timeout = 30
        config.database.max_connection_age = 3600

        # Add security config
        config.security = Mock(spec=SecurityConfig)
        config.security.enable_masking = True
        config.security.auth_type = "token"
        config.security.token_secret = "test_secret"
        config.security.token_expiry = 3600

        return config

    @pytest.fixture
    def doris_server(self, mock_config):
        """Create a Doris server instance from the mock configuration."""
        return DorisServer(mock_config)

    @pytest.mark.asyncio
    async def test_complete_query_workflow_with_security(self, doris_server, sample_data):
        """Test complete query workflow with security"""
        from doris_mcp_server.utils.security import ValidationResult

        security = doris_server.security_manager
        executor = doris_server.tools_manager.query_executor

        # Masked values follow the same format the data-masking tests assert:
        # first/last character of the email local part kept, and the middle
        # eight digits of the ID card hidden.
        masked_data = [
            {
                "id": 1,
                "name": "张三",
                "phone": "138****5678",
                "email": "z******n@example.com",
                "id_card": "110101********1234",
                "salary": 50000,
            }
        ]

        with (
            patch.object(executor, 'execute_query') as mock_execute,
            patch.object(security, 'authenticate_request') as mock_auth,
            patch.object(security, 'authorize_resource_access') as mock_authz,
            patch.object(security, 'validate_sql_security') as mock_validate,
            patch.object(security, 'apply_data_masking') as mock_mask,
        ):
            mock_execute.return_value = sample_data
            mock_auth.return_value = self._analyst_context()
            mock_authz.return_value = True
            mock_validate.return_value = ValidationResult(is_valid=True)
            mock_mask.return_value = masked_data

            # Simulate complete workflow
            auth_info = {"type": "token", "token": "valid_token_123"}
            auth_context = await security.authenticate_request(auth_info)

            resource_uri = "/api/table/users"
            has_access = await security.authorize_resource_access(auth_context, resource_uri)
            assert has_access is True

            sql = "SELECT * FROM users LIMIT 1"
            validation = await security.validate_sql_security(sql, auth_context)
            assert validation.is_valid is True

            raw_data = await executor.execute_query(sql)
            final_data = await security.apply_data_masking(raw_data, auth_context)

            # Verify data is properly masked
            assert final_data[0]["phone"] == "138****5678"
            assert final_data[0]["email"] == "z******n@example.com"

    @pytest.mark.asyncio
    async def test_security_violation_workflow(self, doris_server):
        """Test security violation detection workflow"""
        security = doris_server.security_manager

        with (
            patch.object(security, 'authenticate_request') as mock_auth,
            patch.object(security, 'authorize_resource_access') as mock_authz,
        ):
            mock_auth.return_value = self._analyst_context()
            # Test unauthorized resource access
            mock_authz.return_value = False

            auth_context = await security.authenticate_request({
                "type": "token", "token": "valid_token_123"
            })

            # Try to access confidential resource
            resource_uri = "/api/table/payment_records"
            has_access = await security.authorize_resource_access(auth_context, resource_uri)

            assert has_access is False

    @pytest.mark.asyncio
    async def test_sql_injection_prevention_workflow(self, doris_server):
        """Test SQL injection prevention workflow"""
        security = doris_server.security_manager

        # Only authentication is patched; SQL validation is the real implementation.
        with patch.object(security, 'authenticate_request') as mock_auth:
            mock_auth.return_value = self._analyst_context()

            auth_context = await security.authenticate_request({
                "type": "token", "token": "valid_token_123"
            })

            # Test SQL injection attempt
            malicious_sql = "SELECT * FROM users WHERE id = 1; DROP TABLE users;"
            validation = await security.validate_sql_security(malicious_sql, auth_context)

            assert validation.is_valid is False
            assert validation.risk_level == "high"

    @pytest.mark.asyncio
    async def test_admin_bypass_workflow(self, doris_server, sample_data):
        """Test admin user bypassing restrictions"""
        security = doris_server.security_manager
        executor = doris_server.tools_manager.query_executor

        with (
            patch.object(executor, 'execute_query') as mock_execute,
            patch.object(security, 'authenticate_request') as mock_auth,
            patch.object(security, 'authorize_resource_access') as mock_authz,
            patch.object(security, 'apply_data_masking') as mock_mask,
        ):
            mock_execute.return_value = sample_data
            mock_auth.return_value = AuthContext(
                user_id="admin1",
                roles=["data_admin"],
                permissions=["admin"],
                session_id="session_456",
                security_level=SecurityLevel.SECRET,
            )
            # Admin should access any resource
            mock_authz.return_value = True
            # Admin should see original data (no masking)
            mock_mask.return_value = sample_data

            auth_context = await security.authenticate_request({
                "type": "basic", "username": "admin", "password": "admin123"
            })

            # Admin accesses secret resource
            resource_uri = "/api/table/payment_records"
            has_access = await security.authorize_resource_access(auth_context, resource_uri)
            assert has_access is True

            # Admin sees original data
            raw_data = await executor.execute_query("SELECT * FROM users LIMIT 1")
            final_data = await security.apply_data_masking(raw_data, auth_context)

            # Should be original data (no masking)
            assert final_data[0]["phone"] == "13812345678"
            assert final_data[0]["email"] == "zhangsan@example.com"

    @pytest.mark.asyncio
    async def test_tool_execution_with_security(self, doris_server):
        """Test tool execution with security checks"""
        executor = doris_server.tools_manager.query_executor

        with patch.object(executor, 'execute_query') as mock_execute:
            mock_execute.return_value = [{"Database": "test_db"}]

            # Test tool execution through tools manager
            result = await doris_server.tools_manager.call_tool("get_db_list", {})
            result_data = json.loads(result)

            # Accept either success result or error (due to mock environment)
            assert "result" in result_data or "error" in result_data

    @pytest.mark.asyncio
    async def test_error_handling_workflow(self, doris_server):
        """Test error handling in complete workflow"""
        security = doris_server.security_manager

        # Test authentication failure
        with patch.object(security, 'authenticate_request') as mock_auth:
            mock_auth.side_effect = Exception("Invalid token")

            # ``match`` searches the exception text, i.e. the same check as
            # asserting the message is contained in ``str(exc_info.value)``.
            with pytest.raises(Exception, match="Invalid token"):
                await security.authenticate_request({
                    "type": "token", "token": "invalid_token"
                })

    @pytest.mark.asyncio
    async def test_performance_monitoring_integration(self, doris_server):
        """Test performance monitoring integration"""
        executor = doris_server.tools_manager.query_executor

        with patch.object(executor, 'execute_query') as mock_execute:
            mock_execute.return_value = [
                {
                    "query_count": 1500,
                    "avg_execution_time": 0.25,
                    "slow_query_count": 5,
                    "error_count": 2,
                }
            ]

            # Test performance stats tool
            result = await doris_server.tools_manager.call_tool("performance_stats", {
                "metric_type": "queries",
                "time_range": "1h",
            })
            result_data = json.loads(result)

            # Accept either success result or error (due to mock environment)
            assert "result" in result_data or "error" in result_data

    @pytest.mark.asyncio
    async def test_server_initialization(self, doris_server):
        """Test server initialization"""
        # Verify all components are initialized
        assert doris_server.config is not None
        assert doris_server.tools_manager is not None
        assert doris_server.security_manager is not None

        # Verify tools are available. The coroutine is awaited on the test's
        # own event loop instead of calling asyncio.run(), which creates a
        # separate loop and clears the current-loop setting when it returns.
        tools = await doris_server.tools_manager.list_tools()
        assert len(tools) > 0
|
||||
87
test/security/test_authentication.py
Normal file
87
test/security/test_authentication.py
Normal file
@@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Authentication module tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
AuthenticationProvider,
|
||||
AuthContext,
|
||||
SecurityLevel
|
||||
)
|
||||
|
||||
|
||||
class TestAuthenticationProvider:
    """Tests for ``AuthenticationProvider.authenticate``.

    Covers token and basic credentials (accepted and rejected) plus an
    unsupported authentication type.
    """

    @pytest.fixture
    def auth_provider(self, test_config):
        """Build the provider under test from the ``test_config`` fixture."""
        provider = AuthenticationProvider(test_config)
        return provider

    @pytest.mark.asyncio
    async def test_token_authentication_success(self, auth_provider):
        """A known token yields an INTERNAL-level analyst context."""
        credentials = dict(type="token", token="valid_token_123")

        context = await auth_provider.authenticate(credentials)

        assert isinstance(context, AuthContext)
        assert context.user_id == "test_user"
        assert "data_analyst" in context.roles
        assert context.security_level == SecurityLevel.INTERNAL

    @pytest.mark.asyncio
    async def test_token_authentication_failure(self, auth_provider):
        """An unknown token makes ``authenticate`` raise."""
        credentials = dict(type="token", token="invalid_token")

        with pytest.raises(Exception):
            await auth_provider.authenticate(credentials)

    @pytest.mark.asyncio
    async def test_basic_authentication_success(self, auth_provider):
        """Valid admin username/password yields a SECRET-level admin context."""
        credentials = dict(type="basic", username="admin", password="admin123")

        context = await auth_provider.authenticate(credentials)

        assert isinstance(context, AuthContext)
        assert context.user_id == "admin_user"
        assert "data_admin" in context.roles
        assert context.security_level == SecurityLevel.SECRET

    @pytest.mark.asyncio
    async def test_basic_authentication_failure(self, auth_provider):
        """A wrong password makes ``authenticate`` raise."""
        credentials = dict(type="basic", username="admin", password="wrong_password")

        with pytest.raises(Exception):
            await auth_provider.authenticate(credentials)

    @pytest.mark.asyncio
    async def test_unsupported_auth_type(self, auth_provider):
        """An authentication type other than token/basic makes ``authenticate`` raise."""
        credentials = dict(type="oauth", token="oauth_token")

        with pytest.raises(Exception):
            await auth_provider.authenticate(credentials)
|
||||
131
test/security/test_authorization.py
Normal file
131
test/security/test_authorization.py
Normal file
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Authorization module tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
AuthorizationProvider,
|
||||
AuthContext,
|
||||
SecurityLevel
|
||||
)
|
||||
|
||||
|
||||
class TestAuthorizationProvider:
    """Tests for ``AuthorizationProvider`` permission checks and URI helpers."""

    @pytest.fixture
    def authz_provider(self, test_config):
        """Build the provider under test from the ``test_config`` fixture."""
        provider = AuthorizationProvider(test_config)
        return provider

    @pytest.fixture
    def analyst_context(self):
        """Auth context for a data analyst at INTERNAL security level."""
        context = AuthContext(
            user_id="analyst1",
            roles=["data_analyst"],
            permissions=["read_data"],
            session_id="session_123",
            security_level=SecurityLevel.INTERNAL,
        )
        return context

    @pytest.fixture
    def admin_context(self):
        """Auth context for a data admin at SECRET security level."""
        context = AuthContext(
            user_id="admin1",
            roles=["data_admin"],
            permissions=["admin"],
            session_id="session_456",
            security_level=SecurityLevel.SECRET,
        )
        return context

    @pytest.mark.asyncio
    async def test_analyst_access_public_resource(self, authz_provider, analyst_context):
        """An analyst may read the table labelled public."""
        allowed = await authz_provider.check_permission(
            analyst_context, "/api/table/public_reports", "read"
        )

        assert allowed is True

    @pytest.mark.asyncio
    async def test_analyst_denied_confidential_resource(self, authz_provider):
        """A PUBLIC-level analyst may not read the table labelled confidential."""
        # PUBLIC is lower than the CONFIDENTIAL level of the target table.
        low_level_analyst = AuthContext(
            user_id="analyst1",
            roles=["data_analyst"],
            permissions=["read_data"],
            session_id="session_123",
            security_level=SecurityLevel.PUBLIC,
        )

        allowed = await authz_provider.check_permission(
            low_level_analyst, "/api/table/user_info", "read"
        )

        assert allowed is False

    @pytest.mark.asyncio
    async def test_admin_access_secret_resource(self, authz_provider, admin_context):
        """An admin may read the table labelled secret."""
        allowed = await authz_provider.check_permission(
            admin_context, "/api/table/payment_records", "read"
        )

        assert allowed is True

    @pytest.mark.asyncio
    async def test_role_based_permission(self, authz_provider):
        """An analyst role grants read but not write on an unlabelled table."""
        analyst = AuthContext(
            user_id="analyst1",
            roles=["data_analyst"],
            permissions=["read_data"],
            session_id="session_123",
            security_level=SecurityLevel.INTERNAL,
        )
        target = "/api/table/some_table"

        # Analyst should have read permission
        can_read = await authz_provider.check_permission(analyst, target, "read")
        assert can_read is True

        # Analyst should not have write permission
        can_write = await authz_provider.check_permission(analyst, target, "write")
        assert can_write is False

    @pytest.mark.asyncio
    async def test_admin_override(self, authz_provider, admin_context):
        """An admin is granted both read and write on an arbitrary table."""
        target = "/api/table/any_table"

        # Admin should have all permissions
        for action in ("read", "write"):
            granted = await authz_provider.check_permission(admin_context, target, action)
            assert granted is True

    def test_parse_resource_uri(self, authz_provider):
        """``_parse_resource_uri`` splits a URI into type, name and schema."""
        parsed = authz_provider._parse_resource_uri("/api/table/user_info/default")

        assert parsed["type"] == "table"
        assert parsed["name"] == "user_info"
        assert parsed["schema"] == "default"

    def test_get_resource_security_level(self, authz_provider):
        """``user_info`` resolves to the CONFIDENTIAL security level."""
        descriptor = {"name": "user_info", "type": "table"}

        resolved = authz_provider._get_resource_security_level(descriptor)

        assert resolved == SecurityLevel.CONFIDENTIAL
|
||||
181
test/security/test_data_masking.py
Normal file
181
test/security/test_data_masking.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Data masking tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
DataMaskingProcessor,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
MaskingRule
|
||||
)
|
||||
|
||||
|
||||
class TestDataMaskingProcessor:
|
||||
"""Data masking processor tests"""
|
||||
|
||||
@pytest.fixture
def masking_processor(self, test_config):
    """Build the ``DataMaskingProcessor`` under test from ``test_config``."""
    processor = DataMaskingProcessor(test_config)
    return processor
|
||||
|
||||
@pytest.fixture
def internal_user_context(self):
    """Auth context for an analyst at INTERNAL security level."""
    context = AuthContext(
        user_id="internal_user",
        roles=["data_analyst"],
        permissions=["read_data"],
        session_id="session_123",
        security_level=SecurityLevel.INTERNAL,
    )
    return context
|
||||
|
||||
@pytest.fixture
def admin_context(self):
    """Auth context for a data admin at SECRET security level."""
    context = AuthContext(
        user_id="admin",
        roles=["data_admin"],
        permissions=["admin"],
        session_id="session_456",
        security_level=SecurityLevel.SECRET,
    )
    return context
|
||||
|
||||
@pytest.mark.asyncio
async def test_phone_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
    """Phone numbers keep 3 leading and 4 trailing digits for an internal user."""
    masked_rows = await masking_processor.process(sample_data, internal_user_context)
    first_row = masked_rows[0]
    second_row = masked_rows[1]

    # Phone numbers should be masked
    assert first_row["phone"] == "138****5678"
    assert second_row["phone"] == "139****4321"
|
||||
|
||||
@pytest.mark.asyncio
async def test_email_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
    """Email local parts keep only their first and last character for an internal user."""
    masked_rows = await masking_processor.process(sample_data, internal_user_context)
    first_row = masked_rows[0]
    second_row = masked_rows[1]

    # Emails should be masked
    assert first_row["email"] == "z******n@example.com"
    assert second_row["email"] == "l**i@example.com"
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_masking_for_admin(self, masking_processor, admin_context, sample_data):
    """An admin context receives phone and email values unmasked."""
    rows = await masking_processor.process(sample_data, admin_context)
    first_row = rows[0]
    second_row = rows[1]

    # Admin should see original data
    assert first_row["phone"] == "13812345678"
    assert first_row["email"] == "zhangsan@example.com"
    assert second_row["phone"] == "13987654321"
    assert second_row["email"] == "lisi@example.com"
|
||||
|
||||
@pytest.mark.asyncio
async def test_id_card_masking_for_confidential_data(self, masking_processor, internal_user_context, sample_data):
    """ID card numbers keep 6 leading and 4 trailing digits for an internal user."""
    # Internal user should not see ID card details (confidential level)
    masked_rows = await masking_processor.process(sample_data, internal_user_context)
    first_row = masked_rows[0]
    second_row = masked_rows[1]

    # ID cards should be masked for internal users
    assert first_row["id_card"] == "110101********1234"
    assert second_row["id_card"] == "110101********2345"
|
||||
|
||||
@pytest.mark.asyncio
async def test_empty_data_handling(self, masking_processor, internal_user_context):
    """Processing an empty row list yields an empty list."""
    processed = await masking_processor.process([], internal_user_context)

    assert processed == []
|
||||
|
||||
@pytest.mark.asyncio
async def test_null_value_handling(self, masking_processor, internal_user_context):
    """``None`` values in maskable columns stay ``None`` after processing."""
    row_with_nulls = {
        "id": 1,
        "name": "张三",
        "phone": None,
        "email": None,
        "id_card": None,
    }

    processed = await masking_processor.process([row_with_nulls], internal_user_context)

    # Each sensitive column must still be None rather than a masked string.
    for column in ("phone", "email", "id_card"):
        assert processed[0][column] is None
|
||||
|
||||
def test_phone_masking_algorithm(self, masking_processor):
    """``_mask_phone`` with prefix 3 / suffix 4 yields ``138****5678``."""
    mask_options = {"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4}

    masked = masking_processor._mask_phone("13812345678", mask_options)

    assert masked == "138****5678"
|
||||
|
||||
def test_email_masking_algorithm(self, masking_processor):
    """``_mask_email`` masks the middle of the local part of an address."""
    mask_options = {"mask_char": "*"}

    masked = masking_processor._mask_email("zhangsan@example.com", mask_options)

    assert masked == "z******n@example.com"
|
||||
|
||||
def test_id_card_masking_algorithm(self, masking_processor):
    """``_mask_id_card`` with prefix 6 / suffix 4 masks the middle digits."""
    mask_options = {"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4}

    masked = masking_processor._mask_id_card("110101199001011234", mask_options)

    assert masked == "110101********1234"
|
||||
|
||||
def test_name_masking_algorithm(self, masking_processor):
    """``_mask_name`` is checked on a two- and a three-character name."""
    params = {"mask_char": "*"}

    # (input name, expected masked form): 2-character case, then 3-character.
    cases = (
        ("张三", "张*"),
        ("李小明", "李*明"),
    )
    for raw_name, expected in cases:
        masked = masking_processor._mask_name(raw_name, params)
        assert masked == expected
|
||||
|
||||
def test_partial_masking_algorithm(self, masking_processor):
    """``_mask_partial`` with ratio 0.5 masks something and keeps the length.

    Only the presence of the mask character and the unchanged length are
    verified; the exact masked positions are not asserted.
    """
    mask_options = {"mask_char": "*", "mask_ratio": 0.5}

    masked = masking_processor._mask_partial("1234567890", mask_options)

    assert "*" in masked
    assert len(masked) == 10
|
||||
|
||||
def test_should_apply_rule_logic(self, masking_processor, internal_user_context, admin_context):
    """An INTERNAL-level phone rule applies to the internal user, not the admin."""
    phone_rule = MaskingRule(
        column_pattern=r".*phone.*",
        algorithm="phone_mask",
        parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
        security_level=SecurityLevel.INTERNAL,
    )

    # The internal user is subject to the rule.
    applies_to_internal = masking_processor._should_apply_rule(phone_rule, internal_user_context)
    assert applies_to_internal is True

    # The admin context is exempt from it.
    applies_to_admin = masking_processor._should_apply_rule(phone_rule, admin_context)
    assert applies_to_admin is False
|
||||
|
||||
def test_get_applicable_rules(self, masking_processor, internal_user_context):
    """At least one rule is returned for the internal user, all ``MaskingRule``."""
    applicable = masking_processor._get_applicable_rules(internal_user_context)

    assert len(applicable) > 0
    assert all(isinstance(candidate, MaskingRule) for candidate in applicable)
|
||||
156
test/security/test_security_manager.py
Normal file
156
test/security/test_security_manager.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Security manager integration tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
DorisSecurityManager,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
ValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestDorisSecurityManager:
    """Integration tests that drive ``DorisSecurityManager`` end to end."""

    @pytest.fixture
    def security_manager(self, test_config):
        """Provide a ``DorisSecurityManager`` built from the ``test_config`` fixture."""
        manager = DorisSecurityManager(test_config)
        return manager

    @pytest.mark.asyncio
    async def test_complete_security_workflow(self, security_manager, sample_data):
        """Run authentication, authorization, SQL validation and masking in order."""
        # Step 1 - authenticate with a token credential.
        credentials = {"type": "token", "token": "valid_token_123"}
        context = await security_manager.authenticate_request(credentials)
        assert isinstance(context, AuthContext)
        assert context.security_level == SecurityLevel.INTERNAL

        # Step 2 - the authenticated user may read the public reports table.
        granted = await security_manager.authorize_resource_access(
            context, "/api/table/public_reports"
        )
        assert granted is True

        # Step 3 - a plain SELECT passes SQL validation.
        verdict = await security_manager.validate_sql_security(
            "SELECT name, email FROM users WHERE department = 'sales'", context
        )
        assert verdict.is_valid is True

        # Step 4 - the phone column comes back masked.
        rows = await security_manager.apply_data_masking(sample_data, context)
        assert rows[0]["phone"] == "138****5678"

    @pytest.mark.asyncio
    async def test_admin_workflow(self, security_manager, sample_data):
        """An admin authenticates with basic auth, reaches a secret table, sees raw data."""
        credentials = {
            "type": "basic",
            "username": "admin",
            "password": "admin123",
        }
        context = await security_manager.authenticate_request(credentials)
        assert context.security_level == SecurityLevel.SECRET

        # The payment records table must be reachable for the admin.
        granted = await security_manager.authorize_resource_access(
            context, "/api/table/payment_records"
        )
        assert granted is True

        # No masking is expected: the original phone number is returned.
        rows = await security_manager.apply_data_masking(sample_data, context)
        assert rows[0]["phone"] == "13812345678"

    @pytest.mark.asyncio
    async def test_security_violation_detection(self, security_manager):
        """A token user is denied a confidential table and a DROP statement."""
        credentials = {"type": "token", "token": "valid_token_123"}
        context = await security_manager.authenticate_request(credentials)

        # The test expects user_info to be out of reach for this user
        # (INTERNAL clearance versus a CONFIDENTIAL resource).
        granted = await security_manager.authorize_resource_access(
            context, "/api/table/user_info"
        )
        assert granted is False

        # A destructive statement must be rejected and reported as DROP.
        verdict = await security_manager.validate_sql_security("DROP TABLE users", context)
        assert verdict.is_valid is False
        assert "DROP" in verdict.blocked_operations

    @pytest.mark.asyncio
    async def test_sql_injection_prevention(self, security_manager):
        """Each sample injection statement is rejected with medium or high risk."""
        credentials = {"type": "token", "token": "valid_token_123"}
        context = await security_manager.authenticate_request(credentials)

        # Stacked query, UNION exfiltration, tautology, and comment truncation.
        injection_attempts = (
            "SELECT * FROM users WHERE id = 1; DROP TABLE users;",
            "SELECT * FROM users UNION SELECT password FROM admin_users",
            "SELECT * FROM users WHERE id = 1 OR 1=1",
            "SELECT * FROM users WHERE name = 'test' -- AND password = 'secret'",
        )
        for statement in injection_attempts:
            verdict = await security_manager.validate_sql_security(statement, context)
            assert verdict.is_valid is False
            assert verdict.risk_level in ["medium", "high"]

    @pytest.mark.asyncio
    async def test_authentication_failure_handling(self, security_manager):
        """Authenticating with an unknown token raises an exception."""
        bad_credentials = {"type": "token", "token": "invalid_token"}

        with pytest.raises(Exception):
            await security_manager.authenticate_request(bad_credentials)

    @pytest.mark.asyncio
    async def test_configuration_loading(self, security_manager):
        """Blocked keywords, sensitive tables and masking rules are populated."""
        # Blocked keywords.
        for keyword in ("DROP", "DELETE"):
            assert keyword in security_manager.blocked_keywords

        # Sensitive tables must include both of these classification levels.
        for level in (SecurityLevel.CONFIDENTIAL, SecurityLevel.SECRET):
            assert level in security_manager.sensitive_tables.values()

        # Masking rules: at least one exists and at least one targets phones.
        assert len(security_manager.masking_rules) > 0
        phone_rules = [
            rule
            for rule in security_manager.masking_rules
            if "phone" in rule.column_pattern
        ]
        assert len(phone_rules) > 0

    def test_security_level_hierarchy(self, security_manager):
        """Every ``SecurityLevel`` member has one of the four known string values."""
        all_levels = (
            SecurityLevel.PUBLIC,
            SecurityLevel.INTERNAL,
            SecurityLevel.CONFIDENTIAL,
            SecurityLevel.SECRET,
        )

        for level in all_levels:
            assert isinstance(level, SecurityLevel)
            assert level.value in ["public", "internal", "confidential", "secret"]
|
||||
145
test/security/test_sql_validation.py
Normal file
145
test/security/test_sql_validation.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SQL security validation tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
SQLSecurityValidator,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
ValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestSQLSecurityValidator:
    """Tests for ``SQLSecurityValidator.validate`` and its table-name helper.

    All validation tests run as an analyst with INTERNAL clearance (see the
    ``analyst_context`` fixture). Query texts come from the
    ``test_sql_queries`` fixture unless a test defines its own.
    """

    @pytest.fixture
    def sql_validator(self, test_config):
        """Return a ``SQLSecurityValidator`` built from the ``test_config`` fixture."""
        return SQLSecurityValidator(test_config)

    @pytest.fixture
    def analyst_context(self):
        """Return the ``AuthContext`` of a read-only analyst (INTERNAL level)."""
        return AuthContext(
            user_id="analyst1",
            roles=["data_analyst"],
            permissions=["read_data"],
            session_id="session_123",
            security_level=SecurityLevel.INTERNAL,
        )

    @pytest.mark.asyncio
    async def test_safe_select_query(self, sql_validator, analyst_context, test_sql_queries):
        """A safe SELECT is valid and carries no error message."""
        sql = test_sql_queries["safe_select"]

        result = await sql_validator.validate(sql, analyst_context)

        assert result.is_valid is True
        assert result.error_message is None

    @pytest.mark.asyncio
    async def test_blocked_drop_operation(self, sql_validator, analyst_context, test_sql_queries):
        """A DROP statement is rejected and listed in ``blocked_operations``."""
        sql = test_sql_queries["dangerous_drop"]

        result = await sql_validator.validate(sql, analyst_context)

        assert result.is_valid is False
        assert "blocked operations" in result.error_message.lower()
        assert "DROP" in result.blocked_operations

    @pytest.mark.asyncio
    async def test_sql_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
        """An injection attempt is rejected, named as such, and rated high risk."""
        sql = test_sql_queries["sql_injection"]

        result = await sql_validator.validate(sql, analyst_context)

        assert result.is_valid is False
        assert "injection" in result.error_message.lower()
        assert result.risk_level == "high"

    @pytest.mark.asyncio
    async def test_union_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
        """A UNION-based injection is rejected with an injection error message."""
        sql = test_sql_queries["union_injection"]

        result = await sql_validator.validate(sql, analyst_context)

        assert result.is_valid is False
        assert "injection" in result.error_message.lower()

    @pytest.mark.asyncio
    async def test_comment_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
        """A comment-based injection is rejected with a message mentioning comments."""
        sql = test_sql_queries["comment_injection"]

        result = await sql_validator.validate(sql, analyst_context)

        assert result.is_valid is False
        assert "comment" in result.error_message.lower()

    @pytest.mark.asyncio
    async def test_complex_query_validation(self, sql_validator, analyst_context, test_sql_queries):
        """The ``complex_query`` sample is expected to pass validation."""
        sql = test_sql_queries["complex_query"]

        result = await sql_validator.validate(sql, analyst_context)

        assert result.is_valid is True

    @pytest.mark.asyncio
    async def test_blocked_keywords_detection(self, sql_validator, analyst_context):
        """Each write/DDL statement is rejected with a non-empty blocked list."""
        blocked_sqls = [
            "DELETE FROM users WHERE id = 1",
            "TRUNCATE TABLE logs",
            "ALTER TABLE users ADD COLUMN new_col VARCHAR(50)",
            "CREATE TABLE test (id INT)",
            "INSERT INTO users VALUES (1, 'test')",
            "UPDATE users SET name = 'test' WHERE id = 1",
        ]

        for sql in blocked_sqls:
            result = await sql_validator.validate(sql, analyst_context)
            # The messages name the offending statement so a failure inside
            # the loop can be attributed to one input.
            assert result.is_valid is False, f"Statement was not rejected: {sql}"
            assert result.blocked_operations is not None, (
                f"No blocked operations reported for: {sql}"
            )
            assert len(result.blocked_operations) > 0, (
                f"Empty blocked operations for: {sql}"
            )

    @pytest.mark.asyncio
    async def test_table_access_validation(self, sql_validator, analyst_context):
        """Selecting from ``sensitive_data`` is denied with an access error."""
        sql = "SELECT * FROM sensitive_data"

        result = await sql_validator.validate(sql, analyst_context)

        assert result.is_valid is False
        assert "access" in result.error_message.lower()

    def test_extract_table_names(self, sql_validator):
        """``_extract_table_names`` finds at least one table in a JOIN query."""
        # Imported locally: this is the only test that needs the parser itself.
        import sqlparse

        sql = "SELECT u.name FROM users u JOIN departments d ON u.dept_id = d.id"

        parsed = sqlparse.parse(sql)[0]
        tables = sql_validator._extract_table_names(parsed)

        assert len(tables) > 0

    @pytest.mark.asyncio
    async def test_malformed_sql_handling(self, sql_validator, analyst_context):
        """A truncated statement still produces a ``ValidationResult``.

        Validity is deliberately not asserted; only the return type is.
        """
        malformed_sql = "SELECT * FROM users WHERE"

        result = await sql_validator.validate(malformed_sql, analyst_context)

        assert isinstance(result, ValidationResult)
|
||||
69
test/test_config.json
Normal file
69
test/test_config.json
Normal file
@@ -0,0 +1,69 @@
|
||||
{
|
||||
"server_endpoints": {
|
||||
"http": {
|
||||
"url": "http://localhost:3000/mcp",
|
||||
"timeout": 30,
|
||||
"headers": {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
},
|
||||
"http_network": {
|
||||
"url": "http://192.168.31.168:3000/mcp",
|
||||
"timeout": 30,
|
||||
"headers": {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
},
|
||||
"stdio": {
|
||||
"command": "uv",
|
||||
"args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"],
|
||||
"timeout": 30,
|
||||
"working_directory": ".."
|
||||
}
|
||||
},
|
||||
"test_settings": {
|
||||
"default_transport": "http",
|
||||
"retry_attempts": 3,
|
||||
"retry_delay": 1.0,
|
||||
"test_timeout": 60,
|
||||
"enable_performance_tests": true,
|
||||
"enable_security_tests": true
|
||||
},
|
||||
"test_data": {
|
||||
"sample_queries": [
|
||||
"SELECT 1 as test_value",
|
||||
"SHOW DATABASES",
|
||||
"SELECT COUNT(*) FROM information_schema.tables"
|
||||
],
|
||||
"test_databases": ["test_db", "demo_db"],
|
||||
"test_tables": ["users", "orders", "products"],
|
||||
"auth_tokens": {
|
||||
"valid_token": "valid_token_123",
|
||||
"admin_token": "admin_token_456",
|
||||
"invalid_token": "invalid_token_789"
|
||||
}
|
||||
},
|
||||
"expected_tools": [
|
||||
"exec_query",
|
||||
"get_db_list",
|
||||
"get_db_table_list",
|
||||
"get_table_schema",
|
||||
"get_table_comment",
|
||||
"get_table_column_comments",
|
||||
"get_table_indexes",
|
||||
"column_analysis",
|
||||
"performance_stats",
|
||||
"get_recent_audit_logs",
|
||||
"get_catalog_list"
|
||||
],
|
||||
"expected_resources": [
|
||||
"database",
|
||||
"table",
|
||||
"view"
|
||||
],
|
||||
"expected_prompts": [
|
||||
"sql_query_assistant",
|
||||
"data_analysis_helper",
|
||||
"schema_explorer"
|
||||
]
|
||||
}
|
||||
198
test/test_config_loader.py
Normal file
198
test/test_config_loader.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Test Configuration Loader
|
||||
|
||||
Loads test configuration and provides methods to connect to running servers
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from doris_mcp_client.client import DorisUnifiedClient, DorisClientConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestConfigLoader:
    """Load the JSON test configuration and build MCP clients from it.

    The parsed document is kept in ``self.config``; the accessor methods
    below are thin lookups into that dictionary and raise ``KeyError`` when
    the expected section is missing.
    """

    # The name matches pytest's ``Test*`` class pattern, but this is a helper
    # with an ``__init__``. Opting out keeps pytest from trying to collect it
    # and from emitting a collection warning.
    __test__ = False

    def __init__(self, config_path: Optional[str] = None):
        """Load the configuration file.

        Args:
            config_path: Path to the JSON file. Defaults to
                ``test_config.json`` next to this module.

        Raises:
            FileNotFoundError: The file does not exist.
            json.JSONDecodeError: The file is not valid JSON.
        """
        if config_path is None:
            config_path = os.path.join(os.path.dirname(__file__), "test_config.json")

        self.config_path = Path(config_path)
        self.config = self._load_config()

    def _load_config(self) -> Dict[str, Any]:
        """Read and parse ``self.config_path``; log and re-raise on failure."""
        try:
            with open(self.config_path, 'r', encoding='utf-8') as f:
                config = json.load(f)
            logger.info(f"Loaded test configuration from {self.config_path}")
            return config
        except FileNotFoundError:
            logger.error(f"Test configuration file not found: {self.config_path}")
            raise
        except json.JSONDecodeError as e:
            logger.error(f"Invalid JSON in test configuration: {e}")
            raise

    def get_http_client_config(self) -> DorisClientConfig:
        """Build the HTTP client config from ``server_endpoints.http``."""
        http_config = self.config["server_endpoints"]["http"]
        return DorisClientConfig.http(
            url=http_config["url"],
            timeout=http_config["timeout"],
        )

    def get_stdio_client_config(self) -> DorisClientConfig:
        """Build the stdio client config from ``server_endpoints.stdio``."""
        stdio_config = self.config["server_endpoints"]["stdio"]
        return DorisClientConfig.stdio(
            command=stdio_config["command"],
            args=stdio_config["args"],
        )

    def _client_config_for(self, transport: str) -> DorisClientConfig:
        """Map a transport name (``"http"`` or ``"stdio"``) to its client config.

        Raises:
            ValueError: ``transport`` is neither ``"http"`` nor ``"stdio"``.
        """
        if transport == "http":
            return self.get_http_client_config()
        if transport == "stdio":
            return self.get_stdio_client_config()
        raise ValueError(f"Unknown transport type: {transport}")

    def get_default_client_config(self) -> DorisClientConfig:
        """Build the client config for ``test_settings.default_transport``.

        Raises:
            ValueError: The configured default transport is not supported.
        """
        return self._client_config_for(self.config["test_settings"]["default_transport"])

    def create_client(self, transport: Optional[str] = None) -> DorisUnifiedClient:
        """Create an MCP client.

        Args:
            transport: ``"http"``, ``"stdio"``, or ``None`` to use the
                configured default transport.

        Raises:
            ValueError: The transport name is not supported.
        """
        if transport is None:
            client_config = self.get_default_client_config()
        else:
            client_config = self._client_config_for(transport)

        return DorisUnifiedClient(client_config)

    def get_test_settings(self) -> Dict[str, Any]:
        """Return the ``test_settings`` section."""
        return self.config["test_settings"]

    def get_test_data(self) -> Dict[str, Any]:
        """Return the ``test_data`` section."""
        return self.config["test_data"]

    def get_expected_tools(self) -> list[str]:
        """Return the ``expected_tools`` list."""
        return self.config["expected_tools"]

    def get_expected_resources(self) -> list[str]:
        """Return the ``expected_resources`` list."""
        return self.config["expected_resources"]

    def get_expected_prompts(self) -> list[str]:
        """Return the ``expected_prompts`` list."""
        return self.config["expected_prompts"]

    def get_sample_queries(self) -> list[str]:
        """Return ``test_data.sample_queries``."""
        return self.config["test_data"]["sample_queries"]

    def get_auth_tokens(self) -> Dict[str, str]:
        """Return ``test_data.auth_tokens``."""
        return self.config["test_data"]["auth_tokens"]

    def get_test_databases(self) -> list[str]:
        """Return ``test_data.test_databases``."""
        return self.config["test_data"]["test_databases"]

    def get_test_tables(self) -> list[str]:
        """Return ``test_data.test_tables``."""
        return self.config["test_data"]["test_tables"]

    def is_performance_tests_enabled(self) -> bool:
        """Return ``test_settings.enable_performance_tests``."""
        return self.config["test_settings"]["enable_performance_tests"]

    def is_security_tests_enabled(self) -> bool:
        """Return ``test_settings.enable_security_tests``."""
        return self.config["test_settings"]["enable_security_tests"]

    def get_retry_config(self) -> Dict[str, Any]:
        """Return retry settings as ``{"attempts": ..., "delay": ...}``."""
        settings = self.config["test_settings"]
        return {
            "attempts": settings["retry_attempts"],
            "delay": settings["retry_delay"],
        }

    def get_test_timeout(self) -> int:
        """Return ``test_settings.test_timeout`` (documented as seconds)."""
        return self.config["test_settings"]["test_timeout"]
|
||||
|
||||
|
||||
# Module-wide loader instance, created on the first get_test_config() call.
_test_config = None


def get_test_config() -> TestConfigLoader:
    """Return the shared ``TestConfigLoader``, creating it on first use."""
    global _test_config
    if _test_config is not None:
        return _test_config
    _test_config = TestConfigLoader()
    return _test_config
|
||||
|
||||
|
||||
def create_test_client(transport: Optional[str] = None) -> DorisUnifiedClient:
    """Build a client for ``transport`` from the shared test configuration.

    ``transport`` is forwarded unchanged to ``TestConfigLoader.create_client``.
    """
    shared_config = get_test_config()
    return shared_config.create_client(transport)
|
||||
|
||||
|
||||
async def test_server_connectivity(transport: Optional[str] = None) -> bool:
    """Report whether the MCP server answers a tool-listing request.

    Args:
        transport: Transport name passed to ``create_test_client``; ``None``
            lets that factory choose.

    Returns:
        ``True`` when the server returned at least one tool. ``False`` when it
        returned none, or when creating the client, connecting, or listing
        tools raised; the exception is logged rather than propagated.
    """
    try:
        client = create_test_client(transport)

        async def _probe(client_instance):
            try:
                # Listing tools doubles as the connectivity check.
                tools = await client_instance.list_all_tools()
                return len(tools) > 0
            except Exception as e:
                logger.error(f"Connectivity test failed: {e}")
                return False

        return await client.connect_and_run(_probe)
    except Exception as e:
        logger.error(f"Failed to test server connectivity: {e}")
        return False


# This is a helper, not a test case. Its ``test_`` prefix would otherwise make
# pytest collect it in this module and in every test module that imports it.
test_server_connectivity.__test__ = False
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke check: print a summary of the loaded configuration, then
    # probe the server over both transports.
    import asyncio

    async def main():
        """Print the configuration summary and the connectivity results."""
        config = get_test_config()

        print("Test Configuration Loaded:")
        print(f"  Default transport: {config.get_test_settings()['default_transport']}")
        print(f"  Expected tools: {len(config.get_expected_tools())}")
        print(f"  Sample queries: {len(config.get_sample_queries())}")

        print("\nTesting server connectivity...")

        # HTTP first, then stdio; each result is printed before the next probe.
        http_ok = await test_server_connectivity("http")
        print(f"  HTTP connectivity: {'✓' if http_ok else '✗'}")

        stdio_ok = await test_server_connectivity("stdio")
        print(f"  Stdio connectivity: {'✓' if stdio_ok else '✗'}")

    asyncio.run(main())
|
||||
176
test/tools/test_tools_client_server.py
Normal file
176
test/tools/test_tools_client_server.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Tools Manager Client-Server Integration Tests
|
||||
|
||||
Tests the tools functionality through actual MCP client-server communication
|
||||
Assumes the server is already running and configured properly
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
|
||||
from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity
|
||||
|
||||
|
||||
class TestToolsClientServer:
    """Exercise the MCP tools through a real client-server round trip.

    A server is expected to be running already. When the connectivity probe
    fails, every test in the class is skipped.
    """

    @pytest.fixture
    def test_config(self):
        """Return the shared test configuration loader."""
        return get_test_config()

    @pytest.fixture
    def client(self, test_config):
        """Return a new client for the default transport.

        This is a plain synchronous fixture because nothing here awaits. As a
        coroutine under ``@pytest.fixture`` it would only be resolved when
        pytest-asyncio runs in auto mode; otherwise tests would receive a
        coroutine object instead of a client.
        """
        return create_test_client()

    @pytest.fixture(scope="class", autouse=True)
    def check_server_connectivity(self):
        """Skip the whole class when the server cannot be reached.

        The async probe is run to completion with ``asyncio.run`` so this
        check needs neither an async-fixture plugin nor a class-scoped
        event loop.
        """
        is_connected = asyncio.run(test_server_connectivity())
        if not is_connected:
            pytest.skip("Server is not running or not accessible")

    @pytest.mark.asyncio
    async def test_list_tools_via_client(self, client, test_config):
        """The server lists tools, including every expected tool name."""
        expected_tools = test_config.get_expected_tools()

        async def test_callback(client_instance):
            tools = await client_instance.list_all_tools()

            assert len(tools) > 0, "No tools returned from server"

            # Every configured tool name must be advertised by the server.
            tool_names = [tool.name for tool in tools]
            for expected_tool in expected_tools:
                assert expected_tool in tool_names, f"Expected tool '{expected_tool}' not found"

            return tools

        result = await client.connect_and_run(test_callback)
        assert len(result) > 0

    @pytest.mark.asyncio
    async def test_call_tool_exec_query_via_client(self, client, test_config):
        """``exec_query`` returns a well-formed success or error payload."""
        sample_queries = test_config.get_sample_queries()

        async def test_callback(client_instance):
            # The first sample query is used as the statement to run.
            result = await client_instance.call_tool("exec_query", {
                "sql": sample_queries[0],
                "max_rows": 100
            })

            assert "success" in result, "Result should contain 'success' field"

            if result["success"]:
                assert "result" in result, "Successful result should contain 'result' field"
            else:
                assert "error" in result, "Failed result should contain 'error' field"

            return result

        # success=True is not asserted: it depends on the actual server state.
        await client.connect_and_run(test_callback)

    @pytest.mark.asyncio
    async def test_call_tool_get_db_list_via_client(self, client, test_config):
        """``get_db_list`` returns a payload whose result, on success, is a list."""
        async def test_callback(client_instance):
            result = await client_instance.call_tool("get_db_list", {})

            assert "success" in result, "Result should contain 'success' field"

            if result["success"]:
                assert "result" in result, "Successful result should contain 'result' field"
                assert isinstance(result["result"], list), "Database list should be a list"

            return result

        result = await client.connect_and_run(test_callback)
        assert "success" in result

    @pytest.mark.asyncio
    async def test_call_tool_get_table_schema_via_client(self, client, test_config):
        """``get_table_schema`` returns a payload with a ``success`` field."""
        test_tables = test_config.get_test_tables()

        async def test_callback(client_instance):
            result = await client_instance.call_tool("get_table_schema", {
                "table_name": test_tables[0],
                # information_schema is used as a database that should exist.
                "db_name": "information_schema"
            })

            assert "success" in result, "Result should contain 'success' field"
            return result

        result = await client.connect_and_run(test_callback)
        assert "success" in result

    @pytest.mark.asyncio
    async def test_call_tool_performance_stats_via_client(self, client, test_config):
        """``performance_stats`` returns a payload; skipped when perf tests are off."""
        if not test_config.is_performance_tests_enabled():
            pytest.skip("Performance tests are disabled")

        async def test_callback(client_instance):
            result = await client_instance.call_tool("performance_stats", {
                "metric_type": "queries",
                "time_range": "1h"
            })

            assert "success" in result, "Result should contain 'success' field"
            return result

        result = await client.connect_and_run(test_callback)
        assert "success" in result

    @pytest.mark.asyncio
    async def test_tool_error_handling_via_client(self, client, test_config):
        """An invalid SQL statement still yields a structured payload."""
        async def test_callback(client_instance):
            result = await client_instance.call_tool("exec_query", {
                "sql": "INVALID SQL SYNTAX HERE"
            })

            # Either outcome is acceptable as long as the payload is structured.
            assert "success" in result, "Result should contain 'success' field"
            return result

        result = await client.connect_and_run(test_callback)
        assert "success" in result

    @pytest.mark.asyncio
    async def test_tool_with_auth_token_via_client(self, client, test_config):
        """A tool call carrying an auth token returns a structured payload.

        Skipped when security tests are disabled in the configuration.
        """
        if not test_config.is_security_tests_enabled():
            pytest.skip("Security tests are disabled")

        auth_tokens = test_config.get_auth_tokens()

        async def test_callback(client_instance):
            result = await client_instance.call_tool("get_db_list", {
                "auth_token": auth_tokens["valid_token"]
            })

            assert "success" in result, "Result should contain 'success' field"
            return result

        result = await client.connect_and_run(test_callback)
        assert "success" in result
|
||||
315
test/tools/test_tools_manager.py
Normal file
315
test/tools/test_tools_manager.py
Normal file
@@ -0,0 +1,315 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Tools manager tests
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from doris_mcp_server.tools.tools_manager import DorisToolsManager
|
||||
from doris_mcp_server.utils.config import DorisConfig
|
||||
|
||||
|
||||
class TestDorisToolsManager:
    """Tests for ``DorisToolsManager`` built on a mocked connection manager.

    Most tests stub a method on ``tools_manager.query_executor`` and then
    accept several possible response layouts, so the assertions are
    intentionally lenient.
    """

    @staticmethod
    def _decode(raw):
        """Return *raw* parsed as JSON when it is a string, otherwise unchanged."""
        if isinstance(raw, str):
            return json.loads(raw)
        return raw

    @pytest.fixture
    def mock_config(self):
        """Build a ``Mock`` specced on ``DorisConfig`` with database and security sections."""
        from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig

        cfg = Mock(spec=DorisConfig)
        top_level_settings = {
            "doris_host": "localhost",
            "doris_port": 9030,
            "doris_user": "test_user",
            "doris_password": "test_password",
            "doris_database": "test_db",
        }
        for attr_name, attr_value in top_level_settings.items():
            setattr(cfg, attr_name, attr_value)

        # Database section
        cfg.database = Mock(
            spec=DatabaseConfig,
            host="localhost",
            port=9030,
            user="test_user",
            password="test_password",
            database="test_db",
            health_check_interval=60,
            min_connections=5,
            max_connections=20,
            connection_timeout=30,
            max_connection_age=3600,
        )

        # Security section
        cfg.security = Mock(
            spec=SecurityConfig,
            enable_masking=True,
            auth_type="token",
            token_secret="test_secret",
            token_expiry=3600,
        )

        return cfg

    @pytest.fixture
    def tools_manager(self, mock_config):
        """Create a ``DorisToolsManager`` around a mock connection manager."""
        fake_connection_manager = Mock()
        # get_connection is awaited by callers, so it has to be an AsyncMock.
        fake_connection_manager.get_connection = AsyncMock()
        return DorisToolsManager(fake_connection_manager)

    @pytest.mark.asyncio
    async def test_get_available_tools(self, tools_manager):
        """The tool listing must include the four core tools."""
        listed = await tools_manager.list_tools()
        names = [entry.name for entry in listed]

        for expected in ("exec_query", "get_db_list", "get_db_table_list", "get_table_schema"):
            assert expected in names

    @pytest.mark.asyncio
    async def test_exec_query_tool(self, tools_manager):
        """Run ``exec_query`` with ``execute_sql_for_mcp`` stubbed to return two rows."""
        canned_response = {
            "success": True,
            "data": [
                {"id": 1, "name": "张三"},
                {"id": 2, "name": "李四"}
            ],
            "row_count": 2,
            "execution_time": 0.15
        }
        call_args = {
            "sql": "SELECT id, name FROM users LIMIT 2",
            "max_rows": 100
        }

        with patch.object(tools_manager.query_executor, 'execute_sql_for_mcp') as stub:
            stub.return_value = canned_response
            payload = self._decode(await tools_manager.call_tool("exec_query", call_args))

        # Failure path: an error must at least be reported.
        if not ("success" in payload and payload["success"]):
            assert "error" in payload
            return

        # Success path: rows may live under either 'data' or 'result'.
        for rows_key in ("data", "result"):
            if rows_key in payload and payload[rows_key] is not None:
                assert len(payload[rows_key]) == 2
                break

        # The stub's call arguments are intentionally not asserted, since the
        # implementation may vary or may not reach the stub on errors.

    @pytest.mark.asyncio
    async def test_exec_query_with_error(self, tools_manager):
        """Run ``exec_query`` while ``execute_query`` is stubbed to raise."""
        with patch.object(tools_manager.query_executor, 'execute_query') as stub:
            stub.side_effect = Exception("Database connection failed")

            call_args = {"sql": "SELECT * FROM users"}
            payload = self._decode(await tools_manager.call_tool("exec_query", call_args))

        assert "error" in payload or "success" in payload
        if "error" in payload:
            # Any connection-related wording is accepted.
            message = payload["error"].lower()
            accepted_keywords = ["connection", "failed", "error", "mock"]
            assert any(word in message for word in accepted_keywords)

    @pytest.mark.asyncio
    async def test_get_db_list_tool(self, tools_manager):
        """Run ``get_db_list`` with ``execute_query`` stubbed to return three databases."""
        canned_rows = [
            {"Database": "test_db"},
            {"Database": "information_schema"},
            {"Database": "mysql"}
        ]

        with patch.object(tools_manager.query_executor, 'execute_query') as stub:
            stub.return_value = canned_rows
            payload = self._decode(await tools_manager.call_tool("get_db_list", {}))

        # The listing may be exposed as 'databases' or as a generic 'result'.
        if "databases" in payload:
            assert len(payload["databases"]) == 3
        elif "result" in payload:
            assert len(payload["result"]) >= 0  # May be empty if no databases

    @pytest.mark.asyncio
    async def test_get_db_table_list_tool(self, tools_manager):
        """Run ``get_db_table_list`` with ``execute_query`` stubbed to return three tables."""
        canned_rows = [
            {"Tables_in_test_db": "users"},
            {"Tables_in_test_db": "orders"},
            {"Tables_in_test_db": "products"}
        ]

        with patch.object(tools_manager.query_executor, 'execute_query') as stub:
            stub.return_value = canned_rows
            call_args = {"db_name": "test_db"}
            payload = self._decode(await tools_manager.call_tool("get_db_table_list", call_args))

        # The listing may be exposed as 'tables' or as a generic 'result'.
        if "tables" in payload:
            table_names = payload["tables"]
            assert len(table_names) == 3
            assert "users" in table_names
        elif "result" in payload:
            assert len(payload["result"]) >= 0  # May be empty if no tables

    @pytest.mark.asyncio
    async def test_get_table_schema_tool(self, tools_manager):
        """Run ``get_table_schema`` with ``execute_query`` stubbed to return two columns."""
        id_column = {
            "Field": "id",
            "Type": "int(11)",
            "Null": "NO",
            "Key": "PRI",
            "Default": None,
            "Extra": "auto_increment"
        }
        name_column = {
            "Field": "name",
            "Type": "varchar(100)",
            "Null": "YES",
            "Key": "",
            "Default": None,
            "Extra": ""
        }

        with patch.object(tools_manager.query_executor, 'execute_query') as stub:
            stub.return_value = [id_column, name_column]
            call_args = {"table_name": "users"}
            payload = self._decode(await tools_manager.call_tool("get_table_schema", call_args))

        # The columns may be exposed as 'schema' or as a generic 'result'.
        if "schema" in payload:
            columns = payload["schema"]
            assert len(columns) == 2
            assert columns[0]["Field"] == "id"
        elif "result" in payload:
            assert len(payload["result"]) >= 0  # May be empty if no schema

    @pytest.mark.asyncio
    async def test_get_catalog_list_tool(self, tools_manager):
        """Run ``get_catalog_list`` with ``execute_query`` stubbed to return three catalogs."""
        canned_rows = [
            {"CatalogName": "internal"},
            {"CatalogName": "hive_catalog"},
            {"CatalogName": "iceberg_catalog"}
        ]

        with patch.object(tools_manager.query_executor, 'execute_query') as stub:
            stub.return_value = canned_rows
            call_args = {"random_string": "test_123"}
            payload = self._decode(await tools_manager.call_tool("get_catalog_list", call_args))

        # The listing may be exposed as 'catalogs' or as a generic 'result'.
        if "catalogs" in payload:
            catalog_names = payload["catalogs"]
            assert len(catalog_names) == 3
            assert "internal" in catalog_names
        elif "result" in payload:
            assert len(payload["result"]) >= 0  # May be empty if no catalogs

    @pytest.mark.asyncio
    async def test_column_analysis_tool(self, tools_manager):
        """Run ``column_analysis`` with ``execute_query`` stubbed to return basic statistics."""
        basic_stats = {
            "total_count": 1000,
            "null_count": 10,
            "distinct_count": 950,
            "min_value": 1,
            "max_value": 1000
        }
        call_args = {
            "table_name": "users",
            "column_name": "id",
            "analysis_type": "basic"
        }

        with patch.object(tools_manager.query_executor, 'execute_query') as stub:
            stub.return_value = [basic_stats]
            payload = self._decode(await tools_manager.call_tool("column_analysis", call_args))

        # The statistics may be exposed as 'analysis' or as a generic 'result'.
        if "analysis" in payload:
            assert payload["analysis"]["total_count"] == 1000
        elif "result" in payload:
            assert "result" in payload  # Just check result exists

    @pytest.mark.asyncio
    async def test_performance_stats_tool(self, tools_manager):
        """Run ``performance_stats`` with ``execute_query`` stubbed to return query metrics."""
        query_metrics = {
            "query_count": 1500,
            "avg_execution_time": 0.25,
            "slow_query_count": 5,
            "error_count": 2
        }
        call_args = {
            "metric_type": "queries",
            "time_range": "1h"
        }

        with patch.object(tools_manager.query_executor, 'execute_query') as stub:
            stub.return_value = [query_metrics]
            payload = self._decode(await tools_manager.call_tool("performance_stats", call_args))

        # The metrics may be exposed as 'stats' or as a generic 'result'.
        if "stats" in payload:
            assert payload["stats"]["query_count"] == 1500
        elif "result" in payload:
            assert "result" in payload  # Just check result exists

    @pytest.mark.asyncio
    async def test_invalid_tool_name(self, tools_manager):
        """Calling a tool name that does not exist must yield a structured reply."""
        payload = self._decode(await tools_manager.call_tool("invalid_tool", {}))

        assert "error" in payload or "success" in payload
        if "error" in payload:
            assert "Unknown tool" in payload["error"]

    @pytest.mark.asyncio
    async def test_missing_required_arguments(self, tools_manager):
        """Calling ``exec_query`` without its ``sql`` argument must yield a structured reply."""
        payload = self._decode(await tools_manager.call_tool("exec_query", {}))

        # Either outcome is accepted; the tool may handle the missing
        # parameter gracefully.
        assert "error" in payload or "success" in payload

    @pytest.mark.asyncio
    async def test_tool_definitions_structure(self, tools_manager):
        """Every listed tool must expose a name, a description and an input schema."""
        listed = await tools_manager.list_tools()

        for entry in listed:
            for required_attr in ('name', 'description', 'inputSchema'):
                assert hasattr(entry, required_attr)

            schema = entry.inputSchema
            assert 'properties' in schema

            # 'required' is optional, but must be a list when present.
            if 'required' in schema:
                assert isinstance(schema['required'], list)
|
||||
186
test/utils/test_query_executor.py
Normal file
186
test/utils/test_query_executor.py
Normal file
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Query executor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from doris_mcp_server.utils.query_executor import DorisQueryExecutor
|
||||
from doris_mcp_server.utils.config import DorisConfig
|
||||
|
||||
|
||||
class TestDorisQueryExecutor:
    """Tests for ``DorisQueryExecutor`` built on a mocked connection manager.

    NOTE(review): every test except ``test_execute_sql_for_mcp_interface``
    patches ``execute_sql_for_mcp`` on the executor and then calls that same
    patched method, so those tests check the canned response format rather
    than the executor's own logic.
    """

    @pytest.fixture
    def mock_config(self):
        """Build a ``Mock`` specced on ``DorisConfig`` with a database section."""
        from doris_mcp_server.utils.config import DatabaseConfig

        config = Mock(spec=DorisConfig)
        config.doris_host = "localhost"
        config.doris_port = 9030
        config.doris_user = "test_user"
        config.doris_password = "test_password"
        config.doris_database = "test_db"

        # Database section
        config.database = Mock(spec=DatabaseConfig)
        config.database.host = "localhost"
        config.database.port = 9030
        config.database.user = "test_user"
        config.database.password = "test_password"
        config.database.database = "test_db"
        config.database.health_check_interval = 60
        config.database.min_connections = 5
        config.database.max_connections = 20
        config.database.connection_timeout = 30
        config.database.max_connection_age = 3600

        return config

    @pytest.fixture
    def query_executor(self, mock_config):
        """Create a ``DorisQueryExecutor`` around a mock connection manager."""
        mock_connection_manager = Mock()
        return DorisQueryExecutor(mock_connection_manager, mock_config)

    @pytest.mark.asyncio
    async def test_execute_query_success(self, query_executor):
        """A stubbed successful response exposes rows, row count and success flag."""
        with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
            mock_execute.return_value = {
                "success": True,
                "data": [
                    {"id": 1, "name": "张三", "email": "zhangsan@example.com"},
                    {"id": 2, "name": "李四", "email": "lisi@example.com"}
                ],
                "row_count": 2,
                "execution_time": 0.15,
                "columns": ["id", "name", "email"]
            }

            sql = "SELECT id, name, email FROM users LIMIT 2"
            result = await query_executor.execute_sql_for_mcp(sql)

            assert result["success"] is True
            assert result["row_count"] == 2
            assert len(result["data"]) == 2
            assert result["data"][0]["id"] == 1
            assert result["data"][0]["name"] == "张三"
            assert result["data"][1]["email"] == "lisi@example.com"

    @pytest.mark.asyncio
    async def test_execute_query_with_parameters(self, query_executor):
        """A stubbed single-row response for a filtered query is passed through."""
        with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
            mock_execute.return_value = {
                "success": True,
                "data": [{"id": 1, "name": "张三"}],
                "row_count": 1,
                "execution_time": 0.1
            }

            sql = "SELECT id, name FROM users WHERE department = 'sales'"
            result = await query_executor.execute_sql_for_mcp(sql)

            assert result["success"] is True
            assert result["row_count"] == 1
            assert len(result["data"]) == 1

    @pytest.mark.asyncio
    async def test_execute_query_connection_error(self, query_executor):
        """A stubbed connection failure is reported through ``success`` and ``error``."""
        with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
            mock_execute.return_value = {
                "success": False,
                "error": "Connection failed",
                "data": None
            }

            sql = "SELECT * FROM users"
            result = await query_executor.execute_sql_for_mcp(sql)

            assert result["success"] is False
            assert "Connection failed" in result["error"]

    @pytest.mark.asyncio
    async def test_execute_query_sql_error(self, query_executor):
        """A stubbed SQL error is reported through ``success`` and ``error``."""
        with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
            mock_execute.return_value = {
                "success": False,
                "error": "SQL syntax error",
                "data": None
            }

            sql = "SELECT * FROM non_existent_table"
            result = await query_executor.execute_sql_for_mcp(sql)

            assert result["success"] is False
            assert "SQL syntax error" in result["error"]

    @pytest.mark.asyncio
    async def test_execute_query_empty_result(self, query_executor):
        """A stubbed empty result set yields an empty list and a zero row count."""
        with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
            mock_execute.return_value = {
                "success": True,
                "data": [],
                "row_count": 0,
                "execution_time": 0.05
            }

            sql = "SELECT * FROM users WHERE id = 999"
            result = await query_executor.execute_sql_for_mcp(sql)

            assert result["success"] is True
            assert result["data"] == []
            assert result["row_count"] == 0

    @pytest.mark.asyncio
    async def test_execute_query_max_rows_limit(self, query_executor):
        """A stubbed 100-row response is returned when calling with ``limit=100``."""
        with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
            # Canned result set already capped at 100 rows.
            limited_result = [{"id": i, "name": f"user_{i}"} for i in range(100)]
            mock_execute.return_value = {
                "success": True,
                "data": limited_result,
                "row_count": 100,
                "execution_time": 0.2
            }

            sql = "SELECT id, name FROM users"
            result = await query_executor.execute_sql_for_mcp(sql, limit=100)

            assert result["success"] is True
            assert len(result["data"]) == 100

    @pytest.mark.asyncio
    async def test_execute_sql_for_mcp_interface(self, query_executor):
        """Call the real ``execute_sql_for_mcp`` with only the connection layer mocked."""
        # new_callable=AsyncMock makes the patched get_connection awaitable.
        # The default MagicMock produced for an attribute of a plain Mock
        # cannot be awaited, which would force the executor onto its error path.
        # NOTE(review): assumes the executor awaits get_connection, as the
        # AsyncMock connection below suggests - confirm against the implementation.
        with patch.object(
            query_executor.connection_manager,
            'get_connection',
            new_callable=AsyncMock,
        ) as mock_get_conn:
            mock_connection = AsyncMock()
            mock_connection.execute.return_value = Mock(
                data=[{"id": 1, "name": "张三"}],
                row_count=1,
                execution_time=0.1,
                metadata={}
            )
            mock_get_conn.return_value = mock_connection

            sql = "SELECT id, name FROM users LIMIT 1"
            result = await query_executor.execute_sql_for_mcp(sql)

            # The response must always carry the success flag; row details are
            # only required when the call succeeded.
            assert "success" in result
            if result["success"]:
                assert "data" in result
                assert "row_count" in result
|
||||
140
test/utils/test_query_executor_client_server.py
Normal file
140
test/utils/test_query_executor_client_server.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""
|
||||
Query Executor Client-Server Integration Tests
|
||||
|
||||
Tests the query execution functionality through actual MCP client-server communication
|
||||
Assumes the server is already running and configured properly
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
|
||||
from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity
|
||||
|
||||
|
||||
class TestQueryExecutorClientServer:
    """Query execution tests that talk to a running server through the test client.

    The whole class is skipped when the connectivity probe reports that the
    server cannot be reached.
    """

    @pytest.fixture
    def test_config(self):
        """Return the shared test configuration."""
        return get_test_config()

    @pytest.fixture
    def client(self, test_config):
        """Create the test client.

        This is a plain synchronous fixture: nothing in it is awaited, and an
        ``async def`` fixture under ``@pytest.fixture`` is handed to tests as
        an un-awaited coroutine when pytest-asyncio runs in strict mode.
        """
        return create_test_client()

    @pytest.fixture(scope="class", autouse=True)
    def check_server_connectivity(self):
        """Skip the class when the server connectivity probe fails.

        The probe coroutine is driven with ``asyncio.run`` so this
        class-scoped fixture does not depend on pytest-asyncio handling async
        fixtures or on the scope of its event loop.
        """
        is_connected = asyncio.run(test_server_connectivity())
        if not is_connected:
            pytest.skip("Server is not running or not accessible")

    @pytest.mark.asyncio
    async def test_simple_select_query_via_client(self, client, test_config):
        """Execute the first sample query and check the shape of the reply."""
        sample_queries = test_config.get_sample_queries()

        async def run_query(client_instance):
            reply = await client_instance.execute_sql(sample_queries[0])  # "SELECT 1 as test_value"

            assert "success" in reply, "Result should contain 'success' field"

            # A success must carry its payload, a failure must carry its reason.
            if reply["success"]:
                assert "result" in reply, "Successful result should contain 'result' field"
            else:
                assert "error" in reply, "Failed result should contain 'error' field"

            return reply

        outcome = await client.connect_and_run(run_query)
        assert "success" in outcome

    @pytest.mark.asyncio
    async def test_show_databases_query_via_client(self, client, test_config):
        """Execute the second sample query and check for the ``success`` key."""
        sample_queries = test_config.get_sample_queries()

        async def run_query(client_instance):
            reply = await client_instance.execute_sql(sample_queries[1])  # "SHOW DATABASES"
            assert "success" in reply, "Result should contain 'success' field"
            return reply

        outcome = await client.connect_and_run(run_query)
        assert "success" in outcome

    @pytest.mark.asyncio
    async def test_information_schema_query_via_client(self, client, test_config):
        """Execute the third sample query and check for the ``success`` key."""
        sample_queries = test_config.get_sample_queries()

        async def run_query(client_instance):
            # "SELECT COUNT(*) FROM information_schema.tables"
            reply = await client_instance.execute_sql(sample_queries[2])
            assert "success" in reply, "Result should contain 'success' field"
            return reply

        outcome = await client.connect_and_run(run_query)
        assert "success" in outcome

    @pytest.mark.asyncio
    async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
        """Call ``exec_query`` with a ``max_rows`` argument and check for the ``success`` key."""
        async def run_query(client_instance):
            reply = await client_instance.call_tool("exec_query", {
                "sql": "SELECT 1 as test_value",
                "max_rows": 10
            })
            assert "success" in reply, "Result should contain 'success' field"
            return reply

        outcome = await client.connect_and_run(run_query)
        assert "success" in outcome

    @pytest.mark.asyncio
    async def test_query_error_handling_via_client(self, client, test_config):
        """Execute invalid SQL; either outcome is fine as long as ``success`` is present."""
        async def run_query(client_instance):
            reply = await client_instance.execute_sql("INVALID SQL SYNTAX")
            assert "success" in reply, "Result should contain 'success' field"
            return reply

        outcome = await client.connect_and_run(run_query)
        assert "success" in outcome

    @pytest.mark.asyncio
    async def test_query_with_auth_token_via_client(self, client, test_config):
        """Call ``exec_query`` with the configured valid auth token.

        Skipped when the test configuration reports that security tests are
        disabled.
        """
        if not test_config.is_security_tests_enabled():
            pytest.skip("Security tests are disabled")

        auth_tokens = test_config.get_auth_tokens()

        async def run_query(client_instance):
            reply = await client_instance.call_tool("exec_query", {
                "sql": "SELECT 1 as test_value",
                "auth_token": auth_tokens["valid_token"]
            })
            assert "success" in reply, "Result should contain 'success' field"
            return reply

        outcome = await client.connect_and_run(run_query)
        assert "success" in outcome
|
||||
Reference in New Issue
Block a user