diff --git a/doris_mcp_server/main.py b/doris_mcp_server/main.py index 3bb62fd..b28cce7 100644 --- a/doris_mcp_server/main.py +++ b/doris_mcp_server/main.py @@ -432,9 +432,9 @@ class DorisServer: await self.security_manager.initialize() self.logger.info("Security manager initialization completed") - # Ensure connection manager is initialized - await self.connection_manager.initialize() - self.logger.info("Connection manager initialization completed") + # For stdio mode, we must establish a working database connection + # Use the dedicated stdio mode initialization method + await self.connection_manager.initialize_for_stdio_mode() # Start stdio server - using compatible import approach try: @@ -502,8 +502,12 @@ class DorisServer: await self.security_manager.initialize() self.logger.info("Security manager initialization completed") - # Ensure connection manager is initialized - await self.connection_manager.initialize() + # For HTTP mode, try to initialize global connection pool with graceful degradation + global_pool_created = await self.connection_manager.initialize_for_http_mode() + if global_pool_created: + self.logger.info("Global database connection pool available for HTTP mode") + else: + self.logger.info("HTTP mode running without global database pool, will use token-bound configurations") # Use Starlette and StreamableHTTPSessionManager according to official example import uvicorn diff --git a/doris_mcp_server/utils/db.py b/doris_mcp_server/utils/db.py index 8d5c58a..69c769b 100644 --- a/doris_mcp_server/utils/db.py +++ b/doris_mcp_server/utils/db.py @@ -626,6 +626,213 @@ class DorisConnectionManager: self.logger.error(f"Failed to initialize connection pool: {e}") raise + async def initialize_for_stdio_mode(self, timeout: float = 30.0) -> None: + """ + Initialize connection pool for stdio mode with strict validation + + stdio mode requires a working database connection because: + - No HTTP authentication mechanism to support token-bound configs + - All database operations depend on the global connection pool + + Args: + timeout: Maximum time to wait for connection establishment + + Raises: + RuntimeError: If configuration is invalid or connection fails + """ + try: + # Validate that we have valid global configuration + if not self._has_valid_global_config(): + error_msg = ( + "stdio mode requires valid global database configuration. " + "Please set DORIS_HOST and DORIS_USER in environment variables or .env file. " + f"Current config: host='{self.host}', user='{self.user}'" + ) + self.logger.error(error_msg) + raise RuntimeError(error_msg) + + self.logger.info(f"stdio mode database config validated: {self.host}:{self.port}") + + # Validate configuration format + is_valid, error_message = self.validate_database_configuration() + if not is_valid: + error_msg = f"Database configuration validation failed: {error_message}" + self.logger.error(error_msg) + raise RuntimeError(error_msg) + + # Test connectivity with timeout + self.logger.info("Testing database connectivity for stdio mode...") + if not await self._test_connectivity_with_timeout(timeout): + error_msg = ( + f"Failed to connect to Doris database within {timeout} seconds. " + f"Please check if Doris is running at {self.host}:{self.port} " + f"and verify network connectivity." + ) + self.logger.error(error_msg) + raise RuntimeError(error_msg) + + # Initialize the connection pool + await self._create_connection_pool() + + # Verify that we have a working connection pool + if not self.pool: + error_msg = "Database connection pool was not created successfully." + self.logger.error(error_msg) + raise RuntimeError(error_msg) + + # Start background monitoring tasks + self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor()) + self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor()) + + # Perform initial pool warmup + await self._warmup_pool() + + self.logger.info("Database connection established successfully for stdio mode") + + except Exception as e: + self.logger.error(f"stdio mode database initialization failed: {e}") + raise + + async def initialize_for_http_mode(self) -> bool: + """ + Initialize connection pool for HTTP mode with graceful degradation + + HTTP mode can work without global database configuration because: + - Supports token-bound database configurations + - Can handle authentication and use per-request database configs + - Has fallback mechanisms for database operations + + Returns: + bool: True if global database pool was created, False if gracefully degraded + """ + try: + # First validate configuration format if we have one + if self._has_valid_global_config(): + is_valid, error_message = self.validate_database_configuration() + if not is_valid: + self.logger.warning(f"Global database configuration invalid: {error_message}") + self.logger.info("HTTP mode will rely on token-bound database configurations") + return False + + # Try to establish global connection pool + self.logger.info(f"Attempting to create global connection pool: {self.host}:{self.port}") + + try: + # Test connectivity with shorter timeout for HTTP mode + if await self._test_connectivity_with_timeout(10.0): + await self._create_connection_pool() + + if self.pool: + # Start background monitoring tasks + self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor()) + self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor()) + + # Perform initial pool warmup + await self._warmup_pool() + + self.logger.info("Global database connection pool created successfully for HTTP mode") + return True + else: + self.logger.warning("Global database connection test failed, will use token-bound configs") + return False + + except Exception as pool_error: + self.logger.warning(f"Failed to create global connection pool: {pool_error}") + self.logger.info("HTTP mode will rely on token-bound database configurations") + return False + else: + self.logger.info("No valid global database config found, HTTP mode will use token-bound configurations") + return False + + except Exception as e: + self.logger.warning(f"HTTP mode database initialization encountered error: {e}") + self.logger.info("HTTP mode will rely on token-bound database configurations") + return False + + async def _test_connectivity_with_timeout(self, timeout: float) -> bool: + """ + Test database connectivity with timeout + + Args: + timeout: Maximum time to wait for connection test + + Returns: + bool: True if connection successful, False otherwise + """ + try: + await asyncio.wait_for(self._test_basic_connectivity(), timeout=timeout) + return True + except asyncio.TimeoutError: + self.logger.error(f"Database connectivity test timed out after {timeout} seconds") + return False + except Exception as e: + self.logger.error(f"Database connectivity test failed: {e}") + return False + + async def _test_basic_connectivity(self) -> None: + """ + Test basic database connectivity without connection pool + + Raises: + Exception: If connection fails + """ + import aiomysql + + conn = None + try: + conn = await aiomysql.connect( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + db=self.database, + charset=self.charset, + connect_timeout=self.connect_timeout, + autocommit=True + ) + + async with conn.cursor() as cursor: + await cursor.execute("SELECT 1") + result = await cursor.fetchone() + if not result or result[0] != 1: + raise RuntimeError("Database connectivity test query failed") + + except Exception as e: + raise RuntimeError(f"Database connectivity test failed: {e}") + finally: + if conn: + conn.close() + + async def _create_connection_pool(self) -> None: + """ + Create the connection pool + + Raises: + Exception: If pool creation fails + """ + self.pool = await aiomysql.create_pool( + host=self.host, + port=self.port, + user=self.user, + password=self.password, + db=self.database, + charset=self.charset, + minsize=self.minsize, + maxsize=self.maxsize, + pool_recycle=self.pool_recycle, + connect_timeout=self.connect_timeout, + autocommit=True + ) + + # Test pool health + if not await self._test_pool_health(): + # Clean up the pool if health test fails + if self.pool: + self.pool.close() + await self.pool.wait_closed() + self.pool = None + raise RuntimeError("Connection pool health check failed") + async def _test_pool_health(self) -> bool: """Test connection pool health""" try: diff --git a/pyproject.toml b/pyproject.toml index a54a707..a9ae3da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ build-backend = "hatchling.build" [project] name = "doris-mcp-server" -version = "0.6.0" +version = "0.6.1" description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris" authors = [ {name = "Yijia Su", email = "freeoneplus@apache.org"} diff --git a/uv.lock b/uv.lock index e6d4c34..93b1e12 100644 --- a/uv.lock +++ b/uv.lock @@ -562,7 +562,7 @@ wheels = [ [[package]] name = "doris-mcp-server" -version = "0.5.1" +version = "0.6.1" source = { editable = "." } dependencies = [ { name = "adbc-driver-flightsql" },