#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Apache Doris Database Connection Management Module

Provides database connection pool management, replacement of unhealthy
session connections and periodic connection health checks.
All database operations are asynchronous (asyncio + aiomysql).
"""

import asyncio
import logging
import time
from contextlib import asynccontextmanager
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List

import aiomysql
from aiomysql import Connection, Pool

# Statement prefixes that are always treated as returning a result set.
_RESULT_SET_PREFIXES = ("SELECT", "SHOW", "DESCRIBE", "DESC", "EXPLAIN")


@dataclass
class ConnectionMetrics:
    """Connection pool performance metrics"""

    # Number of session connections created so far.
    total_connections: int = 0
    # Number of session connections currently tracked by the manager.
    active_connections: int = 0
    # Free connections in the pool at the last metrics update.
    idle_connections: int = 0
    # Session connections dropped by the health check.
    failed_connections: int = 0
    # Failures while creating a session connection.
    connection_errors: int = 0
    # Running mean of the time (seconds) spent acquiring a pool connection.
    avg_connection_time: float = 0.0
    # Time of the last completed health check (naive UTC), if any.
    last_health_check: datetime | None = None


@dataclass
class QueryResult:
    """Query result wrapper"""

    # Result rows as dictionaries (possibly masked by the security manager).
    data: list[dict[str, Any]]
    # "columns", "query", "params" and optionally "security_check".
    metadata: dict[str, Any]
    # Seconds from the start of execute() until rows were fetched.
    execution_time: float
    # Number of fetched rows, or the cursor's rowcount for other statements.
    row_count: int


class DorisConnection:
    """Doris database connection wrapper class

    Wraps one aiomysql connection bound to a session and tracks usage
    statistics (creation time, last use, query count, health flag).
    """

    def __init__(self, connection: Connection, session_id: str, security_manager=None):
        self.connection = connection
        self.session_id = session_id
        self.created_at = datetime.utcnow()
        self.last_used = datetime.utcnow()
        self.query_count = 0
        self.is_healthy = True
        self.security_manager = security_manager

    async def _check_security(self, sql: str, auth_context) -> dict[str, Any] | None:
        """Validate ``sql`` with the security manager, if one is configured.

        Returns:
            A summary of the validation result, or ``None`` when no security
            manager or no auth context is available.

        Raises:
            ValueError: If the security manager rejects the statement.
        """
        if not (self.security_manager and auth_context):
            return None

        validation_result = await self.security_manager.validate_sql_security(sql, auth_context)
        if not validation_result.is_valid:
            raise ValueError(f"SQL security validation failed: {validation_result.error_message}")

        return {
            "is_valid": validation_result.is_valid,
            "risk_level": validation_result.risk_level,
            "blocked_operations": validation_result.blocked_operations,
        }

    async def _run_statement(self, sql: str, params: tuple | None) -> tuple[Any, int, list[str]]:
        """Run ``sql`` on the wrapped connection and collect rows, row count and column names."""
        async with self.connection.cursor(aiomysql.DictCursor) as cursor:
            await cursor.execute(sql, params)

            # A non-None description means the server returned a result set.
            # This also covers statements the prefix list misses, such as
            # "WITH ... SELECT" or a parenthesised SELECT.
            returns_rows = cursor.description is not None or sql.strip().upper().startswith(
                _RESULT_SET_PREFIXES
            )
            if returns_rows:
                data = await cursor.fetchall()
                row_count = len(data)
            else:
                data = []
                row_count = cursor.rowcount

            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            return data, row_count, columns

    async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult:
        """Execute SQL query

        Args:
            sql: Statement to execute.
            params: Optional parameters passed to the cursor together with ``sql``.
            auth_context: Optional authentication context. When both it and a
                security manager are present, the statement is validated before
                execution and the returned rows are passed through data masking.

        Returns:
            QueryResult with the rows, metadata, execution time and row count.

        Raises:
            ValueError: If the security validation rejects the statement.
            Exception: Any error raised by the driver or the security manager
                is logged and re-raised.
        """
        start_time = time.time()

        try:
            security_result = await self._check_security(sql, auth_context)

            try:
                data, row_count, columns = await self._run_statement(sql, params)
            except Exception:
                # Only failures of the database stage say something about the
                # connection; a rejected statement or a masking error must not
                # flag the connection as unhealthy.
                self.is_healthy = False
                raise

            execution_time = time.time() - start_time
            self.last_used = datetime.utcnow()
            self.query_count += 1

            # If security manager exists and has auth context, apply data masking
            final_data = list(data) if data else []
            if self.security_manager and auth_context and final_data:
                final_data = await self.security_manager.apply_data_masking(final_data, auth_context)

            metadata = {"columns": columns, "query": sql, "params": params}
            if security_result:
                metadata["security_check"] = security_result

            return QueryResult(
                data=final_data,
                metadata=metadata,
                execution_time=execution_time,
                row_count=row_count,
            )
        except Exception as e:
            logging.error(f"Query execution failed: {e}")
            raise

    async def ping(self) -> bool:
        """Check connection health status

        Returns:
            True if the underlying connection answered the ping, False
            otherwise. ``is_healthy`` is updated to the same value.
        """
        try:
            await self.connection.ping()
        except Exception:
            self.is_healthy = False
            return False
        self.is_healthy = True
        return True

    async def close(self):
        """Close connection

        Errors while closing are logged and not propagated.
        """
        try:
            if self.connection and not self.connection.closed:
                await self.connection.ensure_closed()
        except Exception as e:
            logging.error(f"Error occurred while closing connection: {e}")


class DorisConnectionManager:
    """Doris database connection manager

    Owns the aiomysql connection pool and a mapping from session id to the
    connection currently assigned to that session. Two background tasks
    periodically ping session connections and retire connections that exceed
    the configured maximum age. An optional security manager is handed to
    every connection for SQL validation and data masking.
    """

    def __init__(self, config, security_manager=None):
        self.config = config
        self.pool: Pool | None = None
        self.session_connections: dict[str, DorisConnection] = {}
        self.metrics = ConnectionMetrics()
        self.logger = logging.getLogger(__name__)
        self.security_manager = security_manager

        # Health check configuration (falsy config values fall back to defaults)
        self.health_check_interval = config.database.health_check_interval or 60
        self.max_connection_age = config.database.max_connection_age or 3600
        self.connection_timeout = config.database.connection_timeout or 30

        # Background tasks are started by initialize()
        self._health_check_task = None
        self._cleanup_task = None

    async def initialize(self):
        """Initialize connection manager

        Creates the connection pool and starts the health check and cleanup
        background tasks.

        Raises:
            Exception: Any error raised while creating the pool is logged and
                re-raised.
        """
        try:
            db_config = self.config.database
            min_size = db_config.min_connections or 5
            max_size = db_config.max_connections or 20

            # Create connection pool
            self.pool = await aiomysql.create_pool(
                host=db_config.host,
                port=db_config.port,
                user=db_config.user,
                password=db_config.password,
                db=db_config.database,
                charset="utf8",
                minsize=min_size,
                maxsize=max_size,
                autocommit=True,
                connect_timeout=self.connection_timeout,
            )

            # Log the effective sizes, not the raw (possibly unset) config values.
            self.logger.info(
                f"Connection pool initialized successfully, min connections: {min_size}, "
                f"max connections: {max_size}"
            )

            # Start background monitoring tasks
            self._health_check_task = asyncio.create_task(self._health_check_loop())
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

        except Exception as e:
            self.logger.error(f"Connection pool initialization failed: {e}")
            raise

    async def get_connection(self, session_id: str) -> DorisConnection:
        """Get database connection

        Reuses the connection already assigned to ``session_id`` when it still
        answers a ping; otherwise the stale connection is cleaned up and a new
        one is acquired from the pool.
        """
        existing = self.session_connections.get(session_id)
        if existing is not None:
            if await existing.ping():
                return existing
            # Connection is unhealthy, clean up and create new one
            await self._cleanup_session_connection(session_id)

        return await self._create_new_connection(session_id)

    async def _create_new_connection(self, session_id: str) -> DorisConnection:
        """Create new database connection

        Raises:
            RuntimeError: If the pool has not been initialized.
            Exception: Any error raised while acquiring from the pool; it is
                counted in ``metrics.connection_errors``, logged and re-raised.
        """
        try:
            if not self.pool:
                raise RuntimeError("Connection pool not initialized")

            # Get connection from pool, timing the acquire for the metrics
            acquire_started = time.time()
            raw_connection = await self.pool.acquire()
            acquire_seconds = time.time() - acquire_started

            # Create wrapped connection and remember it for the session
            doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
            self.session_connections[session_id] = doris_conn

            metrics = self.metrics
            metrics.total_connections += 1
            # Incremental running mean over all connections created so far.
            metrics.avg_connection_time += (
                acquire_seconds - metrics.avg_connection_time
            ) / metrics.total_connections

            self.logger.debug(f"Created new connection for session: {session_id}")
            return doris_conn

        except Exception as e:
            self.metrics.connection_errors += 1
            self.logger.error(f"Failed to create connection for session {session_id}: {e}")
            raise

    async def release_connection(self, session_id: str):
        """Release session connection"""
        if session_id in self.session_connections:
            await self._cleanup_session_connection(session_id)

    async def _cleanup_session_connection(self, session_id: str, discard: bool = False):
        """Clean up session connection

        Removes the session's connection from the session map and hands it
        back to the pool. A connection flagged unhealthy, or any connection
        when ``discard`` is true, is closed first so that the pool drops it
        instead of offering it to the next caller.

        Args:
            session_id: Session whose connection should be cleaned up.
            discard: Close the connection even if it is still healthy.
        """
        # Remove the entry first so no other coroutine can pick up a
        # connection that is in the middle of being torn down.
        conn = self.session_connections.pop(session_id, None)
        if conn is None:
            return

        try:
            raw_connection = conn.connection
            if discard or not conn.is_healthy:
                await conn.close()

            # Always release, even when the connection is closed: otherwise the
            # pool keeps counting it as in use and the slot is never reclaimed.
            # A healthy connection must NOT be closed after release, because the
            # pool would then hand out a dead connection.
            if self.pool and raw_connection is not None:
                self.pool.release(raw_connection)
        except Exception as e:
            self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
        finally:
            self.logger.debug(f"Cleaned up connection for session: {session_id}")

    async def _health_check_loop(self):
        """Background health check loop

        Runs a health check every ``health_check_interval`` seconds until the
        task is cancelled.
        """
        while True:
            try:
                await asyncio.sleep(self.health_check_interval)
                await self._perform_health_check()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Health check error: {e}")

    async def _perform_health_check(self):
        """Perform health check

        Pings every session connection, cleans up the ones that do not answer
        and refreshes the metrics. Errors are logged and not propagated.
        """
        try:
            # Iterate over a snapshot: ping() awaits, and other coroutines may
            # add or remove sessions in the meantime.
            snapshot = list(self.session_connections.items())
            unhealthy_sessions = [
                session_id for session_id, conn in snapshot if not await conn.ping()
            ]

            # Clean up unhealthy connections
            for session_id in unhealthy_sessions:
                await self._cleanup_session_connection(session_id)
                self.metrics.failed_connections += 1

            # Update metrics
            await self._update_connection_metrics()
            self.metrics.last_health_check = datetime.utcnow()

            if unhealthy_sessions:
                self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")

        except Exception as e:
            self.logger.error(f"Health check failed: {e}")

    async def _cleanup_loop(self):
        """Background cleanup loop

        Retires aged connections every 5 minutes until the task is cancelled.
        """
        while True:
            try:
                await asyncio.sleep(300)  # Run every 5 minutes
                await self._cleanup_idle_connections()
            except asyncio.CancelledError:
                break
            except Exception as e:
                self.logger.error(f"Cleanup loop error: {e}")

    async def _cleanup_idle_connections(self):
        """Clean up connections older than ``max_connection_age`` seconds

        Age is measured from the connection's creation time, not from its
        last use.
        """
        current_time = datetime.utcnow()
        expired_sessions = [
            session_id
            for session_id, conn in self.session_connections.items()
            if (current_time - conn.created_at).total_seconds() > self.max_connection_age
        ]

        # Aged connections are closed rather than recycled through the pool.
        for session_id in expired_sessions:
            await self._cleanup_session_connection(session_id, discard=True)

        if expired_sessions:
            self.logger.info(f"Cleaned up {len(expired_sessions)} idle connections")

    async def _update_connection_metrics(self):
        """Update connection metrics"""
        self.metrics.active_connections = len(self.session_connections)
        if self.pool:
            self.metrics.idle_connections = self.pool.freesize

    async def get_metrics(self) -> ConnectionMetrics:
        """Get connection metrics"""
        await self._update_connection_metrics()
        return self.metrics

    async def execute_query(
        self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
    ) -> QueryResult:
        """Execute query on the connection assigned to ``session_id``"""
        conn = await self.get_connection(session_id)
        return await conn.execute(sql, params, auth_context)

    @asynccontextmanager
    async def get_connection_context(self, session_id: str):
        """Get connection context manager

        The connection stays assigned to the session when the context exits,
        so that it can be reused; nothing is closed or released here.
        """
        yield await self.get_connection(session_id)

    async def close(self):
        """Close connection manager

        Cancels the background tasks, cleans up all session connections and
        closes the pool. Errors are logged and not propagated.
        """
        try:
            # Cancel background tasks
            for task in (self._health_check_task, self._cleanup_task):
                if task:
                    task.cancel()
                    try:
                        await task
                    except asyncio.CancelledError:
                        pass

            # Clean up all session connections
            for session_id in list(self.session_connections):
                await self._cleanup_session_connection(session_id)

            # Close connection pool
            if self.pool:
                self.pool.close()
                await self.pool.wait_closed()

            self.logger.info("Connection manager closed successfully")

        except Exception as e:
            self.logger.error(f"Error closing connection manager: {e}")

    async def test_connection(self) -> bool:
        """Test database connection

        Returns:
            True if ``SELECT 1`` returned a row on a pooled connection; False
            if the pool is not initialized or any error occurred.
        """
        try:
            if not self.pool:
                return False

            async with self.pool.acquire() as conn:
                async with conn.cursor() as cursor:
                    await cursor.execute("SELECT 1")
                    result = await cursor.fetchone()
                    return result is not None

        except Exception as e:
            self.logger.error(f"Connection test failed: {e}")
            return False


class ConnectionPoolMonitor:
    """Connection pool monitor

    Builds status dictionaries and health reports from a
    DorisConnectionManager's pool, sessions and metrics.
    """

    def __init__(self, connection_manager: DorisConnectionManager):
        self.connection_manager = connection_manager
        self.logger = logging.getLogger(__name__)

    async def get_pool_status(self) -> dict[str, Any]:
        """Get connection pool status"""
        manager = self.connection_manager
        metrics = await manager.get_metrics()
        pool = manager.pool
        last_check = metrics.last_health_check

        return {
            "pool_size": pool.size if pool else 0,
            "free_connections": pool.freesize if pool else 0,
            "active_sessions": len(manager.session_connections),
            "total_connections": metrics.total_connections,
            "failed_connections": metrics.failed_connections,
            "connection_errors": metrics.connection_errors,
            "avg_connection_time": metrics.avg_connection_time,
            "last_health_check": last_check.isoformat() if last_check else None,
        }

    async def get_session_details(self) -> list[dict[str, Any]]:
        """Get session connection details"""
        return [
            {
                "session_id": session_id,
                "created_at": conn.created_at.isoformat(),
                "last_used": conn.last_used.isoformat(),
                "query_count": conn.query_count,
                "is_healthy": conn.is_healthy,
                "connection_age": (datetime.utcnow() - conn.created_at).total_seconds(),
            }
            for session_id, conn in self.connection_manager.session_connections.items()
        ]

    async def generate_health_report(self) -> dict[str, Any]:
        """Generate connection health report

        Combines the pool status and session details with a summary of
        healthy sessions and a list of textual recommendations.
        """
        pool_status = await self.get_pool_status()
        session_details = await self.get_session_details()

        # Calculate health statistics; with no sessions the ratio is 1.0
        total_sessions = len(session_details)
        healthy_sessions = sum(1 for s in session_details if s["is_healthy"])
        health_ratio = healthy_sessions / total_sessions if total_sessions > 0 else 1.0

        # Add recommendations based on health status
        recommendations: list[str] = []
        if health_ratio < 0.8:
            recommendations.append("Consider checking database connectivity and network stability")
        if pool_status["connection_errors"] > 10:
            recommendations.append("High connection error rate detected, review connection configuration")
        if pool_status["active_sessions"] > pool_status["pool_size"] * 0.9:
            recommendations.append("Connection pool utilization is high, consider increasing pool size")

        return {
            "timestamp": datetime.utcnow().isoformat(),
            "pool_status": pool_status,
            "session_summary": {
                "total_sessions": total_sessions,
                "healthy_sessions": healthy_sessions,
                "health_ratio": health_ratio,
            },
            "session_details": session_details,
            "recommendations": recommendations,
        }