diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 296c7bd..0000000 --- a/Dockerfile +++ /dev/null @@ -1,65 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# Use Python 3.12 as base image -FROM python:3.12-slim - -# Set working directory -WORKDIR /app - -# Set environment variables -ENV PYTHONPATH=/app -ENV PYTHONUNBUFFERED=1 -ENV DEBIAN_FRONTEND=noninteractive - -# Install system dependencies -RUN apt-get update && apt-get install -y \ - curl \ - gcc \ - g++ \ - pkg-config \ - default-libmysqlclient-dev \ - && rm -rf /var/lib/apt/lists/* - -# Copy requirements file -COPY requirements.txt . - -# Install Python dependencies -RUN pip install --no-cache-dir -r requirements.txt - -# Copy application code -COPY . . - -# Create necessary directories -RUN mkdir -p /app/logs /app/config /app/data - -# Set permissions -RUN chmod +x /app/start.sh - -# Create non-root user -RUN groupadd -r doris && useradd -r -g doris doris -RUN chown -R doris:doris /app -USER doris - -# Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \ - CMD curl -f http://localhost:3000/health || exit 1 - -# Expose ports -EXPOSE 3000 3001 3002 - -# Start command -CMD ["/app/start.sh"] \ No newline at end of file diff --git a/Makefile b/Makefile deleted file mode 100644 index c7f4f94..0000000 --- a/Makefile +++ /dev/null @@ -1,135 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# Doris MCP Server Makefile -# Provides convenient commands using UV - -.PHONY: help install sync dev test lint format build clean check start-stdio start-sse - -# Default target -help: - @echo "Available commands:" - @echo " install - Install dependencies using UV" - @echo " sync - Sync dependencies and create virtual environment" - @echo " dev - Install development dependencies" - @echo " test - Run tests" - @echo " lint - Run linting tools" - @echo " format - Format code with black and isort" - @echo " build - Build the package" - @echo " clean - Clean build artifacts" - @echo " check - Run all checks (format, lint, test)" - @echo " start-stdio - Start server in stdio mode" - @echo " start-sse - Start server in SSE mode" - -# Install dependencies -install: - uv sync - -# Sync dependencies with development extras -sync: - uv sync - -# Install development dependencies -dev: - uv sync --dev - -# Run tests -test: - uv run pytest - -# Run linting tools -lint: - uv run ruff check doris_mcp_server/ - uv run mypy doris_mcp_server/ - -# Format code -format: - uv run ruff format doris_mcp_server/ - uv run ruff check --fix doris_mcp_server/ - -# Build the package -build: - uv build - -# Clean build artifacts -clean: - rm -rf build/ - rm -rf dist/ - rm -rf *.egg-info/ - find . -type d -name __pycache__ -exec rm -rf {} + - find . -type d -name .pytest_cache -exec rm -rf {} + - find . -type d -name .mypy_cache -exec rm -rf {} + - -# Run all checks -check: format lint test - -# Start server in stdio mode -start-stdio: - uv run python -m doris_mcp_server.main --transport stdio - -# Start server in SSE mode -start-sse: - uv run python -m doris_mcp_server.main --transport sse --host 0.0.0.0 --port 8080 - -# Start server with custom database settings -start-dev: - uv run python -m doris_mcp_server.main \ - --transport stdio \ - --db-host localhost \ - --db-port 9030 \ - --db-user root \ - --log-level DEBUG - -# Run a single test file -test-file: - uv run pytest $(FILE) -v - -# Install and run in one command -run: install start-stdio - -# Development setup -setup: dev - @echo "โœ… Development environment is ready!" - @echo "Run 'make start-stdio' to start the server" - -# Add dependencies -add: - uv add $(PACKAGE) - -# Add development dependencies -add-dev: - uv add --dev $(PACKAGE) - -# Show dependency tree -deps: - uv tree - -# Lock dependencies -lock: - uv lock - -# Check for outdated dependencies -outdated: - uv tree --outdated - -# Export requirements.txt -export-requirements: - uv export --no-hashes > requirements.txt - -# Show UV version and info -info: - uv --version - uv python list \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index 002cfdf..0000000 --- a/docker-compose.yml +++ /dev/null @@ -1,218 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -version: '3.8' - -services: - # Doris MCP Server - doris-mcp-server: - build: - context: . - dockerfile: Dockerfile - container_name: doris-mcp-server - ports: - - "3000:3000" # MCP service port - - "3001:3001" # Monitoring metrics port - - "3002:3002" # Health check port - environment: - # Database configuration - - DORIS_HOST=doris-fe - - DORIS_PORT=9030 - - DORIS_USER=root - - DORIS_PASSWORD=doris123 - - DORIS_DATABASE=test_db - - # Connection pool configuration - - DORIS_MIN_CONNECTIONS=5 - - DORIS_MAX_CONNECTIONS=20 - - # Security configuration - - AUTH_TYPE=token - - TOKEN_SECRET=your_secret_key_here - - MAX_RESULT_ROWS=10000 - - # Performance configuration - - ENABLE_QUERY_CACHE=true - - MAX_CONCURRENT_QUERIES=50 - - # Logging configuration - - LOG_LEVEL=INFO - - LOG_FILE_PATH=/app/logs/doris-mcp-server.log - - # Monitoring configuration - - ENABLE_METRICS=true - - METRICS_PORT=8081 - volumes: - - ./logs:/app/logs - - ./config:/app/config - depends_on: - - doris-fe - - doris-be - networks: - - doris-network - restart: unless-stopped - healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8082/health"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 40s - - # Apache Doris Frontend - doris-fe: - image: apache/doris:2.0.3-fe-x86_64 - container_name: doris-fe - ports: - - "8030:8030" # FE HTTP port - - "9030:9030" # FE MySQL port - environment: - - FE_SERVERS=fe1:doris-fe:9010 - - FE_ID=1 - volumes: - - doris-fe-data:/opt/apache-doris/fe/doris-meta - - doris-fe-log:/opt/apache-doris/fe/log - - ./doris-config/fe.conf:/opt/apache-doris/fe/conf/fe.conf - networks: - - doris-network - restart: unless-stopped - healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8030/api/health"] - interval: 30s - timeout: 10s - retries: 5 - start_period: 60s - - # Apache Doris Backend - doris-be: - image: apache/doris:2.0.3-be-x86_64 - container_name: doris-be - ports: - - "8040:8040" # BE HTTP port - - "9060:9060" # BE heartbeat port - environment: - - FE_SERVERS=doris-fe:9010 - - BE_ADDR=doris-be:9050 - volumes: - - doris-be-data:/opt/apache-doris/be/storage - - doris-be-log:/opt/apache-doris/be/log - - ./doris-config/be.conf:/opt/apache-doris/be/conf/be.conf - depends_on: - - doris-fe - networks: - - doris-network - restart: unless-stopped - healthcheck: - test: ["CMD", "curl", "-f", "http://localhost:8040/api/health"] - interval: 30s - timeout: 10s - retries: 5 - start_period: 60s - - # Redis cache (optional) - redis: - image: redis:7-alpine - container_name: doris-redis - ports: - - "6379:6379" - command: redis-server --appendonly yes --requirepass redis123 - volumes: - - redis-data:/data - networks: - - doris-network - restart: unless-stopped - healthcheck: - test: ["CMD", "redis-cli", "--raw", "incr", "ping"] - interval: 30s - timeout: 10s - retries: 3 - - # Prometheus monitoring - prometheus: - image: prom/prometheus:latest - container_name: doris-prometheus - ports: - - "9090:9090" - volumes: - - ./monitoring/prometheus.yml:/etc/prometheus/prometheus.yml - - prometheus-data:/prometheus - command: - - '--config.file=/etc/prometheus/prometheus.yml' - - '--storage.tsdb.path=/prometheus' - - '--web.console.libraries=/etc/prometheus/console_libraries' - - '--web.console.templates=/etc/prometheus/consoles' - - '--storage.tsdb.retention.time=200h' - - '--web.enable-lifecycle' - networks: - - doris-network - restart: unless-stopped - - # Grafana visualization - grafana: - image: grafana/grafana:latest - container_name: doris-grafana - ports: - - "3000:3000" - environment: - - GF_SECURITY_ADMIN_PASSWORD=admin123 - volumes: - - grafana-data:/var/lib/grafana - - ./monitoring/grafana/dashboards:/etc/grafana/provisioning/dashboards - - ./monitoring/grafana/datasources:/etc/grafana/provisioning/datasources - depends_on: - - prometheus - networks: - - doris-network - restart: unless-stopped - - # Nginx load balancer - nginx: - image: nginx:alpine - container_name: doris-nginx - ports: - - "80:80" - - "443:443" - volumes: - - ./nginx/nginx.conf:/etc/nginx/nginx.conf - - ./nginx/ssl:/etc/nginx/ssl - - ./nginx/logs:/var/log/nginx - depends_on: - - doris-mcp-server - networks: - - doris-network - restart: unless-stopped - -volumes: - doris-fe-data: - driver: local - doris-fe-log: - driver: local - doris-be-data: - driver: local - doris-be-log: - driver: local - redis-data: - driver: local - prometheus-data: - driver: local - grafana-data: - driver: local - -networks: - doris-network: - driver: bridge - ipam: - config: - - subnet: 172.20.0.0/16 \ No newline at end of file diff --git a/doris_mcp_client/README.md b/doris_mcp_client/README.md deleted file mode 100644 index 4df3b20..0000000 --- a/doris_mcp_client/README.md +++ /dev/null @@ -1,340 +0,0 @@ - -# Doris Unified MCP Client - -This is a unified Doris MCP client that supports both **stdio** and **Streamable HTTP** transport modes, providing complete MCP protocol support. - -## ๐Ÿš€ Features - -- โœ… **Dual Mode Support**: Both stdio and HTTP transport methods -- โœ… **Complete MCP Support**: Resources, Tools, and Prompts primitives -- โœ… **Unified API**: Same interface for different transport modes -- โœ… **Asynchronous Design**: High-performance async client based on asyncio -- โœ… **Enterprise Features**: Connection pooling, error handling, logging -- โœ… **Convenience Methods**: High-level wrappers for common database operations - -## ๐Ÿ“ฆ Install Dependencies - -```bash -pip install mcp -``` - -## ๐ŸŽฏ Quick Start - -### 1. stdio Mode - -```python -import asyncio -from client import create_stdio_client - -async def main(): - # Create stdio client - client = await create_stdio_client( - "python", - ["-m", "doris_mcp_server.main", "--transport", "stdio"] - ) - - async def test_client(client): - # Get database list - db_result = await client.get_database_list() - print(f"Databases: {db_result}") - - # Execute SQL query - query_result = await client.execute_sql("SELECT 1 as test") - print(f"Query result: {query_result}") - - await client.connect_and_run(test_client) - -asyncio.run(main()) -``` - -### 2. HTTP Mode - -```python -import asyncio -from unified_client import create_http_client - -async def main(): - # Create HTTP client - client = await create_http_client("http://localhost:3000/mcp") - - async def test_client(client): - # Get all tools - tools = await client.list_all_tools() - print(f"Available tools: {len(tools)}") - - # Execute query - result = await client.execute_sql( - "SELECT COUNT(*) FROM internal.ssb.lineorder LIMIT 1" - ) - print(f"Query result: {result}") - - await client.connect_and_run(test_client) - -asyncio.run(main()) -``` - -## ๐Ÿ”ง API Reference - -### Client Creation - -```python -# stdio mode -client = await create_stdio_client(command, args) - -# HTTP mode -client = await create_http_client(server_url, timeout=60) -``` - -### Basic Operations - -```python -async def test_client(client): - # Get server capabilities - tools = await client.list_all_tools() - resources = await client.list_all_resources() - prompts = await client.list_all_prompts() - - # Call tool - result = await client.call_tool("tool_name", {"param": "value"}) - - # Read resource - content = await client.read_resource("resource://uri") - - # Get prompt - prompt = await client.get_prompt("prompt_name", {"param": "value"}) -``` - -### Advanced Database Operations - -```python -async def database_operations(client): - # Execute SQL query - result = await client.execute_sql("SELECT * FROM table LIMIT 10") - - # Get database list - databases = await client.get_database_list() - - # Get table schema - schema = await client.get_table_schema("table_name", "db_name") - - # Column data analysis - analysis = await client.analyze_column("table", "column", "basic") -``` - -## ๐Ÿงช Testing - -### Run Test Suite - -```bash -# Interactive testing -python test_unified_client.py - -# Test stdio mode -python test_unified_client.py stdio - -# Test HTTP mode -python test_unified_client.py http - -# Test both modes -python test_unified_client.py both - -# Performance benchmark -python test_unified_client.py benchmark -``` - -### Test Output Example - -``` -๐ŸŽฏ Doris Unified Client Test Suite -============================================================ - -๐Ÿš€ Testing HTTP Mode -================================================== -๐Ÿ“‹ Getting server capabilities... -โœ… Found 11 tools -โœ… Found 0 resources -โœ… Found 0 prompts - -๐Ÿ”ง Available tools: - 1. get_db_list: Get database list - 2. get_table_list: Get table list for specified database - 3. get_table_schema: Get table structure information - 4. exec_query: Execute SQL query - 5. column_analysis: Analyze column data distribution and statistics - ... - -๐Ÿงช Testing basic functionality... -1๏ธโƒฃ Getting database list... - โœ… Success: 3 databases -2๏ธโƒฃ Executing simple query... - โœ… Query successful -3๏ธโƒฃ Executing SSB data query... - โœ… SSB query successful -4๏ธโƒฃ Getting table structure... - โœ… Table structure retrieved successfully -5๏ธโƒฃ Column data analysis... - โœ… Column analysis successful - -โœ… HTTP mode testing completed! -``` - -## ๐Ÿ—๏ธ Architecture Design - -### Unified Client Architecture - -``` -DorisUnifiedClient -โ”œโ”€โ”€ DorisResourceClient # Resource management -โ”œโ”€โ”€ DorisToolsClient # Tool invocation -โ”œโ”€โ”€ DorisPromptClient # Prompt management -โ””โ”€โ”€ Transport Layer - โ”œโ”€โ”€ stdio mode # Standard input/output - โ””โ”€โ”€ HTTP mode # Streamable HTTP -``` - -### Key Features - -1. **Unified Interface**: Same API for different transport modes -2. **Async Context**: Proper resource management and connection cleanup -3. **Error Handling**: Comprehensive exception handling and error recovery -4. **Performance Optimization**: Connection reuse and request caching - -## ๐Ÿ“š Usage Examples - -### Complete Example - -```python -import asyncio -from client import DorisUnifiedClient, DorisClientConfig - -async def comprehensive_example(): - # Create configuration - config = DorisClientConfig.stdio( - "python", - ["-m", "doris_mcp_server.main"] - ) - - client = DorisUnifiedClient(config) - - async def demo_operations(client): - print("๐Ÿ” Discovering server capabilities...") - - # List all available tools - tools = await client.list_all_tools() - print(f"Available tools: {[tool.name for tool in tools]}") - - # Get database list - print("\n๐Ÿ“Š Getting database information...") - db_result = await client.get_database_list() - print(f"Databases: {db_result}") - - # Execute queries - print("\n๐Ÿ” Executing queries...") - - # Simple query - result1 = await client.execute_sql("SELECT 1 as test_column") - print(f"Simple query result: {result1}") - - # Get table schema - schema_result = await client.get_table_schema("lineorder", "ssb") - print(f"Table schema: {schema_result}") - - # Column analysis - analysis_result = await client.analyze_column( - "lineorder", "lo_orderkey", "basic" - ) - print(f"Column analysis: {analysis_result}") - - await client.connect_and_run(demo_operations) - -# Run the example -asyncio.run(comprehensive_example()) -``` - -### Error Handling - -```python -async def error_handling_example(client): - try: - # This might fail - result = await client.execute_sql("INVALID SQL") - except Exception as e: - print(f"SQL execution failed: {e}") - - # Check result status - result = await client.get_database_list() - if result.get("success", True): - print("Operation successful") - else: - print(f"Operation failed: {result.get('error')}") -``` - -## ๐Ÿ”ง Configuration - -### Client Configuration Options - -```python -# stdio mode with custom arguments -config = DorisClientConfig( - transport="stdio", - server_command="python", - server_args=["-m", "doris_mcp_server.main", "--debug"], - timeout=30 -) - -# HTTP mode with custom timeout -config = DorisClientConfig( - transport="http", - server_url="http://localhost:8080/mcp", - timeout=60 -) -``` - -### Environment Variables - -```bash -# Set default server URL -export DORIS_MCP_SERVER_URL="http://localhost:8080" - -# Set default timeout -export DORIS_MCP_TIMEOUT=60 - -# Enable debug logging -export DORIS_MCP_DEBUG=true -``` - -## ๐Ÿš€ Performance Tips - -1. **Connection Reuse**: Use the same client instance for multiple operations -2. **Batch Operations**: Group related queries together -3. **Async Context**: Always use proper async context management -4. **Error Recovery**: Implement retry logic for transient failures - -## ๐Ÿค Contributing - -1. Fork the repository -2. Create a feature branch -3. Make your changes -4. Add tests -5. Submit a pull request - -## ๐Ÿ“„ License - -This project is licensed under the Apache 2.0 License. \ No newline at end of file diff --git a/doris_mcp_client/__init__.py b/doris_mcp_client/__init__.py deleted file mode 100644 index e12733f..0000000 --- a/doris_mcp_client/__init__.py +++ /dev/null @@ -1,25 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Doris MCP Client Package - -Unified MCP client supporting both stdio and HTTP transport modes -""" - -from .client import DorisUnifiedClient, DorisClientConfig - -__all__ = ["DorisUnifiedClient", "DorisClientConfig"] \ No newline at end of file diff --git a/doris_mcp_client/client.py b/doris_mcp_client/client.py deleted file mode 100644 index a7e4020..0000000 --- a/doris_mcp_client/client.py +++ /dev/null @@ -1,513 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Unified Doris MCP Client - Supports both stdio and Streamable HTTP modes - -Combines the correct HTTP implementation from http_client.py and the complete architecture from client.py -Provides complete support for the three major primitives: Resources, Tools, and Prompts -""" - -import asyncio -import json -import logging -from typing import Any, Callable -from datetime import timedelta - -from mcp.client.session import ClientSession -from mcp.client.stdio import stdio_client -from mcp.client.streamable_http import streamablehttp_client -from mcp import StdioServerParameters -from mcp.types import ( - Prompt, - Resource, - Tool, -) - -# Configure logging -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -class DorisClientConfig: - """Doris client configuration class""" - - def __init__( - self, - transport: str = "stdio", - server_command: str | None = None, - server_args: list[str] | None = None, - server_url: str | None = None, - timeout: int = 60, - ): - self.transport = transport - self.server_command = server_command - self.server_args = server_args or [] - self.server_url = server_url - self.timeout = timeout - - @classmethod - def stdio(cls, command: str, args: list[str] = None) -> "DorisClientConfig": - """Create stdio connection configuration""" - return cls( - transport="stdio", - server_command=command, - server_args=args or [] - ) - - @classmethod - def http(cls, url: str, timeout: int = 60) -> "DorisClientConfig": - """Create HTTP connection configuration""" - return cls( - transport="http", - server_url=url, - timeout=timeout - ) - - -class DorisResourceClient: - """Doris resource client - Handles Resources related operations""" - - def __init__(self, session: ClientSession): - self.session = session - self.logger = logging.getLogger(f"{__name__}.DorisResourceClient") - self._resources_cache = None - - async def list_resources(self) -> list[Resource]: - """Get list of all available resources""" - try: - self.logger.info("Getting resource list") - response = await self.session.list_resources() - resources = response.resources if hasattr(response, "resources") else [] - self._resources_cache = resources - self.logger.info(f"Retrieved {len(resources)} resources") - return resources - except Exception as e: - self.logger.error(f"Failed to get resource list: {e}") - return [] - - async def read_resource(self, uri: str) -> str | None: - """Read specified resource content""" - try: - self.logger.info(f"Reading resource: {uri}") - response = await self.session.read_resource(uri) - - if hasattr(response, "contents") and response.contents: - # Merge all content - content_parts = [] - for content in response.contents: - if hasattr(content, "text"): - content_parts.append(content.text) - content = "\n".join(content_parts) - self.logger.info(f"Successfully read resource content: {len(content)} characters") - return content - elif hasattr(response, "content"): - return str(response.content) - else: - self.logger.warning(f"Resource {uri} returned no content") - return None - - except Exception as e: - self.logger.error(f"Failed to read resource {uri}: {e}") - return None - - async def filter_resources_by_type(self, resource_type: str) -> list[Resource]: - """Filter resources by type""" - if not self._resources_cache: - await self.list_resources() - - if resource_type == "table": - return [r for r in self._resources_cache if "table" in r.uri] - elif resource_type == "view": - return [r for r in self._resources_cache if "view" in r.uri] - elif resource_type == "database": - return [ - r for r in self._resources_cache - if "database" in r.uri and "table" not in r.uri - ] - else: - return self._resources_cache - - -class DorisToolsClient: - """Doris tools client - Handles Tools related operations""" - - def __init__(self, session: ClientSession): - self.session = session - self.logger = logging.getLogger(f"{__name__}.DorisToolsClient") - self._tools_cache = None - - async def list_tools(self) -> list[Tool]: - """Get list of all available tools""" - try: - self.logger.info("Getting tool list") - response = await self.session.list_tools() - tools = response.tools if hasattr(response, "tools") else [] - self._tools_cache = tools - self.logger.info(f"Retrieved {len(tools)} tools") - return tools - except Exception as e: - self.logger.error(f"Failed to get tool list: {e}") - return [] - - async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]: - """Call specified tool""" - try: - self.logger.info(f"Calling tool: {name}") - self.logger.debug(f"Tool arguments: {arguments}") - - response = await self.session.call_tool(name, arguments) - - if hasattr(response, "content") and response.content: - # Parse response content - result_text = "" - for content in response.content: - if hasattr(content, "text"): - result_text += content.text - - # Try to parse as JSON - try: - result = json.loads(result_text) - self.logger.info(f"Tool call successful: {name}") - return result - except json.JSONDecodeError: - # If not JSON format, return text directly - return {"success": True, "data": result_text} - - self.logger.warning(f"Tool {name} returned no content") - return {"success": False, "error": "No response content"} - - except Exception as e: - self.logger.error(f"Tool call failed {name}: {e}") - return {"success": False, "error": str(e)} - - async def get_tool_by_name(self, name: str) -> Tool | None: - """Get tool definition by name""" - if not self._tools_cache: - await self.list_tools() - - for tool in self._tools_cache: - if tool.name == name: - return tool - return None - - async def get_tools_by_category(self, category: str) -> list[Tool]: - """Filter tools by category""" - if not self._tools_cache: - await self.list_tools() - - category_lower = category.lower() - return [ - tool for tool in self._tools_cache - if category_lower in tool.description.lower() - or category_lower in tool.name.lower() - ] - - -class DorisPromptClient: - """Doris prompt client - Handles Prompts related operations""" - - def __init__(self, session: ClientSession): - self.session = session - self.logger = logging.getLogger(f"{__name__}.DorisPromptClient") - self._prompts_cache = None - - async def list_prompts(self) -> list[Prompt]: - """Get list of all available prompts""" - try: - self.logger.info("Getting prompt list") - response = await self.session.list_prompts() - prompts = response.prompts if hasattr(response, "prompts") else [] - self._prompts_cache = prompts - self.logger.info(f"Retrieved {len(prompts)} prompts") - return prompts - except Exception as e: - self.logger.error(f"Failed to get prompt list: {e}") - return [] - - async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None: - """Get specified prompt content""" - try: - self.logger.info(f"Getting prompt: {name}") - self.logger.debug(f"Prompt arguments: {arguments}") - - response = await self.session.get_prompt(name, arguments) - - if hasattr(response, "messages") and response.messages: - # Merge all message content - content_parts = [] - for message in response.messages: - if hasattr(message, "content"): - if hasattr(message.content, "text"): - content_parts.append(message.content.text) - else: - content_parts.append(str(message.content)) - - content = "\n".join(content_parts) - self.logger.info(f"Successfully retrieved prompt content: {len(content)} characters") - return content - - self.logger.warning(f"Prompt {name} returned no content") - return None - - except Exception as e: - self.logger.error(f"Failed to get prompt {name}: {e}") - return None - - -class DorisUnifiedClient: - """Unified Doris MCP client - Provides complete MCP functionality""" - - def __init__(self, config: DorisClientConfig): - self.config = config - self.logger = logging.getLogger(f"{__name__}.DorisUnifiedClient") - self.session = None - self.resources = None - self.tools = None - self.prompts = None - - async def connect_and_run(self, callback_func: Callable): - """Connect to server and execute callback function""" - if self.config.transport == "stdio": - await self._run_stdio_mode(callback_func) - elif self.config.transport == "http": - await self._run_http_mode(callback_func) - else: - raise ValueError(f"Unsupported transport type: {self.config.transport}") - - async def _run_stdio_mode(self, callback_func: Callable): - """Run in stdio mode""" - try: - self.logger.info(f"Starting stdio client: {self.config.server_command}") - - server_params = StdioServerParameters( - command=self.config.server_command, - args=self.config.server_args, - ) - - async with stdio_client(server_params) as (read, write): - async with ClientSession(read, write) as session: - self.session = session - self._init_sub_clients() - - # Initialize server - await session.initialize() - self.logger.info("Server initialized successfully") - - # Execute callback function - await callback_func(self) - - except Exception as e: - self.logger.error(f"stdio mode execution failed: {e}") - raise - - async def _run_http_mode(self, callback_func: Callable): - """Run in HTTP mode""" - try: - self.logger.info(f"Starting HTTP client: {self.config.server_url}") - - async with streamablehttp_client( - self.config.server_url, - timeout=timedelta(seconds=self.config.timeout) - ) as (read, write): - async with ClientSession(read, write) as session: - self.session = session - self._init_sub_clients() - - # Initialize server - await session.initialize() - self.logger.info("Server initialized successfully") - - # Execute callback function - await callback_func(self) - - except Exception as e: - self.logger.error(f"HTTP mode execution failed: {e}") - raise - - def _init_sub_clients(self): - """Initialize sub-clients""" - self.resources = DorisResourceClient(self.session) - self.tools = DorisToolsClient(self.session) - self.prompts = DorisPromptClient(self.session) - - # Convenience methods - async def list_all_resources(self) -> list[Resource]: - """Get all resources""" - return await self.resources.list_resources() - - async def list_all_tools(self) -> list[Tool]: - """Get all tools""" - return await self.tools.list_tools() - - async def list_all_prompts(self) -> list[Prompt]: - """Get all prompts""" - return await self.prompts.list_prompts() - - async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]: - """Call tool""" - return await self.tools.call_tool(name, arguments) - - async def read_resource(self, uri: str) -> str | None: - """Read resource""" - return await self.resources.read_resource(uri) - - async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None: - """Get prompt""" - return await self.prompts.get_prompt(name, arguments) - - # Smart tool finding methods - async def _find_tool_by_pattern(self, patterns: list[str]) -> str | None: - """Find tool by name pattern""" - tools = await self.list_all_tools() - for pattern in patterns: - for tool in tools: - if pattern in tool.name: - return tool.name - return None - - async def _find_tool_by_function(self, function_keywords: list[str]) -> str | None: - """Find tool by function keywords""" - tools = await self.list_all_tools() - for tool in tools: - tool_desc = tool.description.lower() - tool_name = tool.name.lower() - for keyword in function_keywords: - if keyword.lower() in tool_desc or keyword.lower() in tool_name: - return tool.name - return None - - # High-level business methods - async def execute_sql(self, sql: str, **kwargs) -> dict[str, Any]: - """Execute SQL query""" - tool_name = await self._find_tool_by_pattern(["exec_query", "execute", "query"]) - if not tool_name: - return {"success": False, "error": "SQL execution tool not found"} - - arguments = {"sql": sql, **kwargs} - return await self.call_tool(tool_name, arguments) - - async def get_table_schema(self, table_name: str, db_name: str = None, **kwargs) -> dict[str, Any]: - """Get table schema""" - tool_name = await self._find_tool_by_pattern(["get_table_schema", "table_schema", "schema"]) - if not tool_name: - return {"success": False, "error": "Table schema tool not found"} - - arguments = {"table_name": table_name} - if db_name: - arguments["db_name"] = db_name - arguments.update(kwargs) - - return await self.call_tool(tool_name, arguments) - - async def get_database_list(self, **kwargs) -> dict[str, Any]: - """Get database list""" - tool_name = await self._find_tool_by_pattern(["get_db_list", "database_list", "db_list"]) - if not tool_name: - return {"success": False, "error": "Database list tool not found"} - - return await self.call_tool(tool_name, kwargs) - - async def analyze_column(self, table_name: str, column_name: str, analysis_type: str = "basic", **kwargs) -> dict[str, Any]: - """Analyze column""" - tool_name = await self._find_tool_by_pattern(["column_analysis", "analyze_column", "column"]) - if not tool_name: - return {"success": False, "error": "Column analysis tool not found"} - - arguments = { - "table_name": table_name, - "column_name": column_name, - "analysis_type": analysis_type, - **kwargs - } - return await self.call_tool(tool_name, arguments) - - async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]: - """Call tool by function description""" - # Try to find appropriate tool based on function description - function_keywords = function_description.lower().split() - tool_name = await self._find_tool_by_function(function_keywords) - - if not tool_name: - return { - "success": False, - "error": f"No tool found for function: {function_description}" - } - - return await self.call_tool(tool_name, arguments) - - -# Convenience factory functions -async def create_stdio_client(command: str, args: list[str] = None) -> DorisUnifiedClient: - """Create stdio client""" - config = DorisClientConfig.stdio(command, args) - return DorisUnifiedClient(config) - - -async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnifiedClient: - """Create HTTP client""" - config = DorisClientConfig.http(server_url, timeout) - return DorisUnifiedClient(config) - - -# Example usage -async def example_stdio(): - """stdio mode example""" - client = await create_stdio_client("python", ["doris_mcp_server/main.py"]) - - async def test_client(client: DorisUnifiedClient): - # Get server capabilities - resources = await client.list_all_resources() - tools = await client.list_all_tools() - prompts = await client.list_all_prompts() - - print(f"Resources: {len(resources)}") - print(f"Tools: {len(tools)}") - print(f"Prompts: {len(prompts)}") - - # Test SQL execution - result = await client.execute_sql("SELECT 1 as test") - print(f"SQL execution result: {result}") - - await client.connect_and_run(test_client) - - -async def example_http(): - """HTTP mode example""" - client = await create_http_client("http://localhost:8080") - - async def test_client(client: DorisUnifiedClient): - # Get server capabilities - resources = await client.list_all_resources() - tools = await client.list_all_tools() - - print(f"Resources: {len(resources)}") - print(f"Tools: {len(tools)}") - - # Test database list - result = await client.get_database_list() - print(f"Database list: {result}") - - await client.connect_and_run(test_client) - - -if __name__ == "__main__": - # Run stdio example - asyncio.run(example_stdio()) - - # Run HTTP example - # asyncio.run(example_http()) \ No newline at end of file diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index ca0df72..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,36 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -# ๅผ€ๅ‘ไพ่ต– - ไปŽ pyproject.toml ่‡ชๅŠจ็”Ÿๆˆ -# ๅฎ‰่ฃ…ๅ‘ฝไปค: pip install -r requirements-dev.txt - -pytest>=7.4.0 -pytest-asyncio>=0.23.0 -pytest-cov>=4.1.0 -pytest-mock>=3.12.0 -pytest-xdist>=3.5.0 -ruff>=0.1.0 -black>=23.12.0 -isort>=5.13.0 -flake8>=7.0.0 -mypy>=1.8.0 -bandit>=1.7.0 -safety>=2.3.0 -sphinx>=7.2.0 -sphinx-rtd-theme>=2.0.0 -myst-parser>=2.0.0 -pre-commit>=3.6.0 -tox>=4.11.0 \ No newline at end of file diff --git a/test/README.md b/test/README.md deleted file mode 100644 index c8f75e2..0000000 --- a/test/README.md +++ /dev/null @@ -1,263 +0,0 @@ - -# Doris MCP Server Testing System - -## Overview - -This testing system adopts a layered architecture, including unit tests, integration tests, and client-server tests. The testing system assumes the server is already properly started and focuses on testing functionality rather than startup configuration. - -## Testing Architecture - -### 1. Unit Tests -- **Location**: `test/security/`, `test/utils/`, `test/tools/` -- **Purpose**: Test individual module functionality -- **Features**: Uses Mock objects, no dependency on external services - -### 2. Integration Tests -- **Location**: `test/integration/` -- **Purpose**: Test collaboration between modules -- **Features**: Test complete workflows - -### 3. Client-Server Tests -- **Location**: `test/tools/test_tools_client_server.py`, `test/utils/test_query_executor_client_server.py` -- **Purpose**: Test actual server functionality through MCP client -- **Features**: Assumes server is running, skips tests if server is not available - -## Configuration Files - -### test_config.json -Test configuration file defines how to connect to the running server: - -```json -{ - "server_endpoints": { - "http": { - "url": "http://localhost:3000/mcp", - "timeout": 30 - }, - "stdio": { - "command": "uv", - "args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"], - "timeout": 30 - } - }, - "test_settings": { - "default_transport": "http", - "retry_attempts": 3, - "retry_delay": 1.0, - "test_timeout": 60, - "enable_performance_tests": true, - "enable_security_tests": true - } -} -``` - -## Usage - -### 1. Start the Server - -Before running client-server tests, you need to start the server first: - -#### HTTP Mode (Recommended) -```bash -# Start HTTP server -./start_server.sh -# or -uv run python -m doris_mcp_server.main --transport http --port 3000 -``` - -#### Stdio Mode -```bash -# Stdio mode is started directly by the client, no need to pre-start -``` - -### 2. Run Tests - -#### Run All Tests -```bash -python -m pytest test/ -v -``` - -#### Run Unit Tests -```bash -# Security module tests -python -m pytest test/security/ -v - -# Tools module tests -python -m pytest test/tools/test_tools_manager.py -v - -# Query executor tests -python -m pytest test/utils/test_query_executor.py -v -``` - -#### Run Integration Tests -```bash -python -m pytest test/integration/ -v -``` - -#### Run Client-Server Tests -```bash -# Tools Client-Server tests -python -m pytest test/tools/test_tools_client_server.py -v - -# QueryExecutor Client-Server tests -python -m pytest test/utils/test_query_executor_client_server.py -v -``` - -### 3. Test Configuration - -#### Modify Server Endpoints -Edit the `test/test_config.json` file: - -```json -{ - "server_endpoints": { - "http": { - "url": "http://your-server:port/mcp" - } - } -} -``` - -#### Enable/Disable Specific Tests -```json -{ - "test_settings": { - "enable_performance_tests": false, // Disable performance tests - "enable_security_tests": true // Enable security tests - } -} -``` - -## Test Status - -### โœ… Completed Test Modules - -1. **Security Module** (100% Pass) - - Authentication tests: 5/5 passed - - Authorization tests: 7/7 passed - - Data masking tests: 13/13 passed - - SQL validation tests: 10/10 passed - - Security manager tests: 7/7 passed - - Coverage: 88% - -2. **Client-Server Test Architecture** (Implemented) - - Automatic server connection status detection - - Automatically skip tests when server is not running - - Support for both HTTP and Stdio transport modes - -### ๐Ÿ”„ Tests Requiring Server Running - -1. **Tools Client-Server Tests** - - Tool list retrieval - - SQL query execution - - Database list retrieval - - Table schema queries - - Performance statistics - - Error handling - - Security authentication - -2. **QueryExecutor Client-Server Tests** - - Simple query execution - - Database queries - - Information schema queries - - Parameterized queries - - Error handling - - Security authentication - -## Testing Best Practices - -### 1. Server Startup Check -All client-server tests automatically check server connection status: -- If server is running normally, execute actual tests -- If server is not running, skip tests and display appropriate message - -### 2. Test Isolation -- Unit tests use Mock objects, no dependency on external services -- Integration tests use controlled test environments -- Client-server tests connect to actually running servers - -### 3. Error Handling -- Tests don't assume specific success/failure results -- Verify response structure rather than specific content -- Gracefully handle connection failures and timeouts - -### 4. Configuration Management -- Use configuration files to manage test parameters -- Support configuration switching for different environments -- Provide reasonable default values - -## Troubleshooting - -### 1. Server Connection Failure -``` -ERROR: Server is not running or not accessible -``` -**Solution**: Ensure the server is started and listening on the correct port - -### 2. Import Errors -``` -ImportError: cannot import name 'DorisUnifiedClient' -``` -**Solution**: Check Python path and dependency installation - -### 3. Test Timeouts -``` -TimeoutError: Test execution timeout -``` -**Solution**: Increase timeout settings in `test_config.json` - -## Development Guide - -### Adding New Client-Server Tests - -1. Add test methods in the appropriate test file -2. Use `@pytest.mark.asyncio` decorator -3. Get test client through `client` fixture -4. Implement test callback function -5. Verify response structure - -Example: -```python -@pytest.mark.asyncio -async def test_new_feature_via_client(self, client, test_config): - """Test new feature through client""" - async def test_callback(client_instance): - result = await client_instance.call_tool("new_tool", { - "param": "value" - }) - - assert "success" in result - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result -``` - -### Modifying Test Configuration - -Edit the `test/test_config.json` file to adjust: -- Server endpoints -- Timeout settings -- Test data -- Feature switches - -## Summary - -This testing system provides complete test coverage, from unit tests to end-to-end client-server tests. Through reasonable configuration and automated connection detection, it ensures tests can run stably in different environments. \ No newline at end of file diff --git a/test/__init__.py b/test/__init__.py deleted file mode 100644 index d216be4..0000000 --- a/test/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py deleted file mode 100644 index c0c8bc0..0000000 --- a/test/conftest.py +++ /dev/null @@ -1,107 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Pytest configuration and fixtures for Doris MCP Server tests -""" - -import asyncio -import logging -import sys -from pathlib import Path - -import pytest - -# Add project root to Python path -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -# Configure logging for tests -logging.basicConfig( - level=logging.INFO, - format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' -) - - -@pytest.fixture(scope="session") -def event_loop(): - """Create an instance of the default event loop for the test session.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - - -@pytest.fixture -def test_config(): - """Provide test configuration""" - return { - "doris_host": "localhost", - "doris_port": 9030, - "doris_user": "test_user", - "doris_password": "test_password", - "doris_database": "test_db", - "blocked_keywords": ["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"], - "sensitive_tables": { - "user_info": "confidential", - "payment_records": "secret", - "employee_data": "confidential", - "public_reports": "public" - }, - "max_query_complexity": 100 - } - - -@pytest.fixture -def sample_data(): - """Provide sample test data""" - return [ - { - "id": 1, - "name": "ๅผ ไธ‰", - "phone": "13812345678", - "email": "zhangsan@example.com", - "id_card": "110101199001011234", - "salary": 50000 - }, - { - "id": 2, - "name": "ๆŽๅ››", - "phone": "13987654321", - "email": "lisi@example.com", - "id_card": "110101199002022345", - "salary": 60000 - } - ] - - -@pytest.fixture -def test_sql_queries(): - """Provide test SQL queries""" - return { - "safe_select": "SELECT name, email FROM users WHERE department = 'sales'", - "dangerous_drop": "DROP TABLE users", - "sql_injection": "SELECT * FROM users WHERE id = 1; DROP TABLE users;", - "union_injection": "SELECT name FROM users UNION SELECT password FROM admin_users", - "comment_injection": "SELECT * FROM users WHERE id = 1 -- AND password = 'secret'", - "complex_query": """ - SELECT u.name, u.email, d.department_name - FROM users u - JOIN departments d ON u.department_id = d.id - WHERE u.status = 'active' - ORDER BY u.created_at DESC - """ - } diff --git a/test/integration/test_end_to_end.py b/test/integration/test_end_to_end.py deleted file mode 100644 index 9f9c5ab..0000000 --- a/test/integration/test_end_to_end.py +++ /dev/null @@ -1,299 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -End-to-end integration tests -""" - -import json -import pytest -from unittest.mock import Mock, patch - -from doris_mcp_server.main import DorisServer -from doris_mcp_server.utils.config import DorisConfig -from doris_mcp_server.utils.security import SecurityLevel, AuthContext - - -class TestEndToEndIntegration: - """End-to-end integration tests""" - - @pytest.fixture - def mock_config(self): - """Create mock configuration""" - from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig - - config = Mock(spec=DorisConfig) - config.doris_host = "localhost" - config.doris_port = 9030 - config.doris_user = "test_user" - config.doris_password = "test_password" - config.doris_database = "test_db" - config.server_host = "localhost" - config.server_port = 8000 - config.enable_security = True - - # Add database config - config.database = Mock(spec=DatabaseConfig) - config.database.host = "localhost" - config.database.port = 9030 - config.database.user = "test_user" - config.database.password = "test_password" - config.database.database = "test_db" - config.database.health_check_interval = 60 - config.database.min_connections = 5 - config.database.max_connections = 20 - config.database.connection_timeout = 30 - config.database.max_connection_age = 3600 - - # Add security config - config.security = Mock(spec=SecurityConfig) - config.security.enable_masking = True - config.security.auth_type = "token" - config.security.token_secret = "test_secret" - config.security.token_expiry = 3600 - - return config - - @pytest.fixture - def doris_server(self, mock_config): - """Create Doris server instance""" - return DorisServer(mock_config) - - @pytest.mark.asyncio - async def test_complete_query_workflow_with_security(self, doris_server, sample_data): - """Test complete query workflow with security""" - with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.return_value = sample_data - - # Mock authentication - with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth: - mock_auth.return_value = AuthContext( - user_id="analyst1", - roles=["data_analyst"], - permissions=["read_data"], - session_id="session_123", - security_level=SecurityLevel.INTERNAL - ) - - # Mock authorization - with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz: - mock_authz.return_value = True - - # Mock SQL validation - with patch.object(doris_server.security_manager, 'validate_sql_security') as mock_validate: - from doris_mcp_server.utils.security import ValidationResult - mock_validate.return_value = ValidationResult(is_valid=True) - - # Mock data masking - with patch.object(doris_server.security_manager, 'apply_data_masking') as mock_mask: - masked_data = [ - { - "id": 1, - "name": "ๅผ ไธ‰", - "phone": "138****5678", - "email": "z*******n@example.com", - "id_card": "110101****1234", - "salary": 50000 - } - ] - mock_mask.return_value = masked_data - - # Simulate complete workflow - auth_info = {"type": "token", "token": "valid_token_123"} - auth_context = await doris_server.security_manager.authenticate_request(auth_info) - - resource_uri = "/api/table/users" - has_access = await doris_server.security_manager.authorize_resource_access( - auth_context, resource_uri - ) - assert has_access is True - - sql = "SELECT * FROM users LIMIT 1" - validation = await doris_server.security_manager.validate_sql_security( - sql, auth_context - ) - assert validation.is_valid is True - - raw_data = await doris_server.tools_manager.query_executor.execute_query(sql) - final_data = await doris_server.security_manager.apply_data_masking( - raw_data, auth_context - ) - - # Verify data is properly masked - assert final_data[0]["phone"] == "138****5678" - assert final_data[0]["email"] == "z*******n@example.com" - - @pytest.mark.asyncio - async def test_security_violation_workflow(self, doris_server): - """Test security violation detection workflow""" - with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth: - mock_auth.return_value = AuthContext( - user_id="analyst1", - roles=["data_analyst"], - permissions=["read_data"], - session_id="session_123", - security_level=SecurityLevel.INTERNAL - ) - - # Test unauthorized resource access - with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz: - mock_authz.return_value = False - - auth_context = await doris_server.security_manager.authenticate_request({ - "type": "token", "token": "valid_token_123" - }) - - # Try to access confidential resource - resource_uri = "/api/table/payment_records" - has_access = await doris_server.security_manager.authorize_resource_access( - auth_context, resource_uri - ) - - assert has_access is False - - @pytest.mark.asyncio - async def test_sql_injection_prevention_workflow(self, doris_server): - """Test SQL injection prevention workflow""" - with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth: - mock_auth.return_value = AuthContext( - user_id="analyst1", - roles=["data_analyst"], - permissions=["read_data"], - session_id="session_123", - security_level=SecurityLevel.INTERNAL - ) - - auth_context = await doris_server.security_manager.authenticate_request({ - "type": "token", "token": "valid_token_123" - }) - - # Test SQL injection attempt - malicious_sql = "SELECT * FROM users WHERE id = 1; DROP TABLE users;" - validation = await doris_server.security_manager.validate_sql_security( - malicious_sql, auth_context - ) - - assert validation.is_valid is False - assert validation.risk_level == "high" - - @pytest.mark.asyncio - async def test_admin_bypass_workflow(self, doris_server, sample_data): - """Test admin user bypassing restrictions""" - with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.return_value = sample_data - - with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth: - mock_auth.return_value = AuthContext( - user_id="admin1", - roles=["data_admin"], - permissions=["admin"], - session_id="session_456", - security_level=SecurityLevel.SECRET - ) - - # Admin should access any resource - with patch.object(doris_server.security_manager, 'authorize_resource_access') as mock_authz: - mock_authz.return_value = True - - # Admin should see original data (no masking) - with patch.object(doris_server.security_manager, 'apply_data_masking') as mock_mask: - mock_mask.return_value = sample_data # Original data - - auth_context = await doris_server.security_manager.authenticate_request({ - "type": "basic", "username": "admin", "password": "admin123" - }) - - # Admin accesses secret resource - resource_uri = "/api/table/payment_records" - has_access = await doris_server.security_manager.authorize_resource_access( - auth_context, resource_uri - ) - assert has_access is True - - # Admin sees original data - raw_data = await doris_server.tools_manager.query_executor.execute_query( - "SELECT * FROM users LIMIT 1" - ) - final_data = await doris_server.security_manager.apply_data_masking( - raw_data, auth_context - ) - - # Should be original data (no masking) - assert final_data[0]["phone"] == "13812345678" - assert final_data[0]["email"] == "zhangsan@example.com" - - @pytest.mark.asyncio - async def test_tool_execution_with_security(self, doris_server): - """Test tool execution with security checks""" - with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.return_value = [{"Database": "test_db"}] - - # Test tool execution through tools manager - result = await doris_server.tools_manager.call_tool("get_db_list", {}) - result_data = json.loads(result) - - # Accept either success result or error (due to mock environment) - assert "result" in result_data or "error" in result_data - - @pytest.mark.asyncio - async def test_error_handling_workflow(self, doris_server): - """Test error handling in complete workflow""" - # Test authentication failure - with patch.object(doris_server.security_manager, 'authenticate_request') as mock_auth: - mock_auth.side_effect = Exception("Invalid token") - - with pytest.raises(Exception) as exc_info: - await doris_server.security_manager.authenticate_request({ - "type": "token", "token": "invalid_token" - }) - - assert "Invalid token" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_performance_monitoring_integration(self, doris_server): - """Test performance monitoring integration""" - with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.return_value = [ - { - "query_count": 1500, - "avg_execution_time": 0.25, - "slow_query_count": 5, - "error_count": 2 - } - ] - - # Test performance stats tool - result = await doris_server.tools_manager.call_tool("performance_stats", { - "metric_type": "queries", - "time_range": "1h" - }) - result_data = json.loads(result) - - # Accept either success result or error (due to mock environment) - assert "result" in result_data or "error" in result_data - - def test_server_initialization(self, doris_server): - """Test server initialization""" - # Verify all components are initialized - assert doris_server.config is not None - assert doris_server.tools_manager is not None - assert doris_server.security_manager is not None - - # Verify tools are available - use list_tools instead - import asyncio - tools = asyncio.run(doris_server.tools_manager.list_tools()) - assert len(tools) > 0 \ No newline at end of file diff --git a/test/security/test_authentication.py b/test/security/test_authentication.py deleted file mode 100644 index 50aea98..0000000 --- a/test/security/test_authentication.py +++ /dev/null @@ -1,103 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Authentication module tests -""" - -import pytest -from datetime import datetime - -from doris_mcp_server.utils.security import ( - AuthenticationProvider, - AuthContext, - SecurityLevel -) - - -class TestAuthenticationProvider: - """Authentication provider tests""" - - @pytest.fixture - def auth_provider(self, test_config): - """Create authentication provider instance""" - return AuthenticationProvider(test_config) - - @pytest.mark.asyncio - async def test_token_authentication_success(self, auth_provider): - """Test successful token authentication""" - auth_info = { - "type": "token", - "token": "valid_token_123" - } - - result = await auth_provider.authenticate(auth_info) - - assert isinstance(result, AuthContext) - assert result.user_id == "test_user" - assert "data_analyst" in result.roles - assert result.security_level == SecurityLevel.INTERNAL - - @pytest.mark.asyncio - async def test_token_authentication_failure(self, auth_provider): - """Test failed token authentication""" - auth_info = { - "type": "token", - "token": "invalid_token" - } - - with pytest.raises(Exception): - await auth_provider.authenticate(auth_info) - - @pytest.mark.asyncio - async def test_basic_authentication_success(self, auth_provider): - """Test successful basic authentication""" - auth_info = { - "type": "basic", - "username": "admin", - "password": "admin123" - } - - result = await auth_provider.authenticate(auth_info) - - assert isinstance(result, AuthContext) - assert result.user_id == "admin_user" - assert "data_admin" in result.roles - assert result.security_level == SecurityLevel.SECRET - - @pytest.mark.asyncio - async def test_basic_authentication_failure(self, auth_provider): - """Test failed basic authentication""" - auth_info = { - "type": "basic", - "username": "admin", - "password": "wrong_password" - } - - with pytest.raises(Exception): - await auth_provider.authenticate(auth_info) - - @pytest.mark.asyncio - async def test_unsupported_auth_type(self, auth_provider): - """Test unsupported authentication type""" - auth_info = { - "type": "oauth", - "token": "oauth_token" - } - - with pytest.raises(Exception): - await auth_provider.authenticate(auth_info) \ No newline at end of file diff --git a/test/security/test_authorization.py b/test/security/test_authorization.py deleted file mode 100644 index e24b711..0000000 --- a/test/security/test_authorization.py +++ /dev/null @@ -1,147 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Authorization module tests -""" - -import pytest - -from doris_mcp_server.utils.security import ( - AuthorizationProvider, - AuthContext, - SecurityLevel -) - - -class TestAuthorizationProvider: - """Authorization provider tests""" - - @pytest.fixture - def authz_provider(self, test_config): - """Create authorization provider instance""" - return AuthorizationProvider(test_config) - - @pytest.fixture - def analyst_context(self): - """Create analyst auth context""" - return AuthContext( - user_id="analyst1", - roles=["data_analyst"], - permissions=["read_data"], - session_id="session_123", - security_level=SecurityLevel.INTERNAL - ) - - @pytest.fixture - def admin_context(self): - """Create admin auth context""" - return AuthContext( - user_id="admin1", - roles=["data_admin"], - permissions=["admin"], - session_id="session_456", - security_level=SecurityLevel.SECRET - ) - - @pytest.mark.asyncio - async def test_analyst_access_public_resource(self, authz_provider, analyst_context): - """Test analyst accessing public resource""" - resource_uri = "/api/table/public_reports" - - result = await authz_provider.check_permission(analyst_context, resource_uri, "read") - - assert result is True - - @pytest.mark.asyncio - async def test_analyst_denied_confidential_resource(self, authz_provider): - """Test analyst denied access to confidential resource""" - # Create analyst with lower security level - analyst_context = AuthContext( - user_id="analyst1", - roles=["data_analyst"], - permissions=["read_data"], - session_id="session_123", - security_level=SecurityLevel.PUBLIC # Lower than CONFIDENTIAL - ) - - resource_uri = "/api/table/user_info" - - result = await authz_provider.check_permission(analyst_context, resource_uri, "read") - - assert result is False - - @pytest.mark.asyncio - async def test_admin_access_secret_resource(self, authz_provider, admin_context): - """Test admin accessing secret resource""" - resource_uri = "/api/table/payment_records" - - result = await authz_provider.check_permission(admin_context, resource_uri, "read") - - assert result is True - - @pytest.mark.asyncio - async def test_role_based_permission(self, authz_provider): - """Test role-based permission check""" - # Create analyst context - analyst_context = AuthContext( - user_id="analyst1", - roles=["data_analyst"], - permissions=["read_data"], - session_id="session_123", - security_level=SecurityLevel.INTERNAL - ) - - resource_uri = "/api/table/some_table" - - # Analyst should have read permission - result = await authz_provider.check_permission(analyst_context, resource_uri, "read") - assert result is True - - # Analyst should not have write permission - result = await authz_provider.check_permission(analyst_context, resource_uri, "write") - assert result is False - - @pytest.mark.asyncio - async def test_admin_override(self, authz_provider, admin_context): - """Test admin permission override""" - resource_uri = "/api/table/any_table" - - # Admin should have all permissions - result = await authz_provider.check_permission(admin_context, resource_uri, "read") - assert result is True - - result = await authz_provider.check_permission(admin_context, resource_uri, "write") - assert result is True - - def test_parse_resource_uri(self, authz_provider): - """Test resource URI parsing""" - uri = "/api/table/user_info/default" - - result = authz_provider._parse_resource_uri(uri) - - assert result["type"] == "table" - assert result["name"] == "user_info" - assert result["schema"] == "default" - - def test_get_resource_security_level(self, authz_provider): - """Test getting resource security level""" - resource_info = {"name": "user_info", "type": "table"} - - level = authz_provider._get_resource_security_level(resource_info) - - assert level == SecurityLevel.CONFIDENTIAL \ No newline at end of file diff --git a/test/security/test_data_masking.py b/test/security/test_data_masking.py deleted file mode 100644 index 1647cd9..0000000 --- a/test/security/test_data_masking.py +++ /dev/null @@ -1,197 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Data masking tests -""" - -import pytest - -from doris_mcp_server.utils.security import ( - DataMaskingProcessor, - AuthContext, - SecurityLevel, - MaskingRule -) - - -class TestDataMaskingProcessor: - """Data masking processor tests""" - - @pytest.fixture - def masking_processor(self, test_config): - """Create data masking processor instance""" - return DataMaskingProcessor(test_config) - - @pytest.fixture - def internal_user_context(self): - """Create internal user auth context""" - return AuthContext( - user_id="internal_user", - roles=["data_analyst"], - permissions=["read_data"], - session_id="session_123", - security_level=SecurityLevel.INTERNAL - ) - - @pytest.fixture - def admin_context(self): - """Create admin auth context""" - return AuthContext( - user_id="admin", - roles=["data_admin"], - permissions=["admin"], - session_id="session_456", - security_level=SecurityLevel.SECRET - ) - - @pytest.mark.asyncio - async def test_phone_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data): - """Test phone number masking for internal user""" - result = await masking_processor.process(sample_data, internal_user_context) - - # Phone numbers should be masked - assert result[0]["phone"] == "138****5678" - assert result[1]["phone"] == "139****4321" - - @pytest.mark.asyncio - async def test_email_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data): - """Test email masking for internal user""" - result = await masking_processor.process(sample_data, internal_user_context) - - # Emails should be masked - assert result[0]["email"] == "z******n@example.com" - assert result[1]["email"] == "l**i@example.com" - - @pytest.mark.asyncio - async def test_no_masking_for_admin(self, masking_processor, admin_context, sample_data): - """Test no masking for admin user""" - result = await masking_processor.process(sample_data, admin_context) - - # Admin should see original data - assert result[0]["phone"] == "13812345678" - assert result[0]["email"] == "zhangsan@example.com" - assert result[1]["phone"] == "13987654321" - assert result[1]["email"] == "lisi@example.com" - - @pytest.mark.asyncio - async def test_id_card_masking_for_confidential_data(self, masking_processor, internal_user_context, sample_data): - """Test ID card masking for confidential data""" - # Internal user should not see ID card details (confidential level) - result = await masking_processor.process(sample_data, internal_user_context) - - # ID cards should be masked for internal users - assert result[0]["id_card"] == "110101********1234" - assert result[1]["id_card"] == "110101********2345" - - @pytest.mark.asyncio - async def test_empty_data_handling(self, masking_processor, internal_user_context): - """Test empty data handling""" - empty_data = [] - - result = await masking_processor.process(empty_data, internal_user_context) - - assert result == [] - - @pytest.mark.asyncio - async def test_null_value_handling(self, masking_processor, internal_user_context): - """Test null value handling""" - data_with_nulls = [ - { - "id": 1, - "name": "ๅผ ไธ‰", - "phone": None, - "email": None, - "id_card": None - } - ] - - result = await masking_processor.process(data_with_nulls, internal_user_context) - - # Null values should remain null - assert result[0]["phone"] is None - assert result[0]["email"] is None - assert result[0]["id_card"] is None - - def test_phone_masking_algorithm(self, masking_processor): - """Test phone masking algorithm""" - params = {"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4} - - result = masking_processor._mask_phone("13812345678", params) - - assert result == "138****5678" - - def test_email_masking_algorithm(self, masking_processor): - """Test email masking algorithm""" - params = {"mask_char": "*"} - - result = masking_processor._mask_email("zhangsan@example.com", params) - - assert result == "z******n@example.com" - - def test_id_card_masking_algorithm(self, masking_processor): - """Test ID card masking algorithm""" - params = {"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4} - - result = masking_processor._mask_id_card("110101199001011234", params) - - assert result == "110101********1234" - - def test_name_masking_algorithm(self, masking_processor): - """Test name masking algorithm""" - params = {"mask_char": "*"} - - # Test 2-character name - result = masking_processor._mask_name("ๅผ ไธ‰", params) - assert result == "ๅผ *" - - # Test 3-character name - result = masking_processor._mask_name("ๆŽๅฐๆ˜Ž", params) - assert result == "ๆŽ*ๆ˜Ž" - - def test_partial_masking_algorithm(self, masking_processor): - """Test partial masking algorithm""" - params = {"mask_char": "*", "mask_ratio": 0.5} - - result = masking_processor._mask_partial("1234567890", params) - - # Should mask middle 50% of the string - assert "*" in result - assert len(result) == 10 - - def test_should_apply_rule_logic(self, masking_processor, internal_user_context, admin_context): - """Test masking rule application logic""" - rule = MaskingRule( - column_pattern=r".*phone.*", - algorithm="phone_mask", - parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4}, - security_level=SecurityLevel.INTERNAL - ) - - # Internal user should have rule applied - assert masking_processor._should_apply_rule(rule, internal_user_context) is True - - # Admin should not have rule applied - assert masking_processor._should_apply_rule(rule, admin_context) is False - - def test_get_applicable_rules(self, masking_processor, internal_user_context): - """Test getting applicable rules""" - rules = masking_processor._get_applicable_rules(internal_user_context) - - # Should return some rules for internal user - assert len(rules) > 0 - assert all(isinstance(rule, MaskingRule) for rule in rules) \ No newline at end of file diff --git a/test/security/test_security_manager.py b/test/security/test_security_manager.py deleted file mode 100644 index d7598ae..0000000 --- a/test/security/test_security_manager.py +++ /dev/null @@ -1,172 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Security manager integration tests -""" - -import pytest - -from doris_mcp_server.utils.security import ( - DorisSecurityManager, - AuthContext, - SecurityLevel, - ValidationResult -) - - -class TestDorisSecurityManager: - """Doris security manager integration tests""" - - @pytest.fixture - def security_manager(self, test_config): - """Create security manager instance""" - return DorisSecurityManager(test_config) - - @pytest.mark.asyncio - async def test_complete_security_workflow(self, security_manager, sample_data): - """Test complete security workflow""" - # 1. Authentication - auth_info = { - "type": "token", - "token": "valid_token_123" - } - - auth_context = await security_manager.authenticate_request(auth_info) - assert isinstance(auth_context, AuthContext) - assert auth_context.security_level == SecurityLevel.INTERNAL - - # 2. Authorization - resource_uri = "/api/table/public_reports" - has_access = await security_manager.authorize_resource_access(auth_context, resource_uri) - assert has_access is True - - # 3. SQL Validation - safe_sql = "SELECT name, email FROM users WHERE department = 'sales'" - validation_result = await security_manager.validate_sql_security(safe_sql, auth_context) - assert validation_result.is_valid is True - - # 4. Data Masking - masked_data = await security_manager.apply_data_masking(sample_data, auth_context) - assert masked_data[0]["phone"] == "138****5678" # Should be masked - - @pytest.mark.asyncio - async def test_admin_workflow(self, security_manager, sample_data): - """Test admin user workflow""" - # Admin authentication - auth_info = { - "type": "basic", - "username": "admin", - "password": "admin123" - } - - auth_context = await security_manager.authenticate_request(auth_info) - assert auth_context.security_level == SecurityLevel.SECRET - - # Admin should access secret resources - resource_uri = "/api/table/payment_records" - has_access = await security_manager.authorize_resource_access(auth_context, resource_uri) - assert has_access is True - - # Admin should see original data (no masking) - masked_data = await security_manager.apply_data_masking(sample_data, auth_context) - assert masked_data[0]["phone"] == "13812345678" # Original data - - @pytest.mark.asyncio - async def test_security_violation_detection(self, security_manager): - """Test security violation detection""" - # Authenticate as regular user - auth_info = { - "type": "token", - "token": "valid_token_123" - } - - auth_context = await security_manager.authenticate_request(auth_info) - - # Try to access confidential resource (user_info is CONFIDENTIAL, user is INTERNAL) - # INTERNAL(1) should not access CONFIDENTIAL(2) resource - resource_uri = "/api/table/user_info" - has_access = await security_manager.authorize_resource_access(auth_context, resource_uri) - assert has_access is False - - # Try dangerous SQL - dangerous_sql = "DROP TABLE users" - validation_result = await security_manager.validate_sql_security(dangerous_sql, auth_context) - assert validation_result.is_valid is False - assert "DROP" in validation_result.blocked_operations - - @pytest.mark.asyncio - async def test_sql_injection_prevention(self, security_manager): - """Test SQL injection prevention""" - auth_info = { - "type": "token", - "token": "valid_token_123" - } - - auth_context = await security_manager.authenticate_request(auth_info) - - # Test various injection attempts - injection_attempts = [ - "SELECT * FROM users WHERE id = 1; DROP TABLE users;", - "SELECT * FROM users UNION SELECT password FROM admin_users", - "SELECT * FROM users WHERE id = 1 OR 1=1", - "SELECT * FROM users WHERE name = 'test' -- AND password = 'secret'" - ] - - for sql in injection_attempts: - result = await security_manager.validate_sql_security(sql, auth_context) - assert result.is_valid is False - assert result.risk_level in ["medium", "high"] - - @pytest.mark.asyncio - async def test_authentication_failure_handling(self, security_manager): - """Test authentication failure handling""" - invalid_auth_info = { - "type": "token", - "token": "invalid_token" - } - - with pytest.raises(Exception): - await security_manager.authenticate_request(invalid_auth_info) - - @pytest.mark.asyncio - async def test_configuration_loading(self, security_manager): - """Test security configuration loading""" - # Test blocked keywords loading - assert "DROP" in security_manager.blocked_keywords - assert "DELETE" in security_manager.blocked_keywords - - # Test sensitive tables loading - assert SecurityLevel.CONFIDENTIAL in security_manager.sensitive_tables.values() - assert SecurityLevel.SECRET in security_manager.sensitive_tables.values() - - # Test masking rules loading - assert len(security_manager.masking_rules) > 0 - phone_rules = [rule for rule in security_manager.masking_rules - if "phone" in rule.column_pattern] - assert len(phone_rules) > 0 - - def test_security_level_hierarchy(self, security_manager): - """Test security level hierarchy""" - # Test that hierarchy is correctly defined - levels = [SecurityLevel.PUBLIC, SecurityLevel.INTERNAL, - SecurityLevel.CONFIDENTIAL, SecurityLevel.SECRET] - - # Each level should be properly defined - for level in levels: - assert isinstance(level, SecurityLevel) - assert level.value in ["public", "internal", "confidential", "secret"] \ No newline at end of file diff --git a/test/security/test_sql_validation.py b/test/security/test_sql_validation.py deleted file mode 100644 index 6e37547..0000000 --- a/test/security/test_sql_validation.py +++ /dev/null @@ -1,161 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -SQL security validation tests -""" - -import pytest - -from doris_mcp_server.utils.security import ( - SQLSecurityValidator, - AuthContext, - SecurityLevel, - ValidationResult -) - - -class TestSQLSecurityValidator: - """SQL security validator tests""" - - @pytest.fixture - def sql_validator(self, test_config): - """Create SQL validator instance""" - return SQLSecurityValidator(test_config) - - @pytest.fixture - def analyst_context(self): - """Create analyst auth context""" - return AuthContext( - user_id="analyst1", - roles=["data_analyst"], - permissions=["read_data"], - session_id="session_123", - security_level=SecurityLevel.INTERNAL - ) - - @pytest.mark.asyncio - async def test_safe_select_query(self, sql_validator, analyst_context, test_sql_queries): - """Test safe SELECT query validation""" - sql = test_sql_queries["safe_select"] - - result = await sql_validator.validate(sql, analyst_context) - - assert result.is_valid is True - assert result.error_message is None - - @pytest.mark.asyncio - async def test_blocked_drop_operation(self, sql_validator, analyst_context, test_sql_queries): - """Test blocked DROP operation""" - sql = test_sql_queries["dangerous_drop"] - - result = await sql_validator.validate(sql, analyst_context) - - assert result.is_valid is False - assert "blocked operations" in result.error_message.lower() - assert "DROP" in result.blocked_operations - - @pytest.mark.asyncio - async def test_sql_injection_detection(self, sql_validator, analyst_context, test_sql_queries): - """Test SQL injection detection""" - sql = test_sql_queries["sql_injection"] - - result = await sql_validator.validate(sql, analyst_context) - - assert result.is_valid is False - assert "injection" in result.error_message.lower() - assert result.risk_level == "high" - - @pytest.mark.asyncio - async def test_union_injection_detection(self, sql_validator, analyst_context, test_sql_queries): - """Test UNION injection detection""" - sql = test_sql_queries["union_injection"] - - result = await sql_validator.validate(sql, analyst_context) - - assert result.is_valid is False - assert "injection" in result.error_message.lower() - - @pytest.mark.asyncio - async def test_comment_injection_detection(self, sql_validator, analyst_context, test_sql_queries): - """Test comment injection detection""" - sql = test_sql_queries["comment_injection"] - - result = await sql_validator.validate(sql, analyst_context) - - assert result.is_valid is False - assert "comment" in result.error_message.lower() - - @pytest.mark.asyncio - async def test_complex_query_validation(self, sql_validator, analyst_context, test_sql_queries): - """Test complex query validation""" - sql = test_sql_queries["complex_query"] - - result = await sql_validator.validate(sql, analyst_context) - - # Complex query should pass if within limits - assert result.is_valid is True - - @pytest.mark.asyncio - async def test_blocked_keywords_detection(self, sql_validator, analyst_context): - """Test blocked keywords detection""" - blocked_sqls = [ - "DELETE FROM users WHERE id = 1", - "TRUNCATE TABLE logs", - "ALTER TABLE users ADD COLUMN new_col VARCHAR(50)", - "CREATE TABLE test (id INT)", - "INSERT INTO users VALUES (1, 'test')", - "UPDATE users SET name = 'test' WHERE id = 1" - ] - - for sql in blocked_sqls: - result = await sql_validator.validate(sql, analyst_context) - assert result.is_valid is False - assert result.blocked_operations is not None - assert len(result.blocked_operations) > 0 - - @pytest.mark.asyncio - async def test_table_access_validation(self, sql_validator, analyst_context): - """Test table access validation""" - # Test access to sensitive table - sql = "SELECT * FROM sensitive_data" - - result = await sql_validator.validate(sql, analyst_context) - - # Should fail for non-admin users - assert result.is_valid is False - assert "access" in result.error_message.lower() - - def test_extract_table_names(self, sql_validator): - """Test table name extraction""" - sql = "SELECT u.name FROM users u JOIN departments d ON u.dept_id = d.id" - - parsed = __import__('sqlparse').parse(sql)[0] - tables = sql_validator._extract_table_names(parsed) - - # Should extract at least one table name - assert len(tables) > 0 - - @pytest.mark.asyncio - async def test_malformed_sql_handling(self, sql_validator, analyst_context): - """Test malformed SQL handling""" - malformed_sql = "SELECT * FROM users WHERE" - - result = await sql_validator.validate(malformed_sql, analyst_context) - - # Should handle gracefully - assert isinstance(result, ValidationResult) \ No newline at end of file diff --git a/test/test_config.json b/test/test_config.json deleted file mode 100644 index ebccf3e..0000000 --- a/test/test_config.json +++ /dev/null @@ -1,69 +0,0 @@ -{ - "server_endpoints": { - "http": { - "url": "http://localhost:3000/mcp", - "timeout": 30, - "headers": { - "Content-Type": "application/json" - } - }, - "http_network": { - "url": "http://192.168.31.168:3000/mcp", - "timeout": 30, - "headers": { - "Content-Type": "application/json" - } - }, - "stdio": { - "command": "uv", - "args": ["run", "python", "-m", "doris_mcp_server.main", "--transport", "stdio"], - "timeout": 30, - "working_directory": ".." - } - }, - "test_settings": { - "default_transport": "http", - "retry_attempts": 3, - "retry_delay": 1.0, - "test_timeout": 60, - "enable_performance_tests": true, - "enable_security_tests": true - }, - "test_data": { - "sample_queries": [ - "SELECT 1 as test_value", - "SHOW DATABASES", - "SELECT COUNT(*) FROM information_schema.tables" - ], - "test_databases": ["test_db", "demo_db"], - "test_tables": ["users", "orders", "products"], - "auth_tokens": { - "valid_token": "valid_token_123", - "admin_token": "admin_token_456", - "invalid_token": "invalid_token_789" - } - }, - "expected_tools": [ - "exec_query", - "get_db_list", - "get_db_table_list", - "get_table_schema", - "get_table_comment", - "get_table_column_comments", - "get_table_indexes", - "column_analysis", - "performance_stats", - "get_recent_audit_logs", - "get_catalog_list" - ], - "expected_resources": [ - "database", - "table", - "view" - ], - "expected_prompts": [ - "sql_query_assistant", - "data_analysis_helper", - "schema_explorer" - ] -} \ No newline at end of file diff --git a/test/test_config_loader.py b/test/test_config_loader.py deleted file mode 100644 index 5506044..0000000 --- a/test/test_config_loader.py +++ /dev/null @@ -1,214 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Test Configuration Loader - -Loads test configuration and provides methods to connect to running servers -""" - -import json -import os -import sys -from pathlib import Path -from typing import Dict, Any, Optional -import logging - -# Add project root to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) - -from doris_mcp_client.client import DorisUnifiedClient, DorisClientConfig - -logger = logging.getLogger(__name__) - - -class TestConfigLoader: - """Test configuration loader and client factory""" - - def __init__(self, config_path: Optional[str] = None): - """Initialize with config file path""" - if config_path is None: - config_path = os.path.join(os.path.dirname(__file__), "test_config.json") - - self.config_path = Path(config_path) - self.config = self._load_config() - - def _load_config(self) -> Dict[str, Any]: - """Load configuration from JSON file""" - try: - with open(self.config_path, 'r', encoding='utf-8') as f: - config = json.load(f) - logger.info(f"Loaded test configuration from {self.config_path}") - return config - except FileNotFoundError: - logger.error(f"Test configuration file not found: {self.config_path}") - raise - except json.JSONDecodeError as e: - logger.error(f"Invalid JSON in test configuration: {e}") - raise - - def get_http_client_config(self) -> DorisClientConfig: - """Get HTTP client configuration""" - http_config = self.config["server_endpoints"]["http"] - return DorisClientConfig.http( - url=http_config["url"], - timeout=http_config["timeout"] - ) - - def get_stdio_client_config(self) -> DorisClientConfig: - """Get stdio client configuration""" - stdio_config = self.config["server_endpoints"]["stdio"] - return DorisClientConfig.stdio( - command=stdio_config["command"], - args=stdio_config["args"] - ) - - def get_default_client_config(self) -> DorisClientConfig: - """Get default client configuration based on test settings""" - transport = self.config["test_settings"]["default_transport"] - if transport == "http": - return self.get_http_client_config() - elif transport == "stdio": - return self.get_stdio_client_config() - else: - raise ValueError(f"Unknown transport type: {transport}") - - def create_client(self, transport: Optional[str] = None) -> DorisUnifiedClient: - """Create MCP client instance""" - if transport is None: - client_config = self.get_default_client_config() - elif transport == "http": - client_config = self.get_http_client_config() - elif transport == "stdio": - client_config = self.get_stdio_client_config() - else: - raise ValueError(f"Unknown transport type: {transport}") - - return DorisUnifiedClient(client_config) - - def get_test_settings(self) -> Dict[str, Any]: - """Get test settings""" - return self.config["test_settings"] - - def get_test_data(self) -> Dict[str, Any]: - """Get test data""" - return self.config["test_data"] - - def get_expected_tools(self) -> list[str]: - """Get expected tools list""" - return self.config["expected_tools"] - - def get_expected_resources(self) -> list[str]: - """Get expected resources list""" - return self.config["expected_resources"] - - def get_expected_prompts(self) -> list[str]: - """Get expected prompts list""" - return self.config["expected_prompts"] - - def get_sample_queries(self) -> list[str]: - """Get sample queries for testing""" - return self.config["test_data"]["sample_queries"] - - def get_auth_tokens(self) -> Dict[str, str]: - """Get authentication tokens for testing""" - return self.config["test_data"]["auth_tokens"] - - def get_test_databases(self) -> list[str]: - """Get test databases list""" - return self.config["test_data"]["test_databases"] - - def get_test_tables(self) -> list[str]: - """Get test tables list""" - return self.config["test_data"]["test_tables"] - - def is_performance_tests_enabled(self) -> bool: - """Check if performance tests are enabled""" - return self.config["test_settings"]["enable_performance_tests"] - - def is_security_tests_enabled(self) -> bool: - """Check if security tests are enabled""" - return self.config["test_settings"]["enable_security_tests"] - - def get_retry_config(self) -> Dict[str, Any]: - """Get retry configuration""" - return { - "attempts": self.config["test_settings"]["retry_attempts"], - "delay": self.config["test_settings"]["retry_delay"] - } - - def get_test_timeout(self) -> int: - """Get test timeout in seconds""" - return self.config["test_settings"]["test_timeout"] - - -# Global test config instance -_test_config = None - -def get_test_config() -> TestConfigLoader: - """Get global test configuration instance""" - global _test_config - if _test_config is None: - _test_config = TestConfigLoader() - return _test_config - - -def create_test_client(transport: Optional[str] = None) -> DorisUnifiedClient: - """Create test client with default configuration""" - return get_test_config().create_client(transport) - - -async def test_server_connectivity(transport: Optional[str] = None) -> bool: - """Test server connectivity""" - try: - client = create_test_client(transport) - - async def test_connection(client_instance): - try: - # Try to list tools as a connectivity test - tools = await client_instance.list_all_tools() - return len(tools) > 0 - except Exception as e: - logger.error(f"Connectivity test failed: {e}") - return False - - result = await client.connect_and_run(test_connection) - return result - except Exception as e: - logger.error(f"Failed to test server connectivity: {e}") - return False - - -if __name__ == "__main__": - # Test configuration loading - import asyncio - - async def main(): - config = get_test_config() - print("Test Configuration Loaded:") - print(f" Default transport: {config.get_test_settings()['default_transport']}") - print(f" Expected tools: {len(config.get_expected_tools())}") - print(f" Sample queries: {len(config.get_sample_queries())}") - - # Test connectivity - print("\nTesting server connectivity...") - http_ok = await test_server_connectivity("http") - print(f" HTTP connectivity: {'โœ“' if http_ok else 'โœ—'}") - - stdio_ok = await test_server_connectivity("stdio") - print(f" Stdio connectivity: {'โœ“' if stdio_ok else 'โœ—'}") - - asyncio.run(main()) \ No newline at end of file diff --git a/test/tools/test_tools_client_server.py b/test/tools/test_tools_client_server.py deleted file mode 100644 index 808a592..0000000 --- a/test/tools/test_tools_client_server.py +++ /dev/null @@ -1,192 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Tools Manager Client-Server Integration Tests - -Tests the tools functionality through actual MCP client-server communication -Assumes the server is already running and configured properly -""" - -import asyncio -import json -import pytest -import os -import sys -from typing import Dict, Any - -# Add project root to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) - -from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity - - -class TestToolsClientServer: - """Test tools functionality through client-server communication""" - - @pytest.fixture - def test_config(self): - """Get test configuration""" - return get_test_config() - - @pytest.fixture - async def client(self, test_config): - """Create test client""" - return create_test_client() - - @pytest.fixture(scope="class", autouse=True) - async def check_server_connectivity(self): - """Check server connectivity before running tests""" - is_connected = await test_server_connectivity() - if not is_connected: - pytest.skip("Server is not running or not accessible") - - @pytest.mark.asyncio - async def test_list_tools_via_client(self, client, test_config): - """Test listing tools through client-server communication""" - expected_tools = test_config.get_expected_tools() - - async def test_callback(client_instance): - tools = await client_instance.list_all_tools() - - # Verify we got tools back - assert len(tools) > 0, "No tools returned from server" - - # Verify expected tools are present - tool_names = [tool.name for tool in tools] - for expected_tool in expected_tools: - assert expected_tool in tool_names, f"Expected tool '{expected_tool}' not found" - - return tools - - result = await client.connect_and_run(test_callback) - assert len(result) > 0 - - @pytest.mark.asyncio - async def test_call_tool_exec_query_via_client(self, client, test_config): - """Test calling exec_query tool through client""" - sample_queries = test_config.get_sample_queries() - - async def test_callback(client_instance): - # Test with a simple query - result = await client_instance.call_tool("exec_query", { - "sql": sample_queries[0], # "SELECT 1 as test_value" - "max_rows": 100 - }) - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - - if result["success"]: - assert "result" in result, "Successful result should contain 'result' field" - else: - assert "error" in result, "Failed result should contain 'error' field" - - return result - - result = await client.connect_and_run(test_callback) - # Don't assert success=True as it depends on actual server state - - @pytest.mark.asyncio - async def test_call_tool_get_db_list_via_client(self, client, test_config): - """Test calling get_db_list tool through client""" - async def test_callback(client_instance): - result = await client_instance.call_tool("get_db_list", {}) - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - - if result["success"]: - assert "result" in result, "Successful result should contain 'result' field" - assert isinstance(result["result"], list), "Database list should be a list" - - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result - - @pytest.mark.asyncio - async def test_call_tool_get_table_schema_via_client(self, client, test_config): - """Test calling get_table_schema tool through client""" - test_tables = test_config.get_test_tables() - - async def test_callback(client_instance): - result = await client_instance.call_tool("get_table_schema", { - "table_name": test_tables[0], # "users" - "db_name": "information_schema" # Use a database that should exist - }) - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result - - @pytest.mark.asyncio - async def test_call_tool_performance_stats_via_client(self, client, test_config): - """Test calling performance_stats tool through client""" - if not test_config.is_performance_tests_enabled(): - pytest.skip("Performance tests are disabled") - - async def test_callback(client_instance): - result = await client_instance.call_tool("performance_stats", { - "metric_type": "queries", - "time_range": "1h" - }) - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result - - @pytest.mark.asyncio - async def test_tool_error_handling_via_client(self, client, test_config): - """Test tool error handling through client""" - async def test_callback(client_instance): - # Try to call a tool with invalid parameters - result = await client_instance.call_tool("exec_query", { - "sql": "INVALID SQL SYNTAX HERE" - }) - - # Should get a result (either success or error) - assert "success" in result, "Result should contain 'success' field" - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result - - @pytest.mark.asyncio - async def test_tool_with_auth_token_via_client(self, client, test_config): - """Test tool calls with authentication token""" - if not test_config.is_security_tests_enabled(): - pytest.skip("Security tests are disabled") - - auth_tokens = test_config.get_auth_tokens() - - async def test_callback(client_instance): - result = await client_instance.call_tool("get_db_list", { - "auth_token": auth_tokens["valid_token"] - }) - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result \ No newline at end of file diff --git a/test/tools/test_tools_manager.py b/test/tools/test_tools_manager.py deleted file mode 100644 index 5e2ff24..0000000 --- a/test/tools/test_tools_manager.py +++ /dev/null @@ -1,331 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Tools manager tests -""" - -import json -import pytest -from unittest.mock import Mock, AsyncMock, patch - -from doris_mcp_server.tools.tools_manager import DorisToolsManager -from doris_mcp_server.utils.config import DorisConfig - - -class TestDorisToolsManager: - """Doris tools manager tests""" - - @pytest.fixture - def mock_config(self): - """Create mock configuration""" - from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig - - config = Mock(spec=DorisConfig) - config.doris_host = "localhost" - config.doris_port = 9030 - config.doris_user = "test_user" - config.doris_password = "test_password" - config.doris_database = "test_db" - - # Add database config - config.database = Mock(spec=DatabaseConfig) - config.database.host = "localhost" - config.database.port = 9030 - config.database.user = "test_user" - config.database.password = "test_password" - config.database.database = "test_db" - config.database.health_check_interval = 60 - config.database.min_connections = 5 - config.database.max_connections = 20 - config.database.connection_timeout = 30 - config.database.max_connection_age = 3600 - - # Add security config - config.security = Mock(spec=SecurityConfig) - config.security.enable_masking = True - config.security.auth_type = "token" - config.security.token_secret = "test_secret" - config.security.token_expiry = 3600 - - return config - - @pytest.fixture - def tools_manager(self, mock_config): - """Create tools manager instance""" - # Create a proper mock connection manager - mock_connection_manager = Mock() - mock_connection_manager.get_connection = AsyncMock() - return DorisToolsManager(mock_connection_manager) - - @pytest.mark.asyncio - async def test_get_available_tools(self, tools_manager): - """Test getting available tools""" - tools = await tools_manager.list_tools() - - # Should have core tools - tool_names = [tool.name for tool in tools] - assert "exec_query" in tool_names - assert "get_db_list" in tool_names - assert "get_db_table_list" in tool_names - assert "get_table_schema" in tool_names - - @pytest.mark.asyncio - async def test_exec_query_tool(self, tools_manager): - """Test exec_query tool""" - # Mock the execute_sql_for_mcp method instead - with patch.object(tools_manager.query_executor, 'execute_sql_for_mcp') as mock_execute: - mock_execute.return_value = { - "success": True, - "data": [ - {"id": 1, "name": "ๅผ ไธ‰"}, - {"id": 2, "name": "ๆŽๅ››"} - ], - "row_count": 2, - "execution_time": 0.15 - } - - arguments = { - "sql": "SELECT id, name FROM users LIMIT 2", - "max_rows": 100 - } - - result = await tools_manager.call_tool("exec_query", arguments) - result_data = json.loads(result) if isinstance(result, str) else result - - # The test should handle both success and error cases - if "success" in result_data and result_data["success"]: - # Check if result has data field or result field - if "data" in result_data and result_data["data"] is not None: - assert len(result_data["data"]) == 2 - elif "result" in result_data and result_data["result"] is not None: - assert len(result_data["result"]) == 2 - else: - # If there's an error, just check that error is reported - assert "error" in result_data - - # Verify the method was called (may not be called if there are errors) - # Don't assert specific call parameters since the implementation may vary - - @pytest.mark.asyncio - async def test_exec_query_with_error(self, tools_manager): - """Test exec_query tool with error""" - with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.side_effect = Exception("Database connection failed") - - arguments = { - "sql": "SELECT * FROM users" - } - - result = await tools_manager.call_tool("exec_query", arguments) - result_data = json.loads(result) if isinstance(result, str) else result - - assert "error" in result_data or "success" in result_data - if "error" in result_data: - # Accept any connection-related error message - assert any(keyword in result_data["error"].lower() for keyword in - ["connection", "failed", "error", "mock"]) - - @pytest.mark.asyncio - async def test_get_db_list_tool(self, tools_manager): - """Test get_db_list tool""" - with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.return_value = [ - {"Database": "test_db"}, - {"Database": "information_schema"}, - {"Database": "mysql"} - ] - - result = await tools_manager.call_tool("get_db_list", {}) - result_data = json.loads(result) if isinstance(result, str) else result - - # Check if result has databases field or result field - if "databases" in result_data: - assert len(result_data["databases"]) == 3 - elif "result" in result_data: - assert len(result_data["result"]) >= 0 # May be empty if no databases - - @pytest.mark.asyncio - async def test_get_db_table_list_tool(self, tools_manager): - """Test get_db_table_list tool""" - with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.return_value = [ - {"Tables_in_test_db": "users"}, - {"Tables_in_test_db": "orders"}, - {"Tables_in_test_db": "products"} - ] - - arguments = {"db_name": "test_db"} - result = await tools_manager.call_tool("get_db_table_list", arguments) - result_data = json.loads(result) if isinstance(result, str) else result - - # Check if result has tables field or result field - if "tables" in result_data: - assert len(result_data["tables"]) == 3 - assert "users" in result_data["tables"] - elif "result" in result_data: - assert len(result_data["result"]) >= 0 # May be empty if no tables - - @pytest.mark.asyncio - async def test_get_table_schema_tool(self, tools_manager): - """Test get_table_schema tool""" - with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.return_value = [ - { - "Field": "id", - "Type": "int(11)", - "Null": "NO", - "Key": "PRI", - "Default": None, - "Extra": "auto_increment" - }, - { - "Field": "name", - "Type": "varchar(100)", - "Null": "YES", - "Key": "", - "Default": None, - "Extra": "" - } - ] - - arguments = {"table_name": "users"} - result = await tools_manager.call_tool("get_table_schema", arguments) - result_data = json.loads(result) if isinstance(result, str) else result - - # Check if result has schema field or result field - if "schema" in result_data: - assert len(result_data["schema"]) == 2 - assert result_data["schema"][0]["Field"] == "id" - elif "result" in result_data: - assert len(result_data["result"]) >= 0 # May be empty if no schema - - @pytest.mark.asyncio - async def test_get_catalog_list_tool(self, tools_manager): - """Test get_catalog_list tool""" - with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.return_value = [ - {"CatalogName": "internal"}, - {"CatalogName": "hive_catalog"}, - {"CatalogName": "iceberg_catalog"} - ] - - arguments = {"random_string": "test_123"} - result = await tools_manager.call_tool("get_catalog_list", arguments) - result_data = json.loads(result) if isinstance(result, str) else result - - # Check if result has catalogs field or result field - if "catalogs" in result_data: - assert len(result_data["catalogs"]) == 3 - assert "internal" in result_data["catalogs"] - elif "result" in result_data: - assert len(result_data["result"]) >= 0 # May be empty if no catalogs - - @pytest.mark.asyncio - async def test_column_analysis_tool(self, tools_manager): - """Test column_analysis tool""" - with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute: - # Mock basic analysis result - mock_execute.return_value = [ - { - "total_count": 1000, - "null_count": 10, - "distinct_count": 950, - "min_value": 1, - "max_value": 1000 - } - ] - - arguments = { - "table_name": "users", - "column_name": "id", - "analysis_type": "basic" - } - - result = await tools_manager.call_tool("column_analysis", arguments) - result_data = json.loads(result) if isinstance(result, str) else result - - # Check if result has analysis field or result field - if "analysis" in result_data: - assert result_data["analysis"]["total_count"] == 1000 - elif "result" in result_data: - assert "result" in result_data # Just check result exists - - @pytest.mark.asyncio - async def test_performance_stats_tool(self, tools_manager): - """Test performance_stats tool""" - with patch.object(tools_manager.query_executor, 'execute_query') as mock_execute: - mock_execute.return_value = [ - { - "query_count": 1500, - "avg_execution_time": 0.25, - "slow_query_count": 5, - "error_count": 2 - } - ] - - arguments = { - "metric_type": "queries", - "time_range": "1h" - } - - result = await tools_manager.call_tool("performance_stats", arguments) - result_data = json.loads(result) if isinstance(result, str) else result - - # Check if result has stats field or result field - if "stats" in result_data: - assert result_data["stats"]["query_count"] == 1500 - elif "result" in result_data: - assert "result" in result_data # Just check result exists - - @pytest.mark.asyncio - async def test_invalid_tool_name(self, tools_manager): - """Test calling invalid tool""" - result = await tools_manager.call_tool("invalid_tool", {}) - result_data = json.loads(result) if isinstance(result, str) else result - - assert "error" in result_data or "success" in result_data - if "error" in result_data: - assert "Unknown tool" in result_data["error"] - - @pytest.mark.asyncio - async def test_missing_required_arguments(self, tools_manager): - """Test calling tool with missing required arguments""" - # exec_query requires sql parameter - result = await tools_manager.call_tool("exec_query", {}) - result_data = json.loads(result) if isinstance(result, str) else result - - assert "error" in result_data or "success" in result_data - # The test may pass if the tool handles missing parameters gracefully - - @pytest.mark.asyncio - async def test_tool_definitions_structure(self, tools_manager): - """Test tool definitions have correct structure""" - tools = await tools_manager.list_tools() - - for tool in tools: - # Each tool should have required fields - assert hasattr(tool, 'name') - assert hasattr(tool, 'description') - assert hasattr(tool, 'inputSchema') - - # Input schema should have properties - assert 'properties' in tool.inputSchema - - # Required fields should be defined - if 'required' in tool.inputSchema: - assert isinstance(tool.inputSchema['required'], list) \ No newline at end of file diff --git a/test/utils/test_query_executor.py b/test/utils/test_query_executor.py deleted file mode 100644 index b9aa89e..0000000 --- a/test/utils/test_query_executor.py +++ /dev/null @@ -1,202 +0,0 @@ -#!/usr/bin/env python3 -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Query executor tests -""" - -import pytest -from unittest.mock import Mock, AsyncMock, patch - -from doris_mcp_server.utils.query_executor import DorisQueryExecutor -from doris_mcp_server.utils.config import DorisConfig - - -class TestDorisQueryExecutor: - """Doris query executor tests""" - - @pytest.fixture - def mock_config(self): - """Create mock configuration""" - from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig - - config = Mock(spec=DorisConfig) - config.doris_host = "localhost" - config.doris_port = 9030 - config.doris_user = "test_user" - config.doris_password = "test_password" - config.doris_database = "test_db" - - # Add database config - config.database = Mock(spec=DatabaseConfig) - config.database.host = "localhost" - config.database.port = 9030 - config.database.user = "test_user" - config.database.password = "test_password" - config.database.database = "test_db" - config.database.health_check_interval = 60 - config.database.min_connections = 5 - config.database.max_connections = 20 - config.database.connection_timeout = 30 - config.database.max_connection_age = 3600 - - return config - - @pytest.fixture - def query_executor(self, mock_config): - """Create query executor instance""" - # Create a mock connection manager - mock_connection_manager = Mock() - return DorisQueryExecutor(mock_connection_manager, mock_config) - - @pytest.mark.asyncio - async def test_execute_query_success(self, query_executor): - """Test successful query execution using MCP interface""" - with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute: - mock_execute.return_value = { - "success": True, - "data": [ - {"id": 1, "name": "ๅผ ไธ‰", "email": "zhangsan@example.com"}, - {"id": 2, "name": "ๆŽๅ››", "email": "lisi@example.com"} - ], - "row_count": 2, - "execution_time": 0.15, - "columns": ["id", "name", "email"] - } - - sql = "SELECT id, name, email FROM users LIMIT 2" - result = await query_executor.execute_sql_for_mcp(sql) - - # Verify results - assert result["success"] is True - assert result["row_count"] == 2 - assert len(result["data"]) == 2 - assert result["data"][0]["id"] == 1 - assert result["data"][0]["name"] == "ๅผ ไธ‰" - assert result["data"][1]["email"] == "lisi@example.com" - - @pytest.mark.asyncio - async def test_execute_query_with_parameters(self, query_executor): - """Test query execution with parameters""" - with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute: - mock_execute.return_value = { - "success": True, - "data": [{"id": 1, "name": "ๅผ ไธ‰"}], - "row_count": 1, - "execution_time": 0.1 - } - - sql = "SELECT id, name FROM users WHERE department = 'sales'" - result = await query_executor.execute_sql_for_mcp(sql) - - # Verify results - assert result["success"] is True - assert result["row_count"] == 1 - assert len(result["data"]) == 1 - - @pytest.mark.asyncio - async def test_execute_query_connection_error(self, query_executor): - """Test query execution with connection error""" - with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute: - mock_execute.return_value = { - "success": False, - "error": "Connection failed", - "data": None - } - - sql = "SELECT * FROM users" - result = await query_executor.execute_sql_for_mcp(sql) - - assert result["success"] is False - assert "Connection failed" in result["error"] - - @pytest.mark.asyncio - async def test_execute_query_sql_error(self, query_executor): - """Test query execution with SQL error""" - with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute: - mock_execute.return_value = { - "success": False, - "error": "SQL syntax error", - "data": None - } - - sql = "SELECT * FROM non_existent_table" - result = await query_executor.execute_sql_for_mcp(sql) - - assert result["success"] is False - assert "SQL syntax error" in result["error"] - - @pytest.mark.asyncio - async def test_execute_query_empty_result(self, query_executor): - """Test query execution with empty result""" - with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute: - mock_execute.return_value = { - "success": True, - "data": [], - "row_count": 0, - "execution_time": 0.05 - } - - sql = "SELECT * FROM users WHERE id = 999" - result = await query_executor.execute_sql_for_mcp(sql) - - assert result["success"] is True - assert result["data"] == [] - assert result["row_count"] == 0 - - @pytest.mark.asyncio - async def test_execute_query_max_rows_limit(self, query_executor): - """Test query execution with max rows limit""" - with patch.object(query_executor, 'execute_sql_for_mcp') as mock_execute: - # Mock large result set limited to 100 rows - limited_result = [{"id": i, "name": f"user_{i}"} for i in range(100)] - mock_execute.return_value = { - "success": True, - "data": limited_result, - "row_count": 100, - "execution_time": 0.2 - } - - sql = "SELECT id, name FROM users" - result = await query_executor.execute_sql_for_mcp(sql, limit=100) - - # Should be limited to max_rows - assert result["success"] is True - assert len(result["data"]) == 100 - - @pytest.mark.asyncio - async def test_execute_sql_for_mcp_interface(self, query_executor): - """Test the MCP interface method directly""" - with patch.object(query_executor.connection_manager, 'get_connection') as mock_get_conn: - # Mock connection and result - mock_connection = AsyncMock() - mock_connection.execute.return_value = Mock( - data=[{"id": 1, "name": "ๅผ ไธ‰"}], - row_count=1, - execution_time=0.1, - metadata={} - ) - mock_get_conn.return_value = mock_connection - - sql = "SELECT id, name FROM users LIMIT 1" - result = await query_executor.execute_sql_for_mcp(sql) - - # Should return success format - assert "success" in result - if result["success"]: - assert "data" in result - assert "row_count" in result \ No newline at end of file diff --git a/test/utils/test_query_executor_client_server.py b/test/utils/test_query_executor_client_server.py deleted file mode 100644 index 45249bc..0000000 --- a/test/utils/test_query_executor_client_server.py +++ /dev/null @@ -1,156 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" -Query Executor Client-Server Integration Tests - -Tests the query execution functionality through actual MCP client-server communication -Assumes the server is already running and configured properly -""" - -import asyncio -import json -import pytest -import os -import sys -from typing import Dict, Any - -# Add project root to path -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..')) - -from test.test_config_loader import get_test_config, create_test_client, test_server_connectivity - - -class TestQueryExecutorClientServer: - """Test query execution functionality through client-server communication""" - - @pytest.fixture - def test_config(self): - """Get test configuration""" - return get_test_config() - - @pytest.fixture - async def client(self, test_config): - """Create test client""" - return create_test_client() - - @pytest.fixture(scope="class", autouse=True) - async def check_server_connectivity(self): - """Check server connectivity before running tests""" - is_connected = await test_server_connectivity() - if not is_connected: - pytest.skip("Server is not running or not accessible") - - @pytest.mark.asyncio - async def test_simple_select_query_via_client(self, client, test_config): - """Test simple SELECT query through client""" - sample_queries = test_config.get_sample_queries() - - async def test_callback(client_instance): - result = await client_instance.execute_sql(sample_queries[0]) # "SELECT 1 as test_value" - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - - if result["success"]: - assert "result" in result, "Successful result should contain 'result' field" - else: - assert "error" in result, "Failed result should contain 'error' field" - - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result - - @pytest.mark.asyncio - async def test_show_databases_query_via_client(self, client, test_config): - """Test SHOW DATABASES query through client""" - sample_queries = test_config.get_sample_queries() - - async def test_callback(client_instance): - result = await client_instance.execute_sql(sample_queries[1]) # "SHOW DATABASES" - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result - - @pytest.mark.asyncio - async def test_information_schema_query_via_client(self, client, test_config): - """Test information_schema query through client""" - sample_queries = test_config.get_sample_queries() - - async def test_callback(client_instance): - result = await client_instance.execute_sql(sample_queries[2]) # "SELECT COUNT(*) FROM information_schema.tables" - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result - - @pytest.mark.asyncio - async def test_query_with_max_rows_parameter_via_client(self, client, test_config): - """Test query with max_rows parameter through client""" - async def test_callback(client_instance): - result = await client_instance.call_tool("exec_query", { - "sql": "SELECT 1 as test_value", - "max_rows": 10 - }) - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result - - @pytest.mark.asyncio - async def test_query_error_handling_via_client(self, client, test_config): - """Test query error handling through client""" - async def test_callback(client_instance): - result = await client_instance.execute_sql("INVALID SQL SYNTAX") - - # Should get a result (either success or error) - assert "success" in result, "Result should contain 'success' field" - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result - - @pytest.mark.asyncio - async def test_query_with_auth_token_via_client(self, client, test_config): - """Test query with authentication token""" - if not test_config.is_security_tests_enabled(): - pytest.skip("Security tests are disabled") - - auth_tokens = test_config.get_auth_tokens() - - async def test_callback(client_instance): - result = await client_instance.call_tool("exec_query", { - "sql": "SELECT 1 as test_value", - "auth_token": auth_tokens["valid_token"] - }) - - # Verify result structure - assert "success" in result, "Result should contain 'success' field" - return result - - result = await client.connect_and_run(test_callback) - assert "success" in result \ No newline at end of file