* fix tocken auth * Further fixes to the token overwriting issue and restoration of hot reloading of tokens.json.
972 lines
42 KiB
Python
972 lines
42 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
"""
|
|
Apache Doris MCP Server - Enterprise Database Service Implementation
|
|
|
|
Based on Apache Doris official MCP Server architecture design, providing complete MCP protocol support
|
|
Supports independent encapsulation implementation of Resources, Tools, and Prompts
|
|
Supports both stdio and streamable HTTP startup modes
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from typing import Any
|
|
|
|
# MCP version compatibility handling
|
|
MCP_VERSION = 'unknown'
|
|
Server = None
|
|
InitializationOptions = None
|
|
Prompt = None
|
|
Resource = None
|
|
TextContent = None
|
|
Tool = None
|
|
|
|
def _import_mcp_with_compatibility():
|
|
"""Import MCP components with multi-version compatibility"""
|
|
global MCP_VERSION, Server, InitializationOptions, Prompt, Resource, TextContent, Tool
|
|
|
|
try:
|
|
# Strategy 1: Try direct server-only imports to avoid client-side issues
|
|
from mcp.server import Server as _Server
|
|
from mcp.server.models import InitializationOptions as _InitOptions
|
|
from mcp.types import (
|
|
Prompt as _Prompt,
|
|
Resource as _Resource,
|
|
TextContent as _TextContent,
|
|
Tool as _Tool,
|
|
)
|
|
|
|
# Assign to globals
|
|
Server = _Server
|
|
InitializationOptions = _InitOptions
|
|
Prompt = _Prompt
|
|
Resource = _Resource
|
|
TextContent = _TextContent
|
|
Tool = _Tool
|
|
|
|
# Try to get version safely
|
|
try:
|
|
import mcp
|
|
MCP_VERSION = getattr(mcp, '__version__', None)
|
|
if not MCP_VERSION:
|
|
# Fallback: try to get version from package metadata
|
|
try:
|
|
import importlib.metadata
|
|
MCP_VERSION = importlib.metadata.version('mcp')
|
|
except Exception:
|
|
# Second fallback: try pkg_resources
|
|
try:
|
|
import pkg_resources
|
|
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
|
except Exception:
|
|
MCP_VERSION = 'detected-but-version-unknown'
|
|
except Exception:
|
|
# Version detection failed, but imports worked
|
|
try:
|
|
import importlib.metadata
|
|
MCP_VERSION = importlib.metadata.version('mcp')
|
|
except Exception:
|
|
try:
|
|
import pkg_resources
|
|
MCP_VERSION = pkg_resources.get_distribution('mcp').version
|
|
except Exception:
|
|
MCP_VERSION = 'imported-successfully'
|
|
|
|
logger = logging.getLogger(__name__)
|
|
logger.info(f"MCP components imported successfully, version: {MCP_VERSION}")
|
|
return True
|
|
|
|
except Exception as import_error:
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Strategy 2: Handle RequestContext compatibility issues in 1.9.x versions
|
|
error_str = str(import_error).lower()
|
|
if 'requestcontext' in error_str and 'too few arguments' in error_str:
|
|
logger.warning(f"Detected MCP RequestContext compatibility issue: {import_error}")
|
|
logger.info("Attempting comprehensive workaround for MCP 1.9.x RequestContext issue...")
|
|
|
|
try:
|
|
# Comprehensive monkey patch approach
|
|
import sys
|
|
import types
|
|
|
|
# Create and install mock modules before any MCP imports
|
|
if 'mcp.shared.context' not in sys.modules:
|
|
mock_context_module = types.ModuleType('mcp.shared.context')
|
|
|
|
class FlexibleRequestContext:
|
|
"""Flexible RequestContext that accepts variable arguments"""
|
|
def __init__(self, *args, **kwargs):
|
|
self.args = args
|
|
self.kwargs = kwargs
|
|
|
|
def __class_getitem__(cls, params):
|
|
# Accept any number of parameters and return cls
|
|
return cls
|
|
|
|
# Add other methods that might be called
|
|
def __getattr__(self, name):
|
|
return lambda *args, **kwargs: None
|
|
|
|
mock_context_module.RequestContext = FlexibleRequestContext
|
|
sys.modules['mcp.shared.context'] = mock_context_module
|
|
|
|
# Also patch the typing system to be more permissive
|
|
original_check_generic = None
|
|
try:
|
|
import typing
|
|
if hasattr(typing, '_check_generic'):
|
|
original_check_generic = typing._check_generic
|
|
def permissive_check_generic(cls, params, elen):
|
|
# Don't enforce strict parameter count checking
|
|
return
|
|
typing._check_generic = permissive_check_generic
|
|
except Exception:
|
|
pass
|
|
|
|
# Clear any cached imports that might have failed
|
|
modules_to_clear = [k for k in sys.modules.keys() if k.startswith('mcp.')]
|
|
for module in modules_to_clear:
|
|
if module in sys.modules:
|
|
del sys.modules[module]
|
|
|
|
# Now try importing again with the patches in place
|
|
from mcp.server import Server as _Server
|
|
from mcp.server.models import InitializationOptions as _InitOptions
|
|
from mcp.types import (
|
|
Prompt as _Prompt,
|
|
Resource as _Resource,
|
|
TextContent as _TextContent,
|
|
Tool as _Tool,
|
|
)
|
|
|
|
# Assign to globals
|
|
Server = _Server
|
|
InitializationOptions = _InitOptions
|
|
Prompt = _Prompt
|
|
Resource = _Resource
|
|
TextContent = _TextContent
|
|
Tool = _Tool
|
|
|
|
# Try to detect actual version even in compatibility mode
|
|
try:
|
|
import importlib.metadata
|
|
actual_version = importlib.metadata.version('mcp')
|
|
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
|
except Exception:
|
|
try:
|
|
import pkg_resources
|
|
actual_version = pkg_resources.get_distribution('mcp').version
|
|
MCP_VERSION = f'compatibility-mode-{actual_version}'
|
|
except Exception:
|
|
MCP_VERSION = 'compatibility-mode-1.9.x'
|
|
|
|
logger.info("MCP 1.9.x compatibility workaround successful!")
|
|
|
|
# Restore original typing function if we patched it
|
|
if original_check_generic:
|
|
typing._check_generic = original_check_generic
|
|
|
|
return True
|
|
|
|
except Exception as workaround_error:
|
|
logger.error(f"MCP compatibility workaround failed: {workaround_error}")
|
|
|
|
# Restore original typing function if we patched it
|
|
if original_check_generic:
|
|
try:
|
|
import typing
|
|
typing._check_generic = original_check_generic
|
|
except Exception:
|
|
pass
|
|
|
|
logger.error(f"Failed to import MCP components: {import_error}")
|
|
return False
|
|
|
|
# Perform MCP import with compatibility handling
|
|
if not _import_mcp_with_compatibility():
|
|
raise ImportError(
|
|
"Failed to import MCP components. Please ensure MCP is properly installed. "
|
|
"Supported versions: 1.8.x, 1.9.x"
|
|
)
|
|
|
|
from .tools.tools_manager import DorisToolsManager
|
|
from .tools.prompts_manager import DorisPromptsManager
|
|
from .tools.resources_manager import DorisResourcesManager
|
|
from .utils.config import DorisConfig
|
|
from .utils.db import DorisConnectionManager
|
|
from .utils.security import DorisSecurityManager
|
|
import os
|
|
|
|
# Configure logging - will be properly initialized later
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# Create a default config instance for getting default values
|
|
_default_config = DorisConfig()
|
|
|
|
|
|
|
|
|
|
class DorisServer:
|
|
"""Apache Doris MCP Server main class"""
|
|
|
|
def __init__(self, config: DorisConfig):
|
|
self.config = config
|
|
self.server = Server("doris-mcp-server")
|
|
|
|
# Initialize security manager (without connection_manager initially)
|
|
self.security_manager = DorisSecurityManager(config)
|
|
|
|
# Initialize connection manager, pass in security manager and token manager for token-bound DB config
|
|
token_manager = self.security_manager.auth_provider.token_manager if hasattr(self.security_manager, 'auth_provider') and hasattr(self.security_manager.auth_provider, 'token_manager') else None
|
|
self.connection_manager = DorisConnectionManager(config, self.security_manager, token_manager)
|
|
|
|
# Set connection manager reference in security manager for database validation
|
|
self.security_manager.connection_manager = self.connection_manager
|
|
|
|
# Initialize independent managers
|
|
self.resources_manager = DorisResourcesManager(self.connection_manager)
|
|
self.tools_manager = DorisToolsManager(self.connection_manager)
|
|
self.prompts_manager = DorisPromptsManager(self.connection_manager)
|
|
|
|
# Import here to avoid circular imports
|
|
from .utils.logger import get_logger
|
|
self.logger = get_logger(f"{__name__}.DorisServer")
|
|
self._setup_handlers()
|
|
|
|
async def _extract_auth_info_from_scope(self, scope, headers):
|
|
"""Extract authentication information from ASGI scope and headers"""
|
|
auth_info = {}
|
|
|
|
# Extract client IP
|
|
client = scope.get("client")
|
|
if client:
|
|
auth_info["client_ip"] = client[0]
|
|
else:
|
|
auth_info["client_ip"] = "unknown"
|
|
|
|
# Extract token from Authorization header
|
|
authorization = headers.get(b'authorization', b'').decode('utf-8')
|
|
if authorization:
|
|
if authorization.startswith('Bearer '):
|
|
auth_info["token"] = authorization[7:]
|
|
auth_info["authorization"] = authorization
|
|
elif authorization.startswith('Token '):
|
|
auth_info["token"] = authorization[6:]
|
|
auth_info["authorization"] = authorization
|
|
|
|
# Extract token from query parameters (for compatibility)
|
|
query_string = scope.get("query_string", b"").decode('utf-8')
|
|
if query_string and "token=" in query_string:
|
|
import urllib.parse
|
|
query_params = urllib.parse.parse_qs(query_string)
|
|
if "token" in query_params:
|
|
auth_info["token"] = query_params["token"][0]
|
|
|
|
# If no token found, this will be handled by the authentication system
|
|
# (either return anonymous context if auth disabled, or raise error if auth enabled)
|
|
|
|
return auth_info
|
|
|
|
def _get_mcp_capabilities(self):
|
|
"""Get MCP capabilities with version compatibility"""
|
|
try:
|
|
# For MCP 1.9.x and newer
|
|
from mcp.server.lowlevel.server import NotificationOptions
|
|
|
|
return self.server.get_capabilities(
|
|
notification_options=NotificationOptions(
|
|
prompts_changed=True,
|
|
resources_changed=True,
|
|
tools_changed=True
|
|
),
|
|
experimental_capabilities={}
|
|
)
|
|
except TypeError:
|
|
try:
|
|
# For MCP 1.8.x
|
|
from mcp.server.lowlevel.server import NotificationOptions
|
|
|
|
return self.server.get_capabilities(
|
|
notification_options=NotificationOptions(
|
|
prompts_changed=True,
|
|
resources_changed=True,
|
|
tools_changed=True
|
|
),
|
|
experimental_capabilities={}
|
|
)
|
|
except Exception as e:
|
|
self.logger.warning(f"Could not get capabilities with NotificationOptions: {e}")
|
|
# Fallback for older versions
|
|
try:
|
|
return self.server.get_capabilities()
|
|
except Exception as fallback_e:
|
|
self.logger.error(f"Failed to get capabilities: {fallback_e}")
|
|
# Return minimal capabilities
|
|
return {
|
|
"resources": {},
|
|
"tools": {},
|
|
"prompts": {}
|
|
}
|
|
|
|
def _setup_handlers(self):
|
|
"""Setup MCP protocol handlers"""
|
|
|
|
@self.server.list_resources()
|
|
async def handle_list_resources() -> list[Resource]:
|
|
"""Handle resource list request"""
|
|
try:
|
|
self.logger.info("Handling resource list request")
|
|
resources = await self.resources_manager.list_resources()
|
|
self.logger.info(f"Returning {len(resources)} resources")
|
|
return resources
|
|
except Exception as e:
|
|
self.logger.error(f"Failed to handle resource list request: {e}")
|
|
return []
|
|
|
|
@self.server.read_resource()
|
|
async def handle_read_resource(uri: str) -> str:
|
|
"""Handle resource read request"""
|
|
try:
|
|
self.logger.info(f"Handling resource read request: {uri}")
|
|
content = await self.resources_manager.read_resource(uri)
|
|
return content
|
|
except Exception as e:
|
|
self.logger.error(f"Failed to handle resource read request: {e}")
|
|
return json.dumps(
|
|
{"error": f"Failed to read resource: {str(e)}", "uri": uri},
|
|
ensure_ascii=False,
|
|
indent=2,
|
|
)
|
|
|
|
@self.server.list_tools()
|
|
async def handle_list_tools() -> list[Tool]:
|
|
"""Handle tool list request"""
|
|
try:
|
|
self.logger.info("Handling tool list request")
|
|
tools = await self.tools_manager.list_tools()
|
|
self.logger.info(f"Returning {len(tools)} tools")
|
|
return tools
|
|
except Exception as e:
|
|
self.logger.error(f"Failed to handle tool list request: {e}")
|
|
return []
|
|
|
|
@self.server.call_tool()
|
|
async def handle_call_tool(
|
|
name: str, arguments: dict[str, Any]
|
|
) -> list[TextContent]:
|
|
"""Handle tool call request"""
|
|
try:
|
|
self.logger.info(f"Handling tool call request: {name}")
|
|
result = await self.tools_manager.call_tool(name, arguments)
|
|
|
|
return [TextContent(type="text", text=result)]
|
|
except Exception as e:
|
|
self.logger.error(f"Failed to handle tool call request: {e}")
|
|
error_result = json.dumps(
|
|
{
|
|
"error": f"Tool call failed: {str(e)}",
|
|
"tool_name": name,
|
|
"arguments": arguments,
|
|
},
|
|
ensure_ascii=False,
|
|
indent=2,
|
|
)
|
|
|
|
return [TextContent(type="text", text=error_result)]
|
|
|
|
@self.server.list_prompts()
|
|
async def handle_list_prompts() -> list[Prompt]:
|
|
"""Handle prompt list request"""
|
|
try:
|
|
self.logger.info("Handling prompt list request")
|
|
prompts = await self.prompts_manager.list_prompts()
|
|
self.logger.info(f"Returning {len(prompts)} prompts")
|
|
return prompts
|
|
except Exception as e:
|
|
self.logger.error(f"Failed to handle prompt list request: {e}")
|
|
return []
|
|
|
|
@self.server.get_prompt()
|
|
async def handle_get_prompt(name: str, arguments: dict[str, Any]) -> str:
|
|
"""Handle prompt get request"""
|
|
try:
|
|
self.logger.info(f"Handling prompt get request: {name}")
|
|
result = await self.prompts_manager.get_prompt(name, arguments)
|
|
return result
|
|
except Exception as e:
|
|
self.logger.error(f"Failed to handle prompt get request: {e}")
|
|
error_result = json.dumps(
|
|
{
|
|
"error": f"Failed to get prompt: {str(e)}",
|
|
"prompt_name": name,
|
|
"arguments": arguments,
|
|
},
|
|
ensure_ascii=False,
|
|
indent=2,
|
|
)
|
|
return error_result
|
|
|
|
async def start_stdio(self):
|
|
"""Start stdio transport mode"""
|
|
self.logger.info("Starting Doris MCP Server (stdio mode)")
|
|
|
|
try:
|
|
# Initialize security manager first (includes JWT setup if enabled)
|
|
await self.security_manager.initialize()
|
|
self.logger.info("Security manager initialization completed")
|
|
|
|
# For stdio mode, we must establish a working database connection
|
|
# Use the dedicated stdio mode initialization method
|
|
await self.connection_manager.initialize_for_stdio_mode()
|
|
|
|
# Start stdio server - using compatible import approach
|
|
try:
|
|
from mcp.server.stdio import stdio_server
|
|
except ImportError:
|
|
# Fallback for different MCP versions
|
|
try:
|
|
from mcp.server import stdio_server
|
|
except ImportError as stdio_import_error:
|
|
self.logger.error(f"Failed to import stdio_server: {stdio_import_error}")
|
|
raise RuntimeError("stdio_server module not available in this MCP version")
|
|
|
|
self.logger.info("Creating stdio_server transport...")
|
|
|
|
# Try different startup approaches
|
|
try:
|
|
async with stdio_server() as streams:
|
|
read_stream, write_stream = streams
|
|
self.logger.info("stdio_server streams created successfully")
|
|
|
|
# Create initialization options with version compatibility
|
|
capabilities = self._get_mcp_capabilities()
|
|
|
|
init_options = InitializationOptions(
|
|
server_name="doris-mcp-server",
|
|
server_version=os.getenv("SERVER_VERSION", _default_config.server_version),
|
|
capabilities=capabilities,
|
|
)
|
|
self.logger.info("Initialization options created successfully")
|
|
|
|
# Run server
|
|
self.logger.info("Starting to run MCP server...")
|
|
await self.server.run(read_stream, write_stream, init_options)
|
|
|
|
except Exception as inner_e:
|
|
self.logger.error(f"stdio_server internal error: {inner_e}")
|
|
self.logger.error(f"Error type: {type(inner_e)}")
|
|
|
|
# Try to get more error information
|
|
import traceback
|
|
self.logger.error("Complete error stack:")
|
|
self.logger.error(traceback.format_exc())
|
|
|
|
# If it's ExceptionGroup, try to parse
|
|
if hasattr(inner_e, 'exceptions'):
|
|
self.logger.error(f"ExceptionGroup contains {len(inner_e.exceptions)} exceptions:")
|
|
for i, exc in enumerate(inner_e.exceptions):
|
|
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
|
|
|
raise inner_e
|
|
|
|
except Exception as e:
|
|
self.logger.error(f"stdio server startup failed: {e}")
|
|
self.logger.error(f"Error type: {type(e)}")
|
|
raise
|
|
|
|
|
|
|
|
async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port), workers: int = 1):
|
|
"""Start Streamable HTTP transport mode with workers support"""
|
|
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}")
|
|
|
|
try:
|
|
# Initialize security manager first (includes JWT setup if enabled)
|
|
await self.security_manager.initialize()
|
|
self.logger.info("Security manager initialization completed")
|
|
|
|
# For HTTP mode, try to initialize global connection pool with graceful degradation
|
|
global_pool_created = await self.connection_manager.initialize_for_http_mode()
|
|
if global_pool_created:
|
|
self.logger.info("Global database connection pool available for HTTP mode")
|
|
else:
|
|
self.logger.info("HTTP mode running without global database pool, will use token-bound configurations")
|
|
|
|
# Use Starlette and StreamableHTTPSessionManager according to official example
|
|
import uvicorn
|
|
import contextlib
|
|
from collections.abc import AsyncIterator
|
|
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
|
from starlette.applications import Starlette
|
|
from starlette.routing import Route
|
|
from starlette.responses import JSONResponse, Response
|
|
from starlette.types import Scope
|
|
|
|
# Create session manager
|
|
session_manager = StreamableHTTPSessionManager(
|
|
app=self.server,
|
|
json_response=True, # Enable JSON response
|
|
stateless=False # Maintain session state
|
|
)
|
|
|
|
self.logger.info(f"StreamableHTTP session manager created, will start at http://{host}:{port}")
|
|
|
|
# Health check endpoint
|
|
async def health_check(request):
|
|
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
|
|
|
|
# OAuth endpoints
|
|
from .auth.oauth_handlers import OAuthHandlers
|
|
oauth_handlers = OAuthHandlers(self.security_manager)
|
|
|
|
async def oauth_login(request):
|
|
return await oauth_handlers.handle_login(request)
|
|
|
|
async def oauth_callback(request):
|
|
return await oauth_handlers.handle_callback(request)
|
|
|
|
async def oauth_provider_info(request):
|
|
return await oauth_handlers.handle_provider_info(request)
|
|
|
|
async def oauth_demo(request):
|
|
return await oauth_handlers.handle_demo_page(request)
|
|
|
|
# Token management endpoints
|
|
from .auth.token_handlers import TokenHandlers
|
|
token_handlers = TokenHandlers(self.security_manager, self.config)
|
|
|
|
async def token_create(request):
|
|
return await token_handlers.handle_create_token(request)
|
|
|
|
async def token_revoke(request):
|
|
return await token_handlers.handle_revoke_token(request)
|
|
|
|
async def token_list(request):
|
|
return await token_handlers.handle_list_tokens(request)
|
|
|
|
async def token_stats(request):
|
|
return await token_handlers.handle_token_stats(request)
|
|
|
|
async def token_cleanup(request):
|
|
return await token_handlers.handle_cleanup_tokens(request)
|
|
|
|
async def token_management(request):
|
|
return await token_handlers.handle_management_page(request)
|
|
|
|
# Lifecycle manager - simplified since we manage session_manager externally
|
|
@contextlib.asynccontextmanager
|
|
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
|
"""Context manager for managing application lifecycle"""
|
|
self.logger.info("Application started!")
|
|
try:
|
|
yield
|
|
finally:
|
|
self.logger.info("Application is shutting down...")
|
|
|
|
# Create ASGI application - use direct session manager as ASGI app
|
|
starlette_app = Starlette(
|
|
debug=True,
|
|
routes=[
|
|
Route("/health", health_check, methods=["GET"]),
|
|
# OAuth endpoints
|
|
Route("/auth/login", oauth_login, methods=["GET"]),
|
|
Route("/auth/callback", oauth_callback, methods=["GET"]),
|
|
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
|
|
Route("/auth/demo", oauth_demo, methods=["GET"]),
|
|
# Token management endpoints
|
|
Route("/token/create", token_create, methods=["GET", "POST"]),
|
|
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
|
|
Route("/token/list", token_list, methods=["GET"]),
|
|
Route("/token/stats", token_stats, methods=["GET"]),
|
|
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
|
|
Route("/token/management", token_management, methods=["GET"]),
|
|
],
|
|
lifespan=lifespan,
|
|
)
|
|
|
|
# Custom ASGI app that handles both /mcp and /mcp/ without redirects
|
|
async def mcp_app(scope, receive, send):
|
|
# Handle lifespan events
|
|
if scope["type"] == "lifespan":
|
|
await starlette_app(scope, receive, send)
|
|
return
|
|
|
|
# Handle HTTP requests
|
|
if scope["type"] == "http":
|
|
path = scope.get("path", "")
|
|
self.logger.info(f"Received request for path: {path}")
|
|
|
|
try:
|
|
# Handle health check, auth, and token management endpoints
|
|
if (path.startswith("/health") or
|
|
path.startswith("/auth/") or
|
|
path.startswith("/token/")):
|
|
await starlette_app(scope, receive, send)
|
|
return
|
|
|
|
# Handle MCP requests - both /mcp and /mcp/ go to session manager
|
|
if path == "/mcp" or path.startswith("/mcp/"):
|
|
self.logger.info(f"Handling MCP request for path: {path}")
|
|
# Log request details for debugging
|
|
method = scope.get("method", "UNKNOWN")
|
|
headers = dict(scope.get("headers", []))
|
|
self.logger.info(f"MCP Request - Method: {method}")
|
|
self.logger.info(f"MCP Request - Headers: {headers}")
|
|
|
|
# Authentication check for MCP requests
|
|
try:
|
|
# Extract authentication information
|
|
auth_info = await self._extract_auth_info_from_scope(scope, headers)
|
|
|
|
# Authenticate the request
|
|
auth_context = await self.security_manager.authenticate_request(auth_info)
|
|
self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
|
|
|
|
# Store auth context in scope for potential use by tools/resources
|
|
scope["auth_context"] = auth_context
|
|
|
|
# FIX for Issue #62 Bug 1: Set auth_context in context variable
|
|
# This allows tools to access token information for token-bound database configuration
|
|
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
|
|
try:
|
|
from .utils.security import mcp_auth_context_var
|
|
mcp_auth_context_var.set(auth_context)
|
|
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
|
except Exception as ctx_error:
|
|
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")
|
|
|
|
except Exception as auth_error:
|
|
self.logger.error(f"MCP authentication failed: {auth_error}")
|
|
# Return 401 Unauthorized
|
|
from starlette.responses import JSONResponse
|
|
response = JSONResponse(
|
|
{"error": "Authentication required", "message": str(auth_error)},
|
|
status_code=401
|
|
)
|
|
await response(scope, receive, send)
|
|
return
|
|
|
|
# Handle Dify compatibility for GET requests
|
|
if method == "GET":
|
|
accept_header = headers.get(b'accept', b'').decode('utf-8')
|
|
user_agent = headers.get(b'user-agent', b'').decode('utf-8')
|
|
|
|
|
|
|
|
# For other GET requests, try to add application/json to Accept header
|
|
if 'text/event-stream' in accept_header and 'application/json' not in accept_header:
|
|
self.logger.info("Adding application/json to Accept header for GET request")
|
|
# Modify headers to include both content types
|
|
new_headers = []
|
|
for name, value in scope.get("headers", []):
|
|
if name == b'accept':
|
|
# Add application/json to the accept header
|
|
new_value = value.decode('utf-8') + ', application/json'
|
|
new_headers.append((name, new_value.encode('utf-8')))
|
|
else:
|
|
new_headers.append((name, value))
|
|
# Update scope with modified headers
|
|
scope = dict(scope)
|
|
scope["headers"] = new_headers
|
|
self.logger.info(f"Modified Accept header to: {new_value}")
|
|
|
|
await session_manager.handle_request(scope, receive, send)
|
|
return
|
|
|
|
# 404 for other paths
|
|
self.logger.info(f"Path not found: {path}")
|
|
response = Response("Not Found", status_code=404)
|
|
await response(scope, receive, send)
|
|
except Exception as e:
|
|
self.logger.error(f"Error handling request for {path}: {e}")
|
|
import traceback
|
|
self.logger.error(traceback.format_exc())
|
|
response = Response("Internal Server Error", status_code=500)
|
|
await response(scope, receive, send)
|
|
else:
|
|
# For other scope types, just return
|
|
self.logger.warning(f"Unsupported scope type: {scope['type']}")
|
|
return
|
|
|
|
# Choose startup method based on worker count
|
|
if workers > 1:
|
|
self.logger.info(f"Using multi-process mode with {workers} workers")
|
|
self.logger.info("Note: Multi-worker mode provides full MCP functionality with independent worker processes")
|
|
|
|
# Use the dedicated multiworker app module with full MCP support
|
|
uvicorn.run(
|
|
"doris_mcp_server.multiworker_app:app",
|
|
host=host,
|
|
port=port,
|
|
workers=workers,
|
|
log_level="info"
|
|
)
|
|
|
|
else:
|
|
self.logger.info("Using single-process mode")
|
|
# Single worker mode, use original logic with session manager lifecycle
|
|
config = uvicorn.Config(
|
|
app=mcp_app,
|
|
host=host,
|
|
port=port,
|
|
log_level="info"
|
|
)
|
|
server = uvicorn.Server(config)
|
|
|
|
# Run session manager and server together
|
|
async with session_manager.run():
|
|
self.logger.info("Session manager started, now starting HTTP server")
|
|
await server.serve()
|
|
|
|
except Exception as e:
|
|
self.logger.error(f"Streamable HTTP server startup failed: {e}")
|
|
import traceback
|
|
self.logger.error("Complete error stack:")
|
|
self.logger.error(traceback.format_exc())
|
|
|
|
# If it's ExceptionGroup, try to parse
|
|
if hasattr(e, 'exceptions'):
|
|
self.logger.error(f"ExceptionGroup contains {len(e.exceptions)} exceptions:")
|
|
for i, exc in enumerate(e.exceptions):
|
|
self.logger.error(f" Exception {i+1}: {type(exc).__name__}: {exc}")
|
|
raise
|
|
|
|
|
|
|
|
async def shutdown(self):
|
|
"""Shutdown server"""
|
|
self.logger.info("Shutting down Doris MCP Server")
|
|
try:
|
|
# Shutdown security manager first (includes JWT cleanup)
|
|
await self.security_manager.shutdown()
|
|
self.logger.info("Security manager shutdown completed")
|
|
|
|
await self.connection_manager.close()
|
|
self.logger.info("Doris MCP Server has been shut down")
|
|
except Exception as e:
|
|
self.logger.error(f"Error occurred while shutting down server: {e}")
|
|
|
|
|
|
def create_arg_parser():
|
|
"""Create command line argument parser"""
|
|
parser = argparse.ArgumentParser(
|
|
description="Apache Doris MCP Server - Enterprise Database Service",
|
|
formatter_class=argparse.RawDescriptionHelpFormatter,
|
|
epilog="""
|
|
Transport Modes:
|
|
stdio - Standard input/output (for local process communication)
|
|
http - Streamable HTTP mode (MCP 2025-03-26 protocol)
|
|
|
|
Examples:
|
|
python -m doris_mcp_server --transport stdio
|
|
python -m doris_mcp_server --transport http --host 0.0.0.0 --port 3000
|
|
python -m doris_mcp_server --transport stdio --doris-host localhost --doris-port 9030
|
|
python -m doris_mcp_server --transport http --doris-user admin --doris-database test_db
|
|
|
|
# Backward compatibility: --db-* parameters are also supported
|
|
python -m doris_mcp_server --transport stdio --db-host localhost --db-port 9030
|
|
"""
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--transport",
|
|
type=str,
|
|
choices=["stdio", "http"],
|
|
default=os.getenv("TRANSPORT", _default_config.transport),
|
|
help=f"Transport protocol type: stdio (local), http (Streamable HTTP) (default: {_default_config.transport})",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--host",
|
|
type=str,
|
|
default=os.getenv("SERVER_HOST", _default_config.server_host),
|
|
help=f"Host address for HTTP mode (default: {_default_config.server_host})",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--port",
|
|
type=int,
|
|
default=3000,
|
|
help="Port number for HTTP mode (default: 3000)"
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--workers",
|
|
type=int,
|
|
default=1,
|
|
help="Number of worker processes for HTTP mode (default: 1, use 0 for auto-detect CPU cores)"
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--doris-host", "--db-host",
|
|
type=str,
|
|
default=os.getenv("DORIS_HOST", _default_config.database.host),
|
|
help=f"Doris database host address (default: {_default_config.database.host})",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--doris-port", "--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--doris-user", "--db-user", type=str, default=os.getenv("DORIS_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
|
|
)
|
|
|
|
parser.add_argument("--doris-password", "--db-password", type=str, default=os.getenv("DORIS_PASSWORD", ""), help="Doris database password")
|
|
|
|
parser.add_argument(
|
|
"--doris-database", "--db-database",
|
|
type=str,
|
|
default=os.getenv("DORIS_DATABASE", _default_config.database.database),
|
|
help=f"Doris database name (default: {_default_config.database.database})",
|
|
)
|
|
|
|
parser.add_argument(
|
|
"--log-level",
|
|
type=str,
|
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
|
default=os.getenv("LOG_LEVEL", _default_config.logging.level),
|
|
help=f"Log level (default: {_default_config.logging.level})",
|
|
)
|
|
|
|
return parser
|
|
|
|
|
|
def update_configuration(config: DorisConfig):
|
|
"""Update doris configuration object"""
|
|
# For some arguments, if not specified, environment variables or default configurations will be used as default values
|
|
parser = create_arg_parser()
|
|
args = parser.parse_args()
|
|
|
|
# Update config values
|
|
# Command line arguments override configuration (if provided)
|
|
# basic
|
|
if args.transport != _default_config.transport:
|
|
config.transport = args.transport
|
|
if args.host != _default_config.server_host:
|
|
config.server_host = args.host
|
|
if args.port != _default_config.server_port:
|
|
config.server_port = args.port
|
|
server_name = os.getenv("SERVER_NAME")
|
|
if server_name:
|
|
config.server_name = server_name
|
|
server_version = os.getenv("SERVER_VERSION")
|
|
if server_version:
|
|
config.server_version = server_version
|
|
|
|
# database
|
|
if args.doris_host != _default_config.database.host: # If not default value, use command line argument
|
|
config.database.host = args.doris_host
|
|
if args.doris_port != _default_config.database.port:
|
|
config.database.port = args.doris_port
|
|
if args.doris_user != _default_config.database.user:
|
|
config.database.user = args.doris_user
|
|
if args.doris_password: # Use password if provided
|
|
config.database.password = args.doris_password
|
|
if args.doris_database != _default_config.database.database:
|
|
config.database.database = args.doris_database
|
|
|
|
# logging
|
|
if args.log_level != _default_config.logging.level:
|
|
config.logging.level = args.log_level
|
|
|
|
# workers (add to config for HTTP mode)
|
|
if hasattr(args, 'workers'):
|
|
config.workers = args.workers
|
|
|
|
|
|
async def main():
|
|
"""Main function"""
|
|
# Create configuration - priority: command line arguments > env variables > .env file > default values
|
|
# First load from .env file and environment variables
|
|
config = DorisConfig.from_env()
|
|
|
|
# Then parse the command line arguments, and update the config object.
|
|
update_configuration(config)
|
|
|
|
# Initialize enhanced logging system
|
|
from .utils.config import ConfigManager
|
|
config_manager = ConfigManager(config)
|
|
config_manager.setup_logging()
|
|
|
|
# Get logger with proper configuration
|
|
from .utils.logger import get_logger, log_system_info
|
|
logger = get_logger(__name__)
|
|
|
|
# Log system information for debugging
|
|
log_system_info()
|
|
|
|
logger.info("Starting Doris MCP Server...")
|
|
logger.info(f"Transport: {config.transport}")
|
|
logger.info(f"Log Level: {config.logging.level}")
|
|
|
|
# Create server instance
|
|
server = DorisServer(config)
|
|
|
|
try:
|
|
if config.transport == "stdio":
|
|
await server.start_stdio()
|
|
elif config.transport == "http":
|
|
# Get workers configuration with auto-detection support
|
|
workers = getattr(config, 'workers', 1)
|
|
if workers == 0:
|
|
import multiprocessing
|
|
workers = multiprocessing.cpu_count()
|
|
logger.info(f"Auto-detected {workers} CPU cores for worker processes")
|
|
|
|
await server.start_http(config.server_host, config.server_port, workers)
|
|
else:
|
|
logger.error(f"Unsupported transport protocol: {config.transport}")
|
|
await server.shutdown()
|
|
return 1
|
|
|
|
except KeyboardInterrupt:
|
|
logger.info("Received interrupt signal, shutting down server...")
|
|
except Exception as e:
|
|
logger.error(f"Server runtime error: {e}")
|
|
# Clean up resources even in case of exception
|
|
try:
|
|
await server.shutdown()
|
|
except Exception as shutdown_error:
|
|
logger.error(f"Error occurred while shutting down server: {shutdown_error}")
|
|
return 1
|
|
finally:
|
|
# Cleanup in case of normal shutdown
|
|
try:
|
|
await server.shutdown()
|
|
except Exception as shutdown_error:
|
|
logger.error(f"Error occurred while shutting down server: {shutdown_error}")
|
|
|
|
# Shutdown logging system
|
|
from .utils.logger import shutdown_logging
|
|
shutdown_logging()
|
|
|
|
return 0
|
|
|
|
|
|
def main_sync():
|
|
"""Synchronous main function for entry point"""
|
|
exit_code = asyncio.run(main())
|
|
exit(exit_code)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main_sync()
|