From 5d46d153e13d5d9c952ffa454a22449f3f781a24 Mon Sep 17 00:00:00 2001 From: FreeOnePlus Date: Wed, 11 Jun 2025 11:52:15 +0800 Subject: [PATCH] 1. Fix DB Connection BUG 2. Modify the global default configuration items and obtain them from Config --- doris_mcp_server/main.py | 48 +++++++++++++++-------------- doris_mcp_server/utils/config.py | 13 ++++---- doris_mcp_server/utils/db.py | 52 ++++++++++++++++++++++++++++++-- 3 files changed, 83 insertions(+), 30 deletions(-) diff --git a/doris_mcp_server/main.py b/doris_mcp_server/main.py index b6c5536..a26e223 100644 --- a/doris_mcp_server/main.py +++ b/doris_mcp_server/main.py @@ -44,11 +44,15 @@ from .tools.resources_manager import DorisResourcesManager from .utils.config import DorisConfig from .utils.db import DorisConnectionManager from .utils.security import DorisSecurityManager +import os # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) +# Create a default config instance for getting default values +_default_config = DorisConfig() + class DorisServer: """Apache Doris MCP Server main class""" @@ -204,7 +208,7 @@ class DorisServer: init_options = InitializationOptions( server_name="doris-mcp-server", - server_version="1.0.0", + server_version=os.getenv("SERVER_VERSION", _default_config.server_version), capabilities=capabilities, ) self.logger.info("Initialization options created successfully") @@ -237,7 +241,7 @@ class DorisServer: - async def start_http(self, host: str = "localhost", port: int = 3000): + async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port)): """Start Streamable HTTP transport mode""" self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}") @@ -251,9 +255,9 @@ class DorisServer: from collections.abc import AsyncIterator from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from starlette.applications import Starlette - from starlette.routing import Mount, Route + from starlette.routing import Route from starlette.responses import JSONResponse, Response - from starlette.types import Receive, Scope, Send + from starlette.types import Scope # Create session manager session_manager = StreamableHTTPSessionManager( @@ -413,34 +417,34 @@ Examples: "--transport", type=str, choices=["stdio", "http"], - default="stdio", - help="Transport protocol type: stdio (local), http (Streamable HTTP)", + default=os.getenv("TRANSPORT", _default_config.transport), + help=f"Transport protocol type: stdio (local), http (Streamable HTTP) (default: {_default_config.transport})", ) parser.add_argument( "--host", type=str, - default="localhost", - help="Host address for HTTP mode (default: localhost)", + default=os.getenv("SERVER_HOST", _default_config.database.host), + help=f"Host address for HTTP mode (default: {_default_config.database.host})", ) parser.add_argument( - "--port", type=int, default=3000, help="Port number for HTTP mode (default: 3000)" + "--port", type=int, default=os.getenv("SERVER_PORT", _default_config.server_port), help=f"Port number for HTTP mode (default: {_default_config.server_port})" ) parser.add_argument( "--db-host", type=str, - default="localhost", - help="Doris database host address (default: localhost)", + default=os.getenv("DB_HOST", _default_config.database.host), + help=f"Doris database host address (default: {_default_config.database.host})", ) parser.add_argument( - "--db-port", type=int, default=9030, help="Doris database port number (default: 9030)" + "--db-port", type=int, default=os.getenv("DB_PORT", _default_config.database.port), help=f"Doris database port number (default: {_default_config.database.port})" ) parser.add_argument( - "--db-user", type=str, default="root", help="Doris database username (default: root)" + "--db-user", type=str, default=os.getenv("DB_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})" ) parser.add_argument("--db-password", type=str, default="", help="Doris database password") @@ -448,16 +452,16 @@ Examples: parser.add_argument( "--db-database", type=str, - default="information_schema", - help="Doris database name (default: information_schema)", + default=os.getenv("DB_DATABASE", _default_config.database.database), + help=f"Doris database name (default: {_default_config.database.database})", ) parser.add_argument( "--log-level", type=str, choices=["DEBUG", "INFO", "WARNING", "ERROR"], - default="INFO", - help="Log level (default: INFO)", + default=os.getenv("LOG_LEVEL", _default_config.logging.level), + help=f"Log level (default: {_default_config.logging.level})", ) return parser @@ -475,17 +479,17 @@ async def main(): config = DorisConfig.from_env() # First load from .env file and environment variables # Command line arguments override configuration (if provided) - if args.db_host != "localhost": # If not default value, use command line argument + if args.db_host != _default_config.database.host: # If not default value, use command line argument config.database.host = args.db_host - if args.db_port != 9030: + if args.db_port != _default_config.database.port: config.database.port = args.db_port - if args.db_user != "root": + if args.db_user != _default_config.database.user: config.database.user = args.db_user if args.db_password: # Use password if provided config.database.password = args.db_password - if args.db_database != "information_schema": + if args.db_database != _default_config.database.database: config.database.database = args.db_database - if args.log_level != "INFO": + if args.log_level != _default_config.logging.level: config.logging.level = args.log_level # Create server instance diff --git a/doris_mcp_server/utils/config.py b/doris_mcp_server/utils/config.py index 5de6e15..50c196f 100644 --- a/doris_mcp_server/utils/config.py +++ b/doris_mcp_server/utils/config.py @@ -41,8 +41,8 @@ class DatabaseConfig: port: int = 9030 user: str = "root" password: str = "" - database: str = "test" - charset: str = "utf8mb4" + database: str = "information_schema" + charset: str = "UTF8" # Connection pool configuration min_connections: int = 5 @@ -125,11 +125,11 @@ class MonitoringConfig: # Metrics collection configuration enable_metrics: bool = True - metrics_port: int = 8081 + metrics_port: int = 3001 metrics_path: str = "/metrics" # Health check configuration - health_check_port: int = 8082 + health_check_port: int = 3002 health_check_path: str = "/health" # Alert configuration @@ -143,8 +143,9 @@ class DorisConfig: # Basic configuration server_name: str = "doris-mcp-server" - server_version: str = "1.0.0" - server_port: int = 8080 + server_version: str = "0.3.0" + server_port: int = 3000 + transport: str = "stdio" # Sub-configuration modules database: DatabaseConfig = field(default_factory=DatabaseConfig) diff --git a/doris_mcp_server/utils/db.py b/doris_mcp_server/utils/db.py index 4d67a0a..190aa1c 100644 --- a/doris_mcp_server/utils/db.py +++ b/doris_mcp_server/utils/db.py @@ -137,10 +137,18 @@ class DorisConnection: async def ping(self) -> bool: """Check connection health status""" try: + # Check if connection exists and is not closed + if not self.connection or self.connection.closed: + self.is_healthy = False + return False + + # Try to ping the connection await self.connection.ping() self.is_healthy = True return True - except Exception: + except Exception as e: + # Log the specific error for debugging + logging.debug(f"Connection ping failed for session {self.session_id}: {e}") self.is_healthy = False return False @@ -181,7 +189,17 @@ class DorisConnectionManager: async def initialize(self): """Initialize connection manager""" try: - # Create connection pool + self.logger.info(f"Initializing connection pool to {self.config.database.host}:{self.config.database.port}") + + # Validate configuration + if not self.config.database.host: + raise ValueError("Database host is required") + if not self.config.database.user: + raise ValueError("Database user is required") + if not self.config.database.password: + self.logger.warning("Database password is empty, this may cause connection issues") + + # Create connection pool with additional parameters for stability self.pool = await aiomysql.create_pool( host=self.config.database.host, port=self.config.database.port, @@ -193,8 +211,15 @@ class DorisConnectionManager: maxsize=self.config.database.max_connections or 20, autocommit=True, connect_timeout=self.connection_timeout, + # Additional parameters for stability + pool_recycle=3600, # Recycle connections every hour + echo=False, # Don't echo SQL statements ) + # Test the connection pool + if not await self.test_connection(): + raise RuntimeError("Connection pool test failed") + self.logger.info( f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, " f"max connections: {self.config.database.max_connections}" @@ -206,6 +231,14 @@ class DorisConnectionManager: except Exception as e: self.logger.error(f"Connection pool initialization failed: {e}") + # Clean up partial initialization + if self.pool: + try: + self.pool.close() + await self.pool.wait_closed() + except Exception: + pass + self.pool = None raise async def get_connection(self, session_id: str) -> DorisConnection: @@ -235,9 +268,24 @@ class DorisConnectionManager: # Get connection from pool raw_connection = await self.pool.acquire() + # Validate the raw connection + if not raw_connection: + raise RuntimeError(f"Failed to acquire connection from pool for session {session_id}") + + # Verify the connection is not closed + if raw_connection.closed: + raise RuntimeError(f"Acquired connection is already closed for session {session_id}") + # Create wrapped connection doris_conn = DorisConnection(raw_connection, session_id, self.security_manager) + # Test the connection before storing it + if not await doris_conn.ping(): + # If ping fails, release the connection and raise error + if self.pool and raw_connection and not raw_connection.closed: + self.pool.release(raw_connection) + raise RuntimeError(f"New connection failed ping test for session {session_id}") + # Store in session connections self.session_connections[session_id] = doris_conn