Compare commits
2 Commits
infrastruc
...
0.3.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
82d2ef6f40 | ||
|
|
44dc2ddc05 |
65
Dockerfile
65
Dockerfile
@@ -1,65 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# Use Python 3.12 as base image
|
||||
FROM python:3.12-slim
|
||||
|
||||
# Set working directory
|
||||
WORKDIR /app
|
||||
|
||||
# Set environment variables
|
||||
ENV PYTHONPATH=/app
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
# Install system dependencies
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
gcc \
|
||||
g++ \
|
||||
pkg-config \
|
||||
default-libmysqlclient-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy requirements file
|
||||
COPY requirements.txt .
|
||||
|
||||
# Install Python dependencies
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
|
||||
# Copy application code
|
||||
COPY . .
|
||||
|
||||
# Create necessary directories
|
||||
RUN mkdir -p /app/logs /app/config /app/data
|
||||
|
||||
# Set permissions
|
||||
RUN chmod +x /app/start.sh
|
||||
|
||||
# Create non-root user
|
||||
RUN groupadd -r doris && useradd -r -g doris doris
|
||||
RUN chown -R doris:doris /app
|
||||
USER doris
|
||||
|
||||
# Health check
|
||||
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||
CMD curl -f http://localhost:3000/health || exit 1
|
||||
|
||||
# Expose ports
|
||||
EXPOSE 3000 3001 3002
|
||||
|
||||
# Start command
|
||||
CMD ["/app/start.sh"]
|
||||
135
Makefile
135
Makefile
@@ -1,135 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# Doris MCP Server Makefile
|
||||
# Provides convenient commands using UV
|
||||
|
||||
.PHONY: help install sync dev test lint format build clean check start-stdio start-sse
|
||||
|
||||
# Default target
|
||||
help:
|
||||
@echo "Available commands:"
|
||||
@echo " install - Install dependencies using UV"
|
||||
@echo " sync - Sync dependencies and create virtual environment"
|
||||
@echo " dev - Install development dependencies"
|
||||
@echo " test - Run tests"
|
||||
@echo " lint - Run linting tools"
|
||||
@echo " format - Format code with black and isort"
|
||||
@echo " build - Build the package"
|
||||
@echo " clean - Clean build artifacts"
|
||||
@echo " check - Run all checks (format, lint, test)"
|
||||
@echo " start-stdio - Start server in stdio mode"
|
||||
@echo " start-sse - Start server in SSE mode"
|
||||
|
||||
# Install dependencies
|
||||
install:
|
||||
uv sync
|
||||
|
||||
# Sync dependencies with development extras
|
||||
sync:
|
||||
uv sync
|
||||
|
||||
# Install development dependencies
|
||||
dev:
|
||||
uv sync --dev
|
||||
|
||||
# Run tests
|
||||
test:
|
||||
uv run pytest
|
||||
|
||||
# Run linting tools
|
||||
lint:
|
||||
uv run ruff check doris_mcp_server/
|
||||
uv run mypy doris_mcp_server/
|
||||
|
||||
# Format code
|
||||
format:
|
||||
uv run ruff format doris_mcp_server/
|
||||
uv run ruff check --fix doris_mcp_server/
|
||||
|
||||
# Build the package
|
||||
build:
|
||||
uv build
|
||||
|
||||
# Clean build artifacts
|
||||
clean:
|
||||
rm -rf build/
|
||||
rm -rf dist/
|
||||
rm -rf *.egg-info/
|
||||
find . -type d -name __pycache__ -exec rm -rf {} +
|
||||
find . -type d -name .pytest_cache -exec rm -rf {} +
|
||||
find . -type d -name .mypy_cache -exec rm -rf {} +
|
||||
|
||||
# Run all checks
|
||||
check: format lint test
|
||||
|
||||
# Start server in stdio mode
|
||||
start-stdio:
|
||||
uv run python -m doris_mcp_server.main --transport stdio
|
||||
|
||||
# Start server in SSE mode
|
||||
start-sse:
|
||||
uv run python -m doris_mcp_server.main --transport sse --host 0.0.0.0 --port 8080
|
||||
|
||||
# Start server with custom database settings
|
||||
start-dev:
|
||||
uv run python -m doris_mcp_server.main \
|
||||
--transport stdio \
|
||||
--db-host localhost \
|
||||
--db-port 9030 \
|
||||
--db-user root \
|
||||
--log-level DEBUG
|
||||
|
||||
# Run a single test file
|
||||
test-file:
|
||||
uv run pytest $(FILE) -v
|
||||
|
||||
# Install and run in one command
|
||||
run: install start-stdio
|
||||
|
||||
# Development setup
|
||||
setup: dev
|
||||
@echo "✅ Development environment is ready!"
|
||||
@echo "Run 'make start-stdio' to start the server"
|
||||
|
||||
# Add dependencies
|
||||
add:
|
||||
uv add $(PACKAGE)
|
||||
|
||||
# Add development dependencies
|
||||
add-dev:
|
||||
uv add --dev $(PACKAGE)
|
||||
|
||||
# Show dependency tree
|
||||
deps:
|
||||
uv tree
|
||||
|
||||
# Lock dependencies
|
||||
lock:
|
||||
uv lock
|
||||
|
||||
# Check for outdated dependencies
|
||||
outdated:
|
||||
uv tree --outdated
|
||||
|
||||
# Export requirements.txt
|
||||
export-requirements:
|
||||
uv export --no-hashes > requirements.txt
|
||||
|
||||
# Show UV version and info
|
||||
info:
|
||||
uv --version
|
||||
uv python list
|
||||
146
README.md
146
README.md
@@ -58,7 +58,82 @@ Doris MCP (Model Context Protocol) Server is a backend service built with Python
|
||||
* Python 3.12+
|
||||
* Database connection details (e.g., Doris Host, Port, User, Password, Database)
|
||||
|
||||
## Quick Start
|
||||
## 🚀 Quick Start
|
||||
|
||||
### Installation from PyPI
|
||||
|
||||
```bash
|
||||
# Install the latest version
|
||||
pip install mcp-doris-server
|
||||
|
||||
# Install specific version
|
||||
pip install mcp-doris-server==0.3
|
||||
```
|
||||
|
||||
> **💡 Command Compatibility**: After installation, both `doris-mcp-server` and `mcp-doris-server` commands are available for backward compatibility. You can use either command interchangeably.
|
||||
|
||||
### Start Streamable HTTP Mode (Web Service)
|
||||
|
||||
```bash
|
||||
# Full configuration with database connection
|
||||
doris-mcp-server \
|
||||
--transport http \
|
||||
--host 0.0.0.0 \
|
||||
--port 3000 \
|
||||
--db-host 127.0.0.1 \
|
||||
--db-port 9030 \
|
||||
--db-user root \
|
||||
--db-password your_password
|
||||
```
|
||||
|
||||
### Start Stdio Mode (for Cursor and other MCP clients)
|
||||
|
||||
```bash
|
||||
# For direct integration with MCP clients like Cursor
|
||||
doris-mcp-server --transport stdio
|
||||
```
|
||||
|
||||
### Verify Installation
|
||||
|
||||
```bash
|
||||
# Check installation
|
||||
doris-mcp-server --help
|
||||
|
||||
# Test HTTP mode (in another terminal)
|
||||
curl http://localhost:3000/health
|
||||
```
|
||||
|
||||
### Environment Variables (Optional)
|
||||
|
||||
Instead of command-line arguments, you can use environment variables:
|
||||
|
||||
```bash
|
||||
export DORIS_HOST="127.0.0.1"
|
||||
export DORIS_PORT="9030"
|
||||
export DORIS_USER="root"
|
||||
export DORIS_PASSWORD="your_password"
|
||||
|
||||
# Then start with simplified command
|
||||
doris-mcp-server --transport http --host 0.0.0.0 --port 3000
|
||||
```
|
||||
|
||||
### Command Line Arguments
|
||||
|
||||
The `doris-mcp-server` command supports the following arguments:
|
||||
|
||||
| Argument | Description | Default | Required |
|
||||
|:---------|:------------|:--------|:---------|
|
||||
| `--transport` | Transport mode: `http` or `stdio` | `http` | No |
|
||||
| `--host` | HTTP server host (HTTP mode only) | `0.0.0.0` | No |
|
||||
| `--port` | HTTP server port (HTTP mode only) | `3000` | No |
|
||||
| `--db-host` | Doris database host | `localhost` | No |
|
||||
| `--db-port` | Doris database port | `9030` | No |
|
||||
| `--db-user` | Doris database username | `root` | No |
|
||||
| `--db-password` | Doris database password | - | Yes (unless in env) |
|
||||
|
||||
## Development Setup
|
||||
|
||||
For developers who want to build from source:
|
||||
|
||||
### 1. Clone the Repository
|
||||
|
||||
@@ -481,9 +556,36 @@ You can connect Cursor to this MCP server using Stdio mode (recommended) or Stre
|
||||
|
||||
Stdio mode allows Cursor to manage the server process directly. Configuration is done within Cursor's MCP Server settings file (typically `~/.cursor/mcp.json` or similar).
|
||||
|
||||
### Using uv (Recommended)
|
||||
### Method 1: Using PyPI Installation (Recommended)
|
||||
|
||||
If you have `uv` installed, you can run the server directly:
|
||||
Install the package from PyPI and configure Cursor to use it:
|
||||
|
||||
```bash
|
||||
pip install mcp-doris-server
|
||||
```
|
||||
|
||||
**Configure Cursor:** Add an entry like the following to your Cursor MCP configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"doris-stdio": {
|
||||
"command": "doris-mcp-server",
|
||||
"args": ["--transport", "stdio"],
|
||||
"env": {
|
||||
"DORIS_HOST": "127.0.0.1",
|
||||
"DORIS_PORT": "9030",
|
||||
"DORIS_USER": "root",
|
||||
"DORIS_PASSWORD": "your_db_password"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Method 2: Using uv (Development)
|
||||
|
||||
If you have `uv` installed and want to run from source:
|
||||
|
||||
```bash
|
||||
uv run --project /path/to/doris-mcp-server doris-mcp-server
|
||||
@@ -491,32 +593,24 @@ uv run --project /path/to/doris-mcp-server doris-mcp-server
|
||||
|
||||
**Note:** Replace `/path/to/doris-mcp-server` with the actual absolute path to your project directory.
|
||||
|
||||
1. **Configure Cursor:** Add an entry like the following to your Cursor MCP configuration:
|
||||
**Configure Cursor:** Add an entry like the following to your Cursor MCP configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"doris-stdio": {
|
||||
"command": "uv",
|
||||
"args": ["run", "--project", "/path/to/your/doris-mcp-server", "doris-mcp-server"],
|
||||
"env": {
|
||||
"DORIS_HOST": "127.0.0.1",
|
||||
"DORIS_PORT": "9030",
|
||||
"DORIS_USER": "root",
|
||||
"DORIS_PASSWORD": "your_db_password",
|
||||
"DORIS_DATABASE": "your_default_db",
|
||||
"LOG_LEVEL": "INFO"
|
||||
}
|
||||
}
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"doris-stdio": {
|
||||
"command": "uv",
|
||||
"args": ["run", "--project", "/path/to/your/doris-mcp-server", "doris-mcp-server"],
|
||||
"env": {
|
||||
"DORIS_HOST": "127.0.0.1",
|
||||
"DORIS_PORT": "9030",
|
||||
"DORIS_USER": "root",
|
||||
"DORIS_PASSWORD": "your_db_password"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
2. **Key Points:**
|
||||
* Replace `/path/to/your/doris-mcp-server` with the actual absolute path to the project's root directory on your system.
|
||||
* The `--project` argument is crucial for `uv` to find the `pyproject.toml` and run the correct command.
|
||||
* Database connection details are set directly in the `env` block. Cursor will pass these to the server process.
|
||||
* No `.env` file is needed for this mode when configured via Cursor.
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Streamable HTTP Mode (v0.3.0+)
|
||||
|
||||
|
||||
@@ -1,218 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
version: '3.8'
|
||||
|
||||
services:
|
||||
# Doris MCP Server
|
||||
doris-mcp-server:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
container_name: doris-mcp-server
|
||||
ports:
|
||||
- "3000:3000" # MCP service port
|
||||
- "3001:3001" # Monitoring metrics port
|
||||
- "3002:3002" # Health check port
|
||||
environment:
|
||||
# Database configuration
|
||||
- DORIS_HOST=doris-fe
|
||||
- DORIS_PORT=9030
|
||||
- DORIS_USER=root
|
||||
- DORIS_PASSWORD=doris123
|
||||
- DORIS_DATABASE=test_db
|
||||
|
||||
# Connection pool configuration
|
||||
- DORIS_MIN_CONNECTIONS=5
|
||||
- DORIS_MAX_CONNECTIONS=20
|
||||
|
||||
# Security configuration
|
||||
- AUTH_TYPE=token
|
||||
- TOKEN_SECRET=your_secret_key_here
|
||||
- MAX_RESULT_ROWS=10000
|
||||
|
||||
# Performance configuration
|
||||
- ENABLE_QUERY_CACHE=true
|
||||
- MAX_CONCURRENT_QUERIES=50
|
||||
|
||||
# Logging configuration
|
||||
- LOG_LEVEL=INFO
|
||||
- LOG_FILE_PATH=/app/logs/doris-mcp-server.log
|
||||
|
||||
# Monitoring configuration
|
||||
- ENABLE_METRICS=true
|
||||
- METRICS_PORT=8081
|
||||
volumes:
|
||||
- ./logs:/app/logs
|
||||
- ./config:/app/config
|
||||
depends_on:
|
||||
- doris-fe
|
||||
- doris-be
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8082/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
start_period: 40s
|
||||
|
||||
# Apache Doris Frontend
|
||||
doris-fe:
|
||||
image: apache/doris:2.0.3-fe-x86_64
|
||||
container_name: doris-fe
|
||||
ports:
|
||||
- "8030:8030" # FE HTTP port
|
||||
- "9030:9030" # FE MySQL port
|
||||
environment:
|
||||
- FE_SERVERS=fe1:doris-fe:9010
|
||||
- FE_ID=1
|
||||
volumes:
|
||||
- doris-fe-data:/opt/apache-doris/fe/doris-meta
|
||||
- doris-fe-log:/opt/apache-doris/fe/log
|
||||
- ./doris-config/fe.conf:/opt/apache-doris/fe/conf/fe.conf
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8030/api/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
start_period: 60s
|
||||
|
||||
# Apache Doris Backend
|
||||
doris-be:
|
||||
image: apache/doris:2.0.3-be-x86_64
|
||||
container_name: doris-be
|
||||
ports:
|
||||
- "8040:8040" # BE HTTP port
|
||||
- "9060:9060" # BE heartbeat port
|
||||
environment:
|
||||
- FE_SERVERS=doris-fe:9010
|
||||
- BE_ADDR=doris-be:9050
|
||||
volumes:
|
||||
- doris-be-data:/opt/apache-doris/be/storage
|
||||
- doris-be-log:/opt/apache-doris/be/log
|
||||
- ./doris-config/be.conf:/opt/apache-doris/be/conf/be.conf
|
||||
depends_on:
|
||||
- doris-fe
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:8040/api/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
start_period: 60s
|
||||
|
||||
# Redis cache (optional)
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
container_name: doris-redis
|
||||
ports:
|
||||
- "6379:6379"
|
||||
command: redis-server --appendonly yes --requirepass redis123
|
||||
volumes:
|
||||
- redis-data:/data
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "--raw", "incr", "ping"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 3
|
||||
|
||||
# Prometheus monitoring
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
container_name: doris-prometheus
|
||||
ports:
|
||||
- "9090:9090"
|
||||
volumes:
|
||||
- ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml
|
||||
- prometheus-data:/prometheus
|
||||
command:
|
||||
- '--config.file=/etc/prometheus/prometheus.yml'
|
||||
- '--storage.tsdb.path=/prometheus'
|
||||
- '--web.console.libraries=/etc/prometheus/console_libraries'
|
||||
- '--web.console.templates=/etc/prometheus/consoles'
|
||||
- '--storage.tsdb.retention.time=200h'
|
||||
- '--web.enable-lifecycle'
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
|
||||
# Grafana visualization
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
container_name: doris-grafana
|
||||
ports:
|
||||
- "3000:3000"
|
||||
environment:
|
||||
- GF_SECURITY_ADMIN_PASSWORD=admin123
|
||||
volumes:
|
||||
- grafana-data:/var/lib/grafana
|
||||
- ./monitoring/grafana/dashboards:/etc/grafana/provisioning/dashboards
|
||||
- ./monitoring/grafana/datasources:/etc/grafana/provisioning/datasources
|
||||
depends_on:
|
||||
- prometheus
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
|
||||
# Nginx load balancer
|
||||
nginx:
|
||||
image: nginx:alpine
|
||||
container_name: doris-nginx
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
volumes:
|
||||
- ./nginx/nginx.conf:/etc/nginx/nginx.conf
|
||||
- ./nginx/ssl:/etc/nginx/ssl
|
||||
- ./nginx/logs:/var/log/nginx
|
||||
depends_on:
|
||||
- doris-mcp-server
|
||||
networks:
|
||||
- doris-network
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
doris-fe-data:
|
||||
driver: local
|
||||
doris-fe-log:
|
||||
driver: local
|
||||
doris-be-data:
|
||||
driver: local
|
||||
doris-be-log:
|
||||
driver: local
|
||||
redis-data:
|
||||
driver: local
|
||||
prometheus-data:
|
||||
driver: local
|
||||
grafana-data:
|
||||
driver: local
|
||||
|
||||
networks:
|
||||
doris-network:
|
||||
driver: bridge
|
||||
ipam:
|
||||
config:
|
||||
- subnet: 172.20.0.0/16
|
||||
@@ -1,340 +0,0 @@
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
# Doris Unified MCP Client
|
||||
|
||||
This is a unified Doris MCP client that supports both **stdio** and **Streamable HTTP** transport modes, providing complete MCP protocol support.
|
||||
|
||||
## 🚀 Features
|
||||
|
||||
- ✅ **Dual Mode Support**: Both stdio and HTTP transport methods
|
||||
- ✅ **Complete MCP Support**: Resources, Tools, and Prompts primitives
|
||||
- ✅ **Unified API**: Same interface for different transport modes
|
||||
- ✅ **Asynchronous Design**: High-performance async client based on asyncio
|
||||
- ✅ **Enterprise Features**: Connection pooling, error handling, logging
|
||||
- ✅ **Convenience Methods**: High-level wrappers for common database operations
|
||||
|
||||
## 📦 Install Dependencies
|
||||
|
||||
```bash
|
||||
pip install mcp
|
||||
```
|
||||
|
||||
## 🎯 Quick Start
|
||||
|
||||
### 1. stdio Mode
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from client import create_stdio_client
|
||||
|
||||
async def main():
|
||||
# Create stdio client
|
||||
client = await create_stdio_client(
|
||||
"python",
|
||||
["-m", "doris_mcp_server.main", "--transport", "stdio"]
|
||||
)
|
||||
|
||||
async def test_client(client):
|
||||
# Get database list
|
||||
db_result = await client.get_database_list()
|
||||
print(f"Databases: {db_result}")
|
||||
|
||||
# Execute SQL query
|
||||
query_result = await client.execute_sql("SELECT 1 as test")
|
||||
print(f"Query result: {query_result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### 2. HTTP Mode
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from unified_client import create_http_client
|
||||
|
||||
async def main():
|
||||
# Create HTTP client
|
||||
client = await create_http_client("http://localhost:3000/mcp")
|
||||
|
||||
async def test_client(client):
|
||||
# Get all tools
|
||||
tools = await client.list_all_tools()
|
||||
print(f"Available tools: {len(tools)}")
|
||||
|
||||
# Execute query
|
||||
result = await client.execute_sql(
|
||||
"SELECT COUNT(*) FROM internal.ssb.lineorder LIMIT 1"
|
||||
)
|
||||
print(f"Query result: {result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## 🔧 API Reference
|
||||
|
||||
### Client Creation
|
||||
|
||||
```python
|
||||
# stdio mode
|
||||
client = await create_stdio_client(command, args)
|
||||
|
||||
# HTTP mode
|
||||
client = await create_http_client(server_url, timeout=60)
|
||||
```
|
||||
|
||||
### Basic Operations
|
||||
|
||||
```python
|
||||
async def test_client(client):
|
||||
# Get server capabilities
|
||||
tools = await client.list_all_tools()
|
||||
resources = await client.list_all_resources()
|
||||
prompts = await client.list_all_prompts()
|
||||
|
||||
# Call tool
|
||||
result = await client.call_tool("tool_name", {"param": "value"})
|
||||
|
||||
# Read resource
|
||||
content = await client.read_resource("resource://uri")
|
||||
|
||||
# Get prompt
|
||||
prompt = await client.get_prompt("prompt_name", {"param": "value"})
|
||||
```
|
||||
|
||||
### Advanced Database Operations
|
||||
|
||||
```python
|
||||
async def database_operations(client):
|
||||
# Execute SQL query
|
||||
result = await client.execute_sql("SELECT * FROM table LIMIT 10")
|
||||
|
||||
# Get database list
|
||||
databases = await client.get_database_list()
|
||||
|
||||
# Get table schema
|
||||
schema = await client.get_table_schema("table_name", "db_name")
|
||||
|
||||
# Column data analysis
|
||||
analysis = await client.analyze_column("table", "column", "basic")
|
||||
```
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
### Run Test Suite
|
||||
|
||||
```bash
|
||||
# Interactive testing
|
||||
python test_unified_client.py
|
||||
|
||||
# Test stdio mode
|
||||
python test_unified_client.py stdio
|
||||
|
||||
# Test HTTP mode
|
||||
python test_unified_client.py http
|
||||
|
||||
# Test both modes
|
||||
python test_unified_client.py both
|
||||
|
||||
# Performance benchmark
|
||||
python test_unified_client.py benchmark
|
||||
```
|
||||
|
||||
### Test Output Example
|
||||
|
||||
```
|
||||
🎯 Doris Unified Client Test Suite
|
||||
============================================================
|
||||
|
||||
🚀 Testing HTTP Mode
|
||||
==================================================
|
||||
📋 Getting server capabilities...
|
||||
✅ Found 11 tools
|
||||
✅ Found 0 resources
|
||||
✅ Found 0 prompts
|
||||
|
||||
🔧 Available tools:
|
||||
1. get_db_list: Get database list
|
||||
2. get_table_list: Get table list for specified database
|
||||
3. get_table_schema: Get table structure information
|
||||
4. exec_query: Execute SQL query
|
||||
5. column_analysis: Analyze column data distribution and statistics
|
||||
...
|
||||
|
||||
🧪 Testing basic functionality...
|
||||
1️⃣ Getting database list...
|
||||
✅ Success: 3 databases
|
||||
2️⃣ Executing simple query...
|
||||
✅ Query successful
|
||||
3️⃣ Executing SSB data query...
|
||||
✅ SSB query successful
|
||||
4️⃣ Getting table structure...
|
||||
✅ Table structure retrieved successfully
|
||||
5️⃣ Column data analysis...
|
||||
✅ Column analysis successful
|
||||
|
||||
✅ HTTP mode testing completed!
|
||||
```
|
||||
|
||||
## 🏗️ Architecture Design
|
||||
|
||||
### Unified Client Architecture
|
||||
|
||||
```
|
||||
DorisUnifiedClient
|
||||
├── DorisResourceClient # Resource management
|
||||
├── DorisToolsClient # Tool invocation
|
||||
├── DorisPromptClient # Prompt management
|
||||
└── Transport Layer
|
||||
├── stdio mode # Standard input/output
|
||||
└── HTTP mode # Streamable HTTP
|
||||
```
|
||||
|
||||
### Key Features
|
||||
|
||||
1. **Unified Interface**: Same API for different transport modes
|
||||
2. **Async Context**: Proper resource management and connection cleanup
|
||||
3. **Error Handling**: Comprehensive exception handling and error recovery
|
||||
4. **Performance Optimization**: Connection reuse and request caching
|
||||
|
||||
## 📚 Usage Examples
|
||||
|
||||
### Complete Example
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from client import DorisUnifiedClient, DorisClientConfig
|
||||
|
||||
async def comprehensive_example():
|
||||
# Create configuration
|
||||
config = DorisClientConfig.stdio(
|
||||
"python",
|
||||
["-m", "doris_mcp_server.main"]
|
||||
)
|
||||
|
||||
client = DorisUnifiedClient(config)
|
||||
|
||||
async def demo_operations(client):
|
||||
print("🔍 Discovering server capabilities...")
|
||||
|
||||
# List all available tools
|
||||
tools = await client.list_all_tools()
|
||||
print(f"Available tools: {[tool.name for tool in tools]}")
|
||||
|
||||
# Get database list
|
||||
print("\n📊 Getting database information...")
|
||||
db_result = await client.get_database_list()
|
||||
print(f"Databases: {db_result}")
|
||||
|
||||
# Execute queries
|
||||
print("\n🔍 Executing queries...")
|
||||
|
||||
# Simple query
|
||||
result1 = await client.execute_sql("SELECT 1 as test_column")
|
||||
print(f"Simple query result: {result1}")
|
||||
|
||||
# Get table schema
|
||||
schema_result = await client.get_table_schema("lineorder", "ssb")
|
||||
print(f"Table schema: {schema_result}")
|
||||
|
||||
# Column analysis
|
||||
analysis_result = await client.analyze_column(
|
||||
"lineorder", "lo_orderkey", "basic"
|
||||
)
|
||||
print(f"Column analysis: {analysis_result}")
|
||||
|
||||
await client.connect_and_run(demo_operations)
|
||||
|
||||
# Run the example
|
||||
asyncio.run(comprehensive_example())
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
async def error_handling_example(client):
|
||||
try:
|
||||
# This might fail
|
||||
result = await client.execute_sql("INVALID SQL")
|
||||
except Exception as e:
|
||||
print(f"SQL execution failed: {e}")
|
||||
|
||||
# Check result status
|
||||
result = await client.get_database_list()
|
||||
if result.get("success", True):
|
||||
print("Operation successful")
|
||||
else:
|
||||
print(f"Operation failed: {result.get('error')}")
|
||||
```
|
||||
|
||||
## 🔧 Configuration
|
||||
|
||||
### Client Configuration Options
|
||||
|
||||
```python
|
||||
# stdio mode with custom arguments
|
||||
config = DorisClientConfig(
|
||||
transport="stdio",
|
||||
server_command="python",
|
||||
server_args=["-m", "doris_mcp_server.main", "--debug"],
|
||||
timeout=30
|
||||
)
|
||||
|
||||
# HTTP mode with custom timeout
|
||||
config = DorisClientConfig(
|
||||
transport="http",
|
||||
server_url="http://localhost:8080/mcp",
|
||||
timeout=60
|
||||
)
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Set default server URL
|
||||
export DORIS_MCP_SERVER_URL="http://localhost:8080"
|
||||
|
||||
# Set default timeout
|
||||
export DORIS_MCP_TIMEOUT=60
|
||||
|
||||
# Enable debug logging
|
||||
export DORIS_MCP_DEBUG=true
|
||||
```
|
||||
|
||||
## 🚀 Performance Tips
|
||||
|
||||
1. **Connection Reuse**: Use the same client instance for multiple operations
|
||||
2. **Batch Operations**: Group related queries together
|
||||
3. **Async Context**: Always use proper async context management
|
||||
4. **Error Recovery**: Implement retry logic for transient failures
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Make your changes
|
||||
4. Add tests
|
||||
5. Submit a pull request
|
||||
|
||||
## 📄 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License.
|
||||
@@ -1,25 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Doris MCP Client Package
|
||||
|
||||
Unified MCP client supporting both stdio and HTTP transport modes
|
||||
"""
|
||||
|
||||
from .client import DorisUnifiedClient, DorisClientConfig
|
||||
|
||||
__all__ = ["DorisUnifiedClient", "DorisClientConfig"]
|
||||
@@ -1,513 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Unified Doris MCP Client - Supports both stdio and Streamable HTTP modes
|
||||
|
||||
Combines the correct HTTP implementation from http_client.py and the complete architecture from client.py
|
||||
Provides complete support for the three major primitives: Resources, Tools, and Prompts
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
from datetime import timedelta
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.types import (
|
||||
Prompt,
|
||||
Resource,
|
||||
Tool,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DorisClientConfig:
|
||||
"""Doris client configuration class"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: str = "stdio",
|
||||
server_command: str | None = None,
|
||||
server_args: list[str] | None = None,
|
||||
server_url: str | None = None,
|
||||
timeout: int = 60,
|
||||
):
|
||||
self.transport = transport
|
||||
self.server_command = server_command
|
||||
self.server_args = server_args or []
|
||||
self.server_url = server_url
|
||||
self.timeout = timeout
|
||||
|
||||
@classmethod
|
||||
def stdio(cls, command: str, args: list[str] = None) -> "DorisClientConfig":
|
||||
"""Create stdio connection configuration"""
|
||||
return cls(
|
||||
transport="stdio",
|
||||
server_command=command,
|
||||
server_args=args or []
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def http(cls, url: str, timeout: int = 60) -> "DorisClientConfig":
|
||||
"""Create HTTP connection configuration"""
|
||||
return cls(
|
||||
transport="http",
|
||||
server_url=url,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
|
||||
class DorisResourceClient:
|
||||
"""Doris resource client - Handles Resources related operations"""
|
||||
|
||||
def __init__(self, session: ClientSession):
|
||||
self.session = session
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisResourceClient")
|
||||
self._resources_cache = None
|
||||
|
||||
async def list_resources(self) -> list[Resource]:
|
||||
"""Get list of all available resources"""
|
||||
try:
|
||||
self.logger.info("Getting resource list")
|
||||
response = await self.session.list_resources()
|
||||
resources = response.resources if hasattr(response, "resources") else []
|
||||
self._resources_cache = resources
|
||||
self.logger.info(f"Retrieved {len(resources)} resources")
|
||||
return resources
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get resource list: {e}")
|
||||
return []
|
||||
|
||||
async def read_resource(self, uri: str) -> str | None:
|
||||
"""Read specified resource content"""
|
||||
try:
|
||||
self.logger.info(f"Reading resource: {uri}")
|
||||
response = await self.session.read_resource(uri)
|
||||
|
||||
if hasattr(response, "contents") and response.contents:
|
||||
# Merge all content
|
||||
content_parts = []
|
||||
for content in response.contents:
|
||||
if hasattr(content, "text"):
|
||||
content_parts.append(content.text)
|
||||
content = "\n".join(content_parts)
|
||||
self.logger.info(f"Successfully read resource content: {len(content)} characters")
|
||||
return content
|
||||
elif hasattr(response, "content"):
|
||||
return str(response.content)
|
||||
else:
|
||||
self.logger.warning(f"Resource {uri} returned no content")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to read resource {uri}: {e}")
|
||||
return None
|
||||
|
||||
async def filter_resources_by_type(self, resource_type: str) -> list[Resource]:
|
||||
"""Filter resources by type"""
|
||||
if not self._resources_cache:
|
||||
await self.list_resources()
|
||||
|
||||
if resource_type == "table":
|
||||
return [r for r in self._resources_cache if "table" in r.uri]
|
||||
elif resource_type == "view":
|
||||
return [r for r in self._resources_cache if "view" in r.uri]
|
||||
elif resource_type == "database":
|
||||
return [
|
||||
r for r in self._resources_cache
|
||||
if "database" in r.uri and "table" not in r.uri
|
||||
]
|
||||
else:
|
||||
return self._resources_cache
|
||||
|
||||
|
||||
class DorisToolsClient:
|
||||
"""Doris tools client - Handles Tools related operations"""
|
||||
|
||||
def __init__(self, session: ClientSession):
|
||||
self.session = session
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisToolsClient")
|
||||
self._tools_cache = None
|
||||
|
||||
async def list_tools(self) -> list[Tool]:
|
||||
"""Get list of all available tools"""
|
||||
try:
|
||||
self.logger.info("Getting tool list")
|
||||
response = await self.session.list_tools()
|
||||
tools = response.tools if hasattr(response, "tools") else []
|
||||
self._tools_cache = tools
|
||||
self.logger.info(f"Retrieved {len(tools)} tools")
|
||||
return tools
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get tool list: {e}")
|
||||
return []
|
||||
|
||||
async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call specified tool"""
|
||||
try:
|
||||
self.logger.info(f"Calling tool: {name}")
|
||||
self.logger.debug(f"Tool arguments: {arguments}")
|
||||
|
||||
response = await self.session.call_tool(name, arguments)
|
||||
|
||||
if hasattr(response, "content") and response.content:
|
||||
# Parse response content
|
||||
result_text = ""
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
result_text += content.text
|
||||
|
||||
# Try to parse as JSON
|
||||
try:
|
||||
result = json.loads(result_text)
|
||||
self.logger.info(f"Tool call successful: {name}")
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON format, return text directly
|
||||
return {"success": True, "data": result_text}
|
||||
|
||||
self.logger.warning(f"Tool {name} returned no content")
|
||||
return {"success": False, "error": "No response content"}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Tool call failed {name}: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def get_tool_by_name(self, name: str) -> Tool | None:
|
||||
"""Get tool definition by name"""
|
||||
if not self._tools_cache:
|
||||
await self.list_tools()
|
||||
|
||||
for tool in self._tools_cache:
|
||||
if tool.name == name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
async def get_tools_by_category(self, category: str) -> list[Tool]:
|
||||
"""Filter tools by category"""
|
||||
if not self._tools_cache:
|
||||
await self.list_tools()
|
||||
|
||||
category_lower = category.lower()
|
||||
return [
|
||||
tool for tool in self._tools_cache
|
||||
if category_lower in tool.description.lower()
|
||||
or category_lower in tool.name.lower()
|
||||
]
|
||||
|
||||
|
||||
class DorisPromptClient:
|
||||
"""Doris prompt client - Handles Prompts related operations"""
|
||||
|
||||
def __init__(self, session: ClientSession):
|
||||
self.session = session
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisPromptClient")
|
||||
self._prompts_cache = None
|
||||
|
||||
async def list_prompts(self) -> list[Prompt]:
|
||||
"""Get list of all available prompts"""
|
||||
try:
|
||||
self.logger.info("Getting prompt list")
|
||||
response = await self.session.list_prompts()
|
||||
prompts = response.prompts if hasattr(response, "prompts") else []
|
||||
self._prompts_cache = prompts
|
||||
self.logger.info(f"Retrieved {len(prompts)} prompts")
|
||||
return prompts
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get prompt list: {e}")
|
||||
return []
|
||||
|
||||
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None:
|
||||
"""Get specified prompt content"""
|
||||
try:
|
||||
self.logger.info(f"Getting prompt: {name}")
|
||||
self.logger.debug(f"Prompt arguments: {arguments}")
|
||||
|
||||
response = await self.session.get_prompt(name, arguments)
|
||||
|
||||
if hasattr(response, "messages") and response.messages:
|
||||
# Merge all message content
|
||||
content_parts = []
|
||||
for message in response.messages:
|
||||
if hasattr(message, "content"):
|
||||
if hasattr(message.content, "text"):
|
||||
content_parts.append(message.content.text)
|
||||
else:
|
||||
content_parts.append(str(message.content))
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
self.logger.info(f"Successfully retrieved prompt content: {len(content)} characters")
|
||||
return content
|
||||
|
||||
self.logger.warning(f"Prompt {name} returned no content")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get prompt {name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class DorisUnifiedClient:
|
||||
"""Unified Doris MCP client - Provides complete MCP functionality"""
|
||||
|
||||
def __init__(self, config: DorisClientConfig):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisUnifiedClient")
|
||||
self.session = None
|
||||
self.resources = None
|
||||
self.tools = None
|
||||
self.prompts = None
|
||||
|
||||
async def connect_and_run(self, callback_func: Callable):
|
||||
"""Connect to server and execute callback function"""
|
||||
if self.config.transport == "stdio":
|
||||
await self._run_stdio_mode(callback_func)
|
||||
elif self.config.transport == "http":
|
||||
await self._run_http_mode(callback_func)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport type: {self.config.transport}")
|
||||
|
||||
async def _run_stdio_mode(self, callback_func: Callable):
|
||||
"""Run in stdio mode"""
|
||||
try:
|
||||
self.logger.info(f"Starting stdio client: {self.config.server_command}")
|
||||
|
||||
server_params = StdioServerParameters(
|
||||
command=self.config.server_command,
|
||||
args=self.config.server_args,
|
||||
)
|
||||
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
self.session = session
|
||||
self._init_sub_clients()
|
||||
|
||||
# Initialize server
|
||||
await session.initialize()
|
||||
self.logger.info("Server initialized successfully")
|
||||
|
||||
# Execute callback function
|
||||
await callback_func(self)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"stdio mode execution failed: {e}")
|
||||
raise
|
||||
|
||||
async def _run_http_mode(self, callback_func: Callable):
|
||||
"""Run in HTTP mode"""
|
||||
try:
|
||||
self.logger.info(f"Starting HTTP client: {self.config.server_url}")
|
||||
|
||||
async with streamablehttp_client(
|
||||
self.config.server_url,
|
||||
timeout=timedelta(seconds=self.config.timeout)
|
||||
) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
self.session = session
|
||||
self._init_sub_clients()
|
||||
|
||||
# Initialize server
|
||||
await session.initialize()
|
||||
self.logger.info("Server initialized successfully")
|
||||
|
||||
# Execute callback function
|
||||
await callback_func(self)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"HTTP mode execution failed: {e}")
|
||||
raise
|
||||
|
||||
def _init_sub_clients(self):
|
||||
"""Initialize sub-clients"""
|
||||
self.resources = DorisResourceClient(self.session)
|
||||
self.tools = DorisToolsClient(self.session)
|
||||
self.prompts = DorisPromptClient(self.session)
|
||||
|
||||
# Convenience methods
|
||||
async def list_all_resources(self) -> list[Resource]:
|
||||
"""Get all resources"""
|
||||
return await self.resources.list_resources()
|
||||
|
||||
async def list_all_tools(self) -> list[Tool]:
|
||||
"""Get all tools"""
|
||||
return await self.tools.list_tools()
|
||||
|
||||
async def list_all_prompts(self) -> list[Prompt]:
|
||||
"""Get all prompts"""
|
||||
return await self.prompts.list_prompts()
|
||||
|
||||
async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call tool"""
|
||||
return await self.tools.call_tool(name, arguments)
|
||||
|
||||
async def read_resource(self, uri: str) -> str | None:
|
||||
"""Read resource"""
|
||||
return await self.resources.read_resource(uri)
|
||||
|
||||
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None:
|
||||
"""Get prompt"""
|
||||
return await self.prompts.get_prompt(name, arguments)
|
||||
|
||||
# Smart tool finding methods
|
||||
async def _find_tool_by_pattern(self, patterns: list[str]) -> str | None:
|
||||
"""Find tool by name pattern"""
|
||||
tools = await self.list_all_tools()
|
||||
for pattern in patterns:
|
||||
for tool in tools:
|
||||
if pattern in tool.name:
|
||||
return tool.name
|
||||
return None
|
||||
|
||||
async def _find_tool_by_function(self, function_keywords: list[str]) -> str | None:
|
||||
"""Find tool by function keywords"""
|
||||
tools = await self.list_all_tools()
|
||||
for tool in tools:
|
||||
tool_desc = tool.description.lower()
|
||||
tool_name = tool.name.lower()
|
||||
for keyword in function_keywords:
|
||||
if keyword.lower() in tool_desc or keyword.lower() in tool_name:
|
||||
return tool.name
|
||||
return None
|
||||
|
||||
# High-level business methods
|
||||
async def execute_sql(self, sql: str, **kwargs) -> dict[str, Any]:
|
||||
"""Execute SQL query"""
|
||||
tool_name = await self._find_tool_by_pattern(["exec_query", "execute", "query"])
|
||||
if not tool_name:
|
||||
return {"success": False, "error": "SQL execution tool not found"}
|
||||
|
||||
arguments = {"sql": sql, **kwargs}
|
||||
return await self.call_tool(tool_name, arguments)
|
||||
|
||||
async def get_table_schema(self, table_name: str, db_name: str = None, **kwargs) -> dict[str, Any]:
|
||||
"""Get table schema"""
|
||||
tool_name = await self._find_tool_by_pattern(["get_table_schema", "table_schema", "schema"])
|
||||
if not tool_name:
|
||||
return {"success": False, "error": "Table schema tool not found"}
|
||||
|
||||
arguments = {"table_name": table_name}
|
||||
if db_name:
|
||||
arguments["db_name"] = db_name
|
||||
arguments.update(kwargs)
|
||||
|
||||
return await self.call_tool(tool_name, arguments)
|
||||
|
||||
async def get_database_list(self, **kwargs) -> dict[str, Any]:
|
||||
"""Get database list"""
|
||||
tool_name = await self._find_tool_by_pattern(["get_db_list", "database_list", "db_list"])
|
||||
if not tool_name:
|
||||
return {"success": False, "error": "Database list tool not found"}
|
||||
|
||||
return await self.call_tool(tool_name, kwargs)
|
||||
|
||||
async def analyze_column(self, table_name: str, column_name: str, analysis_type: str = "basic", **kwargs) -> dict[str, Any]:
|
||||
"""Analyze column"""
|
||||
tool_name = await self._find_tool_by_pattern(["column_analysis", "analyze_column", "column"])
|
||||
if not tool_name:
|
||||
return {"success": False, "error": "Column analysis tool not found"}
|
||||
|
||||
arguments = {
|
||||
"table_name": table_name,
|
||||
"column_name": column_name,
|
||||
"analysis_type": analysis_type,
|
||||
**kwargs
|
||||
}
|
||||
return await self.call_tool(tool_name, arguments)
|
||||
|
||||
async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call tool by function description"""
|
||||
# Try to find appropriate tool based on function description
|
||||
function_keywords = function_description.lower().split()
|
||||
tool_name = await self._find_tool_by_function(function_keywords)
|
||||
|
||||
if not tool_name:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"No tool found for function: {function_description}"
|
||||
}
|
||||
|
||||
return await self.call_tool(tool_name, arguments)
|
||||
|
||||
|
||||
# Convenience factory functions
|
||||
async def create_stdio_client(command: str, args: list[str] = None) -> DorisUnifiedClient:
|
||||
"""Create stdio client"""
|
||||
config = DorisClientConfig.stdio(command, args)
|
||||
return DorisUnifiedClient(config)
|
||||
|
||||
|
||||
async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnifiedClient:
|
||||
"""Create HTTP client"""
|
||||
config = DorisClientConfig.http(server_url, timeout)
|
||||
return DorisUnifiedClient(config)
|
||||
|
||||
|
||||
# Example usage
|
||||
async def example_stdio():
|
||||
"""stdio mode example"""
|
||||
client = await create_stdio_client("python", ["doris_mcp_server/main.py"])
|
||||
|
||||
async def test_client(client: DorisUnifiedClient):
|
||||
# Get server capabilities
|
||||
resources = await client.list_all_resources()
|
||||
tools = await client.list_all_tools()
|
||||
prompts = await client.list_all_prompts()
|
||||
|
||||
print(f"Resources: {len(resources)}")
|
||||
print(f"Tools: {len(tools)}")
|
||||
print(f"Prompts: {len(prompts)}")
|
||||
|
||||
# Test SQL execution
|
||||
result = await client.execute_sql("SELECT 1 as test")
|
||||
print(f"SQL execution result: {result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
|
||||
async def example_http():
|
||||
"""HTTP mode example"""
|
||||
client = await create_http_client("http://localhost:8080")
|
||||
|
||||
async def test_client(client: DorisUnifiedClient):
|
||||
# Get server capabilities
|
||||
resources = await client.list_all_resources()
|
||||
tools = await client.list_all_tools()
|
||||
|
||||
print(f"Resources: {len(resources)}")
|
||||
print(f"Tools: {len(tools)}")
|
||||
|
||||
# Test database list
|
||||
result = await client.get_database_list()
|
||||
print(f"Database list: {result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run stdio example
|
||||
asyncio.run(example_stdio())
|
||||
|
||||
# Run HTTP example
|
||||
# asyncio.run(example_http())
|
||||
@@ -42,7 +42,7 @@ classifiers = [
|
||||
|
||||
dependencies = [
|
||||
# Core MCP dependencies
|
||||
"mcp>=1.0.0",
|
||||
"mcp>=1.8.0",
|
||||
# Database drivers
|
||||
"aiomysql>=0.2.0",
|
||||
"PyMySQL>=1.1.0",
|
||||
@@ -147,8 +147,10 @@ Issues = "https://github.com/apache/doris-mcp-server/issues"
|
||||
Changelog = "https://github.com/apache/doris-mcp-server/blob/main/CHANGELOG.md"
|
||||
|
||||
[project.scripts]
|
||||
mcp-doris-server = "doris_mcp_server.main:main_sync"
|
||||
doris-mcp-server = "doris_mcp_server.main:main_sync"
|
||||
doris-mcp-client = "doris_mcp_server.client:main"
|
||||
mcp-doris-client = "doris_mcp_server.client:main"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["doris_mcp_server"]
|
||||
@@ -163,7 +165,7 @@ include = [
|
||||
# Black configuration
|
||||
[tool.black]
|
||||
line-length = 88
|
||||
target-version = ['py310', 'py311', 'py312']
|
||||
target-version = ['py312']
|
||||
include = '\.pyi?$'
|
||||
extend-exclude = '''
|
||||
/(
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# 开发依赖 - 从 pyproject.toml 自动生成
|
||||
# 安装命令: pip install -r requirements-dev.txt
|
||||
|
||||
pytest>=7.4.0
|
||||
pytest-asyncio>=0.23.0
|
||||
pytest-cov>=4.1.0
|
||||
pytest-mock>=3.12.0
|
||||
pytest-xdist>=3.5.0
|
||||
ruff>=0.1.0
|
||||
black>=23.12.0
|
||||
isort>=5.13.0
|
||||
flake8>=7.0.0
|
||||
mypy>=1.8.0
|
||||
bandit>=1.7.0
|
||||
safety>=2.3.0
|
||||
sphinx>=7.2.0
|
||||
sphinx-rtd-theme>=2.0.0
|
||||
myst-parser>=2.0.0
|
||||
pre-commit>=3.6.0
|
||||
tox>=4.11.0
|
||||
263
test/README.md
263
test/README.md
@@ -1,263 +0,0 @@
|
||||
<!--
|
||||
Licensed to the Apache Software Foundation (ASF) under one
|
||||
or more contributor license agreements. See the NOTICE file
|
||||
distributed with this work for additional information
|
||||
regarding copyright ownership. The ASF licenses this file
|
||||
to you under the Apache License, Version 2.0 (the
|
||||
"License"); you may not use this file except in compliance
|
||||
with the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing,
|
||||
software distributed under the License is distributed on an
|
||||
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations
|
||||
under the License.
|
||||
-->
|
||||
# Doris MCP Server Testing System
|
||||
|
||||
## Overview
|
||||
|
||||
This testing system adopts a layered architecture, including unit tests, integration tests, and client-server tests. The testing system assumes the server is already properly started and focuses on testing functionality rather than startup configuration.
|
||||
|
||||
## Testing Architecture
|
||||
|
||||
### 1. Unit Tests
|
||||
- **Location**: `test/security/`, `test/utils/`, `test/tools/`
|
||||
- **Purpose**: Test individual module functionality
|
||||
- **Features**: Uses Mock objects, no dependency on external services
|
||||
|
||||
### 2. Integration Tests
|
||||
- **Location**: `test/integration/`
|
||||
- **Purpose**: Test collaboration between modules
|
||||
- **Features**: Test complete workflows
|
||||
|
||||
### 3. Client-Server Tests
|
||||
- **Location**: `test/tools/test_tools_client_server.py`, `test/utils/test_query_executor_client_server.py`
|
||||
- **Purpose**: Test actual server functionality through MCP client
|
||||
- **Features**: Assumes server is running, skips tests if server is not available
|
||||
|
||||
## Configuration Files
|
||||
|
||||
### test_config.json
|
||||
Test configuration file defines how to connect to the running server:
|
||||
|
||||
```json
|
||||
{
|
||||
"server_endpoints": {
|
||||
"http": {
|
||||
"url": "http://localhost:3000/mcp",
|
||||
"timeout": 30
|
||||
},
|
||||
"stdio": {
|
||||
"command": "uv",
|
||||
"args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"],
|
||||
"timeout": 30
|
||||
}
|
||||
},
|
||||
"test_settings": {
|
||||
"default_transport": "http",
|
||||
"retry_attempts": 3,
|
||||
"retry_delay": 1.0,
|
||||
"test_timeout": 60,
|
||||
"enable_performance_tests": true,
|
||||
"enable_security_tests": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
### 1. Start the Server
|
||||
|
||||
Before running client-server tests, you need to start the server first:
|
||||
|
||||
#### HTTP Mode (Recommended)
|
||||
```bash
|
||||
# Start HTTP server
|
||||
./start_server.sh
|
||||
# or
|
||||
uv run python -m doris_mcp_server.main --transport http --port 3000
|
||||
```
|
||||
|
||||
#### Stdio Mode
|
||||
```bash
|
||||
# Stdio mode is started directly by the client, no need to pre-start
|
||||
```
|
||||
|
||||
### 2. Run Tests
|
||||
|
||||
#### Run All Tests
|
||||
```bash
|
||||
python -m pytest test/ -v
|
||||
```
|
||||
|
||||
#### Run Unit Tests
|
||||
```bash
|
||||
# Security module tests
|
||||
python -m pytest test/security/ -v
|
||||
|
||||
# Tools module tests
|
||||
python -m pytest test/tools/test_tools_manager.py -v
|
||||
|
||||
# Query executor tests
|
||||
python -m pytest test/utils/test_query_executor.py -v
|
||||
```
|
||||
|
||||
#### Run Integration Tests
|
||||
```bash
|
||||
python -m pytest test/integration/ -v
|
||||
```
|
||||
|
||||
#### Run Client-Server Tests
|
||||
```bash
|
||||
# Tools Client-Server tests
|
||||
python -m pytest test/tools/test_tools_client_server.py -v
|
||||
|
||||
# QueryExecutor Client-Server tests
|
||||
python -m pytest test/utils/test_query_executor_client_server.py -v
|
||||
```
|
||||
|
||||
### 3. Test Configuration
|
||||
|
||||
#### Modify Server Endpoints
|
||||
Edit the `test/test_config.json` file:
|
||||
|
||||
```json
|
||||
{
|
||||
"server_endpoints": {
|
||||
"http": {
|
||||
"url": "http://your-server:port/mcp"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### Enable/Disable Specific Tests
|
||||
```json
|
||||
{
|
||||
"test_settings": {
|
||||
"enable_performance_tests": false, // Disable performance tests
|
||||
"enable_security_tests": true // Enable security tests
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Test Status
|
||||
|
||||
### ✅ Completed Test Modules
|
||||
|
||||
1. **Security Module** (100% Pass)
|
||||
- Authentication tests: 5/5 passed
|
||||
- Authorization tests: 7/7 passed
|
||||
- Data masking tests: 13/13 passed
|
||||
- SQL validation tests: 10/10 passed
|
||||
- Security manager tests: 7/7 passed
|
||||
- Coverage: 88%
|
||||
|
||||
2. **Client-Server Test Architecture** (Implemented)
|
||||
- Automatic server connection status detection
|
||||
- Automatically skip tests when server is not running
|
||||
- Support for both HTTP and Stdio transport modes
|
||||
|
||||
### 🔄 Tests Requiring Server Running
|
||||
|
||||
1. **Tools Client-Server Tests**
|
||||
- Tool list retrieval
|
||||
- SQL query execution
|
||||
- Database list retrieval
|
||||
- Table schema queries
|
||||
- Performance statistics
|
||||
- Error handling
|
||||
- Security authentication
|
||||
|
||||
2. **QueryExecutor Client-Server Tests**
|
||||
- Simple query execution
|
||||
- Database queries
|
||||
- Information schema queries
|
||||
- Parameterized queries
|
||||
- Error handling
|
||||
- Security authentication
|
||||
|
||||
## Testing Best Practices
|
||||
|
||||
### 1. Server Startup Check
|
||||
All client-server tests automatically check server connection status:
|
||||
- If server is running normally, execute actual tests
|
||||
- If server is not running, skip tests and display appropriate message
|
||||
|
||||
### 2. Test Isolation
|
||||
- Unit tests use Mock objects, no dependency on external services
|
||||
- Integration tests use controlled test environments
|
||||
- Client-server tests connect to actually running servers
|
||||
|
||||
### 3. Error Handling
|
||||
- Tests don't assume specific success/failure results
|
||||
- Verify response structure rather than specific content
|
||||
- Gracefully handle connection failures and timeouts
|
||||
|
||||
### 4. Configuration Management
|
||||
- Use configuration files to manage test parameters
|
||||
- Support configuration switching for different environments
|
||||
- Provide reasonable default values
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### 1. Server Connection Failure
|
||||
```
|
||||
ERROR: Server is not running or not accessible
|
||||
```
|
||||
**Solution**: Ensure the server is started and listening on the correct port
|
||||
|
||||
### 2. Import Errors
|
||||
```
|
||||
ImportError: cannot import name 'DorisUnifiedClient'
|
||||
```
|
||||
**Solution**: Check Python path and dependency installation
|
||||
|
||||
### 3. Test Timeouts
|
||||
```
|
||||
TimeoutError: Test execution timeout
|
||||
```
|
||||
**Solution**: Increase timeout settings in `test_config.json`
|
||||
|
||||
## Development Guide
|
||||
|
||||
### Adding New Client-Server Tests
|
||||
|
||||
1. Add test methods in the appropriate test file
|
||||
2. Use `@pytest.mark.asyncio` decorator
|
||||
3. Get test client through `client` fixture
|
||||
4. Implement test callback function
|
||||
5. Verify response structure
|
||||
|
||||
Example:
|
||||
```python
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_feature_via_client(self, client, test_config):
|
||||
"""Test new feature through client"""
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.call_tool("new_tool", {
|
||||
"param": "value"
|
||||
})
|
||||
|
||||
assert "success" in result
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
```
|
||||
|
||||
### Modifying Test Configuration
|
||||
|
||||
Edit the `test/test_config.json` file to adjust:
|
||||
- Server endpoints
|
||||
- Timeout settings
|
||||
- Test data
|
||||
- Feature switches
|
||||
|
||||
## Summary
|
||||
|
||||
This testing system provides complete test coverage, from unit tests to end-to-end client-server tests. Through reasonable configuration and automated connection detection, it ensures tests can run stably in different environments.
|
||||
@@ -1,16 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
107
test/conftest.py
107
test/conftest.py
@@ -1,107 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Pytest configuration and fixtures for Doris MCP Server tests
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
# Add project root to Python path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
# Configure logging for tests
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create an instance of the default event loop for the test session."""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_config():
|
||||
"""Provide test configuration"""
|
||||
return {
|
||||
"doris_host": "localhost",
|
||||
"doris_port": 9030,
|
||||
"doris_user": "test_user",
|
||||
"doris_password": "test_password",
|
||||
"doris_database": "test_db",
|
||||
"blocked_keywords": ["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"],
|
||||
"sensitive_tables": {
|
||||
"user_info": "confidential",
|
||||
"payment_records": "secret",
|
||||
"employee_data": "confidential",
|
||||
"public_reports": "public"
|
||||
},
|
||||
"max_query_complexity": 100
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_data():
|
||||
"""Provide sample test data"""
|
||||
return [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "张三",
|
||||
"phone": "13812345678",
|
||||
"email": "zhangsan@example.com",
|
||||
"id_card": "110101199001011234",
|
||||
"salary": 50000
|
||||
},
|
||||
{
|
||||
"id": 2,
|
||||
"name": "李四",
|
||||
"phone": "13987654321",
|
||||
"email": "lisi@example.com",
|
||||
"id_card": "110101199002022345",
|
||||
"salary": 60000
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_sql_queries():
|
||||
"""Provide test SQL queries"""
|
||||
return {
|
||||
"safe_select": "SELECT name, email FROM users WHERE department = 'sales'",
|
||||
"dangerous_drop": "DROP TABLE users",
|
||||
"sql_injection": "SELECT * FROM users WHERE id = 1; DROP TABLE users;",
|
||||
"union_injection": "SELECT name FROM users UNION SELECT password FROM admin_users",
|
||||
"comment_injection": "SELECT * FROM users WHERE id = 1 -- AND password = 'secret'",
|
||||
"complex_query": """
|
||||
SELECT u.name, u.email, d.department_name
|
||||
FROM users u
|
||||
JOIN departments d ON u.department_id = d.id
|
||||
WHERE u.status = 'active'
|
||||
ORDER BY u.created_at DESC
|
||||
"""
|
||||
}
|
||||
@@ -1,299 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
End-to-end integration tests
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from doris_mcp_server.main import DorisServer
|
||||
from doris_mcp_server.utils.config import DorisConfig
|
||||
from doris_mcp_server.utils.security import SecurityLevel, AuthContext
|
||||
|
||||
|
||||
class TestEndToEndIntegration:
|
||||
"""End-to-end integration tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Create mock configuration"""
|
||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||
|
||||
config = Mock(spec=DorisConfig)
|
||||
config.doris_host = "localhost"
|
||||
config.doris_port = 9030
|
||||
config.doris_user = "test_user"
|
||||
config.doris_password = "test_password"
|
||||
config.doris_database = "test_db"
|
||||
config.server_host = "localhost"
|
||||
config.server_port = 8000
|
||||
config.enable_security = True
|
||||
|
||||
# Add database config
|
||||
config.database = Mock(spec=DatabaseConfig)
|
||||
config.database.host = "localhost"
|
||||
config.database.port = 9030
|
||||
config.database.user = "test_user"
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
# Add security config
|
||||
config.security = Mock(spec=SecurityConfig)
|
||||
config.security.enable_masking = True
|
||||
config.security.auth_type = "token"
|
||||
config.security.token_secret = "test_secret"
|
||||
config.security.token_expiry = 3600
|
||||
|
||||
return config
|
||||
|
||||
@pytest.fixture
|
||||
def doris_server(self, mock_config):
|
||||
"""Create Doris server instance"""
|
||||
return DorisServer(mock_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_query_workflow_with_security(self, doris_server, sample_data):
|
||||
"""Test complete query workflow with security"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = sample_data
|
||||
|
||||
# Mock authentication
|
||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
||||
mock_auth.return_value = AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
# Mock authorization
|
||||
with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz:
|
||||
mock_authz.return_value = True
|
||||
|
||||
# Mock SQL validation
|
||||
with patch.object(doris_server.security_manager, 'validate_sql_security') as mock_validate:
|
||||
from doris_mcp_server.utils.security import ValidationResult
|
||||
mock_validate.return_value = ValidationResult(is_valid=True)
|
||||
|
||||
# Mock data masking
|
||||
with patch.object(doris_server.security_manager, 'apply_data_masking') as mock_mask:
|
||||
masked_data = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "张三",
|
||||
"phone": "138****5678",
|
||||
"email": "z*******n@example.com",
|
||||
"id_card": "110101****1234",
|
||||
"salary": 50000
|
||||
}
|
||||
]
|
||||
mock_mask.return_value = masked_data
|
||||
|
||||
# Simulate complete workflow
|
||||
auth_info = {"type": "token", "token": "valid_token_123"}
|
||||
auth_context = await doris_server.security_manager.authenticate_request(auth_info)
|
||||
|
||||
resource_uri = "/api/table/users"
|
||||
has_access = await doris_server.security_manager.authorize_resource_access(
|
||||
auth_context, resource_uri
|
||||
)
|
||||
assert has_access is True
|
||||
|
||||
sql = "SELECT * FROM users LIMIT 1"
|
||||
validation = await doris_server.security_manager.validate_sql_security(
|
||||
sql, auth_context
|
||||
)
|
||||
assert validation.is_valid is True
|
||||
|
||||
raw_data = await doris_server.tools_manager.query_executor.execute_query(sql)
|
||||
final_data = await doris_server.security_manager.apply_data_masking(
|
||||
raw_data, auth_context
|
||||
)
|
||||
|
||||
# Verify data is properly masked
|
||||
assert final_data[0]["phone"] == "138****5678"
|
||||
assert final_data[0]["email"] == "z*******n@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_security_violation_workflow(self, doris_server):
|
||||
"""Test security violation detection workflow"""
|
||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
||||
mock_auth.return_value = AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
# Test unauthorized resource access
|
||||
with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz:
|
||||
mock_authz.return_value = False
|
||||
|
||||
auth_context = await doris_server.security_manager.authenticate_request({
|
||||
"type": "token", "token": "valid_token_123"
|
||||
})
|
||||
|
||||
# Try to access confidential resource
|
||||
resource_uri = "/api/table/payment_records"
|
||||
has_access = await doris_server.security_manager.authorize_resource_access(
|
||||
auth_context, resource_uri
|
||||
)
|
||||
|
||||
assert has_access is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sql_injection_prevention_workflow(self, doris_server):
|
||||
"""Test SQL injection prevention workflow"""
|
||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
||||
mock_auth.return_value = AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
auth_context = await doris_server.security_manager.authenticate_request({
|
||||
"type": "token", "token": "valid_token_123"
|
||||
})
|
||||
|
||||
# Test SQL injection attempt
|
||||
malicious_sql = "SELECT * FROM users WHERE id = 1; DROP TABLE users;"
|
||||
validation = await doris_server.security_manager.validate_sql_security(
|
||||
malicious_sql, auth_context
|
||||
)
|
||||
|
||||
assert validation.is_valid is False
|
||||
assert validation.risk_level == "high"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_bypass_workflow(self, doris_server, sample_data):
|
||||
"""Test admin user bypassing restrictions"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = sample_data
|
||||
|
||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
||||
mock_auth.return_value = AuthContext(
|
||||
user_id="admin1",
|
||||
roles=["data_admin"],
|
||||
permissions=["admin"],
|
||||
session_id="session_456",
|
||||
security_level=SecurityLevel.SECRET
|
||||
)
|
||||
|
||||
# Admin should access any resource
|
||||
with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz:
|
||||
mock_authz.return_value = True
|
||||
|
||||
# Admin should see original data (no masking)
|
||||
with patch.object(doris_server.security_manager, 'apply_data_masking') as mock_mask:
|
||||
mock_mask.return_value = sample_data # Original data
|
||||
|
||||
auth_context = await doris_server.security_manager.authenticate_request({
|
||||
"type": "basic", "username": "admin", "password": "admin123"
|
||||
})
|
||||
|
||||
# Admin accesses secret resource
|
||||
resource_uri = "/api/table/payment_records"
|
||||
has_access = await doris_server.security_manager.authorize_resource_access(
|
||||
auth_context, resource_uri
|
||||
)
|
||||
assert has_access is True
|
||||
|
||||
# Admin sees original data
|
||||
raw_data = await doris_server.tools_manager.query_executor.execute_query(
|
||||
"SELECT * FROM users LIMIT 1"
|
||||
)
|
||||
final_data = await doris_server.security_manager.apply_data_masking(
|
||||
raw_data, auth_context
|
||||
)
|
||||
|
||||
# Should be original data (no masking)
|
||||
assert final_data[0]["phone"] == "13812345678"
|
||||
assert final_data[0]["email"] == "zhangsan@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_execution_with_security(self, doris_server):
|
||||
"""Test tool execution with security checks"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [{"Database": "test_db"}]
|
||||
|
||||
# Test tool execution through tools manager
|
||||
result = await doris_server.tools_manager.call_tool("get_db_list", {})
|
||||
result_data = json.loads(result)
|
||||
|
||||
# Accept either success result or error (due to mock environment)
|
||||
assert "result" in result_data or "error" in result_data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling_workflow(self, doris_server):
|
||||
"""Test error handling in complete workflow"""
|
||||
# Test authentication failure
|
||||
with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth:
|
||||
mock_auth.side_effect = Exception("Invalid token")
|
||||
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
await doris_server.security_manager.authenticate_request({
|
||||
"type": "token", "token": "invalid_token"
|
||||
})
|
||||
|
||||
assert "Invalid token" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_monitoring_integration(self, doris_server):
|
||||
"""Test performance monitoring integration"""
|
||||
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{
|
||||
"query_count": 1500,
|
||||
"avg_execution_time": 0.25,
|
||||
"slow_query_count": 5,
|
||||
"error_count": 2
|
||||
}
|
||||
]
|
||||
|
||||
# Test performance stats tool
|
||||
result = await doris_server.tools_manager.call_tool("performance_stats", {
|
||||
"metric_type": "queries",
|
||||
"time_range": "1h"
|
||||
})
|
||||
result_data = json.loads(result)
|
||||
|
||||
# Accept either success result or error (due to mock environment)
|
||||
assert "result" in result_data or "error" in result_data
|
||||
|
||||
def test_server_initialization(self, doris_server):
|
||||
"""Test server initialization"""
|
||||
# Verify all components are initialized
|
||||
assert doris_server.config is not None
|
||||
assert doris_server.tools_manager is not None
|
||||
assert doris_server.security_manager is not None
|
||||
|
||||
# Verify tools are available - use list_tools instead
|
||||
import asyncio
|
||||
tools = asyncio.run(doris_server.tools_manager.list_tools())
|
||||
assert len(tools) > 0
|
||||
@@ -1,103 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Authentication module tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
AuthenticationProvider,
|
||||
AuthContext,
|
||||
SecurityLevel
|
||||
)
|
||||
|
||||
|
||||
class TestAuthenticationProvider:
|
||||
"""Authentication provider tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def auth_provider(self, test_config):
|
||||
"""Create authentication provider instance"""
|
||||
return AuthenticationProvider(test_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_authentication_success(self, auth_provider):
|
||||
"""Test successful token authentication"""
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "valid_token_123"
|
||||
}
|
||||
|
||||
result = await auth_provider.authenticate(auth_info)
|
||||
|
||||
assert isinstance(result, AuthContext)
|
||||
assert result.user_id == "test_user"
|
||||
assert "data_analyst" in result.roles
|
||||
assert result.security_level == SecurityLevel.INTERNAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_authentication_failure(self, auth_provider):
|
||||
"""Test failed token authentication"""
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "invalid_token"
|
||||
}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await auth_provider.authenticate(auth_info)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_authentication_success(self, auth_provider):
|
||||
"""Test successful basic authentication"""
|
||||
auth_info = {
|
||||
"type": "basic",
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
|
||||
result = await auth_provider.authenticate(auth_info)
|
||||
|
||||
assert isinstance(result, AuthContext)
|
||||
assert result.user_id == "admin_user"
|
||||
assert "data_admin" in result.roles
|
||||
assert result.security_level == SecurityLevel.SECRET
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_authentication_failure(self, auth_provider):
|
||||
"""Test failed basic authentication"""
|
||||
auth_info = {
|
||||
"type": "basic",
|
||||
"username": "admin",
|
||||
"password": "wrong_password"
|
||||
}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await auth_provider.authenticate(auth_info)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_auth_type(self, auth_provider):
|
||||
"""Test unsupported authentication type"""
|
||||
auth_info = {
|
||||
"type": "oauth",
|
||||
"token": "oauth_token"
|
||||
}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await auth_provider.authenticate(auth_info)
|
||||
@@ -1,147 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Authorization module tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
AuthorizationProvider,
|
||||
AuthContext,
|
||||
SecurityLevel
|
||||
)
|
||||
|
||||
|
||||
class TestAuthorizationProvider:
|
||||
"""Authorization provider tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def authz_provider(self, test_config):
|
||||
"""Create authorization provider instance"""
|
||||
return AuthorizationProvider(test_config)
|
||||
|
||||
@pytest.fixture
|
||||
def analyst_context(self):
|
||||
"""Create analyst auth context"""
|
||||
return AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def admin_context(self):
|
||||
"""Create admin auth context"""
|
||||
return AuthContext(
|
||||
user_id="admin1",
|
||||
roles=["data_admin"],
|
||||
permissions=["admin"],
|
||||
session_id="session_456",
|
||||
security_level=SecurityLevel.SECRET
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyst_access_public_resource(self, authz_provider, analyst_context):
|
||||
"""Test analyst accessing public resource"""
|
||||
resource_uri = "/api/table/public_reports"
|
||||
|
||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyst_denied_confidential_resource(self, authz_provider):
|
||||
"""Test analyst denied access to confidential resource"""
|
||||
# Create analyst with lower security level
|
||||
analyst_context = AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.PUBLIC # Lower than CONFIDENTIAL
|
||||
)
|
||||
|
||||
resource_uri = "/api/table/user_info"
|
||||
|
||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_access_secret_resource(self, authz_provider, admin_context):
|
||||
"""Test admin accessing secret resource"""
|
||||
resource_uri = "/api/table/payment_records"
|
||||
|
||||
result = await authz_provider.check_permission(admin_context, resource_uri, "read")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_based_permission(self, authz_provider):
|
||||
"""Test role-based permission check"""
|
||||
# Create analyst context
|
||||
analyst_context = AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
resource_uri = "/api/table/some_table"
|
||||
|
||||
# Analyst should have read permission
|
||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
|
||||
assert result is True
|
||||
|
||||
# Analyst should not have write permission
|
||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "write")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_override(self, authz_provider, admin_context):
|
||||
"""Test admin permission override"""
|
||||
resource_uri = "/api/table/any_table"
|
||||
|
||||
# Admin should have all permissions
|
||||
result = await authz_provider.check_permission(admin_context, resource_uri, "read")
|
||||
assert result is True
|
||||
|
||||
result = await authz_provider.check_permission(admin_context, resource_uri, "write")
|
||||
assert result is True
|
||||
|
||||
def test_parse_resource_uri(self, authz_provider):
|
||||
"""Test resource URI parsing"""
|
||||
uri = "/api/table/user_info/default"
|
||||
|
||||
result = authz_provider._parse_resource_uri(uri)
|
||||
|
||||
assert result["type"] == "table"
|
||||
assert result["name"] == "user_info"
|
||||
assert result["schema"] == "default"
|
||||
|
||||
def test_get_resource_security_level(self, authz_provider):
|
||||
"""Test getting resource security level"""
|
||||
resource_info = {"name": "user_info", "type": "table"}
|
||||
|
||||
level = authz_provider._get_resource_security_level(resource_info)
|
||||
|
||||
assert level == SecurityLevel.CONFIDENTIAL
|
||||
@@ -1,197 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data masking tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
DataMaskingProcessor,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
MaskingRule
|
||||
)
|
||||
|
||||
|
||||
class TestDataMaskingProcessor:
|
||||
"""Data masking processor tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def masking_processor(self, test_config):
|
||||
"""Create data masking processor instance"""
|
||||
return DataMaskingProcessor(test_config)
|
||||
|
||||
@pytest.fixture
|
||||
def internal_user_context(self):
|
||||
"""Create internal user auth context"""
|
||||
return AuthContext(
|
||||
user_id="internal_user",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def admin_context(self):
|
||||
"""Create admin auth context"""
|
||||
return AuthContext(
|
||||
user_id="admin",
|
||||
roles=["data_admin"],
|
||||
permissions=["admin"],
|
||||
session_id="session_456",
|
||||
security_level=SecurityLevel.SECRET
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_phone_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
|
||||
"""Test phone number masking for internal user"""
|
||||
result = await masking_processor.process(sample_data, internal_user_context)
|
||||
|
||||
# Phone numbers should be masked
|
||||
assert result[0]["phone"] == "138****5678"
|
||||
assert result[1]["phone"] == "139****4321"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_email_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
|
||||
"""Test email masking for internal user"""
|
||||
result = await masking_processor.process(sample_data, internal_user_context)
|
||||
|
||||
# Emails should be masked
|
||||
assert result[0]["email"] == "z******n@example.com"
|
||||
assert result[1]["email"] == "l**i@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_masking_for_admin(self, masking_processor, admin_context, sample_data):
|
||||
"""Test no masking for admin user"""
|
||||
result = await masking_processor.process(sample_data, admin_context)
|
||||
|
||||
# Admin should see original data
|
||||
assert result[0]["phone"] == "13812345678"
|
||||
assert result[0]["email"] == "zhangsan@example.com"
|
||||
assert result[1]["phone"] == "13987654321"
|
||||
assert result[1]["email"] == "lisi@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_id_card_masking_for_confidential_data(self, masking_processor, internal_user_context, sample_data):
|
||||
"""Test ID card masking for confidential data"""
|
||||
# Internal user should not see ID card details (confidential level)
|
||||
result = await masking_processor.process(sample_data, internal_user_context)
|
||||
|
||||
# ID cards should be masked for internal users
|
||||
assert result[0]["id_card"] == "110101********1234"
|
||||
assert result[1]["id_card"] == "110101********2345"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_data_handling(self, masking_processor, internal_user_context):
|
||||
"""Test empty data handling"""
|
||||
empty_data = []
|
||||
|
||||
result = await masking_processor.process(empty_data, internal_user_context)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_null_value_handling(self, masking_processor, internal_user_context):
|
||||
"""Test null value handling"""
|
||||
data_with_nulls = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "张三",
|
||||
"phone": None,
|
||||
"email": None,
|
||||
"id_card": None
|
||||
}
|
||||
]
|
||||
|
||||
result = await masking_processor.process(data_with_nulls, internal_user_context)
|
||||
|
||||
# Null values should remain null
|
||||
assert result[0]["phone"] is None
|
||||
assert result[0]["email"] is None
|
||||
assert result[0]["id_card"] is None
|
||||
|
||||
def test_phone_masking_algorithm(self, masking_processor):
|
||||
"""Test phone masking algorithm"""
|
||||
params = {"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4}
|
||||
|
||||
result = masking_processor._mask_phone("13812345678", params)
|
||||
|
||||
assert result == "138****5678"
|
||||
|
||||
def test_email_masking_algorithm(self, masking_processor):
|
||||
"""Test email masking algorithm"""
|
||||
params = {"mask_char": "*"}
|
||||
|
||||
result = masking_processor._mask_email("zhangsan@example.com", params)
|
||||
|
||||
assert result == "z******n@example.com"
|
||||
|
||||
def test_id_card_masking_algorithm(self, masking_processor):
|
||||
"""Test ID card masking algorithm"""
|
||||
params = {"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4}
|
||||
|
||||
result = masking_processor._mask_id_card("110101199001011234", params)
|
||||
|
||||
assert result == "110101********1234"
|
||||
|
||||
def test_name_masking_algorithm(self, masking_processor):
|
||||
"""Test name masking algorithm"""
|
||||
params = {"mask_char": "*"}
|
||||
|
||||
# Test 2-character name
|
||||
result = masking_processor._mask_name("张三", params)
|
||||
assert result == "张*"
|
||||
|
||||
# Test 3-character name
|
||||
result = masking_processor._mask_name("李小明", params)
|
||||
assert result == "李*明"
|
||||
|
||||
def test_partial_masking_algorithm(self, masking_processor):
|
||||
"""Test partial masking algorithm"""
|
||||
params = {"mask_char": "*", "mask_ratio": 0.5}
|
||||
|
||||
result = masking_processor._mask_partial("1234567890", params)
|
||||
|
||||
# Should mask middle 50% of the string
|
||||
assert "*" in result
|
||||
assert len(result) == 10
|
||||
|
||||
def test_should_apply_rule_logic(self, masking_processor, internal_user_context, admin_context):
|
||||
"""Test masking rule application logic"""
|
||||
rule = MaskingRule(
|
||||
column_pattern=r".*phone.*",
|
||||
algorithm="phone_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
# Internal user should have rule applied
|
||||
assert masking_processor._should_apply_rule(rule, internal_user_context) is True
|
||||
|
||||
# Admin should not have rule applied
|
||||
assert masking_processor._should_apply_rule(rule, admin_context) is False
|
||||
|
||||
def test_get_applicable_rules(self, masking_processor, internal_user_context):
|
||||
"""Test getting applicable rules"""
|
||||
rules = masking_processor._get_applicable_rules(internal_user_context)
|
||||
|
||||
# Should return some rules for internal user
|
||||
assert len(rules) > 0
|
||||
assert all(isinstance(rule, MaskingRule) for rule in rules)
|
||||
@@ -1,172 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Security manager integration tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
DorisSecurityManager,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
ValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestDorisSecurityManager:
|
||||
"""Doris security manager integration tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def security_manager(self, test_config):
|
||||
"""Create security manager instance"""
|
||||
return DorisSecurityManager(test_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_security_workflow(self, security_manager, sample_data):
|
||||
"""Test complete security workflow"""
|
||||
# 1. Authentication
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "valid_token_123"
|
||||
}
|
||||
|
||||
auth_context = await security_manager.authenticate_request(auth_info)
|
||||
assert isinstance(auth_context, AuthContext)
|
||||
assert auth_context.security_level == SecurityLevel.INTERNAL
|
||||
|
||||
# 2. Authorization
|
||||
resource_uri = "/api/table/public_reports"
|
||||
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
|
||||
assert has_access is True
|
||||
|
||||
# 3. SQL Validation
|
||||
safe_sql = "SELECT name, email FROM users WHERE department = 'sales'"
|
||||
validation_result = await security_manager.validate_sql_security(safe_sql, auth_context)
|
||||
assert validation_result.is_valid is True
|
||||
|
||||
# 4. Data Masking
|
||||
masked_data = await security_manager.apply_data_masking(sample_data, auth_context)
|
||||
assert masked_data[0]["phone"] == "138****5678" # Should be masked
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_workflow(self, security_manager, sample_data):
|
||||
"""Test admin user workflow"""
|
||||
# Admin authentication
|
||||
auth_info = {
|
||||
"type": "basic",
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
|
||||
auth_context = await security_manager.authenticate_request(auth_info)
|
||||
assert auth_context.security_level == SecurityLevel.SECRET
|
||||
|
||||
# Admin should access secret resources
|
||||
resource_uri = "/api/table/payment_records"
|
||||
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
|
||||
assert has_access is True
|
||||
|
||||
# Admin should see original data (no masking)
|
||||
masked_data = await security_manager.apply_data_masking(sample_data, auth_context)
|
||||
assert masked_data[0]["phone"] == "13812345678" # Original data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_security_violation_detection(self, security_manager):
|
||||
"""Test security violation detection"""
|
||||
# Authenticate as regular user
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "valid_token_123"
|
||||
}
|
||||
|
||||
auth_context = await security_manager.authenticate_request(auth_info)
|
||||
|
||||
# Try to access confidential resource (user_info is CONFIDENTIAL, user is INTERNAL)
|
||||
# INTERNAL(1) should not access CONFIDENTIAL(2) resource
|
||||
resource_uri = "/api/table/user_info"
|
||||
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
|
||||
assert has_access is False
|
||||
|
||||
# Try dangerous SQL
|
||||
dangerous_sql = "DROP TABLE users"
|
||||
validation_result = await security_manager.validate_sql_security(dangerous_sql, auth_context)
|
||||
assert validation_result.is_valid is False
|
||||
assert "DROP" in validation_result.blocked_operations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sql_injection_prevention(self, security_manager):
|
||||
"""Test SQL injection prevention"""
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "valid_token_123"
|
||||
}
|
||||
|
||||
auth_context = await security_manager.authenticate_request(auth_info)
|
||||
|
||||
# Test various injection attempts
|
||||
injection_attempts = [
|
||||
"SELECT * FROM users WHERE id = 1; DROP TABLE users;",
|
||||
"SELECT * FROM users UNION SELECT password FROM admin_users",
|
||||
"SELECT * FROM users WHERE id = 1 OR 1=1",
|
||||
"SELECT * FROM users WHERE name = 'test' -- AND password = 'secret'"
|
||||
]
|
||||
|
||||
for sql in injection_attempts:
|
||||
result = await security_manager.validate_sql_security(sql, auth_context)
|
||||
assert result.is_valid is False
|
||||
assert result.risk_level in ["medium", "high"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_failure_handling(self, security_manager):
|
||||
"""Test authentication failure handling"""
|
||||
invalid_auth_info = {
|
||||
"type": "token",
|
||||
"token": "invalid_token"
|
||||
}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await security_manager.authenticate_request(invalid_auth_info)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configuration_loading(self, security_manager):
|
||||
"""Test security configuration loading"""
|
||||
# Test blocked keywords loading
|
||||
assert "DROP" in security_manager.blocked_keywords
|
||||
assert "DELETE" in security_manager.blocked_keywords
|
||||
|
||||
# Test sensitive tables loading
|
||||
assert SecurityLevel.CONFIDENTIAL in security_manager.sensitive_tables.values()
|
||||
assert SecurityLevel.SECRET in security_manager.sensitive_tables.values()
|
||||
|
||||
# Test masking rules loading
|
||||
assert len(security_manager.masking_rules) > 0
|
||||
phone_rules = [rule for rule in security_manager.masking_rules
|
||||
if "phone" in rule.column_pattern]
|
||||
assert len(phone_rules) > 0
|
||||
|
||||
def test_security_level_hierarchy(self, security_manager):
|
||||
"""Test security level hierarchy"""
|
||||
# Test that hierarchy is correctly defined
|
||||
levels = [SecurityLevel.PUBLIC, SecurityLevel.INTERNAL,
|
||||
SecurityLevel.CONFIDENTIAL, SecurityLevel.SECRET]
|
||||
|
||||
# Each level should be properly defined
|
||||
for level in levels:
|
||||
assert isinstance(level, SecurityLevel)
|
||||
assert level.value in ["public", "internal", "confidential", "secret"]
|
||||
@@ -1,161 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
SQL security validation tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
SQLSecurityValidator,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
ValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestSQLSecurityValidator:
|
||||
"""SQL security validator tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def sql_validator(self, test_config):
|
||||
"""Create SQL validator instance"""
|
||||
return SQLSecurityValidator(test_config)
|
||||
|
||||
@pytest.fixture
|
||||
def analyst_context(self):
|
||||
"""Create analyst auth context"""
|
||||
return AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_safe_select_query(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test safe SELECT query validation"""
|
||||
sql = test_sql_queries["safe_select"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.error_message is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_drop_operation(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test blocked DROP operation"""
|
||||
sql = test_sql_queries["dangerous_drop"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "blocked operations" in result.error_message.lower()
|
||||
assert "DROP" in result.blocked_operations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sql_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test SQL injection detection"""
|
||||
sql = test_sql_queries["sql_injection"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "injection" in result.error_message.lower()
|
||||
assert result.risk_level == "high"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_union_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test UNION injection detection"""
|
||||
sql = test_sql_queries["union_injection"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "injection" in result.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_comment_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test comment injection detection"""
|
||||
sql = test_sql_queries["comment_injection"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "comment" in result.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_query_validation(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test complex query validation"""
|
||||
sql = test_sql_queries["complex_query"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
# Complex query should pass if within limits
|
||||
assert result.is_valid is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_keywords_detection(self, sql_validator, analyst_context):
|
||||
"""Test blocked keywords detection"""
|
||||
blocked_sqls = [
|
||||
"DELETE FROM users WHERE id = 1",
|
||||
"TRUNCATE TABLE logs",
|
||||
"ALTER TABLE users ADD COLUMN new_col VARCHAR(50)",
|
||||
"CREATE TABLE test (id INT)",
|
||||
"INSERT INTO users VALUES (1, 'test')",
|
||||
"UPDATE users SET name = 'test' WHERE id = 1"
|
||||
]
|
||||
|
||||
for sql in blocked_sqls:
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
assert result.is_valid is False
|
||||
assert result.blocked_operations is not None
|
||||
assert len(result.blocked_operations) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_access_validation(self, sql_validator, analyst_context):
|
||||
"""Test table access validation"""
|
||||
# Test access to sensitive table
|
||||
sql = "SELECT * FROM sensitive_data"
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
# Should fail for non-admin users
|
||||
assert result.is_valid is False
|
||||
assert "access" in result.error_message.lower()
|
||||
|
||||
def test_extract_table_names(self, sql_validator):
|
||||
"""Test table name extraction"""
|
||||
sql = "SELECT u.name FROM users u JOIN departments d ON u.dept_id = d.id"
|
||||
|
||||
parsed = __import__('sqlparse').parse(sql)[0]
|
||||
tables = sql_validator._extract_table_names(parsed)
|
||||
|
||||
# Should extract at least one table name
|
||||
assert len(tables) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_sql_handling(self, sql_validator, analyst_context):
|
||||
"""Test malformed SQL handling"""
|
||||
malformed_sql = "SELECT * FROM users WHERE"
|
||||
|
||||
result = await sql_validator.validate(malformed_sql, analyst_context)
|
||||
|
||||
# Should handle gracefully
|
||||
assert isinstance(result, ValidationResult)
|
||||
@@ -1,69 +0,0 @@
|
||||
{
|
||||
"server_endpoints": {
|
||||
"http": {
|
||||
"url": "http://localhost:3000/mcp",
|
||||
"timeout": 30,
|
||||
"headers": {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
},
|
||||
"http_network": {
|
||||
"url": "http://192.168.31.168:3000/mcp",
|
||||
"timeout": 30,
|
||||
"headers": {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
},
|
||||
"stdio": {
|
||||
"command": "uv",
|
||||
"args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"],
|
||||
"timeout": 30,
|
||||
"working_directory": ".."
|
||||
}
|
||||
},
|
||||
"test_settings": {
|
||||
"default_transport": "http",
|
||||
"retry_attempts": 3,
|
||||
"retry_delay": 1.0,
|
||||
"test_timeout": 60,
|
||||
"enable_performance_tests": true,
|
||||
"enable_security_tests": true
|
||||
},
|
||||
"test_data": {
|
||||
"sample_queries": [
|
||||
"SELECT 1 as test_value",
|
||||
"SHOW DATABASES",
|
||||
"SELECT COUNT(*) FROM information_schema.tables"
|
||||
],
|
||||
"test_databases": ["test_db", "demo_db"],
|
||||
"test_tables": ["users", "orders", "products"],
|
||||
"auth_tokens": {
|
||||
"valid_token": "valid_token_123",
|
||||
"admin_token": "admin_token_456",
|
||||
"invalid_token": "invalid_token_789"
|
||||
}
|
||||
},
|
||||
"expected_tools": [
|
||||
"exec_query",
|
||||
"get_db_list",
|
||||
"get_db_table_list",
|
||||
"get_table_schema",
|
||||
"get_table_comment",
|
||||
"get_table_column_comments",
|
||||
"get_table_indexes",
|
||||
"column_analysis",
|
||||
"performance_stats",
|
||||
"get_recent_audit_logs",
|
||||
"get_catalog_list"
|
||||
],
|
||||
"expected_resources": [
|
||||
"database",
|
||||
"table",
|
||||
"view"
|
||||
],
|
||||
"expected_prompts": [
|
||||
"sql_query_assistant",
|
||||
"data_analysis_helper",
|
||||
"schema_explorer"
|
||||
]
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Test Configuration Loader
|
||||
|
||||
Loads test configuration and provides methods to connect to running servers
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
import logging
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
||||
|
||||
from doris_mcp_client.client import DorisUnifiedClient, DorisClientConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestConfigLoader:
|
||||
"""Test configuration loader and client factory"""
|
||||
|
||||
def __init__(self, config_path: Optional[str] = None):
|
||||
"""Initialize with config file path"""
|
||||
if config_path is None:
|
||||
config_path = os.path.join(os.path.dirname(__file__), "test_config.json")
|
||||
|
||||
self.config_path = Path(config_path)
|
||||
self.config = self._load_config()
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""Load configuration from JSON file"""
|
||||
try:
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
logger.info(f"Loaded test configuration from {self.config_path}")
|
||||
return config
|
||||
except FileNotFoundError:
|
||||
logger.error(f"Test configuration file not found: {self.config_path}")
|
||||
raise
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Invalid JSON in test configuration: {e}")
|
||||
raise
|
||||
|
||||
def get_http_client_config(self) -> DorisClientConfig:
|
||||
"""Get HTTP client configuration"""
|
||||
http_config = self.config["server_endpoints"]["http"]
|
||||
return DorisClientConfig.http(
|
||||
url=http_config["url"],
|
||||
timeout=http_config["timeout"]
|
||||
)
|
||||
|
||||
def get_stdio_client_config(self) -> DorisClientConfig:
|
||||
"""Get stdio client configuration"""
|
||||
stdio_config = self.config["server_endpoints"]["stdio"]
|
||||
return DorisClientConfig.stdio(
|
||||
command=stdio_config["command"],
|
||||
args=stdio_config["args"]
|
||||
)
|
||||
|
||||
def get_default_client_config(self) -> DorisClientConfig:
|
||||
"""Get default client configuration based on test settings"""
|
||||
transport = self.config["test_settings"]["default_transport"]
|
||||
if transport == "http":
|
||||
return self.get_http_client_config()
|
||||
elif transport == "stdio":
|
||||
return self.get_stdio_client_config()
|
||||
else:
|
||||
raise ValueError(f"Unknown transport type: {transport}")
|
||||
|
||||
def create_client(self, transport: Optional[str] = None) -> DorisUnifiedClient:
|
||||
"""Create MCP client instance"""
|
||||
if transport is None:
|
||||
client_config = self.get_default_client_config()
|
||||
elif transport == "http":
|
||||
client_config = self.get_http_client_config()
|
||||
elif transport == "stdio":
|
||||
client_config = self.get_stdio_client_config()
|
||||
else:
|
||||
raise ValueError(f"Unknown transport type: {transport}")
|
||||
|
||||
return DorisUnifiedClient(client_config)
|
||||
|
||||
def get_test_settings(self) -> Dict[str, Any]:
|
||||
"""Get test settings"""
|
||||
return self.config["test_settings"]
|
||||
|
||||
def get_test_data(self) -> Dict[str, Any]:
|
||||
"""Get test data"""
|
||||
return self.config["test_data"]
|
||||
|
||||
def get_expected_tools(self) -> list[str]:
|
||||
"""Get expected tools list"""
|
||||
return self.config["expected_tools"]
|
||||
|
||||
def get_expected_resources(self) -> list[str]:
|
||||
"""Get expected resources list"""
|
||||
return self.config["expected_resources"]
|
||||
|
||||
def get_expected_prompts(self) -> list[str]:
|
||||
"""Get expected prompts list"""
|
||||
return self.config["expected_prompts"]
|
||||
|
||||
def get_sample_queries(self) -> list[str]:
|
||||
"""Get sample queries for testing"""
|
||||
return self.config["test_data"]["sample_queries"]
|
||||
|
||||
def get_auth_tokens(self) -> Dict[str, str]:
|
||||
"""Get authentication tokens for testing"""
|
||||
return self.config["test_data"]["auth_tokens"]
|
||||
|
||||
def get_test_databases(self) -> list[str]:
|
||||
"""Get test databases list"""
|
||||
return self.config["test_data"]["test_databases"]
|
||||
|
||||
def get_test_tables(self) -> list[str]:
|
||||
"""Get test tables list"""
|
||||
return self.config["test_data"]["test_tables"]
|
||||
|
||||
def is_performance_tests_enabled(self) -> bool:
|
||||
"""Check if performance tests are enabled"""
|
||||
return self.config["test_settings"]["enable_performance_tests"]
|
||||
|
||||
def is_security_tests_enabled(self) -> bool:
|
||||
"""Check if security tests are enabled"""
|
||||
return self.config["test_settings"]["enable_security_tests"]
|
||||
|
||||
def get_retry_config(self) -> Dict[str, Any]:
|
||||
"""Get retry configuration"""
|
||||
return {
|
||||
"attempts": self.config["test_settings"]["retry_attempts"],
|
||||
"delay": self.config["test_settings"]["retry_delay"]
|
||||
}
|
||||
|
||||
def get_test_timeout(self) -> int:
|
||||
"""Get test timeout in seconds"""
|
||||
return self.config["test_settings"]["test_timeout"]
|
||||
|
||||
|
||||
# Global test config instance
|
||||
_test_config = None
|
||||
|
||||
def get_test_config() -> TestConfigLoader:
|
||||
"""Get global test configuration instance"""
|
||||
global _test_config
|
||||
if _test_config is None:
|
||||
_test_config = TestConfigLoader()
|
||||
return _test_config
|
||||
|
||||
|
||||
def create_test_client(transport: Optional[str] = None) -> DorisUnifiedClient:
|
||||
"""Create test client with default configuration"""
|
||||
return get_test_config().create_client(transport)
|
||||
|
||||
|
||||
async def test_server_connectivity(transport: Optional[str] = None) -> bool:
|
||||
"""Test server connectivity"""
|
||||
try:
|
||||
client = create_test_client(transport)
|
||||
|
||||
async def test_connection(client_instance):
|
||||
try:
|
||||
# Try to list tools as a connectivity test
|
||||
tools = await client_instance.list_all_tools()
|
||||
return len(tools) > 0
|
||||
except Exception as e:
|
||||
logger.error(f"Connectivity test failed: {e}")
|
||||
return False
|
||||
|
||||
result = await client.connect_and_run(test_connection)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to test server connectivity: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Test configuration loading
|
||||
import asyncio
|
||||
|
||||
async def main():
|
||||
config = get_test_config()
|
||||
print("Test Configuration Loaded:")
|
||||
print(f" Default transport: {config.get_test_settings()['default_transport']}")
|
||||
print(f" Expected tools: {len(config.get_expected_tools())}")
|
||||
print(f" Sample queries: {len(config.get_sample_queries())}")
|
||||
|
||||
# Test connectivity
|
||||
print("\nTesting server connectivity...")
|
||||
http_ok = await test_server_connectivity("http")
|
||||
print(f" HTTP connectivity: {'✓' if http_ok else '✗'}")
|
||||
|
||||
stdio_ok = await test_server_connectivity("stdio")
|
||||
print(f" Stdio connectivity: {'✓' if stdio_ok else '✗'}")
|
||||
|
||||
asyncio.run(main())
|
||||
@@ -1,192 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Tools Manager Client-Server Integration Tests
|
||||
|
||||
Tests the tools functionality through actual MCP client-server communication
|
||||
Assumes the server is already running and configured properly
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
|
||||
from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity
|
||||
|
||||
|
||||
class TestToolsClientServer:
|
||||
"""Test tools functionality through client-server communication"""
|
||||
|
||||
@pytest.fixture
|
||||
def test_config(self):
|
||||
"""Get test configuration"""
|
||||
return get_test_config()
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, test_config):
|
||||
"""Create test client"""
|
||||
return create_test_client()
|
||||
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
async def check_server_connectivity(self):
|
||||
"""Check server connectivity before running tests"""
|
||||
is_connected = await test_server_connectivity()
|
||||
if not is_connected:
|
||||
pytest.skip("Server is not running or not accessible")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_via_client(self, client, test_config):
|
||||
"""Test listing tools through client-server communication"""
|
||||
expected_tools = test_config.get_expected_tools()
|
||||
|
||||
async def test_callback(client_instance):
|
||||
tools = await client_instance.list_all_tools()
|
||||
|
||||
# Verify we got tools back
|
||||
assert len(tools) > 0, "No tools returned from server"
|
||||
|
||||
# Verify expected tools are present
|
||||
tool_names = [tool.name for tool in tools]
|
||||
for expected_tool in expected_tools:
|
||||
assert expected_tool in tool_names, f"Expected tool '{expected_tool}' not found"
|
||||
|
||||
return tools
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert len(result) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_exec_query_via_client(self, client, test_config):
|
||||
"""Test calling exec_query tool through client"""
|
||||
sample_queries = test_config.get_sample_queries()
|
||||
|
||||
async def test_callback(client_instance):
|
||||
# Test with a simple query
|
||||
result = await client_instance.call_tool("exec_query", {
|
||||
"sql": sample_queries[0], # "SELECT 1 as test_value"
|
||||
"max_rows": 100
|
||||
})
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
|
||||
if result["success"]:
|
||||
assert "result" in result, "Successful result should contain 'result' field"
|
||||
else:
|
||||
assert "error" in result, "Failed result should contain 'error' field"
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
# Don't assert success=True as it depends on actual server state
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_db_list_via_client(self, client, test_config):
|
||||
"""Test calling get_db_list tool through client"""
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.call_tool("get_db_list", {})
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
|
||||
if result["success"]:
|
||||
assert "result" in result, "Successful result should contain 'result' field"
|
||||
assert isinstance(result["result"], list), "Database list should be a list"
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_get_table_schema_via_client(self, client, test_config):
|
||||
"""Test calling get_table_schema tool through client"""
|
||||
test_tables = test_config.get_test_tables()
|
||||
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.call_tool("get_table_schema", {
|
||||
"table_name": test_tables[0], # "users"
|
||||
"db_name": "information_schema" # Use a database that should exist
|
||||
})
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_tool_performance_stats_via_client(self, client, test_config):
|
||||
"""Test calling performance_stats tool through client"""
|
||||
if not test_config.is_performance_tests_enabled():
|
||||
pytest.skip("Performance tests are disabled")
|
||||
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.call_tool("performance_stats", {
|
||||
"metric_type": "queries",
|
||||
"time_range": "1h"
|
||||
})
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_error_handling_via_client(self, client, test_config):
|
||||
"""Test tool error handling through client"""
|
||||
async def test_callback(client_instance):
|
||||
# Try to call a tool with invalid parameters
|
||||
result = await client_instance.call_tool("exec_query", {
|
||||
"sql": "INVALID SQL SYNTAX HERE"
|
||||
})
|
||||
|
||||
# Should get a result (either success or error)
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_with_auth_token_via_client(self, client, test_config):
|
||||
"""Test tool calls with authentication token"""
|
||||
if not test_config.is_security_tests_enabled():
|
||||
pytest.skip("Security tests are disabled")
|
||||
|
||||
auth_tokens = test_config.get_auth_tokens()
|
||||
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.call_tool("get_db_list", {
|
||||
"auth_token": auth_tokens["valid_token"]
|
||||
})
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
@@ -1,331 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Tools manager tests
|
||||
"""
|
||||
|
||||
import json
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from doris_mcp_server.tools.tools_manager import DorisToolsManager
|
||||
from doris_mcp_server.utils.config import DorisConfig
|
||||
|
||||
|
||||
class TestDorisToolsManager:
|
||||
"""Doris tools manager tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Create mock configuration"""
|
||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||
|
||||
config = Mock(spec=DorisConfig)
|
||||
config.doris_host = "localhost"
|
||||
config.doris_port = 9030
|
||||
config.doris_user = "test_user"
|
||||
config.doris_password = "test_password"
|
||||
config.doris_database = "test_db"
|
||||
|
||||
# Add database config
|
||||
config.database = Mock(spec=DatabaseConfig)
|
||||
config.database.host = "localhost"
|
||||
config.database.port = 9030
|
||||
config.database.user = "test_user"
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
# Add security config
|
||||
config.security = Mock(spec=SecurityConfig)
|
||||
config.security.enable_masking = True
|
||||
config.security.auth_type = "token"
|
||||
config.security.token_secret = "test_secret"
|
||||
config.security.token_expiry = 3600
|
||||
|
||||
return config
|
||||
|
||||
@pytest.fixture
|
||||
def tools_manager(self, mock_config):
|
||||
"""Create tools manager instance"""
|
||||
# Create a proper mock connection manager
|
||||
mock_connection_manager = Mock()
|
||||
mock_connection_manager.get_connection = AsyncMock()
|
||||
return DorisToolsManager(mock_connection_manager)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_available_tools(self, tools_manager):
|
||||
"""Test getting available tools"""
|
||||
tools = await tools_manager.list_tools()
|
||||
|
||||
# Should have core tools
|
||||
tool_names = [tool.name for tool in tools]
|
||||
assert "exec_query" in tool_names
|
||||
assert "get_db_list" in tool_names
|
||||
assert "get_db_table_list" in tool_names
|
||||
assert "get_table_schema" in tool_names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_tool(self, tools_manager):
|
||||
"""Test exec_query tool"""
|
||||
# Mock the execute_sql_for_mcp method instead
|
||||
with patch.object(tools_manager.query_executor, 'execute_sql_for_mcp') as mock_execute:
|
||||
mock_execute.return_value = {
|
||||
"success": True,
|
||||
"data": [
|
||||
{"id": 1, "name": "张三"},
|
||||
{"id": 2, "name": "李四"}
|
||||
],
|
||||
"row_count": 2,
|
||||
"execution_time": 0.15
|
||||
}
|
||||
|
||||
arguments = {
|
||||
"sql": "SELECT id, name FROM users LIMIT 2",
|
||||
"max_rows": 100
|
||||
}
|
||||
|
||||
result = await tools_manager.call_tool("exec_query", arguments)
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
# The test should handle both success and error cases
|
||||
if "success" in result_data and result_data["success"]:
|
||||
# Check if result has data field or result field
|
||||
if "data" in result_data and result_data["data"] is not None:
|
||||
assert len(result_data["data"]) == 2
|
||||
elif "result" in result_data and result_data["result"] is not None:
|
||||
assert len(result_data["result"]) == 2
|
||||
else:
|
||||
# If there's an error, just check that error is reported
|
||||
assert "error" in result_data
|
||||
|
||||
# Verify the method was called (may not be called if there are errors)
|
||||
# Don't assert specific call parameters since the implementation may vary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_error(self, tools_manager):
|
||||
"""Test exec_query tool with error"""
|
||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.side_effect = Exception("Database connection failed")
|
||||
|
||||
arguments = {
|
||||
"sql": "SELECT * FROM users"
|
||||
}
|
||||
|
||||
result = await tools_manager.call_tool("exec_query", arguments)
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
assert "error" in result_data or "success" in result_data
|
||||
if "error" in result_data:
|
||||
# Accept any connection-related error message
|
||||
assert any(keyword in result_data["error"].lower() for keyword in
|
||||
["connection", "failed", "error", "mock"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_list_tool(self, tools_manager):
|
||||
"""Test get_db_list tool"""
|
||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{"Database": "test_db"},
|
||||
{"Database": "information_schema"},
|
||||
{"Database": "mysql"}
|
||||
]
|
||||
|
||||
result = await tools_manager.call_tool("get_db_list", {})
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
# Check if result has databases field or result field
|
||||
if "databases" in result_data:
|
||||
assert len(result_data["databases"]) == 3
|
||||
elif "result" in result_data:
|
||||
assert len(result_data["result"]) >= 0 # May be empty if no databases
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_table_list_tool(self, tools_manager):
|
||||
"""Test get_db_table_list tool"""
|
||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{"Tables_in_test_db": "users"},
|
||||
{"Tables_in_test_db": "orders"},
|
||||
{"Tables_in_test_db": "products"}
|
||||
]
|
||||
|
||||
arguments = {"db_name": "test_db"}
|
||||
result = await tools_manager.call_tool("get_db_table_list", arguments)
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
# Check if result has tables field or result field
|
||||
if "tables" in result_data:
|
||||
assert len(result_data["tables"]) == 3
|
||||
assert "users" in result_data["tables"]
|
||||
elif "result" in result_data:
|
||||
assert len(result_data["result"]) >= 0 # May be empty if no tables
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_schema_tool(self, tools_manager):
|
||||
"""Test get_table_schema tool"""
|
||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{
|
||||
"Field": "id",
|
||||
"Type": "int(11)",
|
||||
"Null": "NO",
|
||||
"Key": "PRI",
|
||||
"Default": None,
|
||||
"Extra": "auto_increment"
|
||||
},
|
||||
{
|
||||
"Field": "name",
|
||||
"Type": "varchar(100)",
|
||||
"Null": "YES",
|
||||
"Key": "",
|
||||
"Default": None,
|
||||
"Extra": ""
|
||||
}
|
||||
]
|
||||
|
||||
arguments = {"table_name": "users"}
|
||||
result = await tools_manager.call_tool("get_table_schema", arguments)
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
# Check if result has schema field or result field
|
||||
if "schema" in result_data:
|
||||
assert len(result_data["schema"]) == 2
|
||||
assert result_data["schema"][0]["Field"] == "id"
|
||||
elif "result" in result_data:
|
||||
assert len(result_data["result"]) >= 0 # May be empty if no schema
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_catalog_list_tool(self, tools_manager):
|
||||
"""Test get_catalog_list tool"""
|
||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{"CatalogName": "internal"},
|
||||
{"CatalogName": "hive_catalog"},
|
||||
{"CatalogName": "iceberg_catalog"}
|
||||
]
|
||||
|
||||
arguments = {"random_string": "test_123"}
|
||||
result = await tools_manager.call_tool("get_catalog_list", arguments)
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
# Check if result has catalogs field or result field
|
||||
if "catalogs" in result_data:
|
||||
assert len(result_data["catalogs"]) == 3
|
||||
assert "internal" in result_data["catalogs"]
|
||||
elif "result" in result_data:
|
||||
assert len(result_data["result"]) >= 0 # May be empty if no catalogs
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_column_analysis_tool(self, tools_manager):
|
||||
"""Test column_analysis tool"""
|
||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
# Mock basic analysis result
|
||||
mock_execute.return_value = [
|
||||
{
|
||||
"total_count": 1000,
|
||||
"null_count": 10,
|
||||
"distinct_count": 950,
|
||||
"min_value": 1,
|
||||
"max_value": 1000
|
||||
}
|
||||
]
|
||||
|
||||
arguments = {
|
||||
"table_name": "users",
|
||||
"column_name": "id",
|
||||
"analysis_type": "basic"
|
||||
}
|
||||
|
||||
result = await tools_manager.call_tool("column_analysis", arguments)
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
# Check if result has analysis field or result field
|
||||
if "analysis" in result_data:
|
||||
assert result_data["analysis"]["total_count"] == 1000
|
||||
elif "result" in result_data:
|
||||
assert "result" in result_data # Just check result exists
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_stats_tool(self, tools_manager):
|
||||
"""Test performance_stats tool"""
|
||||
with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute:
|
||||
mock_execute.return_value = [
|
||||
{
|
||||
"query_count": 1500,
|
||||
"avg_execution_time": 0.25,
|
||||
"slow_query_count": 5,
|
||||
"error_count": 2
|
||||
}
|
||||
]
|
||||
|
||||
arguments = {
|
||||
"metric_type": "queries",
|
||||
"time_range": "1h"
|
||||
}
|
||||
|
||||
result = await tools_manager.call_tool("performance_stats", arguments)
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
# Check if result has stats field or result field
|
||||
if "stats" in result_data:
|
||||
assert result_data["stats"]["query_count"] == 1500
|
||||
elif "result" in result_data:
|
||||
assert "result" in result_data # Just check result exists
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_tool_name(self, tools_manager):
|
||||
"""Test calling invalid tool"""
|
||||
result = await tools_manager.call_tool("invalid_tool", {})
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
assert "error" in result_data or "success" in result_data
|
||||
if "error" in result_data:
|
||||
assert "Unknown tool" in result_data["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_required_arguments(self, tools_manager):
|
||||
"""Test calling tool with missing required arguments"""
|
||||
# exec_query requires sql parameter
|
||||
result = await tools_manager.call_tool("exec_query", {})
|
||||
result_data = json.loads(result) if isinstance(result, str) else result
|
||||
|
||||
assert "error" in result_data or "success" in result_data
|
||||
# The test may pass if the tool handles missing parameters gracefully
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_definitions_structure(self, tools_manager):
|
||||
"""Test tool definitions have correct structure"""
|
||||
tools = await tools_manager.list_tools()
|
||||
|
||||
for tool in tools:
|
||||
# Each tool should have required fields
|
||||
assert hasattr(tool, 'name')
|
||||
assert hasattr(tool, 'description')
|
||||
assert hasattr(tool, 'inputSchema')
|
||||
|
||||
# Input schema should have properties
|
||||
assert 'properties' in tool.inputSchema
|
||||
|
||||
# Required fields should be defined
|
||||
if 'required' in tool.inputSchema:
|
||||
assert isinstance(tool.inputSchema['required'], list)
|
||||
@@ -1,202 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Query executor tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, patch
|
||||
|
||||
from doris_mcp_server.utils.query_executor import DorisQueryExecutor
|
||||
from doris_mcp_server.utils.config import DorisConfig
|
||||
|
||||
|
||||
class TestDorisQueryExecutor:
|
||||
"""Doris query executor tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
"""Create mock configuration"""
|
||||
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig
|
||||
|
||||
config = Mock(spec=DorisConfig)
|
||||
config.doris_host = "localhost"
|
||||
config.doris_port = 9030
|
||||
config.doris_user = "test_user"
|
||||
config.doris_password = "test_password"
|
||||
config.doris_database = "test_db"
|
||||
|
||||
# Add database config
|
||||
config.database = Mock(spec=DatabaseConfig)
|
||||
config.database.host = "localhost"
|
||||
config.database.port = 9030
|
||||
config.database.user = "test_user"
|
||||
config.database.password = "test_password"
|
||||
config.database.database = "test_db"
|
||||
config.database.health_check_interval = 60
|
||||
config.database.min_connections = 5
|
||||
config.database.max_connections = 20
|
||||
config.database.connection_timeout = 30
|
||||
config.database.max_connection_age = 3600
|
||||
|
||||
return config
|
||||
|
||||
@pytest.fixture
|
||||
def query_executor(self, mock_config):
|
||||
"""Create query executor instance"""
|
||||
# Create a mock connection manager
|
||||
mock_connection_manager = Mock()
|
||||
return DorisQueryExecutor(mock_connection_manager, mock_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_success(self, query_executor):
|
||||
"""Test successful query execution using MCP interface"""
|
||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
||||
mock_execute.return_value = {
|
||||
"success": True,
|
||||
"data": [
|
||||
{"id": 1, "name": "张三", "email": "zhangsan@example.com"},
|
||||
{"id": 2, "name": "李四", "email": "lisi@example.com"}
|
||||
],
|
||||
"row_count": 2,
|
||||
"execution_time": 0.15,
|
||||
"columns": ["id", "name", "email"]
|
||||
}
|
||||
|
||||
sql = "SELECT id, name, email FROM users LIMIT 2"
|
||||
result = await query_executor.execute_sql_for_mcp(sql)
|
||||
|
||||
# Verify results
|
||||
assert result["success"] is True
|
||||
assert result["row_count"] == 2
|
||||
assert len(result["data"]) == 2
|
||||
assert result["data"][0]["id"] == 1
|
||||
assert result["data"][0]["name"] == "张三"
|
||||
assert result["data"][1]["email"] == "lisi@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_with_parameters(self, query_executor):
|
||||
"""Test query execution with parameters"""
|
||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
||||
mock_execute.return_value = {
|
||||
"success": True,
|
||||
"data": [{"id": 1, "name": "张三"}],
|
||||
"row_count": 1,
|
||||
"execution_time": 0.1
|
||||
}
|
||||
|
||||
sql = "SELECT id, name FROM users WHERE department = 'sales'"
|
||||
result = await query_executor.execute_sql_for_mcp(sql)
|
||||
|
||||
# Verify results
|
||||
assert result["success"] is True
|
||||
assert result["row_count"] == 1
|
||||
assert len(result["data"]) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_connection_error(self, query_executor):
|
||||
"""Test query execution with connection error"""
|
||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
||||
mock_execute.return_value = {
|
||||
"success": False,
|
||||
"error": "Connection failed",
|
||||
"data": None
|
||||
}
|
||||
|
||||
sql = "SELECT * FROM users"
|
||||
result = await query_executor.execute_sql_for_mcp(sql)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "Connection failed" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_sql_error(self, query_executor):
|
||||
"""Test query execution with SQL error"""
|
||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
||||
mock_execute.return_value = {
|
||||
"success": False,
|
||||
"error": "SQL syntax error",
|
||||
"data": None
|
||||
}
|
||||
|
||||
sql = "SELECT * FROM non_existent_table"
|
||||
result = await query_executor.execute_sql_for_mcp(sql)
|
||||
|
||||
assert result["success"] is False
|
||||
assert "SQL syntax error" in result["error"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_empty_result(self, query_executor):
|
||||
"""Test query execution with empty result"""
|
||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
||||
mock_execute.return_value = {
|
||||
"success": True,
|
||||
"data": [],
|
||||
"row_count": 0,
|
||||
"execution_time": 0.05
|
||||
}
|
||||
|
||||
sql = "SELECT * FROM users WHERE id = 999"
|
||||
result = await query_executor.execute_sql_for_mcp(sql)
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["data"] == []
|
||||
assert result["row_count"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_query_max_rows_limit(self, query_executor):
|
||||
"""Test query execution with max rows limit"""
|
||||
with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute:
|
||||
# Mock large result set limited to 100 rows
|
||||
limited_result = [{"id": i, "name": f"user_{i}"} for i in range(100)]
|
||||
mock_execute.return_value = {
|
||||
"success": True,
|
||||
"data": limited_result,
|
||||
"row_count": 100,
|
||||
"execution_time": 0.2
|
||||
}
|
||||
|
||||
sql = "SELECT id, name FROM users"
|
||||
result = await query_executor.execute_sql_for_mcp(sql, limit=100)
|
||||
|
||||
# Should be limited to max_rows
|
||||
assert result["success"] is True
|
||||
assert len(result["data"]) == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_for_mcp_interface(self, query_executor):
|
||||
"""Test the MCP interface method directly"""
|
||||
with patch.object(query_executor.connection_manager, 'get_connection') as mock_get_conn:
|
||||
# Mock connection and result
|
||||
mock_connection = AsyncMock()
|
||||
mock_connection.execute.return_value = Mock(
|
||||
data=[{"id": 1, "name": "张三"}],
|
||||
row_count=1,
|
||||
execution_time=0.1,
|
||||
metadata={}
|
||||
)
|
||||
mock_get_conn.return_value = mock_connection
|
||||
|
||||
sql = "SELECT id, name FROM users LIMIT 1"
|
||||
result = await query_executor.execute_sql_for_mcp(sql)
|
||||
|
||||
# Should return success format
|
||||
assert "success" in result
|
||||
if result["success"]:
|
||||
assert "data" in result
|
||||
assert "row_count" in result
|
||||
@@ -1,156 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Query Executor Client-Server Integration Tests
|
||||
|
||||
Tests the query execution functionality through actual MCP client-server communication
|
||||
Assumes the server is already running and configured properly
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import pytest
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, Any
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
|
||||
|
||||
from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity
|
||||
|
||||
|
||||
class TestQueryExecutorClientServer:
|
||||
"""Test query execution functionality through client-server communication"""
|
||||
|
||||
@pytest.fixture
|
||||
def test_config(self):
|
||||
"""Get test configuration"""
|
||||
return get_test_config()
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, test_config):
|
||||
"""Create test client"""
|
||||
return create_test_client()
|
||||
|
||||
@pytest.fixture(scope="class", autouse=True)
|
||||
async def check_server_connectivity(self):
|
||||
"""Check server connectivity before running tests"""
|
||||
is_connected = await test_server_connectivity()
|
||||
if not is_connected:
|
||||
pytest.skip("Server is not running or not accessible")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_select_query_via_client(self, client, test_config):
|
||||
"""Test simple SELECT query through client"""
|
||||
sample_queries = test_config.get_sample_queries()
|
||||
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.execute_sql(sample_queries[0]) # "SELECT 1 as test_value"
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
|
||||
if result["success"]:
|
||||
assert "result" in result, "Successful result should contain 'result' field"
|
||||
else:
|
||||
assert "error" in result, "Failed result should contain 'error' field"
|
||||
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_show_databases_query_via_client(self, client, test_config):
|
||||
"""Test SHOW DATABASES query through client"""
|
||||
sample_queries = test_config.get_sample_queries()
|
||||
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.execute_sql(sample_queries[1]) # "SHOW DATABASES"
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_information_schema_query_via_client(self, client, test_config):
|
||||
"""Test information_schema query through client"""
|
||||
sample_queries = test_config.get_sample_queries()
|
||||
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.execute_sql(sample_queries[2]) # "SELECT COUNT(*) FROM information_schema.tables"
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
|
||||
"""Test query with max_rows parameter through client"""
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.call_tool("exec_query", {
|
||||
"sql": "SELECT 1 as test_value",
|
||||
"max_rows": 10
|
||||
})
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_error_handling_via_client(self, client, test_config):
|
||||
"""Test query error handling through client"""
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.execute_sql("INVALID SQL SYNTAX")
|
||||
|
||||
# Should get a result (either success or error)
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_query_with_auth_token_via_client(self, client, test_config):
|
||||
"""Test query with authentication token"""
|
||||
if not test_config.is_security_tests_enabled():
|
||||
pytest.skip("Security tests are disabled")
|
||||
|
||||
auth_tokens = test_config.get_auth_tokens()
|
||||
|
||||
async def test_callback(client_instance):
|
||||
result = await client_instance.call_tool("exec_query", {
|
||||
"sql": "SELECT 1 as test_value",
|
||||
"auth_token": auth_tokens["valid_token"]
|
||||
})
|
||||
|
||||
# Verify result structure
|
||||
assert "success" in result, "Result should contain 'success' field"
|
||||
return result
|
||||
|
||||
result = await client.connect_and_run(test_callback)
|
||||
assert "success" in result
|
||||
344
uv.lock
generated
344
uv.lock
generated
@@ -1,19 +1,3 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.12"
|
||||
@@ -534,170 +518,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl", hash = "sha256:dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2", size = 587408 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "doris-mcp-server"
|
||||
version = "0.3.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "aiomysql" },
|
||||
{ name = "aioredis" },
|
||||
{ name = "asyncio-mqtt" },
|
||||
{ name = "bcrypt" },
|
||||
{ name = "click" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "httpx" },
|
||||
{ name = "mcp" },
|
||||
{ name = "numpy" },
|
||||
{ name = "orjson" },
|
||||
{ name = "pandas" },
|
||||
{ name = "passlib", extra = ["bcrypt"] },
|
||||
{ name = "prometheus-client" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pyjwt" },
|
||||
{ name = "pymysql" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "python-jose", extra = ["cryptography"] },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "requests" },
|
||||
{ name = "rich" },
|
||||
{ name = "sqlparse" },
|
||||
{ name = "starlette" },
|
||||
{ name = "structlog" },
|
||||
{ name = "toml" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "typer" },
|
||||
{ name = "uvicorn", extra = ["standard"] },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
dev = [
|
||||
{ name = "bandit" },
|
||||
{ name = "black" },
|
||||
{ name = "flake8" },
|
||||
{ name = "isort" },
|
||||
{ name = "mypy" },
|
||||
{ name = "myst-parser" },
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-mock" },
|
||||
{ name = "pytest-xdist" },
|
||||
{ name = "ruff" },
|
||||
{ name = "safety" },
|
||||
{ name = "sphinx" },
|
||||
{ name = "sphinx-rtd-theme" },
|
||||
{ name = "tox" },
|
||||
]
|
||||
docs = [
|
||||
{ name = "myst-parser" },
|
||||
{ name = "sphinx" },
|
||||
{ name = "sphinx-autoapi" },
|
||||
{ name = "sphinx-rtd-theme" },
|
||||
]
|
||||
monitoring = [
|
||||
{ name = "grafana-client" },
|
||||
{ name = "jaeger-client" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
{ name = "prometheus-client" },
|
||||
]
|
||||
performance = [
|
||||
{ name = "cchardet" },
|
||||
{ name = "orjson" },
|
||||
{ name = "uvloop" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiofiles", specifier = ">=23.0.0" },
|
||||
{ name = "aiohttp", specifier = ">=3.9.0" },
|
||||
{ name = "aiomysql", specifier = ">=0.2.0" },
|
||||
{ name = "aioredis", specifier = ">=2.0.0" },
|
||||
{ name = "asyncio-mqtt", specifier = ">=0.16.0" },
|
||||
{ name = "bandit", marker = "extra == 'dev'", specifier = ">=1.7.0" },
|
||||
{ name = "bcrypt", specifier = ">=4.1.0" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = ">=23.12.0" },
|
||||
{ name = "cchardet", marker = "extra == 'performance'", specifier = ">=2.1.0" },
|
||||
{ name = "click", specifier = ">=8.1.0" },
|
||||
{ name = "cryptography", specifier = ">=41.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.108.0" },
|
||||
{ name = "flake8", marker = "extra == 'dev'", specifier = ">=7.0.0" },
|
||||
{ name = "grafana-client", marker = "extra == 'monitoring'", specifier = ">=3.5.0" },
|
||||
{ name = "httpx", specifier = ">=0.26.0" },
|
||||
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
|
||||
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
|
||||
{ name = "mcp", specifier = ">=1.0.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||
{ name = "numpy", specifier = ">=1.24.0" },
|
||||
{ name = "opentelemetry-api", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
|
||||
{ name = "opentelemetry-sdk", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
|
||||
{ name = "orjson", specifier = ">=3.9.0" },
|
||||
{ name = "orjson", marker = "extra == 'performance'", specifier = ">=3.9.0" },
|
||||
{ name = "pandas", specifier = ">=2.0.0" },
|
||||
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.0" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
|
||||
{ name = "prometheus-client", specifier = ">=0.19.0" },
|
||||
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
|
||||
{ name = "pydantic", specifier = ">=2.5.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.1.0" },
|
||||
{ name = "pyjwt", specifier = ">=2.8.0" },
|
||||
{ name = "pymysql", specifier = ">=1.1.0" },
|
||||
{ name = "pytest", specifier = ">=8.4.0" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
|
||||
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
|
||||
{ name = "pytest-cov", specifier = ">=6.1.1" },
|
||||
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0" },
|
||||
{ name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.12.0" },
|
||||
{ name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.5.0" },
|
||||
{ name = "python-dateutil", specifier = ">=2.8.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.0" },
|
||||
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.6" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.0" },
|
||||
{ name = "requests", specifier = ">=2.31.0" },
|
||||
{ name = "rich", specifier = ">=13.7.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
|
||||
{ name = "safety", marker = "extra == 'dev'", specifier = ">=2.3.0" },
|
||||
{ name = "sphinx", marker = "extra == 'dev'", specifier = ">=7.2.0" },
|
||||
{ name = "sphinx", marker = "extra == 'docs'", specifier = ">=7.2.0" },
|
||||
{ name = "sphinx-autoapi", marker = "extra == 'docs'", specifier = ">=3.0.0" },
|
||||
{ name = "sphinx-rtd-theme", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||
{ name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||
{ name = "sqlparse", specifier = ">=0.4.4" },
|
||||
{ name = "starlette", specifier = ">=0.27.0" },
|
||||
{ name = "structlog", specifier = ">=23.2.0" },
|
||||
{ name = "toml", specifier = ">=0.10.0" },
|
||||
{ name = "tox", marker = "extra == 'dev'", specifier = ">=4.11.0" },
|
||||
{ name = "tqdm", specifier = ">=4.66.0" },
|
||||
{ name = "typer", specifier = ">=0.9.0" },
|
||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.25.0" },
|
||||
{ name = "uvloop", marker = "extra == 'performance'", specifier = ">=0.19.0" },
|
||||
{ name = "websockets", specifier = ">=12.0" },
|
||||
]
|
||||
provides-extras = ["dev", "docs", "performance", "monitoring"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [{ name = "ruff", specifier = ">=0.11.13" }]
|
||||
|
||||
[[package]]
|
||||
name = "dparse"
|
||||
version = "0.6.4"
|
||||
@@ -1126,6 +946,170 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/79/45/823ad05504bea55cb0feb7470387f151252127ad5c72f8882e8fe6cf5c0e/mcp-1.9.3-py3-none-any.whl", hash = "sha256:69b0136d1ac9927402ed4cf221d4b8ff875e7132b0b06edd446448766f34f9b9", size = 131063 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mcp-doris-server"
|
||||
version = "0.3.0"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "aiofiles" },
|
||||
{ name = "aiohttp" },
|
||||
{ name = "aiomysql" },
|
||||
{ name = "aioredis" },
|
||||
{ name = "asyncio-mqtt" },
|
||||
{ name = "bcrypt" },
|
||||
{ name = "click" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "httpx" },
|
||||
{ name = "mcp" },
|
||||
{ name = "numpy" },
|
||||
{ name = "orjson" },
|
||||
{ name = "pandas" },
|
||||
{ name = "passlib", extra = ["bcrypt"] },
|
||||
{ name = "prometheus-client" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pydantic-settings" },
|
||||
{ name = "pyjwt" },
|
||||
{ name = "pymysql" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "python-dateutil" },
|
||||
{ name = "python-dotenv" },
|
||||
{ name = "python-jose", extra = ["cryptography"] },
|
||||
{ name = "python-multipart" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "requests" },
|
||||
{ name = "rich" },
|
||||
{ name = "sqlparse" },
|
||||
{ name = "starlette" },
|
||||
{ name = "structlog" },
|
||||
{ name = "toml" },
|
||||
{ name = "tqdm" },
|
||||
{ name = "typer" },
|
||||
{ name = "uvicorn", extra = ["standard"] },
|
||||
{ name = "websockets" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
dev = [
|
||||
{ name = "bandit" },
|
||||
{ name = "black" },
|
||||
{ name = "flake8" },
|
||||
{ name = "isort" },
|
||||
{ name = "mypy" },
|
||||
{ name = "myst-parser" },
|
||||
{ name = "pre-commit" },
|
||||
{ name = "pytest" },
|
||||
{ name = "pytest-asyncio" },
|
||||
{ name = "pytest-cov" },
|
||||
{ name = "pytest-mock" },
|
||||
{ name = "pytest-xdist" },
|
||||
{ name = "ruff" },
|
||||
{ name = "safety" },
|
||||
{ name = "sphinx" },
|
||||
{ name = "sphinx-rtd-theme" },
|
||||
{ name = "tox" },
|
||||
]
|
||||
docs = [
|
||||
{ name = "myst-parser" },
|
||||
{ name = "sphinx" },
|
||||
{ name = "sphinx-autoapi" },
|
||||
{ name = "sphinx-rtd-theme" },
|
||||
]
|
||||
monitoring = [
|
||||
{ name = "grafana-client" },
|
||||
{ name = "jaeger-client" },
|
||||
{ name = "opentelemetry-api" },
|
||||
{ name = "opentelemetry-sdk" },
|
||||
{ name = "prometheus-client" },
|
||||
]
|
||||
performance = [
|
||||
{ name = "cchardet" },
|
||||
{ name = "orjson" },
|
||||
{ name = "uvloop" },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "ruff" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiofiles", specifier = ">=23.0.0" },
|
||||
{ name = "aiohttp", specifier = ">=3.9.0" },
|
||||
{ name = "aiomysql", specifier = ">=0.2.0" },
|
||||
{ name = "aioredis", specifier = ">=2.0.0" },
|
||||
{ name = "asyncio-mqtt", specifier = ">=0.16.0" },
|
||||
{ name = "bandit", marker = "extra == 'dev'", specifier = ">=1.7.0" },
|
||||
{ name = "bcrypt", specifier = ">=4.1.0" },
|
||||
{ name = "black", marker = "extra == 'dev'", specifier = ">=23.12.0" },
|
||||
{ name = "cchardet", marker = "extra == 'performance'", specifier = ">=2.1.0" },
|
||||
{ name = "click", specifier = ">=8.1.0" },
|
||||
{ name = "cryptography", specifier = ">=41.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.108.0" },
|
||||
{ name = "flake8", marker = "extra == 'dev'", specifier = ">=7.0.0" },
|
||||
{ name = "grafana-client", marker = "extra == 'monitoring'", specifier = ">=3.5.0" },
|
||||
{ name = "httpx", specifier = ">=0.26.0" },
|
||||
{ name = "isort", marker = "extra == 'dev'", specifier = ">=5.13.0" },
|
||||
{ name = "jaeger-client", marker = "extra == 'monitoring'", specifier = ">=4.8.0" },
|
||||
{ name = "mcp", specifier = ">=1.0.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.8.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||
{ name = "myst-parser", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||
{ name = "numpy", specifier = ">=1.24.0" },
|
||||
{ name = "opentelemetry-api", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
|
||||
{ name = "opentelemetry-sdk", marker = "extra == 'monitoring'", specifier = ">=1.21.0" },
|
||||
{ name = "orjson", specifier = ">=3.9.0" },
|
||||
{ name = "orjson", marker = "extra == 'performance'", specifier = ">=3.9.0" },
|
||||
{ name = "pandas", specifier = ">=2.0.0" },
|
||||
{ name = "passlib", extras = ["bcrypt"], specifier = ">=1.7.0" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=3.6.0" },
|
||||
{ name = "prometheus-client", specifier = ">=0.19.0" },
|
||||
{ name = "prometheus-client", marker = "extra == 'monitoring'", specifier = ">=0.19.0" },
|
||||
{ name = "pydantic", specifier = ">=2.5.0" },
|
||||
{ name = "pydantic-settings", specifier = ">=2.1.0" },
|
||||
{ name = "pyjwt", specifier = ">=2.8.0" },
|
||||
{ name = "pymysql", specifier = ">=1.1.0" },
|
||||
{ name = "pytest", specifier = ">=8.4.0" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=7.4.0" },
|
||||
{ name = "pytest-asyncio", specifier = ">=1.0.0" },
|
||||
{ name = "pytest-asyncio", marker = "extra == 'dev'", specifier = ">=0.23.0" },
|
||||
{ name = "pytest-cov", specifier = ">=6.1.1" },
|
||||
{ name = "pytest-cov", marker = "extra == 'dev'", specifier = ">=4.1.0" },
|
||||
{ name = "pytest-mock", marker = "extra == 'dev'", specifier = ">=3.12.0" },
|
||||
{ name = "pytest-xdist", marker = "extra == 'dev'", specifier = ">=3.5.0" },
|
||||
{ name = "python-dateutil", specifier = ">=2.8.0" },
|
||||
{ name = "python-dotenv", specifier = ">=1.0.0" },
|
||||
{ name = "python-jose", extras = ["cryptography"], specifier = ">=3.3.0" },
|
||||
{ name = "python-multipart", specifier = ">=0.0.6" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.0" },
|
||||
{ name = "requests", specifier = ">=2.31.0" },
|
||||
{ name = "rich", specifier = ">=13.7.0" },
|
||||
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.1.0" },
|
||||
{ name = "safety", marker = "extra == 'dev'", specifier = ">=2.3.0" },
|
||||
{ name = "sphinx", marker = "extra == 'dev'", specifier = ">=7.2.0" },
|
||||
{ name = "sphinx", marker = "extra == 'docs'", specifier = ">=7.2.0" },
|
||||
{ name = "sphinx-autoapi", marker = "extra == 'docs'", specifier = ">=3.0.0" },
|
||||
{ name = "sphinx-rtd-theme", marker = "extra == 'dev'", specifier = ">=2.0.0" },
|
||||
{ name = "sphinx-rtd-theme", marker = "extra == 'docs'", specifier = ">=2.0.0" },
|
||||
{ name = "sqlparse", specifier = ">=0.4.4" },
|
||||
{ name = "starlette", specifier = ">=0.27.0" },
|
||||
{ name = "structlog", specifier = ">=23.2.0" },
|
||||
{ name = "toml", specifier = ">=0.10.0" },
|
||||
{ name = "tox", marker = "extra == 'dev'", specifier = ">=4.11.0" },
|
||||
{ name = "tqdm", specifier = ">=4.66.0" },
|
||||
{ name = "typer", specifier = ">=0.9.0" },
|
||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.25.0" },
|
||||
{ name = "uvloop", marker = "extra == 'performance'", specifier = ">=0.19.0" },
|
||||
{ name = "websockets", specifier = ">=12.0" },
|
||||
]
|
||||
provides-extras = ["dev", "docs", "performance", "monitoring"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [{ name = "ruff", specifier = ">=0.11.13" }]
|
||||
|
||||
[[package]]
|
||||
name = "mdit-py-plugins"
|
||||
version = "0.4.2"
|
||||
|
||||
Reference in New Issue
Block a user