v0.4.1 preview

This commit is contained in:
FreeOnePlus
2025-06-26 18:55:30 +08:00
parent 72865654e2
commit 97006a756d
6 changed files with 505 additions and 270 deletions

View File

@@ -70,17 +70,26 @@ class SecurityConfig:
token_expiry: int = 3600
# SQL security configuration
enable_security_check: bool = True # Main switch: whether to enable SQL security check
blocked_keywords: list[str] = field(
default_factory=lambda: [
# DDL Operations (Data Definition Language)
"DROP",
"DELETE",
"TRUNCATE",
"CREATE",
"ALTER",
"CREATE",
"TRUNCATE",
# DML Operations (Data Manipulation Language)
"DELETE",
"INSERT",
"UPDATE",
# DCL Operations (Data Control Language)
"GRANT",
"REVOKE",
# System Operations
"EXEC",
"EXECUTE",
"SHUTDOWN",
"KILL",
]
)
max_query_complexity: int = 100
@@ -154,7 +163,7 @@ class DorisConfig:
# Basic configuration
server_name: str = "doris-mcp-server"
server_version: str = "0.4.0"
server_version: str = "0.4.1"
server_port: int = 3000
transport: str = "stdio"
@@ -267,6 +276,22 @@ class DorisConfig:
config.security.max_query_complexity = int(
os.getenv("MAX_QUERY_COMPLEXITY", str(config.security.max_query_complexity))
)
config.security.enable_security_check = (
os.getenv("ENABLE_SECURITY_CHECK", str(config.security.enable_security_check).lower()).lower() == "true"
)
# Handle blocked keywords environment variable configuration
# Format: BLOCKED_KEYWORDS="DROP,DELETE,TRUNCATE,ALTER,CREATE,INSERT,UPDATE,GRANT,REVOKE"
blocked_keywords_env = os.getenv("BLOCKED_KEYWORDS", "")
if blocked_keywords_env:
# If environment variable is provided, use keywords list from environment variable
config.security.blocked_keywords = [
keyword.strip().upper()
for keyword in blocked_keywords_env.split(",")
if keyword.strip()
]
# If environment variable is empty, keep default configuration unchanged
config.security.enable_masking = (
os.getenv("ENABLE_MASKING", str(config.security.enable_masking).lower()).lower() == "true"
)
@@ -399,6 +424,7 @@ class DorisConfig:
"auth_type": self.security.auth_type,
"token_secret": "***", # Hide secret key
"token_expiry": self.security.token_expiry,
"enable_security_check": self.security.enable_security_check,
"blocked_keywords": self.security.blocked_keywords,
"max_query_complexity": self.security.max_query_complexity,
"max_result_rows": self.security.max_result_rows,

View File

@@ -142,11 +142,22 @@ class DorisConnection:
self.is_healthy = False
return False
# Check if connection has _reader (aiomysql internal state)
# This prevents the 'NoneType' object has no attribute 'at_eof' error
if not hasattr(self.connection, '_reader') or self.connection._reader is None:
self.is_healthy = False
return False
# Additional check for reader's state
if hasattr(self.connection._reader, '_transport') and self.connection._reader._transport is None:
self.is_healthy = False
return False
# Try to ping the connection
await self.connection.ping()
self.is_healthy = True
return True
except Exception as e:
except (AttributeError, OSError, ConnectionError, Exception) as e:
# Log the specific error for debugging
logging.debug(f"Connection ping failed for session {self.session_id}: {e}")
self.is_healthy = False
@@ -309,15 +320,34 @@ class DorisConnectionManager:
if session_id in self.session_connections:
conn = self.session_connections[session_id]
try:
# Return connection to pool
if self.pool and conn.connection and not conn.connection.closed:
self.pool.release(conn.connection)
# Return connection to pool only if it's valid and not closed
if (self.pool and
conn.connection and
not conn.connection.closed and
hasattr(conn.connection, '_reader') and
conn.connection._reader is not None):
try:
# Try to gracefully return to pool
self.pool.release(conn.connection)
except Exception as pool_error:
self.logger.debug(f"Failed to return connection to pool for session {session_id}: {pool_error}")
# If pool release fails, try to close the connection directly
try:
await conn.connection.ensure_closed()
except Exception:
pass # Ignore errors during forced close
# Close connection wrapper
await conn.close()
except Exception as e:
self.logger.error(f"Error cleaning up connection for session {session_id}: {e}")
# Force close if normal cleanup fails
try:
if conn.connection and not conn.connection.closed:
await conn.connection.ensure_closed()
except Exception:
pass # Ignore errors during forced close
finally:
# Remove from session connections
del self.session_connections[session_id]
@@ -339,12 +369,26 @@ class DorisConnectionManager:
try:
unhealthy_sessions = []
# First pass: check basic connectivity
for session_id, conn in self.session_connections.items():
if not await conn.ping():
unhealthy_sessions.append(session_id)
# Clean up unhealthy connections
for session_id in unhealthy_sessions:
# Second pass: check for stale connections (over 30 minutes old)
current_time = datetime.utcnow()
stale_sessions = []
for session_id, conn in self.session_connections.items():
if session_id not in unhealthy_sessions: # Don't double-check
last_used_delta = (current_time - conn.last_used).total_seconds()
if last_used_delta > 1800: # 30 minutes
# Force a ping check for stale connections
if not await conn.ping():
stale_sessions.append(session_id)
all_problematic_sessions = list(set(unhealthy_sessions + stale_sessions))
# Clean up problematic connections
for session_id in all_problematic_sessions:
await self._cleanup_session_connection(session_id)
self.metrics.failed_connections += 1
@@ -352,11 +396,19 @@ class DorisConnectionManager:
await self._update_connection_metrics()
self.metrics.last_health_check = datetime.utcnow()
if unhealthy_sessions:
self.logger.warning(f"Cleaned up {len(unhealthy_sessions)} unhealthy connections")
if all_problematic_sessions:
self.logger.warning(f"Health check: cleaned up {len(unhealthy_sessions)} unhealthy and {len(stale_sessions)} stale connections")
else:
self.logger.debug(f"Health check: all {len(self.session_connections)} connections healthy")
except Exception as e:
self.logger.error(f"Health check failed: {e}")
# If health check fails, try to diagnose the issue
try:
diagnosis = await self.diagnose_connection_health()
self.logger.error(f"Connection diagnosis: {diagnosis}")
except Exception:
pass # Don't let diagnosis failure crash health check
async def _cleanup_loop(self):
"""Background cleanup loop"""
@@ -463,6 +515,93 @@ class DorisConnectionManager:
self.logger.error(f"Connection test failed: {e}")
return False
async def diagnose_connection_health(self) -> Dict[str, Any]:
"""Diagnose connection pool and session health"""
diagnosis = {
"timestamp": datetime.utcnow().isoformat(),
"pool_status": "unknown",
"session_connections": {},
"problematic_connections": [],
"recommendations": []
}
try:
# Check pool status
if not self.pool:
diagnosis["pool_status"] = "not_initialized"
diagnosis["recommendations"].append("Initialize connection pool")
return diagnosis
if self.pool.closed:
diagnosis["pool_status"] = "closed"
diagnosis["recommendations"].append("Recreate connection pool")
return diagnosis
diagnosis["pool_status"] = "healthy"
diagnosis["pool_info"] = {
"size": self.pool.size,
"free_size": self.pool.freesize,
"min_size": self.pool.minsize,
"max_size": self.pool.maxsize
}
# Check session connections
problematic_sessions = []
for session_id, conn in self.session_connections.items():
conn_status = {
"session_id": session_id,
"created_at": conn.created_at.isoformat(),
"last_used": conn.last_used.isoformat(),
"query_count": conn.query_count,
"is_healthy": conn.is_healthy
}
# Detailed connection checks
if conn.connection:
conn_status["connection_closed"] = conn.connection.closed
conn_status["has_reader"] = hasattr(conn.connection, '_reader') and conn.connection._reader is not None
if hasattr(conn.connection, '_reader') and conn.connection._reader:
conn_status["reader_transport"] = conn.connection._reader._transport is not None
else:
conn_status["reader_transport"] = False
else:
conn_status["connection_closed"] = True
conn_status["has_reader"] = False
conn_status["reader_transport"] = False
# Check if connection is problematic
if (not conn.is_healthy or
conn_status["connection_closed"] or
not conn_status["has_reader"] or
not conn_status["reader_transport"]):
problematic_sessions.append(session_id)
diagnosis["problematic_connections"].append(conn_status)
diagnosis["session_connections"][session_id] = conn_status
# Generate recommendations
if problematic_sessions:
diagnosis["recommendations"].append(f"Clean up {len(problematic_sessions)} problematic connections")
if self.pool.freesize == 0 and self.pool.size >= self.pool.maxsize:
diagnosis["recommendations"].append("Connection pool exhausted - consider increasing max_connections")
# Auto-cleanup problematic connections
for session_id in problematic_sessions:
try:
await self._cleanup_session_connection(session_id)
self.logger.info(f"Auto-cleaned problematic connection for session: {session_id}")
except Exception as e:
self.logger.error(f"Failed to auto-clean session {session_id}: {e}")
return diagnosis
except Exception as e:
diagnosis["error"] = str(e)
diagnosis["recommendations"].append("Manual intervention required")
return diagnosis
class ConnectionPoolMonitor:
"""Connection pool monitor

View File

@@ -548,79 +548,127 @@ class DorisQueryExecutor:
user_id: str = "mcp_user"
) -> Dict[str, Any]:
"""Execute SQL query for MCP interface - unified method"""
try:
if not sql:
return {
"success": False,
"error": "SQL query is required",
"data": None
}
max_retries = 2
retry_count = 0
while retry_count <= max_retries:
try:
if not sql:
return {
"success": False,
"error": "SQL query is required",
"data": None
}
# Add LIMIT if not present and it's a SELECT query
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
if sql.endswith(";"):
sql = sql[:-1]
sql = f"{sql} LIMIT {limit}"
# Add LIMIT if not present and it's a SELECT query
if sql.upper().startswith("SELECT") and "LIMIT" not in sql.upper():
if sql.endswith(";"):
sql = sql[:-1]
sql = f"{sql} LIMIT {limit}"
# Create auth context for MCP calls
class MockAuthContext:
def __init__(self):
self.user_id = user_id
self.roles = ["data_analyst"]
self.permissions = ["read_data", "execute_query"]
self.session_id = session_id
self.security_level = "internal"
# Create auth context for MCP calls
class MockAuthContext:
def __init__(self):
self.user_id = user_id
self.roles = ["data_analyst"]
self.permissions = ["read_data", "execute_query"]
self.session_id = session_id
self.security_level = "internal"
auth_context = MockAuthContext()
# Create query request
query_request = QueryRequest(
sql=sql,
session_id=session_id,
user_id=user_id,
timeout=timeout,
cache_enabled=True
)
# Execute query
result = await self.execute_query(query_request, auth_context)
# Process results
processed_data = []
if result.data:
for row in result.data:
processed_row = self._serialize_row_data(row)
processed_data.append(processed_row)
auth_context = MockAuthContext()
# Create query request
query_request = QueryRequest(
sql=sql,
session_id=session_id,
user_id=user_id,
timeout=timeout,
cache_enabled=False # Disable cache for MCP calls to ensure fresh data
)
return {
"success": True,
"data": processed_data,
"metadata": {
"row_count": result.row_count,
"execution_time": result.execution_time,
"columns": result.metadata.get("columns", []),
"query": sql
},
"error": None
}
# Execute query with retry logic
try:
result = await self.execute_query(query_request, auth_context)
# Serialize data for JSON response
serialized_data = []
for row in result.data:
serialized_data.append(self._serialize_row_data(row))
except Exception as e:
error_msg = str(e)
self.logger.error(f"SQL execution error: {error_msg}")
# Analyze error for better user feedback
error_analysis = self._analyze_error(error_msg)
return {
"success": False,
"error": error_analysis.get("user_message", error_msg),
"error_type": error_analysis.get("error_type", "execution_error"),
"data": None,
"metadata": {
"query": sql,
"error_details": error_msg
}
}
return {
"success": True,
"data": serialized_data,
"row_count": result.row_count,
"execution_time": result.execution_time,
"metadata": {
"columns": result.metadata.get("columns", []),
"query": sql
}
}
except Exception as query_error:
# Check if it's a connection-related error that we should retry
error_str = str(query_error).lower()
connection_errors = [
"at_eof", "connection", "closed", "nonetype",
"transport", "reader", "broken pipe", "connection reset"
]
is_connection_error = any(err in error_str for err in connection_errors)
if is_connection_error and retry_count < max_retries:
retry_count += 1
self.logger.warning(f"Connection error detected, retrying ({retry_count}/{max_retries}): {query_error}")
# Release the problematic connection
try:
await self.connection_manager.release_connection(session_id)
except Exception:
pass # Ignore cleanup errors
# Wait a bit before retry
await asyncio.sleep(0.5 * retry_count)
continue
else:
# Re-raise if not a connection error or max retries exceeded
raise query_error
except Exception as e:
error_msg = str(e)
# If we've exhausted retries or it's not a connection error, return error
if retry_count >= max_retries or "at_eof" not in error_msg.lower():
error_analysis = self._analyze_error(error_msg)
return {
"success": False,
"error": error_analysis.get("user_message", error_msg),
"error_type": error_analysis.get("error_type", "general_error"),
"data": None,
"metadata": {
"query": sql,
"error_details": error_msg,
"retry_count": retry_count
}
}
else:
# Try one more time for connection errors
retry_count += 1
if retry_count <= max_retries:
self.logger.warning(f"Retrying query due to connection error ({retry_count}/{max_retries}): {e}")
await asyncio.sleep(0.5 * retry_count)
continue
else:
return {
"success": False,
"error": f"Query failed after {max_retries} retries: {error_msg}",
"data": None,
"metadata": {
"query": sql,
"error_details": error_msg,
"retry_count": retry_count
}
}
def _serialize_row_data(self, row_data: Dict[str, Any]) -> Dict[str, Any]:
"""Serialize row data for JSON response"""
@@ -649,7 +697,12 @@ class DorisQueryExecutor:
"""Analyze error message and provide user-friendly feedback"""
error_msg_lower = error_message.lower()
if "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
if "at_eof" in error_msg_lower or "nonetype" in error_msg_lower and "at_eof" in error_msg_lower:
return {
"error_type": "connection_lost",
"user_message": "Database connection was lost. The query has been automatically retried. If this persists, please restart the server."
}
elif "table" in error_msg_lower and "doesn't exist" in error_msg_lower:
return {
"error_type": "table_not_found",
"user_message": "The specified table does not exist. Please check the table name and database."
@@ -674,6 +727,11 @@ class DorisQueryExecutor:
"error_type": "timeout",
"user_message": "Query execution timed out. Try simplifying your query or adding more specific filters."
}
elif "connection" in error_msg_lower and ("closed" in error_msg_lower or "reset" in error_msg_lower):
return {
"error_type": "connection_error",
"user_message": "Database connection was interrupted. The query has been automatically retried."
}
else:
return {
"error_type": "general_error",

View File

@@ -20,7 +20,6 @@ Doris Security Management Module
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
"""
import hashlib
import logging
import re
from dataclasses import dataclass
@@ -101,30 +100,24 @@ class DorisSecurityManager:
self.masking_rules = self._load_masking_rules()
def _load_blocked_keywords(self) -> set[str]:
"""Load blocked SQL keywords"""
default_blocked = {
"DROP",
"DELETE",
"TRUNCATE",
"ALTER",
"CREATE",
"INSERT",
"UPDATE",
"GRANT",
"REVOKE",
"EXEC",
"EXECUTE",
"SHUTDOWN",
"KILL",
}
# Load custom rules from configuration file
"""Load blocked SQL keywords from configuration"""
# Load keywords from configuration, unified source of truth
if hasattr(self.config, 'get'):
custom_blocked = set(self.config.get("blocked_keywords", []))
# Dictionary-style configuration
blocked_keywords = self.config.get("blocked_keywords", [])
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'blocked_keywords'):
# DorisConfig object, get through security.blocked_keywords
blocked_keywords = self.config.security.blocked_keywords
else:
custom_blocked = set()
# Fallback to default if no configuration available
blocked_keywords = [
"DROP", "CREATE", "ALTER", "TRUNCATE",
"DELETE", "INSERT", "UPDATE",
"GRANT", "REVOKE",
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
]
return default_blocked.union(custom_blocked)
return set(blocked_keywords)
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
"""Load sensitive table configuration"""
@@ -478,13 +471,30 @@ class SQLSecurityValidator:
# Dictionary configuration
self.blocked_keywords = set(config.get("blocked_keywords", []))
self.max_query_complexity = config.get("max_query_complexity", 100)
self.enable_security_check = config.get("enable_security_check", True)
elif hasattr(config, 'security'):
# DorisConfig object with security attribute - unified source from config
self.blocked_keywords = set(config.security.blocked_keywords)
self.max_query_complexity = config.security.max_query_complexity
self.enable_security_check = getattr(config.security, 'enable_security_check', True)
else:
# DorisConfig object, use default values
self.blocked_keywords = set(["DROP", "DELETE", "TRUNCATE", "ALTER", "CREATE", "INSERT", "UPDATE"])
# Fallback to default if no configuration available
self.blocked_keywords = set([
"DROP", "CREATE", "ALTER", "TRUNCATE",
"DELETE", "INSERT", "UPDATE",
"GRANT", "REVOKE",
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
])
self.max_query_complexity = 100
self.enable_security_check = True
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
"""Validate SQL query security"""
# If security check is disabled, always return valid
if not self.enable_security_check:
self.logger.debug("SQL security check is disabled, allowing all queries")
return ValidationResult(is_valid=True)
try:
# Parse SQL statement
parsed = sqlparse.parse(sql)[0]