1. Fix DB Connection BUG
2. Modify the global default configuration items and obtain them from Config
This commit is contained in:
@@ -44,11 +44,15 @@ from .tools.resources_manager import DorisResourcesManager
|
|||||||
from .utils.config import DorisConfig
|
from .utils.config import DorisConfig
|
||||||
from .utils.db import DorisConnectionManager
|
from .utils.db import DorisConnectionManager
|
||||||
from .utils.security import DorisSecurityManager
|
from .utils.security import DorisSecurityManager
|
||||||
|
import os
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Create a default config instance for getting default values
|
||||||
|
_default_config = DorisConfig()
|
||||||
|
|
||||||
|
|
||||||
class DorisServer:
|
class DorisServer:
|
||||||
"""Apache Doris MCP Server main class"""
|
"""Apache Doris MCP Server main class"""
|
||||||
@@ -204,7 +208,7 @@ class DorisServer:
|
|||||||
|
|
||||||
init_options = InitializationOptions(
|
init_options = InitializationOptions(
|
||||||
server_name="doris-mcp-server",
|
server_name="doris-mcp-server",
|
||||||
server_version="1.0.0",
|
server_version=os.getenv("SERVER_VERSION", _default_config.server_version),
|
||||||
capabilities=capabilities,
|
capabilities=capabilities,
|
||||||
)
|
)
|
||||||
self.logger.info("Initialization options created successfully")
|
self.logger.info("Initialization options created successfully")
|
||||||
@@ -237,7 +241,7 @@ class DorisServer:
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def start_http(self, host: str = "localhost", port: int = 3000):
|
async def start_http(self, host: str = os.getenv("SERVER_HOST", _default_config.database.host), port: int = os.getenv("SERVER_PORT", _default_config.server_port)):
|
||||||
"""Start Streamable HTTP transport mode"""
|
"""Start Streamable HTTP transport mode"""
|
||||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
|
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}")
|
||||||
|
|
||||||
@@ -251,9 +255,9 @@ class DorisServer:
|
|||||||
from collections.abc import AsyncIterator
|
from collections.abc import AsyncIterator
|
||||||
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
|
||||||
from starlette.applications import Starlette
|
from starlette.applications import Starlette
|
||||||
from starlette.routing import Mount, Route
|
from starlette.routing import Route
|
||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
from starlette.types import Receive, Scope, Send
|
from starlette.types import Scope
|
||||||
|
|
||||||
# Create session manager
|
# Create session manager
|
||||||
session_manager = StreamableHTTPSessionManager(
|
session_manager = StreamableHTTPSessionManager(
|
||||||
@@ -413,34 +417,34 @@ Examples:
|
|||||||
"--transport",
|
"--transport",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["stdio", "http"],
|
choices=["stdio", "http"],
|
||||||
default="stdio",
|
default=os.getenv("TRANSPORT", _default_config.transport),
|
||||||
help="Transport protocol type: stdio (local), http (Streamable HTTP)",
|
help=f"Transport protocol type: stdio (local), http (Streamable HTTP) (default: {_default_config.transport})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--host",
|
"--host",
|
||||||
type=str,
|
type=str,
|
||||||
default="localhost",
|
default=os.getenv("SERVER_HOST", _default_config.database.host),
|
||||||
help="Host address for HTTP mode (default: localhost)",
|
help=f"Host address for HTTP mode (default: {_default_config.database.host})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--port", type=int, default=3000, help="Port number for HTTP mode (default: 3000)"
|
"--port", type=int, default=os.getenv("SERVER_PORT", _default_config.server_port), help=f"Port number for HTTP mode (default: {_default_config.server_port})"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-host",
|
"--db-host",
|
||||||
type=str,
|
type=str,
|
||||||
default="localhost",
|
default=os.getenv("DB_HOST", _default_config.database.host),
|
||||||
help="Doris database host address (default: localhost)",
|
help=f"Doris database host address (default: {_default_config.database.host})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-port", type=int, default=9030, help="Doris database port number (default: 9030)"
|
"--db-port", type=int, default=os.getenv("DB_PORT", _default_config.database.port), help=f"Doris database port number (default: {_default_config.database.port})"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-user", type=str, default="root", help="Doris database username (default: root)"
|
"--db-user", type=str, default=os.getenv("DB_USER", _default_config.database.user), help=f"Doris database username (default: {_default_config.database.user})"
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument("--db-password", type=str, default="", help="Doris database password")
|
parser.add_argument("--db-password", type=str, default="", help="Doris database password")
|
||||||
@@ -448,16 +452,16 @@ Examples:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--db-database",
|
"--db-database",
|
||||||
type=str,
|
type=str,
|
||||||
default="information_schema",
|
default=os.getenv("DB_DATABASE", _default_config.database.database),
|
||||||
help="Doris database name (default: information_schema)",
|
help=f"Doris database name (default: {_default_config.database.database})",
|
||||||
)
|
)
|
||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--log-level",
|
"--log-level",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
|
||||||
default="INFO",
|
default=os.getenv("LOG_LEVEL", _default_config.logging.level),
|
||||||
help="Log level (default: INFO)",
|
help=f"Log level (default: {_default_config.logging.level})",
|
||||||
)
|
)
|
||||||
|
|
||||||
return parser
|
return parser
|
||||||
@@ -475,17 +479,17 @@ async def main():
|
|||||||
config = DorisConfig.from_env() # First load from .env file and environment variables
|
config = DorisConfig.from_env() # First load from .env file and environment variables
|
||||||
|
|
||||||
# Command line arguments override configuration (if provided)
|
# Command line arguments override configuration (if provided)
|
||||||
if args.db_host != "localhost": # If not default value, use command line argument
|
if args.db_host != _default_config.database.host: # If not default value, use command line argument
|
||||||
config.database.host = args.db_host
|
config.database.host = args.db_host
|
||||||
if args.db_port != 9030:
|
if args.db_port != _default_config.database.port:
|
||||||
config.database.port = args.db_port
|
config.database.port = args.db_port
|
||||||
if args.db_user != "root":
|
if args.db_user != _default_config.database.user:
|
||||||
config.database.user = args.db_user
|
config.database.user = args.db_user
|
||||||
if args.db_password: # Use password if provided
|
if args.db_password: # Use password if provided
|
||||||
config.database.password = args.db_password
|
config.database.password = args.db_password
|
||||||
if args.db_database != "information_schema":
|
if args.db_database != _default_config.database.database:
|
||||||
config.database.database = args.db_database
|
config.database.database = args.db_database
|
||||||
if args.log_level != "INFO":
|
if args.log_level != _default_config.logging.level:
|
||||||
config.logging.level = args.log_level
|
config.logging.level = args.log_level
|
||||||
|
|
||||||
# Create server instance
|
# Create server instance
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ class DatabaseConfig:
|
|||||||
port: int = 9030
|
port: int = 9030
|
||||||
user: str = "root"
|
user: str = "root"
|
||||||
password: str = ""
|
password: str = ""
|
||||||
database: str = "test"
|
database: str = "information_schema"
|
||||||
charset: str = "utf8mb4"
|
charset: str = "UTF8"
|
||||||
|
|
||||||
# Connection pool configuration
|
# Connection pool configuration
|
||||||
min_connections: int = 5
|
min_connections: int = 5
|
||||||
@@ -125,11 +125,11 @@ class MonitoringConfig:
|
|||||||
|
|
||||||
# Metrics collection configuration
|
# Metrics collection configuration
|
||||||
enable_metrics: bool = True
|
enable_metrics: bool = True
|
||||||
metrics_port: int = 8081
|
metrics_port: int = 3001
|
||||||
metrics_path: str = "/metrics"
|
metrics_path: str = "/metrics"
|
||||||
|
|
||||||
# Health check configuration
|
# Health check configuration
|
||||||
health_check_port: int = 8082
|
health_check_port: int = 3002
|
||||||
health_check_path: str = "/health"
|
health_check_path: str = "/health"
|
||||||
|
|
||||||
# Alert configuration
|
# Alert configuration
|
||||||
@@ -143,8 +143,9 @@ class DorisConfig:
|
|||||||
|
|
||||||
# Basic configuration
|
# Basic configuration
|
||||||
server_name: str = "doris-mcp-server"
|
server_name: str = "doris-mcp-server"
|
||||||
server_version: str = "1.0.0"
|
server_version: str = "0.3.0"
|
||||||
server_port: int = 8080
|
server_port: int = 3000
|
||||||
|
transport: str = "stdio"
|
||||||
|
|
||||||
# Sub-configuration modules
|
# Sub-configuration modules
|
||||||
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
database: DatabaseConfig = field(default_factory=DatabaseConfig)
|
||||||
|
|||||||
@@ -137,10 +137,18 @@ class DorisConnection:
|
|||||||
async def ping(self) -> bool:
|
async def ping(self) -> bool:
|
||||||
"""Check connection health status"""
|
"""Check connection health status"""
|
||||||
try:
|
try:
|
||||||
|
# Check if connection exists and is not closed
|
||||||
|
if not self.connection or self.connection.closed:
|
||||||
|
self.is_healthy = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Try to ping the connection
|
||||||
await self.connection.ping()
|
await self.connection.ping()
|
||||||
self.is_healthy = True
|
self.is_healthy = True
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception as e:
|
||||||
|
# Log the specific error for debugging
|
||||||
|
logging.debug(f"Connection ping failed for session {self.session_id}: {e}")
|
||||||
self.is_healthy = False
|
self.is_healthy = False
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -181,7 +189,17 @@ class DorisConnectionManager:
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""Initialize connection manager"""
|
"""Initialize connection manager"""
|
||||||
try:
|
try:
|
||||||
# Create connection pool
|
self.logger.info(f"Initializing connection pool to {self.config.database.host}:{self.config.database.port}")
|
||||||
|
|
||||||
|
# Validate configuration
|
||||||
|
if not self.config.database.host:
|
||||||
|
raise ValueError("Database host is required")
|
||||||
|
if not self.config.database.user:
|
||||||
|
raise ValueError("Database user is required")
|
||||||
|
if not self.config.database.password:
|
||||||
|
self.logger.warning("Database password is empty, this may cause connection issues")
|
||||||
|
|
||||||
|
# Create connection pool with additional parameters for stability
|
||||||
self.pool = await aiomysql.create_pool(
|
self.pool = await aiomysql.create_pool(
|
||||||
host=self.config.database.host,
|
host=self.config.database.host,
|
||||||
port=self.config.database.port,
|
port=self.config.database.port,
|
||||||
@@ -193,8 +211,15 @@ class DorisConnectionManager:
|
|||||||
maxsize=self.config.database.max_connections or 20,
|
maxsize=self.config.database.max_connections or 20,
|
||||||
autocommit=True,
|
autocommit=True,
|
||||||
connect_timeout=self.connection_timeout,
|
connect_timeout=self.connection_timeout,
|
||||||
|
# Additional parameters for stability
|
||||||
|
pool_recycle=3600, # Recycle connections every hour
|
||||||
|
echo=False, # Don't echo SQL statements
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test the connection pool
|
||||||
|
if not await self.test_connection():
|
||||||
|
raise RuntimeError("Connection pool test failed")
|
||||||
|
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
|
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
|
||||||
f"max connections: {self.config.database.max_connections}"
|
f"max connections: {self.config.database.max_connections}"
|
||||||
@@ -206,6 +231,14 @@ class DorisConnectionManager:
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Connection pool initialization failed: {e}")
|
self.logger.error(f"Connection pool initialization failed: {e}")
|
||||||
|
# Clean up partial initialization
|
||||||
|
if self.pool:
|
||||||
|
try:
|
||||||
|
self.pool.close()
|
||||||
|
await self.pool.wait_closed()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
self.pool = None
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def get_connection(self, session_id: str) -> DorisConnection:
|
async def get_connection(self, session_id: str) -> DorisConnection:
|
||||||
@@ -235,9 +268,24 @@ class DorisConnectionManager:
|
|||||||
# Get connection from pool
|
# Get connection from pool
|
||||||
raw_connection = await self.pool.acquire()
|
raw_connection = await self.pool.acquire()
|
||||||
|
|
||||||
|
# Validate the raw connection
|
||||||
|
if not raw_connection:
|
||||||
|
raise RuntimeError(f"Failed to acquire connection from pool for session {session_id}")
|
||||||
|
|
||||||
|
# Verify the connection is not closed
|
||||||
|
if raw_connection.closed:
|
||||||
|
raise RuntimeError(f"Acquired connection is already closed for session {session_id}")
|
||||||
|
|
||||||
# Create wrapped connection
|
# Create wrapped connection
|
||||||
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
|
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
|
||||||
|
|
||||||
|
# Test the connection before storing it
|
||||||
|
if not await doris_conn.ping():
|
||||||
|
# If ping fails, release the connection and raise error
|
||||||
|
if self.pool and raw_connection and not raw_connection.closed:
|
||||||
|
self.pool.release(raw_connection)
|
||||||
|
raise RuntimeError(f"New connection failed ping test for session {session_id}")
|
||||||
|
|
||||||
# Store in session connections
|
# Store in session connections
|
||||||
self.session_connections[session_id] = doris_conn
|
self.session_connections[session_id] = doris_conn
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user