1. Fix DB Connection BUG

2. Modify the global default configuration items and obtain them from Config
This commit is contained in:
FreeOnePlus
2025-06-11 11:52:15 +08:00
parent 0a81d5693b
commit 5d46d153e1
3 changed files with 83 additions and 30 deletions

View File

@@ -44,11 +44,15 @@ from .tools.resources_manager import DorisResourcesManager
from .utils.config import DorisConfig from .utils.config import DorisConfig
from .utils.db import DorisConnectionManager from .utils.db import DorisConnectionManager
from .utils.security import DorisSecurityManager from .utils.security import DorisSecurityManager
import os
# Configure logging # Configure logging
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Create a default config instance for getting default values
_default_config = DorisConfig()
class DorisServer: class DorisServer:
"""Apache Doris MCP Server main class""" """Apache Doris MCP Server main class"""
@@ -204,7 +208,7 @@ class DorisServer:
init_options = InitializationOptions( init_options = InitializationOptions(
server_name="doris-mcp-server", server_name="doris-mcp-server",
server_version="1.0.0", server_version=os.getenv("SERVER_VERSION", _default_config.server_version),
capabilities=capabilities, capabilities=capabilities,
) )
self.logger.info("Initialization options created successfully") self.logger.info("Initialization options created successfully")
@@ -237,7 +241,7 @@ class DorisServer:
async def start_http(self, host: str = "localhost", port: int = 3000): async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port)):
"""Start Streamable HTTP transport mode""" """Start Streamable HTTP transport mode"""
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}") self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
@@ -251,9 +255,9 @@ class DorisServer:
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
from starlette.applications import Starlette from starlette.applications import Starlette
from starlette.routing import Mount, Route from starlette.routing import Route
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from starlette.types import Receive, Scope, Send from starlette.types import Scope
# Create session manager # Create session manager
session_manager = StreamableHTTPSessionManager( session_manager = StreamableHTTPSessionManager(
@@ -413,34 +417,34 @@ Examples:
"--transport", "--transport",
type=str, type=str,
choices=["stdio", "http"], choices=["stdio", "http"],
default="stdio", default=os.getenv("TRANSPORT", _default_config.transport),
help="Transport protocol type: stdio (local), http (Streamable HTTP)", help=f"Transport protocol type: stdio (local), http (Streamable HTTP) (default: {_default_config.transport})",
) )
parser.add_argument( parser.add_argument(
"--host", "--host",
type=str, type=str,
default="localhost", default=os.getenv("SERVER_HOST", _default_config.database.host),
help="Host address for HTTP mode (default: localhost)", help=f"Host address for HTTP mode (default: {_default_config.database.host})",
) )
parser.add_argument( parser.add_argument(
"--port", type=int, default=3000, help="Port number for HTTP mode (default: 3000)" "--port", type=int, default=os.getenv("SERVER_PORT", _default_config.server_port), help=f"Port number for HTTP mode (default: {_default_config.server_port})"
) )
parser.add_argument( parser.add_argument(
"--db-host", "--db-host",
type=str, type=str,
default="localhost", default=os.getenv("DB_HOST", _default_config.database.host),
help="Doris database host address (default: localhost)", help=f"Doris database host address (default: {_default_config.database.host})",
) )
parser.add_argument( parser.add_argument(
"--db-port", type=int, default=9030, help="Doris database port number (default: 9030)" "--db-port", type=int, default=os.getenv("DB_PORT", _default_config.database.port), help=f"Doris database port number (default: {_default_config.database.port})"
) )
parser.add_argument( parser.add_argument(
"--db-user", type=str, default="root", help="Doris database username (default: root)" "--db-user", type=str, default=os.getenv("DB_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
) )
parser.add_argument("--db-password", type=str, default="", help="Doris database password") parser.add_argument("--db-password", type=str, default="", help="Doris database password")
@@ -448,16 +452,16 @@ Examples:
parser.add_argument( parser.add_argument(
"--db-database", "--db-database",
type=str, type=str,
default="information_schema", default=os.getenv("DB_DATABASE", _default_config.database.database),
help="Doris database name (default: information_schema)", help=f"Doris database name (default: {_default_config.database.database})",
) )
parser.add_argument( parser.add_argument(
"--log-level", "--log-level",
type=str, type=str,
choices=["DEBUG", "INFO", "WARNING", "ERROR"], choices=["DEBUG", "INFO", "WARNING", "ERROR"],
default="INFO", default=os.getenv("LOG_LEVEL", _default_config.logging.level),
help="Log level (default: INFO)", help=f"Log level (default: {_default_config.logging.level})",
) )
return parser return parser
@@ -475,17 +479,17 @@ async def main():
config = DorisConfig.from_env() # First load from .env file and environment variables config = DorisConfig.from_env() # First load from .env file and environment variables
# Command line arguments override configuration (if provided) # Command line arguments override configuration (if provided)
if args.db_host != "localhost": # If not default value, use command line argument if args.db_host != _default_config.database.host: # If not default value, use command line argument
config.database.host = args.db_host config.database.host = args.db_host
if args.db_port != 9030: if args.db_port != _default_config.database.port:
config.database.port = args.db_port config.database.port = args.db_port
if args.db_user != "root": if args.db_user != _default_config.database.user:
config.database.user = args.db_user config.database.user = args.db_user
if args.db_password: # Use password if provided if args.db_password: # Use password if provided
config.database.password = args.db_password config.database.password = args.db_password
if args.db_database != "information_schema": if args.db_database != _default_config.database.database:
config.database.database = args.db_database config.database.database = args.db_database
if args.log_level != "INFO": if args.log_level != _default_config.logging.level:
config.logging.level = args.log_level config.logging.level = args.log_level
# Create server instance # Create server instance

View File

@@ -41,8 +41,8 @@ class DatabaseConfig:
port: int = 9030 port: int = 9030
user: str = "root" user: str = "root"
password: str = "" password: str = ""
database: str = "test" database: str = "information_schema"
charset: str = "utf8mb4" charset: str = "UTF8"
# Connection pool configuration # Connection pool configuration
min_connections: int = 5 min_connections: int = 5
@@ -125,11 +125,11 @@ class MonitoringConfig:
# Metrics collection configuration # Metrics collection configuration
enable_metrics: bool = True enable_metrics: bool = True
metrics_port: int = 8081 metrics_port: int = 3001
metrics_path: str = "/metrics" metrics_path: str = "/metrics"
# Health check configuration # Health check configuration
health_check_port: int = 8082 health_check_port: int = 3002
health_check_path: str = "/health" health_check_path: str = "/health"
# Alert configuration # Alert configuration
@@ -143,8 +143,9 @@ class DorisConfig:
# Basic configuration # Basic configuration
server_name: str = "doris-mcp-server" server_name: str = "doris-mcp-server"
server_version: str = "1.0.0" server_version: str = "0.3.0"
server_port: int = 8080 server_port: int = 3000
transport: str = "stdio"
# Sub-configuration modules # Sub-configuration modules
database: DatabaseConfig = field(default_factory=DatabaseConfig) database: DatabaseConfig = field(default_factory=DatabaseConfig)

View File

@@ -137,10 +137,18 @@ class DorisConnection:
async def ping(self) -> bool: async def ping(self) -> bool:
"""Check connection health status""" """Check connection health status"""
try: try:
# Check if connection exists and is not closed
if not self.connection or self.connection.closed:
self.is_healthy = False
return False
# Try to ping the connection
await self.connection.ping() await self.connection.ping()
self.is_healthy = True self.is_healthy = True
return True return True
except Exception: except Exception as e:
# Log the specific error for debugging
logging.debug(f"Connection ping failed for session {self.session_id}: {e}")
self.is_healthy = False self.is_healthy = False
return False return False
@@ -181,7 +189,17 @@ class DorisConnectionManager:
async def initialize(self): async def initialize(self):
"""Initialize connection manager""" """Initialize connection manager"""
try: try:
# Create connection pool self.logger.info(f"Initializing connection pool to {self.config.database.host}:{self.config.database.port}")
# Validate configuration
if not self.config.database.host:
raise ValueError("Database host is required")
if not self.config.database.user:
raise ValueError("Database user is required")
if not self.config.database.password:
self.logger.warning("Database password is empty, this may cause connection issues")
# Create connection pool with additional parameters for stability
self.pool = await aiomysql.create_pool( self.pool = await aiomysql.create_pool(
host=self.config.database.host, host=self.config.database.host,
port=self.config.database.port, port=self.config.database.port,
@@ -193,8 +211,15 @@ class DorisConnectionManager:
maxsize=self.config.database.max_connections or 20, maxsize=self.config.database.max_connections or 20,
autocommit=True, autocommit=True,
connect_timeout=self.connection_timeout, connect_timeout=self.connection_timeout,
# Additional parameters for stability
pool_recycle=3600, # Recycle connections every hour
echo=False, # Don't echo SQL statements
) )
# Test the connection pool
if not await self.test_connection():
raise RuntimeError("Connection pool test failed")
self.logger.info( self.logger.info(
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, " f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
f"max connections: {self.config.database.max_connections}" f"max connections: {self.config.database.max_connections}"
@@ -206,6 +231,14 @@ class DorisConnectionManager:
except Exception as e: except Exception as e:
self.logger.error(f"Connection pool initialization failed: {e}") self.logger.error(f"Connection pool initialization failed: {e}")
# Clean up partial initialization
if self.pool:
try:
self.pool.close()
await self.pool.wait_closed()
except Exception:
pass
self.pool = None
raise raise
async def get_connection(self, session_id: str) -> DorisConnection: async def get_connection(self, session_id: str) -> DorisConnection:
@@ -235,9 +268,24 @@ class DorisConnectionManager:
# Get connection from pool # Get connection from pool
raw_connection = await self.pool.acquire() raw_connection = await self.pool.acquire()
# Validate the raw connection
if not raw_connection:
raise RuntimeError(f"Failed to acquire connection from pool for session {session_id}")
# Verify the connection is not closed
if raw_connection.closed:
raise RuntimeError(f"Acquired connection is already closed for session {session_id}")
# Create wrapped connection # Create wrapped connection
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager) doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
# Test the connection before storing it
if not await doris_conn.ping():
# If ping fails, release the connection and raise error
if self.pool and raw_connection and not raw_connection.closed:
self.pool.release(raw_connection)
raise RuntimeError(f"New connection failed ping test for session {session_id}")
# Store in session connections # Store in session connections
self.session_connections[session_id] = doris_conn self.session_connections[session_id] = doris_conn