0.3.0 Release Version
This commit is contained in:
@@ -1,100 +1,479 @@
|
||||
import os
|
||||
import json
|
||||
import pymysql
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dotenv import load_dotenv
|
||||
import re
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Apache Doris Database Connection Management Module
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
Provides high-performance database connection pool management, automatic reconnection mechanism and connection health check functionality
|
||||
Supports asynchronous operations and concurrent connection management, ensuring stability and performance for enterprise applications
|
||||
"""
|
||||
|
||||
# Database configuration
|
||||
DB_CONFIG = {
|
||||
"host": os.getenv("DB_HOST", "localhost"),
|
||||
"port": int(os.getenv("DB_PORT", "9030")),
|
||||
"user": os.getenv("DB_USER", "root"),
|
||||
"password": os.getenv("DB_PASSWORD", ""),
|
||||
"database": os.getenv("DB_DATABASE", ""),
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": pymysql.cursors.DictCursor
|
||||
}
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
|
||||
def get_db_connection(db_name: Optional[str] = None):
|
||||
import aiomysql
|
||||
from aiomysql import Connection, Pool
|
||||
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConnectionMetrics:
|
||||
"""Connection pool performance metrics"""
|
||||
|
||||
total_connections: int = 0
|
||||
active_connections: int = 0
|
||||
idle_connections: int = 0
|
||||
failed_connections: int = 0
|
||||
connection_errors: int = 0
|
||||
avg_connection_time: float = 0.0
|
||||
last_health_check: datetime | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class QueryResult:
|
||||
"""Query result wrapper"""
|
||||
|
||||
data: list[dict[str, Any]]
|
||||
metadata: dict[str, Any]
|
||||
execution_time: float
|
||||
row_count: int
|
||||
|
||||
|
||||
class DorisConnection:
|
||||
"""Doris database connection wrapper class"""
|
||||
|
||||
def __init__(self, connection: Connection, session_id: str, security_manager=None):
|
||||
self.connection = connection
|
||||
self.session_id = session_id
|
||||
self.created_at = datetime.utcnow()
|
||||
self.last_used = datetime.utcnow()
|
||||
self.query_count = 0
|
||||
self.is_healthy = True
|
||||
self.security_manager = security_manager
|
||||
|
||||
async def execute(self, sql: str, params: tuple | None = None, auth_context=None) -> QueryResult:
|
||||
"""Execute SQL query"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# If security manager exists, perform SQL security check
|
||||
security_result = None
|
||||
if self.security_manager and auth_context:
|
||||
validation_result = await self.security_manager.validate_sql_security(sql, auth_context)
|
||||
if not validation_result.is_valid:
|
||||
raise ValueError(f"SQL security validation failed: {validation_result.error_message}")
|
||||
security_result = {
|
||||
"is_valid": validation_result.is_valid,
|
||||
"risk_level": validation_result.risk_level,
|
||||
"blocked_operations": validation_result.blocked_operations
|
||||
}
|
||||
|
||||
async with self.connection.cursor(aiomysql.DictCursor) as cursor:
|
||||
await cursor.execute(sql, params)
|
||||
|
||||
# Check if it's a query statement (statement that returns result set)
|
||||
sql_upper = sql.strip().upper()
|
||||
if (sql_upper.startswith("SELECT") or
|
||||
sql_upper.startswith("SHOW") or
|
||||
sql_upper.startswith("DESCRIBE") or
|
||||
sql_upper.startswith("DESC") or
|
||||
sql_upper.startswith("EXPLAIN")):
|
||||
data = await cursor.fetchall()
|
||||
row_count = len(data)
|
||||
else:
|
||||
data = []
|
||||
row_count = cursor.rowcount
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
self.last_used = datetime.utcnow()
|
||||
self.query_count += 1
|
||||
|
||||
# Get column information
|
||||
columns = []
|
||||
if cursor.description:
|
||||
columns = [desc[0] for desc in cursor.description]
|
||||
|
||||
# If security manager exists and has auth context, apply data masking
|
||||
final_data = list(data) if data else []
|
||||
if self.security_manager and auth_context and final_data:
|
||||
final_data = await self.security_manager.apply_data_masking(final_data, auth_context)
|
||||
|
||||
metadata = {"columns": columns, "query": sql, "params": params}
|
||||
if security_result:
|
||||
metadata["security_check"] = security_result
|
||||
|
||||
return QueryResult(
|
||||
data=final_data,
|
||||
metadata=metadata,
|
||||
execution_time=execution_time,
|
||||
row_count=row_count,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.is_healthy = False
|
||||
logging.error(f"Query execution failed: {e}")
|
||||
raise
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""Check connection health status"""
|
||||
try:
|
||||
await self.connection.ping()
|
||||
self.is_healthy = True
|
||||
return True
|
||||
except Exception:
|
||||
self.is_healthy = False
|
||||
return False
|
||||
|
||||
async def close(self):
|
||||
"""Close connection"""
|
||||
try:
|
||||
if self.connection and not self.connection.closed:
|
||||
await self.connection.ensure_closed()
|
||||
except Exception as e:
|
||||
logging.error(f"Error occurred while closing connection: {e}")
|
||||
|
||||
|
||||
class DorisConnectionManager:
|
||||
"""Doris database connection manager
|
||||
|
||||
Provides connection pool management, connection health monitoring, fault recovery and other functions
|
||||
Supports session-level connection reuse and intelligent load balancing
|
||||
Integrates security manager to provide unified security validation and data masking
|
||||
"""
|
||||
Get database connection
|
||||
|
||||
Args:
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Database connection
|
||||
"""
|
||||
if db_name:
|
||||
# Use default config but override database name
|
||||
config = DB_CONFIG.copy()
|
||||
config["database"] = db_name
|
||||
return pymysql.connect(**config)
|
||||
else:
|
||||
# Use default config
|
||||
return pymysql.connect(**DB_CONFIG)
|
||||
|
||||
def get_db_name() -> str:
|
||||
"""Get the currently configured default database name"""
|
||||
return DB_CONFIG["database"] or os.getenv("DB_DATABASE", "")
|
||||
def __init__(self, config, security_manager=None):
|
||||
self.config = config
|
||||
self.pool: Pool | None = None
|
||||
self.session_connections: dict[str, DorisConnection] = {}
|
||||
self.metrics = ConnectionMetrics()
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.security_manager = security_manager
|
||||
|
||||
def execute_query(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
# Health check configuration
|
||||
self.health_check_interval = config.database.health_check_interval or 60
|
||||
self.max_connection_age = config.database.max_connection_age or 3600
|
||||
self.connection_timeout = config.database.connection_timeout or 30
|
||||
|
||||
# Start background tasks
|
||||
self._health_check_task = None
|
||||
self._cleanup_task = None
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize connection manager"""
|
||||
try:
|
||||
# Create connection pool
|
||||
self.pool = await aiomysql.create_pool(
|
||||
host=self.config.database.host,
|
||||
port=self.config.database.port,
|
||||
user=self.config.database.user,
|
||||
password=self.config.database.password,
|
||||
db=self.config.database.database,
|
||||
charset="utf8",
|
||||
minsize=self.config.database.min_connections or 5,
|
||||
maxsize=self.config.database.max_connections or 20,
|
||||
autocommit=True,
|
||||
connect_timeout=self.connection_timeout,
|
||||
)
|
||||
|
||||
self.logger.info(
|
||||
f"Connection pool initialized successfully, min connections: {self.config.database.min_connections}, "
|
||||
f"max connections: {self.config.database.max_connections}"
|
||||
)
|
||||
|
||||
# Start background monitoring tasks
|
||||
self._health_check_task = asyncio.create_task(self._health_check_loop())
|
||||
self._cleanup_task = asyncio.create_task(self._cleanup_loop())
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Connection pool initialization failed: {e}")
|
||||
raise
|
||||
|
||||
async def get_connection(self, session_id: str) -> DorisConnection:
|
||||
"""Get database connection
|
||||
|
||||
Supports session-level connection reuse to improve performance and consistency
|
||||
"""
|
||||
# Check if there's an existing session connection
|
||||
if session_id in self.session_connections:
|
||||
conn = self.session_connections[session_id]
|
||||
# Check connection health
|
||||
if await conn.ping():
|
||||
return conn
|
||||
else:
|
||||
# Connection is unhealthy, clean up and create new one
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
# Create new connection
|
||||
return await self._create_new_connection(session_id)
|
||||
|
||||
async def _create_new_connection(self, session_id: str) -> DorisConnection:
|
||||
"""Create new database connection"""
|
||||
try:
|
||||
if not self.pool:
|
||||
raise RuntimeError("Connection pool not initialized")
|
||||
|
||||
# Get connection from pool
|
||||
raw_connection = await self.pool.acquire()
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_query_df(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return pandas DataFrame
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
pandas DataFrame
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
# Use a temporary cursor to execute the query and get results
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
# Create wrapped connection
|
||||
doris_conn = DorisConnection(raw_connection, session_id, self.security_manager)
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
# Store in session connections
|
||||
self.session_connections[session_id] = doris_conn
|
||||
|
||||
self.metrics.total_connections += 1
|
||||
self.logger.debug(f"Created new connection for session: {session_id}")
|
||||
|
||||
return doris_conn
|
||||
|
||||
except Exception as e:
|
||||
self.metrics.connection_errors += 1
|
||||
self.logger.error(f"Failed to create connection for session {session_id}: {e}")
|
||||
raise
|
||||
|
||||
async def release_connection(self, session_id: str):
|
||||
"""Release session connection"""
|
||||
if session_id in self.session_connections:
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
async def _cleanup_session_connection(self, session_id: str):
|
||||
"""Clean up session connection"""
|
||||
if session_id in self.session_connections:
|
||||
conn = self.session_connections[session_id]
|
||||
try:
|
||||
# Return connection to pool
|
||||
if self.pool and conn.connection and not conn.connection.closed:
|
||||
self.pool.release(conn.connection)
|
||||
|
||||
# Close connection wrapper
|
||||
await conn.close()
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
|
||||
finally:
|
||||
# Remove from session connections
|
||||
del self.session_connections[session_id]
|
||||
self.logger.debug(f"Cleaned up connection for session: {session_id}")
|
||||
|
||||
async def _health_check_loop(self):
|
||||
"""Background health check loop"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(self.health_check_interval)
|
||||
await self._perform_health_check()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Health check error: {e}")
|
||||
|
||||
async def _perform_health_check(self):
|
||||
"""Perform health check"""
|
||||
try:
|
||||
unhealthy_sessions = []
|
||||
|
||||
for session_id, conn in self.session_connections.items():
|
||||
if not await conn.ping():
|
||||
unhealthy_sessions.append(session_id)
|
||||
|
||||
# Clean up unhealthy connections
|
||||
for session_id in unhealthy_sessions:
|
||||
await self._cleanup_session_connection(session_id)
|
||||
self.metrics.failed_connections += 1
|
||||
|
||||
# Update metrics
|
||||
await self._update_connection_metrics()
|
||||
self.metrics.last_health_check = datetime.utcnow()
|
||||
|
||||
if unhealthy_sessions:
|
||||
self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Health check failed: {e}")
|
||||
|
||||
async def _cleanup_loop(self):
|
||||
"""Background cleanup loop"""
|
||||
while True:
|
||||
try:
|
||||
await asyncio.sleep(300) # Run every 5 minutes
|
||||
await self._cleanup_idle_connections()
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"Cleanup loop error: {e}")
|
||||
|
||||
async def _cleanup_idle_connections(self):
|
||||
"""Clean up idle connections"""
|
||||
current_time = datetime.utcnow()
|
||||
idle_sessions = []
|
||||
|
||||
# If no results, return empty DataFrame
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
# Manually convert dict results to DataFrame
|
||||
df = pd.DataFrame(result)
|
||||
return df
|
||||
finally:
|
||||
conn.close()
|
||||
for session_id, conn in self.session_connections.items():
|
||||
# Check if connection has exceeded maximum age
|
||||
age = (current_time - conn.created_at).total_seconds()
|
||||
if age > self.max_connection_age:
|
||||
idle_sessions.append(session_id)
|
||||
|
||||
# Clean up idle connections
|
||||
for session_id in idle_sessions:
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
if idle_sessions:
|
||||
self.logger.info(f"Cleaned up {len(idle_sessions)} idle connections")
|
||||
|
||||
async def _update_connection_metrics(self):
|
||||
"""Update connection metrics"""
|
||||
self.metrics.active_connections = len(self.session_connections)
|
||||
if self.pool:
|
||||
self.metrics.idle_connections = self.pool.freesize
|
||||
|
||||
async def get_metrics(self) -> ConnectionMetrics:
|
||||
"""Get connection metrics"""
|
||||
await self._update_connection_metrics()
|
||||
return self.metrics
|
||||
|
||||
async def execute_query(
|
||||
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
|
||||
) -> QueryResult:
|
||||
"""Execute query"""
|
||||
conn = await self.get_connection(session_id)
|
||||
return await conn.execute(sql, params, auth_context)
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_connection_context(self, session_id: str):
|
||||
"""Get connection context manager"""
|
||||
conn = await self.get_connection(session_id)
|
||||
try:
|
||||
yield conn
|
||||
finally:
|
||||
# Connection will be reused, no need to close here
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
"""Close connection manager"""
|
||||
try:
|
||||
# Cancel background tasks
|
||||
if self._health_check_task:
|
||||
self._health_check_task.cancel()
|
||||
try:
|
||||
await self._health_check_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
if self._cleanup_task:
|
||||
self._cleanup_task.cancel()
|
||||
try:
|
||||
await self._cleanup_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Clean up all session connections
|
||||
for session_id in list(self.session_connections.keys()):
|
||||
await self._cleanup_session_connection(session_id)
|
||||
|
||||
# Close connection pool
|
||||
if self.pool:
|
||||
self.pool.close()
|
||||
await self.pool.wait_closed()
|
||||
|
||||
self.logger.info("Connection manager closed successfully")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error closing connection manager: {e}")
|
||||
|
||||
async def test_connection(self) -> bool:
|
||||
"""Test database connection"""
|
||||
try:
|
||||
if not self.pool:
|
||||
return False
|
||||
|
||||
async with self.pool.acquire() as conn:
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
result = await cursor.fetchone()
|
||||
return result is not None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Connection test failed: {e}")
|
||||
return False
|
||||
|
||||
|
||||
class ConnectionPoolMonitor:
|
||||
"""Connection pool monitor
|
||||
|
||||
Provides detailed monitoring and reporting capabilities for connection pool status
|
||||
"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def get_pool_status(self) -> dict[str, Any]:
|
||||
"""Get connection pool status"""
|
||||
metrics = await self.connection_manager.get_metrics()
|
||||
|
||||
status = {
|
||||
"pool_size": self.connection_manager.pool.size if self.connection_manager.pool else 0,
|
||||
"free_connections": self.connection_manager.pool.freesize if self.connection_manager.pool else 0,
|
||||
"active_sessions": len(self.connection_manager.session_connections),
|
||||
"total_connections": metrics.total_connections,
|
||||
"failed_connections": metrics.failed_connections,
|
||||
"connection_errors": metrics.connection_errors,
|
||||
"avg_connection_time": metrics.avg_connection_time,
|
||||
"last_health_check": metrics.last_health_check.isoformat() if metrics.last_health_check else None,
|
||||
}
|
||||
|
||||
return status
|
||||
|
||||
async def get_session_details(self) -> list[dict[str, Any]]:
|
||||
"""Get session connection details"""
|
||||
sessions = []
|
||||
|
||||
for session_id, conn in self.connection_manager.session_connections.items():
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
"created_at": conn.created_at.isoformat(),
|
||||
"last_used": conn.last_used.isoformat(),
|
||||
"query_count": conn.query_count,
|
||||
"is_healthy": conn.is_healthy,
|
||||
"connection_age": (datetime.utcnow() - conn.created_at).total_seconds(),
|
||||
}
|
||||
sessions.append(session_info)
|
||||
|
||||
return sessions
|
||||
|
||||
async def generate_health_report(self) -> dict[str, Any]:
|
||||
"""Generate connection health report"""
|
||||
pool_status = await self.get_pool_status()
|
||||
session_details = await self.get_session_details()
|
||||
|
||||
# Calculate health statistics
|
||||
healthy_sessions = sum(1 for s in session_details if s["is_healthy"])
|
||||
total_sessions = len(session_details)
|
||||
health_ratio = healthy_sessions / total_sessions if total_sessions > 0 else 1.0
|
||||
|
||||
report = {
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"pool_status": pool_status,
|
||||
"session_summary": {
|
||||
"total_sessions": total_sessions,
|
||||
"healthy_sessions": healthy_sessions,
|
||||
"health_ratio": health_ratio,
|
||||
},
|
||||
"session_details": session_details,
|
||||
"recommendations": [],
|
||||
}
|
||||
|
||||
# Add recommendations based on health status
|
||||
if health_ratio < 0.8:
|
||||
report["recommendations"].append("Consider checking database connectivity and network stability")
|
||||
|
||||
if pool_status["connection_errors"] > 10:
|
||||
report["recommendations"].append("High connection error rate detected, review connection configuration")
|
||||
|
||||
if pool_status["active_sessions"] > pool_status["pool_size"] * 0.9:
|
||||
report["recommendations"].append("Connection pool utilization is high, consider increasing pool size")
|
||||
|
||||
return report
|
||||
|
||||
Reference in New Issue
Block a user