0.3.0 Release Version

This commit is contained in:
FreeOnePlus
2025-06-08 18:44:40 +08:00
parent d9fed06c92
commit 4c913743c7
54 changed files with 12649 additions and 4667 deletions

View File

@@ -1,196 +1,515 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Apache Doris MCP Server Main Entry - Primarily handles SSE mode
Apache Doris MCP Server - Enterprise Database Service Implementation
Stdio mode is handled by doris_mcp_server.mcp_core:run_stdio.
Based on Apache Doris official MCP Server architecture design, providing complete MCP protocol support
Supports independent encapsulation implementation of Resources, Tools, and Prompts
Supports both stdio and streamable HTTP startup modes
"""
import os
import sys
import argparse
import asyncio
import json
import logging
from contextlib import asynccontextmanager
from collections.abc import AsyncIterator
from dataclasses import dataclass
from typing import Dict, Any
import uvicorn
from uvicorn import Config, Server
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
from typing import Any
# Add project root to path
PROJECT_ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
sys.path.insert(0, PROJECT_ROOT)
from mcp.server import Server
from mcp.server.models import InitializationOptions
# SSE related imports
from mcp.server.fastmcp import FastMCP
from doris_mcp_server.sse_server import DorisMCPSseServer
from doris_mcp_server.streamable_server import DorisMCPStreamableServer
# Stdio related imports (only needed for tools now, maybe move tool init?)
# from mcp.server.stdio import stdio_server -> No longer used here
# Config and Tool Initializer
from doris_mcp_server.config import load_config # LOG_LEVEL might not be needed here directly
from doris_mcp_server.tools.tool_initializer import register_mcp_tools
# Load environment variables (load early for all modes)
load_dotenv(override=True)
# Get logger
logger = logging.getLogger("doris-mcp-main") # Changed logger name slightly
# --- Configuration Loading and Logging Setup ---
load_config() # Loads .env
# --- Create FastAPI App (Global Scope for SSE Mode) ---
# This 'app' object is targeted by 'mcp run doris_mcp_server/main.py:app --transport sse'
# And used when running directly with --sse
app = FastAPI(
title="Doris MCP Server (SSE Mode)",
# Lifespan will be added in start_sse_server
from mcp.types import (
Prompt,
Resource,
TextContent,
Tool,
)
# --- Removed StdioServerWrapper ---
from .tools.tools_manager import DorisToolsManager
from .tools.prompts_manager import DorisPromptsManager
from .tools.resources_manager import DorisResourcesManager
from .utils.config import DorisConfig
from .utils.db import DorisConnectionManager
from .utils.security import DorisSecurityManager
# --- Command Line Argument Parsing ---
def parse_args():
    """Parse the command-line options of the SSE entry point.

    The ``--host`` default is taken from ``SERVER_HOST``; the ``--port``
    default is taken from ``SERVER_PORT``, then ``MCP_PORT``, then ``3000``.

    Returns:
        argparse.Namespace: Parsed options ``sse``, ``host``, ``port``,
        ``debug`` and ``reload``.
    """
    # Resolve environment-driven defaults up front so the option table below
    # stays readable.
    fallback_host = os.getenv('SERVER_HOST', '0.0.0.0')
    fallback_port = int(os.getenv('SERVER_PORT', os.getenv('MCP_PORT', '3000')))

    cli = argparse.ArgumentParser(description="Apache Doris MCP Server (SSE Mode Entry)")
    # Only SSE related options are exposed by this entry point.
    cli.add_argument('--sse', action='store_true', help='Start SSE Web server mode (required)')
    cli.add_argument('--host', type=str, default=fallback_host, help='Host address')
    cli.add_argument('--port', type=int, default=fallback_port, help='Port number')
    cli.add_argument('--debug', action='store_true', help='Enable debug mode')
    cli.add_argument('--reload', action='store_true', help='Enable auto-reload')

    parsed = cli.parse_args()
    return parsed
# Configure logging: set up the root logger at INFO level and create the
# module-level logger used by the functions in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- SSE Mode Specific Code ---
@dataclass
class AppContext:
    """Container holding the application configuration mapping."""

    # Configuration values keyed by setting name.
    config: Dict[str, Any]
@asynccontextmanager
async def app_lifespan(app_instance: FastAPI) -> AsyncIterator[None]:
    """Lifespan context for the SSE FastAPI application.

    On entry, reads the database settings from the environment and stores
    them on ``app_instance.state.config``. On exit, only logs a cleanup
    message.

    Args:
        app_instance: The FastAPI application whose ``state`` receives the
            configuration mapping.

    Yields:
        None.
    """
    logger.info("SSE application lifecycle start...")

    env = os.getenv
    # Simplified config - maybe get from elsewhere?
    app_instance.state.config = {
        "db_host": env("DB_HOST", "localhost"),
        "db_port": int(env("DB_PORT", "9030")),
        "db_user": env("DB_USER", "root"),
        "db_password": env("DB_PASSWORD", ""),
        "db_database": env("DB_DATABASE", "test"),
    }

    try:
        yield None
    finally:
        logger.info("Cleaning up SSE application resources...")
class DorisServer:
"""Apache Doris MCP Server main class"""
async def start_sse_server(args):
"""Start SSE Web server mode (Configures the global 'app')"""
logger.info("Starting SSE Web server mode...")
global app
def __init__(self, config: DorisConfig):
self.config = config
self.server = Server("doris-mcp-server")
# --- Initialize MCP and Tools for SSE ---
# Create a *separate* MCP instance for SSE mode
sse_mcp = FastMCP(
name="doris-mcp-sse",
description="Apache Doris MCP Server (SSE)",
lifespan=None, # Managed by FastAPI
dependencies=["fastapi", "uvicorn", "openai", "sse_starlette"]
)
logger.info("Registering MCP tools for SSE mode...")
await register_mcp_tools(sse_mcp) # Register tools for the SSE instance
logger.info("MCP tools registered for SSE.")
# Initialize security manager
self.security_manager = DorisSecurityManager(config)
# --- Configure Lifespan and CORS for the global app ---
app.router.lifespan_context = app_lifespan
origins = os.getenv("ALLOWED_ORIGINS", "*").split(",")
allow_credentials = os.getenv("MCP_ALLOW_CREDENTIALS", "false").lower() == "true"
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=allow_credentials,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["Mcp-Session-Id"],
)
# Initialize connection manager, pass in security manager
self.connection_manager = DorisConnectionManager(config, self.security_manager)
# --- Initialize Handlers and Register Routes (Pass sse_mcp instance) ---
logger.info("Initializing SSE server handlers and registering routes...")
sse_server_handler = DorisMCPSseServer(sse_mcp, app)
streamable_server_handler = DorisMCPStreamableServer(sse_mcp, app)
logger.info("SSE Server handlers initialized and routes registered.")
# Initialize independent managers
self.resources_manager = DorisResourcesManager(self.connection_manager)
self.tools_manager = DorisToolsManager(self.connection_manager)
self.prompts_manager = DorisPromptsManager(self.connection_manager)
# --- Print Configuration and Endpoints ---
print("--- SSE Mode Configuration ---")
print(f"Server Host: {args.host}")
print(f"Server Port: {args.port}")
print(f"Allowed Origins: {origins}")
print(f"Allow Credentials: {allow_credentials}")
print(f"Log Level: {os.getenv('LOG_LEVEL', 'info')}")
print(f"Debug Mode: {args.debug}")
print(f"Reload Mode: {args.reload}")
print(f"DB Host: {os.getenv('DB_HOST')}")
print(f"DB Port: {os.getenv('DB_PORT')}")
print(f"DB User: {os.getenv('DB_USER')}")
print(f"DB Database: {os.getenv('DB_DATABASE')}")
print(f"Force Refresh Metadata: {os.getenv('FORCE_REFRESH_METADATA', 'false')}")
print("------------------------------")
base_url = f"http://{args.host}:{args.port}"
print(f"Service running at: {base_url}")
print(f" Health Check: GET {base_url}/health")
print(f" Status Check: GET {base_url}/status")
print(f" SSE Init: GET {base_url}/sse")
print(f" SSE/Legacy Messages: POST {base_url}/mcp/messages")
print(f" Streamable HTTP: GET/POST/DELETE/OPTIONS {base_url}/mcp")
print("------------------------------")
print("Use Ctrl+C to stop the service")
self.logger = logging.getLogger(f"{__name__}.DorisServer")
self._setup_handlers()
# --- Start Uvicorn Server ---
config = Config(
app=app,
host=args.host,
port=args.port,
log_level="debug" if args.debug else "info",
reload=args.reload
)
server = Server(config=config)
await server.serve()
def _setup_handlers(self):
"""Setup MCP protocol handlers"""
# --- Main Execution Logic (Simplified) ---
@self.server.list_resources()
async def handle_list_resources() -> list[Resource]:
    """Return the resources reported by the resources manager.

    Any failure is logged and an empty list is returned instead of
    propagating the exception.
    """
    try:
        self.logger.info("Handling resource list request")
        available = await self.resources_manager.list_resources()
        self.logger.info(f"Returning {len(available)} resources")
    except Exception as err:
        self.logger.error(f"Failed to handle resource list request: {err}")
        return []
    else:
        return available
def run_main_sync():
"""Synchronous wrapper, primarily for SSE mode now."""
sync_logger = logging.getLogger("run_main_sync")
sync_logger.info("Entering run_main_sync (SSE focus)...")
print("DEBUG: Entering run_main_sync (SSE focus)...", file=sys.stderr, flush=True)
args = parse_args()
@self.server.read_resource()
async def handle_read_resource(uri: str) -> str:
    """Return the content of the resource identified by ``uri``.

    Args:
        uri: Identifier of the resource to read.
            NOTE(review): the MCP SDK may pass a URL object rather than a
            plain ``str`` here - confirm against the SDK version in use.

    Returns:
        The content produced by the resources manager, or a JSON document
        with ``error`` and ``uri`` keys when reading fails.
    """
    try:
        self.logger.info(f"Handling resource read request: {uri}")
        content = await self.resources_manager.read_resource(uri)
        return content
    except Exception as e:
        self.logger.error(f"Failed to handle resource read request: {e}")
        # Stringify the URI: json.dumps raises TypeError for non-str URL
        # objects, which would turn this error report into a second failure.
        return json.dumps(
            {"error": f"Failed to read resource: {str(e)}", "uri": str(uri)},
            ensure_ascii=False,
            indent=2,
        )
@self.server.list_tools()
async def handle_list_tools() -> list[Tool]:
    """Return the tools reported by the tools manager.

    Any failure is logged and an empty list is returned instead of
    propagating the exception.
    """
    try:
        self.logger.info("Handling tool list request")
        available = await self.tools_manager.list_tools()
        self.logger.info(f"Returning {len(available)} tools")
    except Exception as err:
        self.logger.error(f"Failed to handle tool list request: {err}")
        return []
    else:
        return available
@self.server.call_tool()
async def handle_call_tool(
    name: str, arguments: dict[str, Any]
) -> list[TextContent]:
    """Invoke the named tool through the tools manager.

    Args:
        name: Name of the tool to call.
        arguments: Arguments forwarded to the tool.

    Returns:
        A single-element list with the tool output as text. On failure the
        text is a JSON document with ``error``, ``tool_name`` and
        ``arguments`` keys.
    """
    try:
        self.logger.info(f"Handling tool call request: {name}")
        output = await self.tools_manager.call_tool(name, arguments)
        return [TextContent(type="text", text=output)]
    except Exception as err:
        self.logger.error(f"Failed to handle tool call request: {err}")
        failure = {
            "error": f"Tool call failed: {str(err)}",
            "tool_name": name,
            "arguments": arguments,
        }
        failure_text = json.dumps(failure, ensure_ascii=False, indent=2)
        return [TextContent(type="text", text=failure_text)]
@self.server.list_prompts()
async def handle_list_prompts() -> list[Prompt]:
    """Return the prompts reported by the prompts manager.

    Any failure is logged and an empty list is returned instead of
    propagating the exception.
    """
    try:
        self.logger.info("Handling prompt list request")
        available = await self.prompts_manager.list_prompts()
        self.logger.info(f"Returning {len(available)} prompts")
    except Exception as err:
        self.logger.error(f"Failed to handle prompt list request: {err}")
        return []
    else:
        return available
@self.server.get_prompt()
async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
    """Fetch the named prompt from the prompts manager.

    Args:
        name: Name of the prompt to fetch.
        arguments: Arguments forwarded to the prompt.

    Returns:
        Whatever the prompts manager returns for the prompt. On failure, a
        JSON document with ``error``, ``prompt_name`` and ``arguments`` keys.
    """
    # NOTE(review): the declared return type is ``str``; confirm that this is
    # what the MCP server expects from a get_prompt handler.
    try:
        self.logger.info(f"Handling prompt get request: {name}")
        prompt = await self.prompts_manager.get_prompt(name, arguments)
        return prompt
    except Exception as err:
        self.logger.error(f"Failed to handle prompt get request: {err}")
        failure = {
            "error": f"Failed to get prompt: {str(err)}",
            "prompt_name": name,
            "arguments": arguments,
        }
        return json.dumps(failure, ensure_ascii=False, indent=2)
async def start_stdio(self):
"""Start stdio transport mode"""
self.logger.info("Starting Doris MCP Server (stdio mode)")
if args.sse:
try:
# Run the async SSE server setup and Uvicorn loop
asyncio.run(start_sse_server(args))
sync_logger.info("asyncio.run(start_sse_server) completed.")
print("DEBUG: asyncio.run(start_sse_server) completed.", file=sys.stderr, flush=True)
except KeyboardInterrupt:
sync_logger.info("SSE server stopped by KeyboardInterrupt.")
# Ensure connection manager is initialized
await self.connection_manager.initialize()
self.logger.info("Connection manager initialization completed")
# Start stdio server - using simpler approach
from mcp.server.stdio import stdio_server
self.logger.info("Creating stdio_server transport...")
# Try different startup approaches
try:
async with stdio_server() as streams:
read_stream, write_stream = streams
self.logger.info("stdio_server streams created successfully")
# Create initialization options
# MCP 1.8.0 requires parameters for get_capabilities
from mcp.server.lowlevel.server import NotificationOptions
capabilities = self.server.get_capabilities(
notification_options=NotificationOptions(
prompts_changed=True,
resources_changed=True,
tools_changed=True
),
experimental_capabilities={}
)
init_options = InitializationOptions(
server_name="doris-mcp-server",
server_version="1.0.0",
capabilities=capabilities,
)
self.logger.info("Initialization options created successfully")
# Run server
self.logger.info("Starting to run MCP server...")
await self.server.run(read_stream, write_stream, init_options)
except Exception as inner_e:
self.logger.error(f"stdio_server internal error: {inner_e}")
self.logger.error(f"Error type: {type(inner_e)}")
# Try to get more error information
import traceback
self.logger.error("Complete error stack:")
self.logger.error(traceback.format_exc())
# If it's ExceptionGroup, try to parse
if hasattr(inner_e, 'exceptions'):
self.logger.error(f"ExceptionGroup contains {len(inner_e.exceptions)} exceptions:")
for i, exc in enumerate(inner_e.exceptions):
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
raise inner_e
except Exception as e:
sync_logger.critical(f"Error during asyncio.run(start_sse_server): {e}", exc_info=True)
print(f"DEBUG: Error during asyncio.run(start_sse_server): {e}", file=sys.stderr, flush=True)
self.logger.error(f"stdio server startup failed: {e}")
self.logger.error(f"Error type: {type(e)}")
raise
else:
# If run without --sse, print help/error
message = "Error: This entry point requires --sse flag. For stdio mode, use 'uv run mcp-doris' or the appropriate command for your stdio setup."
sync_logger.error(message)
print(message, file=sys.stderr)
sys.exit(1)
async def start_http(self, host: str = "localhost", port: int = 3000):
    """Start Streamable HTTP transport mode.

    Initializes the connection manager, then builds a small ASGI application
    that serves the health check through Starlette and forwards ``/mcp`` and
    ``/mcp/...`` to a ``StreamableHTTPSessionManager``. The application is
    served with uvicorn while the session manager is running.

    Args:
        host: Address the HTTP server binds to.
        port: TCP port the HTTP server listens on.

    Raises:
        Exception: Any startup or runtime failure is logged (including the
            members of an exception group) and re-raised.
    """
    # Imported before the try block so the error handlers below can always
    # format a traceback, even when the failure happens on the first line.
    import traceback

    self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")

    try:
        # Ensure connection manager is initialized
        await self.connection_manager.initialize()

        # Use Starlette and StreamableHTTPSessionManager according to official example
        import contextlib
        from collections.abc import AsyncIterator

        import uvicorn
        from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
        from starlette.applications import Starlette
        from starlette.responses import JSONResponse, Response
        from starlette.routing import Route
        from starlette.types import Receive, Scope, Send

        # Header names (lower-case, as delivered by ASGI) whose values must
        # never reach the log output.
        sensitive_headers = frozenset(
            {b"authorization", b"proxy-authorization", b"cookie", b"x-api-key"}
        )

        # Create session manager
        session_manager = StreamableHTTPSessionManager(
            app=self.server,
            json_response=True,  # Enable JSON response
            stateless=False,  # Maintain session state
        )

        self.logger.info(f"StreamableHTTP session manager created, will start at http://{host}:{port}")

        # Health check endpoint
        async def health_check(request):
            return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})

        # Lifecycle manager - simplified since we manage session_manager externally
        @contextlib.asynccontextmanager
        async def lifespan(app: Starlette) -> AsyncIterator[None]:
            """Context manager for managing application lifecycle"""
            self.logger.info("Application started!")
            try:
                yield
            finally:
                self.logger.info("Application is shutting down...")

        # Starlette only serves /health and the lifespan protocol. Debug mode
        # stays off so unhandled errors never render tracebacks to clients.
        starlette_app = Starlette(
            debug=False,
            routes=[
                Route("/health", health_check, methods=["GET"]),
            ],
            lifespan=lifespan,
        )

        # Custom ASGI app that handles both /mcp and /mcp/ without redirects
        async def mcp_app(scope: Scope, receive: Receive, send: Send) -> None:
            # Handle lifespan events
            if scope["type"] == "lifespan":
                await starlette_app(scope, receive, send)
                return

            if scope["type"] != "http":
                # For other scope types, just return
                self.logger.warning(f"Unsupported scope type: {scope['type']}")
                return

            # Handle HTTP requests
            path = scope.get("path", "")
            self.logger.info(f"Received request for path: {path}")

            try:
                # Handle health check
                if path.startswith("/health"):
                    await starlette_app(scope, receive, send)
                    return

                # Handle MCP requests - both /mcp and /mcp/ go to session manager
                if path == "/mcp" or path.startswith("/mcp/"):
                    self.logger.info(f"Handling MCP request for path: {path}")

                    # Log request details for debugging, with credentials masked.
                    method = scope.get("method", "UNKNOWN")
                    headers = dict(scope.get("headers", []))
                    loggable_headers = {
                        key: (b"<redacted>" if key in sensitive_headers else value)
                        for key, value in headers.items()
                    }
                    self.logger.info(f"MCP Request - Method: {method}")
                    self.logger.info(f"MCP Request - Headers: {loggable_headers}")

                    # Handle Dify compatibility for GET requests
                    if method == "GET":
                        # Header bytes are decoded as latin-1, which cannot
                        # fail; UTF-8 decoding raised on non-ASCII bytes.
                        accept_header = headers.get(b"accept", b"").decode("latin-1")

                        # Clients that only accept text/event-stream also get
                        # application/json added to the Accept header.
                        if "text/event-stream" in accept_header and "application/json" not in accept_header:
                            self.logger.info("Adding application/json to Accept header for GET request")
                            new_headers = [
                                (key, value + b", application/json") if key == b"accept" else (key, value)
                                for key, value in scope.get("headers", [])
                            ]
                            # Copy the scope so the caller's mapping is not mutated.
                            scope = dict(scope)
                            scope["headers"] = new_headers
                            self.logger.info(f"Modified Accept header to: {accept_header}, application/json")

                    await session_manager.handle_request(scope, receive, send)
                    return

                # 404 for other paths
                self.logger.info(f"Path not found: {path}")
                response = Response("Not Found", status_code=404)
                await response(scope, receive, send)

            except Exception as e:
                self.logger.error(f"Error handling request for {path}: {e}")
                self.logger.error(traceback.format_exc())
                response = Response("Internal Server Error", status_code=500)
                await response(scope, receive, send)

        # Start uvicorn server with session manager lifecycle
        config = uvicorn.Config(
            app=mcp_app,
            host=host,
            port=port,
            log_level="info",
        )
        server = uvicorn.Server(config)

        # Run session manager and server together
        async with session_manager.run():
            self.logger.info("Session manager started, now starting HTTP server")
            await server.serve()

    except Exception as e:
        self.logger.error(f"Streamable HTTP server startup failed: {e}")
        self.logger.error("Complete error stack:")
        self.logger.error(traceback.format_exc())

        # If it's ExceptionGroup, try to parse
        if hasattr(e, 'exceptions'):
            self.logger.error(f"ExceptionGroup contains {len(e.exceptions)} exceptions:")
            for i, exc in enumerate(e.exceptions):
                self.logger.error(f"  Exception {i+1}: {type(exc).__name__}: {exc}")

        raise
async def shutdown(self):
    """Close the connection manager and log the outcome.

    Errors raised while closing are logged and not propagated.
    """
    log = self.logger
    log.info("Shutting down Doris MCP Server")
    try:
        await self.connection_manager.close()
        log.info("Doris MCP Server has been shut down")
    except Exception as err:
        log.error(f"Error occurred while shutting down server: {err}")
def create_arg_parser():
    """Build the argument parser for the server's command line.

    Returns:
        argparse.ArgumentParser: Parser exposing the transport selection
        (``--transport``, ``--host``, ``--port``), the Doris connection
        settings (``--db-host``, ``--db-port``, ``--db-user``,
        ``--db-password``, ``--db-database``) and ``--log-level``.
    """
    epilog_lines = [
        "",
        "Transport Modes:",
        "stdio - Standard input/output (for local process communication)",
        "http - Streamable HTTP mode (MCP 2025-03-26 protocol)",
        "Examples:",
        "python -m doris_mcp_server --transport stdio",
        "python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000",
        "",
    ]
    arg_parser = argparse.ArgumentParser(
        description="Apache Doris MCP Server - Enterprise Database Service",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="\n".join(epilog_lines),
    )

    # Option table: (flag, keyword arguments), registered in display order.
    option_table = (
        ("--transport", dict(
            type=str,
            choices=["stdio", "http"],
            default="stdio",
            help="Transport protocol type: stdio (local), http (Streamable HTTP)",
        )),
        ("--host", dict(
            type=str,
            default="localhost",
            help="Host address for HTTP mode (default: localhost)",
        )),
        ("--port", dict(
            type=int,
            default=3000,
            help="Port number for HTTP mode (default: 3000)",
        )),
        ("--db-host", dict(
            type=str,
            default="localhost",
            help="Doris database host address (default: localhost)",
        )),
        ("--db-port", dict(
            type=int,
            default=9030,
            help="Doris database port number (default: 9030)",
        )),
        ("--db-user", dict(
            type=str,
            default="root",
            help="Doris database username (default: root)",
        )),
        ("--db-password", dict(
            type=str,
            default="",
            help="Doris database password",
        )),
        ("--db-database", dict(
            type=str,
            default="information_schema",
            help="Doris database name (default: information_schema)",
        )),
        ("--log-level", dict(
            type=str,
            choices=["DEBUG", "INFO", "WARNING", "ERROR"],
            default="INFO",
            help="Log level (default: INFO)",
        )),
    )
    for flag, options in option_table:
        arg_parser.add_argument(flag, **options)

    return arg_parser
async def main():
    """Parse the command line, build the configuration and run the server.

    Configuration priority is: options given on the command line, then the
    values loaded by ``DorisConfig.from_env()``. The server is shut down
    exactly once, whatever way the transport terminates.

    Returns:
        int: ``0`` after a normal stop or a keyboard interrupt, ``1`` when the
        transport is unsupported or the server raised an error.
    """
    parser = create_arg_parser()
    # Swap the literal defaults for None so an option the user actually typed
    # can be told apart from an omitted one. Comparing against the default
    # value made it impossible to override an environment setting with a
    # value that happens to equal the default (e.g. ``--db-host localhost``).
    parser.set_defaults(
        db_host=None,
        db_port=None,
        db_user=None,
        db_password=None,
        db_database=None,
        log_level=None,
    )
    args = parser.parse_args()

    # Set log level (INFO unless one was requested explicitly)
    logging.getLogger().setLevel(getattr(logging, args.log_level or "INFO"))

    # Create configuration - priority: command line arguments > .env file > default values
    config = DorisConfig.from_env()  # First load from .env file and environment variables

    # Command line arguments override configuration (if provided)
    database_overrides = (
        ("host", args.db_host),
        ("port", args.db_port),
        ("user", args.db_user),
        ("password", args.db_password),
        ("database", args.db_database),
    )
    for attribute, value in database_overrides:
        if value is not None:
            setattr(config.database, attribute, value)
    if args.log_level is not None:
        config.logging.level = args.log_level

    # Create server instance
    server = DorisServer(config)

    exit_code = 0
    try:
        if args.transport == "stdio":
            await server.start_stdio()
        elif args.transport == "http":
            await server.start_http(args.host, args.port)
        else:
            logger.error(f"Unsupported transport protocol: {args.transport}")
            exit_code = 1
    except KeyboardInterrupt:
        logger.info("Received interrupt signal, shutting down server...")
    except Exception as e:
        logger.error(f"Server runtime error: {e}")
        exit_code = 1
    finally:
        # Single cleanup point: runs on success, error and interrupt alike, so
        # the server is never shut down twice.
        try:
            await server.shutdown()
        except Exception as shutdown_error:
            logger.error(f"Error occurred while shutting down server: {shutdown_error}")

    return exit_code
def main_sync():
    """Synchronous entry point: run ``main()`` and exit with its return code.

    Raises:
        SystemExit: Always, carrying the value returned by ``main()``.
    """
    exit_code = asyncio.run(main())
    # sys.exit is used instead of the ``exit`` builtin, which is injected by
    # the ``site`` module and is missing under ``python -S`` or frozen builds.
    sys.exit(exit_code)
if __name__ == "__main__":
run_main_sync()
main_sync()