Files
doris-mcp-server/doris_mcp_server/utils/db.py
2025-06-08 18:44:40 +08:00

480 lines
18 KiB
Python

#!/usr/bin/env python3
"""
Apache Doris Database Connection Management Module
Provides high-performance database connection pool management, automatic reconnection mechanism and connection health check functionality
Supports asynchronous operations and concurrent connection management, ensuring stability and performance for enterprise applications
"""
import asyncio
import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List
import aiomysql
from aiomysql import Connection, Pool
@dataclass
class ConnectionMetrics:
"""Connection pool performance metrics"""
total_connections: int = 0
active_connections: int = 0
idle_connections: int = 0
failed_connections: int = 0
connection_errors: int = 0
avg_connection_time: float = 0.0
last_health_check: datetime | None = None
@dataclass
class QueryResult:
    """Outcome of a single executed SQL statement.

    All four fields are required; there are no defaults.
    """

    # Result rows, one dict per row (empty when the statement returned none).
    data: list[dict[str, Any]]
    # Free-form details about the execution (e.g. column names, the SQL text).
    metadata: dict[str, Any]
    # Elapsed time for the statement, in seconds.
    execution_time: float
    # Number of rows returned or affected, as reported by the executor.
    row_count: int
class DorisConnection:
    """Wrapper around one aiomysql connection that is bound to a session.

    Besides the raw connection it records creation/last-use timestamps, a
    query counter and a health flag, and optionally routes statements through
    a security manager for validation and result masking.
    """

    def __init__(self, connection: Connection, session_id: str, security_manager=None):
        """Bind ``connection`` to ``session_id``.

        Args:
            connection: Raw aiomysql connection to wrap.
            session_id: Identifier of the session that owns this connection.
            security_manager: Optional object providing
                ``validate_sql_security`` and ``apply_data_masking``; when
                falsy, no security processing is performed.
        """
        self.connection = connection
        self.session_id = session_id
        self.created_at = datetime.utcnow()
        self.last_used = datetime.utcnow()
        self.query_count = 0
        self.is_healthy = True
        self.security_manager = security_manager

    async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult:
        """Execute a SQL statement and wrap its outcome.

        Args:
            sql: Statement text.
            params: Optional parameters passed to the cursor with ``sql``.
            auth_context: Optional caller context; security validation and
                data masking only run when both this and a security manager
                are present.

        Returns:
            A ``QueryResult`` with the (possibly masked) rows, metadata
            (``columns``, ``query``, ``params`` and, when validation ran,
            ``security_check``), the elapsed time in seconds and a row count.

        Raises:
            ValueError: If the security manager rejects the statement.
            Exception: Any error from the driver or the security manager is
                logged and re-raised unchanged.
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # skewed by wall-clock adjustments.
        start_time = time.perf_counter()
        try:
            security_result = await self._validate_security(sql, auth_context)
            data, columns, row_count = await self._run_statement(sql, params)

            execution_time = time.perf_counter() - start_time
            self.last_used = datetime.utcnow()
            self.query_count += 1

            # Mask the rows only when there is something to mask.
            final_data = list(data) if data else []
            if self.security_manager and auth_context and final_data:
                final_data = await self.security_manager.apply_data_masking(final_data, auth_context)

            metadata = {"columns": columns, "query": sql, "params": params}
            if security_result:
                metadata["security_check"] = security_result

            return QueryResult(
                data=final_data,
                metadata=metadata,
                execution_time=execution_time,
                row_count=row_count,
            )
        except Exception as e:
            logging.error(f"Query execution failed: {e}")
            raise

    async def _validate_security(self, sql: str, auth_context) -> dict[str, Any] | None:
        """Run the optional SQL security check.

        Returns a summary dict of the validation result, or ``None`` when no
        check was performed. A rejection raises ``ValueError`` and does NOT
        touch ``is_healthy``: a refused statement says nothing about the
        state of the underlying connection.
        """
        if not (self.security_manager and auth_context):
            return None

        validation_result = await self.security_manager.validate_sql_security(sql, auth_context)
        if not validation_result.is_valid:
            raise ValueError(f"SQL security validation failed: {validation_result.error_message}")

        return {
            "is_valid": validation_result.is_valid,
            "risk_level": validation_result.risk_level,
            "blocked_operations": validation_result.blocked_operations,
        }

    async def _run_statement(self, sql: str, params: tuple | None):
        """Send the statement to the server and collect its result.

        Returns a ``(rows, columns, row_count)`` tuple. Any failure while
        talking to the database marks the connection unhealthy before the
        exception propagates.
        """
        try:
            async with self.connection.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(sql, params)

                # The cursor description is set exactly when the statement
                # produced a result set, so rely on it rather than on a list
                # of SQL prefixes (which misses e.g. ``WITH ... SELECT``,
                # parenthesised or comment-prefixed queries).
                description = cursor.description
                if description:
                    rows = await cursor.fetchall()
                    columns = [column[0] for column in description]
                    return rows, columns, len(rows)

                # No result set: report the affected-row count instead.
                return [], [], cursor.rowcount
        except Exception:
            self.is_healthy = False
            raise

    async def ping(self) -> bool:
        """Ping the server and update ``is_healthy``.

        Returns:
            ``True`` if the ping succeeded, ``False`` if it raised.
        """
        try:
            await self.connection.ping()
            self.is_healthy = True
            return True
        except Exception:
            self.is_healthy = False
            return False

    async def close(self):
        """Close the underlying connection if it is still open.

        Errors are logged and swallowed so that cleanup paths never fail
        because of a connection that is already broken.
        """
        try:
            if self.connection and not self.connection.closed:
                await self.connection.ensure_closed()
        except Exception as e:
            logging.error(f"Error occurred while closing connection: {e}")
class DorisConnectionManager:
    """Doris database connection manager.

    Owns an aiomysql connection pool, hands out one wrapped connection per
    session (reusing it while it stays healthy), and runs two background
    tasks: a periodic health check and a periodic cleanup of connections that
    exceeded their maximum age. An optional security manager is forwarded to
    every wrapped connection.
    """

    def __init__(self, config, security_manager=None):
        """Store configuration; no I/O happens until ``initialize`` is awaited.

        Args:
            config: Object exposing a ``database`` attribute with connection
                and pool settings.
            security_manager: Optional security manager passed on to each
                ``DorisConnection``.
        """
        self.config = config
        self.pool: Pool | None = None
        self.session_connections: dict[str, DorisConnection] = {}
        self.metrics = ConnectionMetrics()
        self.logger = logging.getLogger(__name__)
        self.security_manager = security_manager

        # Health check configuration (falsy config values fall back to defaults)
        self.health_check_interval = config.database.health_check_interval or 60
        self.max_connection_age = config.database.max_connection_age or 3600
        self.connection_timeout = config.database.connection_timeout or 30

        # Background tasks are created in initialize()
        self._health_check_task = None
        self._cleanup_task = None

    async def initialize(self):
        """Create the connection pool and start the background tasks.

        Raises:
            Exception: Any error raised while creating the pool is logged and
                re-raised.
        """
        try:
            db_config = self.config.database
            min_size = db_config.min_connections or 5
            max_size = db_config.max_connections or 20

            self.pool = await aiomysql.create_pool(
                host=db_config.host,
                port=db_config.port,
                user=db_config.user,
                password=db_config.password,
                db=db_config.database,
                charset="utf8",
                minsize=min_size,
                maxsize=max_size,
                autocommit=True,
                connect_timeout=self.connection_timeout,
            )

            # Log the effective sizes, not the raw (possibly unset) config values.
            self.logger.info(
                f"Connection pool initialized successfully, min connections: {min_size}, "
                f"max connections: {max_size}"
            )

            # Start background monitoring tasks
            self._health_check_task = asyncio.create_task(self._health_check_loop())
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        except Exception as e:
            self.logger.error(f"Connection pool initialization failed: {e}")
            raise

    async def get_connection(self, session_id: str) -> DorisConnection:
        """Return the connection for ``session_id``, creating one if needed.

        An existing session connection is reused when it answers a ping;
        otherwise it is cleaned up and replaced by a new one.
        """
        existing = self.session_connections.get(session_id)
        if existing is not None:
            if await existing.ping():
                return existing
            # Connection is unhealthy: drop it and fall through to a new one.
            await self._cleanup_session_connection(session_id)

        return await self._create_new_connection(session_id)

    async def _create_new_connection(self, session_id: str) -> DorisConnection:
        """Acquire a pool connection, wrap it and register it for the session.

        Raises:
            RuntimeError: If the pool has not been initialized.
            Exception: Any acquisition error; it is counted in
                ``metrics.connection_errors``, logged and re-raised.
        """
        try:
            if not self.pool:
                raise RuntimeError("Connection pool not initialized")

            raw_connection = await self.pool.acquire()
            doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)

            self.session_connections[session_id] = doris_conn
            self.metrics.total_connections += 1
            self.logger.debug(f"Created new connection for session: {session_id}")
            return doris_conn
        except Exception as e:
            self.metrics.connection_errors += 1
            self.logger.error(f"Failed to create connection for session {session_id}: {e}")
            raise

    async def release_connection(self, session_id: str):
        """Release the connection held by ``session_id`` (no-op if none)."""
        if session_id in self.session_connections:
            await self._cleanup_session_connection(session_id)

    async def _cleanup_session_connection(self, session_id: str):
        """Close a session's connection and hand it back to the pool.

        The session entry is removed up front so concurrent callers can
        neither pick up a connection that is being torn down nor trip over a
        second removal of the same key. Errors are logged, never raised.
        """
        conn = self.session_connections.pop(session_id, None)
        if conn is None:
            return

        try:
            raw_connection = conn.connection

            # Close BEFORE releasing: closing a connection that already sits
            # in the pool's free list would leave a dead connection behind
            # for the next acquire.
            await conn.close()

            # Always release, even if the connection was already closed, so
            # the pool stops counting it as in use; skipping this would
            # permanently consume one slot of the pool's capacity.
            if self.pool and raw_connection is not None:
                self.pool.release(raw_connection)
        except Exception as e:
            self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
        finally:
            self.logger.debug(f"Cleaned up connection for session: {session_id}")

    async def _health_check_loop(self):
        """Run a health check every ``health_check_interval`` seconds until cancelled."""
        while True:
            try:
                await asyncio.sleep(self.health_check_interval)
                await self._perform_health_check()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Health check error: {e}")

    async def _perform_health_check(self):
        """Ping every session connection and drop the ones that fail.

        Also refreshes the metrics and records the check time. Errors are
        logged, never raised.
        """
        try:
            unhealthy_sessions = []
            # Iterate over a snapshot: ping() yields to the event loop, and
            # other coroutines may add or remove sessions in the meantime,
            # which would break iteration over the live dict.
            for session_id, conn in list(self.session_connections.items()):
                if not await conn.ping():
                    unhealthy_sessions.append(session_id)

            for session_id in unhealthy_sessions:
                await self._cleanup_session_connection(session_id)
                self.metrics.failed_connections += 1

            await self._update_connection_metrics()
            self.metrics.last_health_check = datetime.utcnow()

            if unhealthy_sessions:
                self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")
        except Exception as e:
            self.logger.error(f"Health check failed: {e}")

    async def _cleanup_loop(self):
        """Run the age-based cleanup every five minutes until cancelled."""
        while True:
            try:
                await asyncio.sleep(300)  # Run every 5 minutes
                await self._cleanup_idle_connections()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Cleanup loop error: {e}")

    async def _cleanup_idle_connections(self):
        """Drop session connections older than ``max_connection_age`` seconds.

        Despite the name, the criterion is the time since creation, not the
        time since last use.
        """
        current_time = datetime.utcnow()
        idle_sessions = [
            session_id
            for session_id, conn in list(self.session_connections.items())
            if (current_time - conn.created_at).total_seconds() > self.max_connection_age
        ]

        for session_id in idle_sessions:
            await self._cleanup_session_connection(session_id)

        if idle_sessions:
            self.logger.info(f"Cleaned up {len(idle_sessions)} idle connections")

    async def _update_connection_metrics(self):
        """Refresh the active/idle gauges from the current state."""
        self.metrics.active_connections = len(self.session_connections)
        if self.pool:
            self.metrics.idle_connections = self.pool.freesize

    async def get_metrics(self) -> ConnectionMetrics:
        """Return the metrics object after refreshing its gauges."""
        await self._update_connection_metrics()
        return self.metrics

    async def execute_query(
        self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
    ) -> QueryResult:
        """Execute ``sql`` on the connection belonging to ``session_id``."""
        conn = await self.get_connection(session_id)
        return await conn.execute(sql, params, auth_context)

    @asynccontextmanager
    async def get_connection_context(self, session_id: str):
        """Async context manager yielding the session's connection.

        The connection is intentionally left open on exit so the session can
        keep reusing it.
        """
        conn = await self.get_connection(session_id)
        try:
            yield conn
        finally:
            # Connection will be reused, no need to close here
            pass

    async def close(self):
        """Stop background tasks, drop all session connections, close the pool.

        Errors are logged, never raised.
        """
        try:
            # Cancel background tasks and wait for them to finish.
            for task in (self._health_check_task, self._cleanup_task):
                if task:
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass

            # Clean up all session connections (snapshot: cleanup mutates the dict).
            for session_id in list(self.session_connections):
                await self._cleanup_session_connection(session_id)

            if self.pool:
                self.pool.close()
                await self.pool.wait_closed()

            self.logger.info("Connection manager closed successfully")
        except Exception as e:
            self.logger.error(f"Error closing connection manager: {e}")

    async def test_connection(self) -> bool:
        """Run ``SELECT 1`` on a pooled connection.

        Returns:
            ``True`` if a row came back; ``False`` if the pool is not
            initialized or any error occurred (the error is logged).
        """
        try:
            if not self.pool:
                return False

            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("SELECT 1")
                    result = await cursor.fetchone()
                    return result is not None
        except Exception as e:
            self.logger.error(f"Connection test failed: {e}")
            return False
class ConnectionPoolMonitor:
    """Connection pool monitor.

    Builds status dictionaries and a health report from the state of a
    ``DorisConnectionManager``.
    """

    def __init__(self, connection_manager: DorisConnectionManager):
        """Remember the manager whose pool and sessions are inspected."""
        self.connection_manager = connection_manager
        self.logger = logging.getLogger(__name__)

    async def get_pool_status(self) -> dict[str, Any]:
        """Return a snapshot of pool and metric values.

        Pool-derived values are ``0`` when the manager has no pool.
        ``pool_size`` is the number of connections the pool currently holds;
        ``max_pool_size`` is the limit it may grow to.
        """
        metrics = await self.connection_manager.get_metrics()
        pool = self.connection_manager.pool
        last_check = metrics.last_health_check

        return {
            "pool_size": pool.size if pool else 0,
            "max_pool_size": pool.maxsize if pool else 0,
            "free_connections": pool.freesize if pool else 0,
            "active_sessions": len(self.connection_manager.session_connections),
            "total_connections": metrics.total_connections,
            "failed_connections": metrics.failed_connections,
            "connection_errors": metrics.connection_errors,
            "avg_connection_time": metrics.avg_connection_time,
            "last_health_check": last_check.isoformat() if last_check else None,
        }

    async def get_session_details(self) -> list[dict[str, Any]]:
        """Return one info dict per session connection.

        ``connection_age`` is the number of seconds since the connection
        wrapper was created.
        """
        now = datetime.utcnow()
        return [
            {
                "session_id": session_id,
                "created_at": conn.created_at.isoformat(),
                "last_used": conn.last_used.isoformat(),
                "query_count": conn.query_count,
                "is_healthy": conn.is_healthy,
                "connection_age": (now - conn.created_at).total_seconds(),
            }
            for session_id, conn in self.connection_manager.session_connections.items()
        ]

    async def generate_health_report(self) -> dict[str, Any]:
        """Build a report with pool status, session summary and recommendations.

        The health ratio is the share of sessions flagged healthy, and is
        ``1.0`` when there are no sessions.
        """
        pool_status = await self.get_pool_status()
        session_details = await self.get_session_details()

        total_sessions = len(session_details)
        healthy_sessions = sum(1 for session in session_details if session["is_healthy"])
        health_ratio = healthy_sessions / total_sessions if total_sessions > 0 else 1.0

        recommendations = []
        if health_ratio < 0.8:
            recommendations.append("Consider checking database connectivity and network stability")
        if pool_status["connection_errors"] > 10:
            recommendations.append("High connection error rate detected, review connection configuration")
        # Measure utilization against the pool's configured maximum. The
        # current size only counts connections opened so far, so comparing
        # against it would raise this warning whenever every open connection
        # is in use, even with plenty of headroom left.
        if pool_status["active_sessions"] > pool_status["max_pool_size"] * 0.9:
            recommendations.append("Connection pool utilization is high, consider increasing pool size")

        return {
            "timestamp": datetime.utcnow().isoformat(),
            "pool_status": pool_status,
            "session_summary": {
                "total_sessions": total_sessions,
                "healthy_sessions": healthy_sessions,
                "health_ratio": health_ratio,
            },
            "session_details": session_details,
            "recommendations": recommendations,
        }