fix some security issues (#68)

This commit is contained in:
bingquanzhao
2025-12-10 09:11:03 +08:00
committed by GitHub
parent a125a2f5f8
commit e58361e04b
17 changed files with 2520 additions and 214 deletions

View File

@@ -28,6 +28,7 @@ from typing import Any, Dict, List, Optional
from ..utils.logger import get_logger
from ..utils.db import DorisConnectionManager
from ..utils.sql_security_utils import get_auth_context
logger = get_logger(__name__)
@@ -277,7 +278,8 @@ class DorisADBCQueryTools:
# Get BE nodes via SHOW BACKENDS
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
connection = await self.connection_manager.get_connection("query")
result = await connection.execute("SHOW BACKENDS")
auth_context = get_auth_context()
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
be_hosts = []
for row in result.data:
@@ -383,6 +385,20 @@ class DorisADBCQueryTools:
"error_type": "no_connection"
}
# SECURITY FIX: Perform SQL security validation before executing
auth_context = get_auth_context()
if self.connection_manager.security_manager:
# Always perform security validation, even without auth_context
# Use a default context for basic SQL security checks
validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context)
if not validation_result.is_valid:
return {
"success": False,
"error": f"SQL security validation failed: {validation_result.error_message}",
"error_type": "security_violation",
"risk_level": validation_result.risk_level
}
cursor = self.adbc_client.cursor()
start_time = time.time()

View File

@@ -29,6 +29,13 @@ from pathlib import Path
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__)
@@ -46,10 +53,17 @@ class TableAnalyzer:
sample_size: int = 10
) -> Dict[str, Any]:
"""Get table summary information"""
# SECURITY FIX: Validate table_name and get auth_context
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
raise ValueError(f"Invalid table name: {e}")
auth_context = get_auth_context()
connection = await self.connection_manager.get_connection("query")
# Get table basic information
table_info_sql = f"""
# Get table basic information using parameterized query
table_info_sql = """
SELECT
table_name,
table_comment,
@@ -58,17 +72,17 @@ class TableAnalyzer:
engine
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = '{table_name}'
AND table_name = %s
"""
table_info_result = await connection.execute(table_info_sql)
table_info_result = await connection.execute(table_info_sql, params=(table_name,), auth_context=auth_context)
if not table_info_result.data:
raise ValueError(f"Table {table_name} does not exist")
table_info = table_info_result.data[0]
# Get column information
columns_sql = f"""
# Get column information using parameterized query
columns_sql = """
SELECT
column_name,
data_type,
@@ -76,11 +90,11 @@ class TableAnalyzer:
column_comment
FROM information_schema.columns
WHERE table_schema = DATABASE()
AND table_name = '{table_name}'
AND table_name = %s
ORDER BY ordinal_position
"""
columns_result = await connection.execute(columns_sql)
columns_result = await connection.execute(columns_sql, params=(table_name,), auth_context=auth_context)
summary = {
"table_name": table_info["table_name"],
@@ -92,10 +106,11 @@ class TableAnalyzer:
"columns": columns_result.data,
}
# Get sample data
# Get sample data using quoted identifier
if include_sample and sample_size > 0:
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
sample_result = await connection.execute(sample_sql)
quoted_table = quote_identifier(table_name, "table name")
sample_sql = f"SELECT * FROM {quoted_table} LIMIT {sample_size}"
sample_result = await connection.execute(sample_sql, auth_context=auth_context)
summary["sample_data"] = sample_result.data
return summary
@@ -120,7 +135,8 @@ class TableAnalyzer:
FROM {table_name}
"""
basic_result = await connection.execute(basic_stats_sql)
auth_context = get_auth_context()
basic_result = await connection.execute(basic_stats_sql, auth_context=auth_context)
if not basic_result.data:
return {
"success": False,
@@ -144,7 +160,7 @@ class TableAnalyzer:
LIMIT 20
"""
distribution_result = await connection.execute(distribution_sql)
distribution_result = await connection.execute(distribution_sql, auth_context=auth_context)
analysis["value_distribution"] = distribution_result.data
if analysis_type == "detailed":
@@ -159,7 +175,7 @@ class TableAnalyzer:
WHERE {column_name} IS NOT NULL
"""
numeric_result = await connection.execute(numeric_stats_sql)
numeric_result = await connection.execute(numeric_stats_sql, auth_context=auth_context)
if numeric_result.data:
analysis.update(numeric_result.data[0])
except Exception:
@@ -196,7 +212,8 @@ class TableAnalyzer:
AND table_name = '{table_name}'
"""
table_result = await connection.execute(table_info_sql)
auth_context = get_auth_context()
table_result = await connection.execute(table_info_sql, auth_context=auth_context)
if not table_result.data:
raise ValueError(f"Table {table_name} does not exist")
@@ -211,7 +228,7 @@ class TableAnalyzer:
AND table_name != %s
"""
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
all_tables_result = await connection.execute(all_tables_sql, params=(table_name,), auth_context=auth_context)
return {
"center_table": table_result.data[0],
@@ -291,7 +308,8 @@ class PerformanceMonitor:
LIMIT 20
"""
tables_result = await connection.execute(tables_sql)
auth_context = get_auth_context()
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
stats = {
"metric_type": "tables",
"time_range": time_range,
@@ -380,8 +398,14 @@ class SQLAnalyzer:
logger.info(f"Generating SQL explain for query ID: {query_id}")
# Switch database if specified
# SECURITY FIX: Validate and quote db_name
if db_name:
await self.connection_manager.execute_query("explain_session", f"USE {db_name}")
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid database name: {e}"}
safe_db = quote_identifier(db_name, "database name")
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}")
# Construct EXPLAIN query
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
@@ -515,24 +539,36 @@ class SQLAnalyzer:
try:
# Switch to specified database/catalog if provided
# SECURITY FIX: Validate identifiers before using in SQL
if catalog_name:
await connection.execute(f"SWITCH `{catalog_name}`")
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid catalog name: {e}"}
safe_catalog = quote_identifier(catalog_name, "catalog name")
auth_context = get_auth_context()
await connection.execute(f"SWITCH {safe_catalog}", auth_context=auth_context)
if db_name:
await connection.execute(f"USE `{db_name}`")
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid database name: {e}"}
safe_db = quote_identifier(db_name, "database name")
await connection.execute(f"USE {safe_db}", auth_context=auth_context)
# Set trace ID for the session using session variable
# According to official docs: set session_context="trace_id:your_trace_id"
await connection.execute(f'set session_context="trace_id:{trace_id}"')
await connection.execute(f'set session_context="trace_id:{trace_id}"', auth_context=auth_context)
logger.info(f"Set trace ID: {trace_id}")
# Enable profile
await connection.execute(f'set enable_profile=true')
await connection.execute(f'set enable_profile=true', auth_context=auth_context)
logger.info(f"Enabled profile")
# Execute the SQL statement
logger.info(f"Executing SQL with trace ID: {sql}")
start_time = time.time()
sql_result = await connection.execute(sql)
sql_result = await connection.execute(sql, auth_context=auth_context)
execution_time = time.time() - start_time
logger.info(f"SQL execution completed in {execution_time:.3f}s")

View File

@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional, Union
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__)
@@ -43,24 +50,30 @@ class DataExplorationTools:
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
    """Build the table reference used in generated SQL.

    Args:
        table_name: Table identifier.
        catalog_name: Catalog identifier; "internal" is used when this is falsy.
        db_name: Database identifier; left out of the reference when falsy.

    Returns:
        The value produced by ``build_table_reference`` for the given parts.
    """
    # SECURITY FIX: delegate to build_table_reference instead of joining the
    # raw parts with "." (presumably it validates/quotes each identifier and
    # may raise SQLSecurityError -- confirm in sql_security_utils).
    effective_catalog = catalog_name if catalog_name else "internal"
    if db_name:
        return build_table_reference(table_name, db_name, effective_catalog)
    # No db_name provided: only the catalog and table are passed on.
    # NOTE(review): the current database is not resolved here -- confirm that
    # a catalog + table reference resolves as intended.
    return build_table_reference(table_name, catalog_name=effective_catalog)
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
    """Get basic table information including row count.

    Args:
        connection: Object exposing an awaitable ``execute(sql, auth_context=...)``.
        table_name: Table reference interpolated directly into the SQL text.
            Callers are expected to pass a reference produced by
            ``_build_full_table_name``; it is not re-validated here.

    Returns:
        ``{"row_count": <count>}`` taken from the first result row, ``None``
        when the query returns no rows, or ``{"row_count": 0}`` when the
        query fails (the failure is logged as a warning).
    """
    try:
        # SECURITY FIX: pass the auth context along with the query.
        auth_context = get_auth_context()
        count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
        result = await connection.execute(count_sql, auth_context=auth_context)

        if result.data:
            return {"row_count": result.data[0]["row_count"]}
        return None
    except SQLSecurityError as e:
        logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
        return {"row_count": 0}
    except Exception as e:
        # Best effort: any other failure degrades to a zero row count.
        logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
        return {"row_count": 0}
@@ -68,10 +81,24 @@ class DataExplorationTools:
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get detailed column information"""
try:
where_conditions = [f"table_name = '{table_name}'"]
# SECURITY FIX: Validate identifiers and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(table_name, "table name")
if db_name:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build parameterized query
params = [table_name]
where_conditions = ["table_name = %s"]
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
where_conditions.append("table_schema = %s")
params.append(db_name)
else:
where_conditions.append("table_schema = DATABASE()")
@@ -87,9 +114,12 @@ class DataExplorationTools:
ORDER BY ordinal_position
"""
result = await connection.execute(columns_sql)
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
return result.data if result.data else []
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e:
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
return []
@@ -177,7 +207,8 @@ class DataExplorationTools:
WHERE {col_name} IS NOT NULL
"""
stats_result = await connection.execute(stats_sql)
auth_context = get_auth_context()
stats_result = await connection.execute(stats_sql, auth_context=auth_context)
if stats_result.data and stats_result.data[0]["count"] > 0:
stats = stats_result.data[0]
@@ -229,7 +260,8 @@ class DataExplorationTools:
WHERE {col_name} IS NOT NULL
"""
result = await connection.execute(percentile_sql)
auth_context = get_auth_context()
result = await connection.execute(percentile_sql, auth_context=auth_context)
if result.data:
data = result.data[0]
@@ -268,7 +300,8 @@ class DataExplorationTools:
WHERE {col_name} IS NOT NULL
"""
result = await connection.execute(outlier_sql)
auth_context = get_auth_context()
result = await connection.execute(outlier_sql, auth_context=auth_context)
if result.data:
data = result.data[0]
@@ -359,7 +392,8 @@ class DataExplorationTools:
{sampling_info.get('sample_query_suffix', '')}
"""
cardinality_result = await connection.execute(cardinality_sql)
auth_context = get_auth_context()
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
if cardinality_result.data:
cardinality_data = cardinality_result.data[0]
@@ -408,7 +442,8 @@ class DataExplorationTools:
LIMIT 20
"""
result = await connection.execute(distribution_sql)
auth_context = get_auth_context()
result = await connection.execute(distribution_sql, auth_context=auth_context)
if result.data:
distribution = []
@@ -458,7 +493,8 @@ class DataExplorationTools:
WHERE {col_name} IS NOT NULL
"""
range_result = await connection.execute(range_sql)
auth_context = get_auth_context()
range_result = await connection.execute(range_sql, auth_context=auth_context)
if range_result.data and range_result.data[0]["non_null_count"] > 0:
range_data = range_result.data[0]
@@ -539,7 +575,8 @@ class DataExplorationTools:
ORDER BY day_of_week
"""
weekly_result = await connection.execute(weekly_pattern_sql)
auth_context = get_auth_context()
weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context)
weekly_pattern = []
if weekly_result.data:
@@ -561,7 +598,7 @@ class DataExplorationTools:
LIMIT 12
"""
monthly_result = await connection.execute(monthly_trend_sql)
monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context)
monthly_trend = "stable" # Simplified trend analysis
if monthly_result.data and len(monthly_result.data) > 3:
@@ -646,7 +683,8 @@ class DataExplorationTools:
FROM {table_expr}
"""
result = await connection.execute(null_sql)
auth_context = get_auth_context()
result = await connection.execute(null_sql, auth_context=auth_context)
if result.data:
data = result.data[0]
total_count = data["total_count"]

View File

@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__)
@@ -216,26 +223,34 @@ class DataGovernanceTools:
# ==================== Private Helper Methods ====================
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
    """Build full table name - use three-level naming convention with security validation

    Args:
        table_name: Table identifier.
        catalog_name: Catalog identifier; "internal" is used when this is falsy.
        db_name: Database identifier; left out of the reference when falsy.

    Returns:
        The value produced by ``build_table_reference`` for the given parts.
    """
    # SECURITY FIX: Use build_table_reference for safe identifier handling
    # (presumably it validates/quotes each part -- confirm in sql_security_utils).
    # Default catalog is internal for internal tables
    effective_catalog = catalog_name if catalog_name else "internal"
    if db_name:
        return build_table_reference(table_name, db_name, effective_catalog)
    # If db_name is not provided, only the catalog and table are passed on.
    # NOTE(review): the current database is not resolved here -- confirm that
    # a catalog + table reference resolves as intended.
    return build_table_reference(table_name, catalog_name=effective_catalog)
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
    """Get table basic information

    Args:
        connection: Object exposing an awaitable ``execute(sql, auth_context=...)``.
        table_name: Table reference interpolated directly into the SQL text.
            Callers are expected to pass a reference produced by
            ``_build_full_table_name``; it is not re-validated here.

    Returns:
        ``{"row_count": <count>}`` taken from the first result row, ``None``
        when the query returns no rows, or ``{"row_count": 0}`` when the
        query fails (the failure is logged as a warning).
    """
    try:
        # SECURITY FIX: Get auth_context for security validation
        auth_context = get_auth_context()

        # Try to get table row count
        count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
        result = await connection.execute(count_sql, auth_context=auth_context)

        if result.data:
            return {"row_count": result.data[0]["row_count"]}
        return None
    except SQLSecurityError as e:
        logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
        return {"row_count": 0}
    except Exception as e:
        # Best effort: any other failure degrades to a zero row count.
        logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
        return {"row_count": 0}
@@ -243,11 +258,24 @@ class DataGovernanceTools:
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get table column information"""
try:
# Build query conditions
where_conditions = [f"table_name = '{table_name}'"]
# SECURITY FIX: Validate identifiers and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(table_name, "table name")
if db_name:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build parameterized query conditions
params = [table_name]
where_conditions = ["table_name = %s"]
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
where_conditions.append("table_schema = %s")
params.append(db_name)
else:
where_conditions.append("table_schema = DATABASE()")
@@ -263,30 +291,49 @@ class DataGovernanceTools:
ORDER BY ordinal_position
"""
result = await connection.execute(columns_sql)
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
return result.data if result.data else []
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e:
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
return []
async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]:
"""Analyze column completeness"""
# SECURITY FIX: Get auth_context for security validation
auth_context = get_auth_context()
column_completeness = {}
for column in columns_info:
column_name = column["column_name"]
try:
# SECURITY FIX: Validate column name before using in SQL
try:
validate_identifier(column_name, "column name")
except SQLSecurityError as e:
logger.warning(f"Invalid column name rejected: {e}")
column_completeness[column_name] = {
"error": f"Invalid column name: {e}",
"completeness_score": 0.0
}
continue
# Use quoted identifier for column name
quoted_column = quote_identifier(column_name, "column name")
# Calculate null value statistics
null_sql = f"""
SELECT
COUNT(*) as total_count,
COUNT({column_name}) as non_null_count,
COUNT(*) - COUNT({column_name}) as null_count
COUNT({quoted_column}) as non_null_count,
COUNT(*) - COUNT({quoted_column}) as null_count
FROM {table_name}
"""
result = await connection.execute(null_sql)
result = await connection.execute(null_sql, auth_context=auth_context)
if result.data:
stats = result.data[0]
total_count = stats["total_count"]
@@ -304,6 +351,12 @@ class DataGovernanceTools:
"completeness_score": round(completeness_score, 4)
}
except SQLSecurityError as e:
logger.warning(f"Security validation failed for column {column_name}: {str(e)}")
column_completeness[column_name] = {
"error": str(e),
"completeness_score": 0.0
}
except Exception as e:
logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}")
column_completeness[column_name] = {
@@ -333,7 +386,8 @@ class DataGovernanceTools:
FROM {table_name}
"""
result = await connection.execute(compliance_sql)
auth_context = get_auth_context()
result = await connection.execute(compliance_sql, auth_context=auth_context)
if result.data:
stats = result.data[0]
pass_count = stats["pass_count"] or 0
@@ -378,7 +432,8 @@ class DataGovernanceTools:
) t
"""
result = await connection.execute(duplicate_sql)
auth_context = get_auth_context()
result = await connection.execute(duplicate_sql, auth_context=auth_context)
if result.data and result.data[0]["duplicate_count"] > 0:
issues.append({
"type": "duplicate_primary_keys",
@@ -456,10 +511,21 @@ class DataGovernanceTools:
async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool:
    """Verify if column exists by selecting it from the table.

    Args:
        connection: Object exposing an awaitable ``execute(sql, auth_context=...)``.
        table_name: Table reference interpolated directly into the SQL text;
            expected to come from ``_build_full_table_name`` (not re-checked here).
        column_name: Column identifier to probe.

    Returns:
        ``True`` when the probe query executes without raising; ``False``
        when the column name is rejected by ``validate_identifier`` or the
        query raises for any reason.
    """
    try:
        # SECURITY FIX: Validate and quote column name
        try:
            validate_identifier(column_name, "column name")
        except SQLSecurityError as e:
            logger.warning(f"Invalid column name rejected: {e}")
            return False

        safe_column = quote_identifier(column_name, "column name")
        # Simple verification method: try to query the column.
        verify_sql = f"SELECT {safe_column} FROM {table_name} LIMIT 1"
        auth_context = get_auth_context()
        await connection.execute(verify_sql, auth_context=auth_context)
        return True
    except SQLSecurityError:
        return False
    except Exception:
        # Any execution failure is treated as "column does not exist".
        return False
@@ -469,21 +535,34 @@ class DataGovernanceTools:
source_chain = []
try:
# SECURITY FIX: Validate the table name and bind the LIKE pattern as a query parameter
table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return []
# Escape special characters for LIKE pattern
safe_pattern = table_name_part.replace('%', r'\%').replace('_', r'\_')
like_pattern = f"%{safe_pattern}%"
# Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range)
auth_context = get_auth_context()
audit_sql = """
SELECT
stmt as sql_statement,
`time` as execution_time,
`user` as user_name
FROM internal.__internal_schema.audit_log
WHERE stmt LIKE '%{}%'
WHERE stmt LIKE %s
AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%')
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
ORDER BY `time` DESC
LIMIT 50
""".format(table_name.split('.')[-1]) # Use the last part of table name
"""
result = await connection.execute(audit_sql)
result = await connection.execute(audit_sql, params=(like_pattern,), auth_context=auth_context)
if result.data:
for i, log_entry in enumerate(result.data[:depth]):
@@ -556,19 +635,33 @@ class DataGovernanceTools:
downstream_usage = []
try:
# SECURITY FIX: Validate inputs and bind the LIKE patterns as query parameters
table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
validate_identifier(column_name, "column name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Escape special characters for LIKE pattern
safe_table_pattern = f"%{table_name_part.replace('%', r'\\%').replace('_', r'\\_')}%"
safe_column_pattern = f"%{column_name.replace('%', r'\\%').replace('_', r'\\_')}%"
# Find other tables that might use this field (through audit logs, one year range)
auth_context = get_auth_context()
usage_sql = """
SELECT DISTINCT
stmt as sql_statement
FROM internal.__internal_schema.audit_log
WHERE stmt LIKE '%{}%'
AND stmt LIKE '%{}%'
WHERE stmt LIKE %s
AND stmt LIKE %s
AND stmt LIKE '%SELECT%'
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
LIMIT 20
""".format(table_name.split('.')[-1], column_name)
"""
result = await connection.execute(usage_sql)
result = await connection.execute(usage_sql, params=(safe_table_pattern, safe_column_pattern), auth_context=auth_context)
if result.data:
for entry in result.data:
@@ -634,14 +727,20 @@ class DataGovernanceTools:
async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]:
"""Get list of all tables"""
try:
where_conditions = []
auth_context = get_auth_context()
params = []
# SECURITY FIX: Use parameterized query
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
where_clause = "table_schema = %s"
params.append(db_name)
else:
where_conditions.append("table_schema = DATABASE()")
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
where_clause = "table_schema = DATABASE()"
tables_sql = f"""
SELECT table_name
@@ -651,7 +750,7 @@ class DataGovernanceTools:
ORDER BY table_name
"""
result = await connection.execute(tables_sql)
result = await connection.execute(tables_sql, params=tuple(params) if params else None, auth_context=auth_context)
return [row["table_name"] for row in result.data] if result.data else []
except Exception as e:
@@ -728,15 +827,23 @@ class DataGovernanceTools:
async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from partition information"""
try:
# Query partition information (if table has partitions)
partition_sql = f"""
# SECURITY FIX: Validate and use parameterized query
table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return None
auth_context = get_auth_context()
partition_sql = """
SELECT MAX(CREATE_TIME) as last_update
FROM information_schema.partitions
WHERE table_name = '{table_name.split('.')[-1]}'
WHERE table_name = %s
AND CREATE_TIME IS NOT NULL
"""
result = await connection.execute(partition_sql)
result = await connection.execute(partition_sql, params=(table_name_part,), auth_context=auth_context)
if result.data and result.data[0]["last_update"]:
return {
"last_update": result.data[0]["last_update"],
@@ -759,7 +866,8 @@ class DataGovernanceTools:
FROM {table_name}
"""
result = await connection.execute(max_time_sql)
auth_context = get_auth_context()
result = await connection.execute(max_time_sql, auth_context=auth_context)
if result.data and result.data[0]["last_update"]:
return {
"last_update": result.data[0]["last_update"],
@@ -773,15 +881,23 @@ class DataGovernanceTools:
async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from table metadata"""
try:
# Query table's update time
metadata_sql = f"""
# SECURITY FIX: Validate and use parameterized query
table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return None
auth_context = get_auth_context()
metadata_sql = """
SELECT UPDATE_TIME as last_update
FROM information_schema.tables
WHERE table_name = '{table_name.split('.')[-1]}'
WHERE table_name = %s
AND UPDATE_TIME IS NOT NULL
"""
result = await connection.execute(metadata_sql)
result = await connection.execute(metadata_sql, params=(table_name_part,), auth_context=auth_context)
if result.data and result.data[0]["last_update"]:
return {
"last_update": result.data[0]["last_update"],
@@ -795,10 +911,19 @@ class DataGovernanceTools:
async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]:
"""Find possible timestamp fields"""
try:
timestamp_sql = f"""
# SECURITY FIX: Validate and use parameterized query
table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return []
auth_context = get_auth_context()
timestamp_sql = """
SELECT column_name
FROM information_schema.columns
WHERE table_name = '{table_name.split('.')[-1]}'
WHERE table_name = %s
AND (
data_type IN ('datetime', 'timestamp', 'date')
OR column_name LIKE '%time%'
@@ -815,7 +940,7 @@ class DataGovernanceTools:
END
"""
result = await connection.execute(timestamp_sql)
result = await connection.execute(timestamp_sql, params=(table_name_part,), auth_context=auth_context)
return [row["column_name"] for row in result.data] if result.data else []
except Exception:

View File

@@ -31,6 +31,12 @@ from collections import Counter, defaultdict
from .db import DorisConnectionManager
from .logger import get_logger
from .config import DorisConfig
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__)
@@ -299,23 +305,26 @@ class DataQualityTools:
# ===========================================
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
    """Build full table name with security validation

    Args:
        table_name: Table identifier.
        catalog_name: Optional catalog identifier, passed through unchanged.
        db_name: Optional database identifier, passed through unchanged.

    Returns:
        The value produced by ``build_table_reference`` for the given parts.
    """
    # SECURITY FIX: Use build_table_reference for safe identifier handling
    # instead of joining the raw parts with "." (presumably it validates and
    # quotes each part and skips missing ones -- confirm in sql_security_utils).
    return build_table_reference(table_name, db_name, catalog_name)
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
    """Get basic table information

    Args:
        connection: Object exposing an awaitable ``execute(sql, auth_context=...)``.
        table_name: Table reference interpolated directly into the SQL text.
            Callers are expected to pass a reference produced by
            ``_build_full_table_name``; it is not re-validated here.

    Returns:
        ``{"row_count": <count>}`` taken from the first result row, or
        ``None`` when the query returns no rows or fails (failures are
        logged as warnings).
    """
    try:
        # SECURITY FIX: pass auth_context so the query goes through security validation
        auth_context = get_auth_context()

        # Try to get row count
        count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
        result = await connection.execute(count_sql, auth_context=auth_context)

        if result.data:
            return {"row_count": result.data[0]["row_count"]}
        return None
    except SQLSecurityError as e:
        logger.warning(f"Security validation failed: {str(e)}")
        return None
    except Exception as e:
        logger.warning(f"Failed to get table basic info: {str(e)}")
        return None
@@ -323,9 +332,13 @@ class DataQualityTools:
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get table column information"""
try:
# Build DESCRIBE query
describe_sql = f"DESCRIBE {self._build_full_table_name(table_name, catalog_name, db_name)}"
result = await connection.execute(describe_sql)
# SECURITY FIX: Build safe table reference and pass auth_context
auth_context = get_auth_context()
# Build DESCRIBE query with safe table reference
safe_table_ref = self._build_full_table_name(table_name, catalog_name, db_name)
describe_sql = f"DESCRIBE {safe_table_ref}"
result = await connection.execute(describe_sql, auth_context=auth_context)
columns_info = []
if result.data:
@@ -339,6 +352,9 @@ class DataQualityTools:
})
return columns_info
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e:
logger.warning(f"Failed to get table columns info: {str(e)}")
return []
@@ -346,7 +362,32 @@ class DataQualityTools:
async def _get_table_partitions(self, connection, table_name: str, db_name: Optional[str] = None) -> List[Dict]:
"""Get table partition information"""
try:
# Query partition information
# SECURITY FIX: Validate identifiers and use parameterized query
auth_context = get_auth_context()
# Validate table_name
try:
validate_identifier(table_name, "table name")
if db_name:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build parameterized query
params = []
where_conditions = []
if db_name:
where_conditions.append("TABLE_SCHEMA = %s")
params.append(db_name)
else:
where_conditions.append("TABLE_SCHEMA = ''")
where_conditions.append("TABLE_NAME = %s")
params.append(table_name)
where_conditions.append("PARTITION_NAME IS NOT NULL")
partition_sql = f"""
SELECT
PARTITION_NAME,
@@ -355,12 +396,10 @@ class DataQualityTools:
DATA_LENGTH,
INDEX_LENGTH
FROM information_schema.PARTITIONS
WHERE TABLE_SCHEMA = '{db_name or ""}'
AND TABLE_NAME = '{table_name}'
AND PARTITION_NAME IS NOT NULL
WHERE {' AND '.join(where_conditions)}
"""
result = await connection.execute(partition_sql)
result = await connection.execute(partition_sql, params=tuple(params), auth_context=auth_context)
partitions = []
if result.data:
for row in result.data:
@@ -373,6 +412,9 @@ class DataQualityTools:
})
return partitions
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e:
logger.warning(f"Failed to get table partitions: {str(e)}")
return []
@@ -417,7 +459,8 @@ class DataQualityTools:
if db_name
else f"SHOW CREATE TABLE {table_name}"
)
result = await connection.execute(query)
auth_context = get_auth_context()
result = await connection.execute(query, auth_context=auth_context)
if result.data:
return result.data[0].get("Create Table")
return None
@@ -428,8 +471,16 @@ class DataQualityTools:
async def _get_table_size_info(self, connection, table_name: str) -> Dict[str, Any]:
"""Get table size information"""
try:
# Query table size information
size_sql = f"""
# SECURITY FIX: Validate and use parameterized query
table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return {"engine": "Unknown", "estimated_rows": 0, "data_length": 0, "index_length": 0, "total_size": 0}
auth_context = get_auth_context()
size_sql = """
SELECT
table_name,
engine,
@@ -438,10 +489,10 @@ class DataQualityTools:
index_length,
(data_length + index_length) as total_size
FROM information_schema.tables
WHERE table_name = '{table_name.split('.')[-1]}'
WHERE table_name = %s
"""
result = await connection.execute(size_sql)
result = await connection.execute(size_sql, params=(table_name_part,), auth_context=auth_context)
if result.data and result.data[0]:
row = result.data[0]
return {
@@ -582,7 +633,8 @@ class DataQualityTools:
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
result = await connection.execute(batch_sql)
auth_context = get_auth_context()
result = await connection.execute(batch_sql, auth_context=auth_context)
if not result.data:
return {"error": "No data returned from batch completeness query"}
@@ -664,7 +716,8 @@ class DataQualityTools:
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
result = await connection.execute(batch_sql)
auth_context = get_auth_context()
result = await connection.execute(batch_sql, auth_context=auth_context)
if not result.data:
return {}
@@ -705,7 +758,8 @@ class DataQualityTools:
LIMIT 10
"""
result = await connection.execute(freq_sql)
auth_context = get_auth_context()
result = await connection.execute(freq_sql, auth_context=auth_context)
frequencies = result.data if result.data else []
categorical_results[col_name] = {
@@ -738,7 +792,8 @@ class DataQualityTools:
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
result = await connection.execute(batch_sql)
auth_context = get_auth_context()
result = await connection.execute(batch_sql, auth_context=auth_context)
if not result.data:
return {}
@@ -780,7 +835,8 @@ class DataQualityTools:
FROM {table_expr}
"""
result = await connection.execute(completeness_sql)
auth_context = get_auth_context()
result = await connection.execute(completeness_sql, auth_context=auth_context)
if result.data:
stats = result.data[0]
total_count = stats["total_count"]
@@ -906,7 +962,8 @@ class DataQualityTools:
WHERE {col_name} IS NOT NULL
"""
result = await connection.execute(stats_sql)
auth_context = get_auth_context()
result = await connection.execute(stats_sql, auth_context=auth_context)
if result.data and result.data[0]["non_null_count"] > 0:
stats = result.data[0]
numeric_analysis[col_name] = {
@@ -945,7 +1002,8 @@ class DataQualityTools:
WHERE {col_name} IS NOT NULL
"""
cardinality_result = await connection.execute(cardinality_sql)
auth_context = get_auth_context()
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
if cardinality_result.data:
stats = cardinality_result.data[0]
@@ -969,7 +1027,7 @@ class DataQualityTools:
LIMIT 10
"""
top_values_result = await connection.execute(top_values_sql)
top_values_result = await connection.execute(top_values_sql, auth_context=auth_context)
if top_values_result.data:
categorical_analysis[col_name]["top_values"] = [
{"value": row[col_name], "count": row["count"]}
@@ -998,7 +1056,8 @@ class DataQualityTools:
WHERE {col_name} IS NOT NULL
"""
result = await connection.execute(stats_sql)
auth_context = get_auth_context()
result = await connection.execute(stats_sql, auth_context=auth_context)
if result.data and result.data[0]["non_null_count"] > 0:
stats = result.data[0]
temporal_analysis[col_name] = {

View File

@@ -27,6 +27,13 @@ from collections import defaultdict, deque
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__)
@@ -122,10 +129,19 @@ class DependencyAnalysisTools:
async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]:
"""Get metadata for all tables and views"""
try:
# Build conditions for query
# Build conditions for query with parameterized values
where_conditions = []
params = []
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
# SECURITY FIX: Validate identifier and use parameterized query
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
where_conditions.append("table_schema = %s")
params.append(db_name)
else:
where_conditions.append("table_schema = DATABASE()")
@@ -148,9 +164,18 @@ class DependencyAnalysisTools:
ORDER BY table_schema, table_name
"""
result = await connection.execute(metadata_sql)
# SECURITY FIX: Get auth_context and pass to execute for security validation
auth_context = get_auth_context()
result = await connection.execute(
metadata_sql,
params=tuple(params) if params else None,
auth_context=auth_context
)
return result.data if result.data else []
except SQLSecurityError as e:
logger.warning(f"Security validation failed in _get_tables_metadata: {str(e)}")
return []
except Exception as e:
logger.warning(f"Failed to get tables metadata: {str(e)}")
return []
@@ -186,17 +211,31 @@ class DependencyAnalysisTools:
async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
"""Analyze view definitions to extract table dependencies"""
# Get auth_context once for all operations in this method
auth_context = get_auth_context()
try:
for table in tables_metadata:
if table["table_type"] == "VIEW":
table_name = table["table_name"]
schema_name = table.get("schema_name", "")
# Get view definition
view_def_sql = f"SHOW CREATE VIEW {schema_name}.{table_name}" if schema_name else f"SHOW CREATE VIEW {table_name}"
# SECURITY FIX: Validate identifiers before using in SQL
try:
validate_identifier(table_name, "table name")
if schema_name:
validate_identifier(schema_name, "schema name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected in view analysis: {e}")
continue
# Build safe view reference using quoted identifiers
view_ref = build_table_reference(table_name, schema_name) if schema_name else quote_identifier(table_name, "table name")
view_def_sql = f"SHOW CREATE VIEW {view_ref}"
try:
result = await connection.execute(view_def_sql)
# SECURITY FIX: Pass auth_context to execute
result = await connection.execute(view_def_sql, auth_context=auth_context)
if result.data and len(result.data) > 0:
# Extract view definition from result
view_definition = ""
@@ -235,6 +274,9 @@ class DependencyAnalysisTools:
async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None:
"""Analyze audit logs to discover runtime table dependencies"""
# Get auth_context for security validation
auth_context = get_auth_context()
try:
# Get recent SQL statements from audit logs
audit_sql = """
@@ -252,7 +294,8 @@ class DependencyAnalysisTools:
LIMIT 1000
"""
result = await connection.execute(audit_sql)
# SECURITY FIX: Pass auth_context to execute
result = await connection.execute(audit_sql, auth_context=auth_context)
if result.data:
for row in result.data:
@@ -274,6 +317,9 @@ class DependencyAnalysisTools:
async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
"""Analyze foreign key constraints for explicit dependencies"""
# Get auth_context for security validation
auth_context = get_auth_context()
try:
# Get foreign key information
fk_sql = """
@@ -288,7 +334,8 @@ class DependencyAnalysisTools:
WHERE REFERENCED_TABLE_NAME IS NOT NULL
"""
result = await connection.execute(fk_sql)
# SECURITY FIX: Pass auth_context to execute
result = await connection.execute(fk_sql, auth_context=auth_context)
if result.data:
for row in result.data:

View File

@@ -28,6 +28,7 @@ from datetime import datetime
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import get_auth_context
logger = get_logger(__name__)
@@ -713,7 +714,8 @@ class DorisMonitoringTools:
# Fallback to SHOW BACKENDS if no BE hosts configured
logger.info("No BE hosts configured, using SHOW BACKENDS to discover BE nodes")
connection = await self.connection_manager.get_connection("query")
result = await connection.execute("SHOW BACKENDS")
auth_context = get_auth_context()
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
be_nodes = []
for row in result.data:

View File

@@ -27,6 +27,13 @@ from collections import defaultdict, Counter
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__)
@@ -229,7 +236,8 @@ class PerformanceAnalyticsTools:
ORDER BY query_date
"""
result = await connection.execute(query_volume_sql)
auth_context = get_auth_context()
result = await connection.execute(query_volume_sql, auth_context=auth_context)
daily_data = result.data if result.data else []
if not daily_data:
@@ -304,7 +312,8 @@ class PerformanceAnalyticsTools:
ORDER BY activity_date
"""
result = await connection.execute(user_activity_sql)
auth_context = get_auth_context()
result = await connection.execute(user_activity_sql, auth_context=auth_context)
daily_data = result.data if result.data else []
if not daily_data:
@@ -383,7 +392,8 @@ class PerformanceAnalyticsTools:
LIMIT 5000
"""
result = await connection.execute(slow_query_sql)
auth_context = get_auth_context()
result = await connection.execute(slow_query_sql, auth_context=auth_context)
return result.data if result.data else []
except Exception as e:
@@ -705,7 +715,8 @@ class PerformanceAnalyticsTools:
ORDER BY size_mb DESC
"""
db_result = await connection.execute(db_sizes_sql)
auth_context = get_auth_context()
db_result = await connection.execute(db_sizes_sql, auth_context=auth_context)
if not db_result.data:
logger.warning("No database size information available")
@@ -805,7 +816,16 @@ class PerformanceAnalyticsTools:
async def _get_database_table_details_from_schema(self, connection, db_name: str) -> List[Dict]:
"""Get table details for a specific database using information_schema"""
try:
table_details_sql = f"""
# SECURITY FIX: Validate db_name and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
table_details_sql = """
SELECT
TABLE_SCHEMA as schema_name,
TABLE_NAME as table_name,
@@ -814,13 +834,13 @@ class PerformanceAnalyticsTools:
CREATE_TIME as create_time,
UPDATE_TIME as update_time
FROM information_schema.tables
WHERE TABLE_SCHEMA = '{db_name}'
WHERE TABLE_SCHEMA = %s
AND TABLE_TYPE = 'BASE TABLE'
AND (COALESCE(DATA_LENGTH, 0) + COALESCE(INDEX_LENGTH, 0)) > 0
ORDER BY size_mb DESC
"""
result = await connection.execute(table_details_sql)
result = await connection.execute(table_details_sql, params=(db_name,), auth_context=auth_context)
if not result.data:
logger.warning(f"No table details found for database {db_name}")
@@ -867,6 +887,13 @@ class PerformanceAnalyticsTools:
async def _get_database_table_details(self, connection, db_name: str) -> List[Dict]:
"""Get table details for a specific database using session-consistent queries"""
try:
# SECURITY FIX: Validate db_name before using in SQL
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
# Method 1: Try to use session-consistent approach with raw connection
# This requires accessing the underlying connection to maintain session state
@@ -877,8 +904,9 @@ class PerformanceAnalyticsTools:
# Use raw connection to maintain session state
cursor = await raw_conn.cursor()
try:
# Execute USE and SHOW DATA in the same session
await cursor.execute(f"USE {db_name}")
# SECURITY FIX: Use quoted identifier for USE statement
quoted_db = quote_identifier(db_name, "database name")
await cursor.execute(f"USE {quoted_db}")
await cursor.execute("SHOW DATA")
result = await cursor.fetchall()
@@ -922,9 +950,19 @@ class PerformanceAnalyticsTools:
async def _get_database_table_details_fallback(self, connection, db_name: str) -> List[Dict]:
"""Fallback method to get table details using individual queries"""
try:
# Get all tables in the database
tables_sql = f"SHOW TABLES FROM {db_name}"
tables_result = await connection.execute(tables_sql)
# SECURITY FIX: Validate db_name and get auth_context
auth_context = get_auth_context()
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
# Get all tables in the database using quoted identifier
quoted_db = quote_identifier(db_name, "database name")
tables_sql = f"SHOW TABLES FROM {quoted_db}"
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
if not tables_result.data:
return []
@@ -934,9 +972,11 @@ class PerformanceAnalyticsTools:
table_name = table_row.get(f"Tables_in_{db_name}", "") or table_row.get("table_name", "")
if table_name:
try:
# Use SHOW DATA FROM db.table for each table
data_sql = f"SHOW DATA FROM {db_name}.{table_name}"
data_result = await connection.execute(data_sql)
# SECURITY FIX: Validate table_name and use safe reference
validate_identifier(table_name, "table name")
safe_table_ref = build_table_reference(table_name, db_name)
data_sql = f"SHOW DATA FROM {safe_table_ref}"
data_result = await connection.execute(data_sql, auth_context=auth_context)
if data_result.data:
for row in data_result.data:
@@ -1036,6 +1076,7 @@ class PerformanceAnalyticsTools:
async def _get_all_tables_info(self, connection) -> List[Dict]:
"""Get basic information for all tables (fallback method)"""
try:
auth_context = get_auth_context()
tables_sql = """
SELECT
table_schema,
@@ -1053,7 +1094,7 @@ class PerformanceAnalyticsTools:
ORDER BY (data_length + index_length) DESC
"""
result = await connection.execute(tables_sql)
result = await connection.execute(tables_sql, auth_context=auth_context)
return result.data if result.data else []
except Exception as e:
@@ -1120,23 +1161,37 @@ class PerformanceAnalyticsTools:
async def _get_current_table_size(self, connection, full_table_name: str) -> Optional[Dict]:
"""Get current table size"""
try:
# Try to query table size directly
size_sql = f"""
# SECURITY FIX: Get auth_context and use parameterized query
auth_context = get_auth_context()
# Extract table name for parameterized query
table_name_only = full_table_name.split('.')[-1] if '.' in full_table_name else full_table_name
# Validate identifiers
try:
validate_identifier(table_name_only, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return None
# Use parameterized query for safety
size_sql = """
SELECT
COALESCE(ROUND((COALESCE(data_length, 0) + COALESCE(index_length, 0)) / 1024 / 1024, 2), 0) as size_mb,
COALESCE(table_rows, 0) as `rows`
FROM information_schema.tables
WHERE CONCAT(table_schema, '.', table_name) = '{full_table_name}'
OR table_name = '{full_table_name.split('.')[-1]}'
WHERE CONCAT(table_schema, '.', table_name) = %s
OR table_name = %s
"""
result = await connection.execute(size_sql)
result = await connection.execute(size_sql, params=(full_table_name, table_name_only), auth_context=auth_context)
if result.data and result.data[0]:
return result.data[0]
# If information_schema has no data, try COUNT query
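# NOTE: full_table_name is interpolated into the SQL text as-is. Only its last
# segment is validated above, so the caller must pass a reference that was
# built with build_table_reference.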
count_sql = f"SELECT COUNT(*) as rows FROM {full_table_name}"
count_result = await connection.execute(count_sql)
count_result = await connection.execute(count_sql, auth_context=auth_context)
if count_result.data:
return {
"size_mb": 0, # Cannot get exact size
@@ -1145,6 +1200,9 @@ class PerformanceAnalyticsTools:
return None
except SQLSecurityError as e:
logger.warning(f"Security validation failed for {full_table_name}: {str(e)}")
return None
except Exception as e:
logger.warning(f"Failed to get current size for {full_table_name}: {str(e)}")
return None
@@ -1154,8 +1212,19 @@ class PerformanceAnalyticsTools:
) -> List[Dict]:
"""Get historical growth data based on partitions"""
try:
# Query partition information
partition_sql = f"""
# SECURITY FIX: Validate identifiers and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(table_name, "table name")
if schema_name:
validate_identifier(schema_name, "schema name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Use parameterized query for safety
partition_sql = """
SELECT
partition_name,
partition_description,
@@ -1163,15 +1232,19 @@ class PerformanceAnalyticsTools:
data_length,
create_time
FROM information_schema.partitions
WHERE table_schema = '{schema_name or ""}'
AND table_name = '{table_name}'
WHERE table_schema = %s
AND table_name = %s
AND partition_name IS NOT NULL
AND create_time IS NOT NULL
AND create_time >= DATE_SUB(NOW(), INTERVAL {days} DAY)
AND create_time >= DATE_SUB(NOW(), INTERVAL %s DAY)
ORDER BY create_time DESC
"""
result = await connection.execute(partition_sql)
result = await connection.execute(
partition_sql,
params=(schema_name or "", table_name, days),
auth_context=auth_context
)
if not result.data:
return []
@@ -1210,6 +1283,9 @@ class PerformanceAnalyticsTools:
) -> List[Dict]:
"""Get historical growth data based on timestamp fields"""
try:
# SECURITY FIX: Get auth_context
auth_context = get_auth_context()
# Find possible timestamp fields
timestamp_columns = await self._find_timestamp_columns(connection, table_name, schema_name)
if not timestamp_columns:
@@ -1218,20 +1294,29 @@ class PerformanceAnalyticsTools:
# Use best timestamp field for analysis
time_column = timestamp_columns[0]
# Aggregate data by date
# SECURITY FIX: Validate time_column before using in SQL
try:
validate_identifier(time_column, "column name")
except SQLSecurityError as e:
logger.warning(f"Invalid column name rejected: {e}")
return []
quoted_time_column = quote_identifier(time_column, "column name")
# Aggregate data by date. full_table_name is interpolated as-is and is not
# validated in this method, so it must already be validated by the caller.
growth_sql = f"""
SELECT
DATE({time_column}) as date,
DATE({quoted_time_column}) as date,
COUNT(*) as daily_records,
COUNT(*) / SUM(COUNT(*)) OVER() * 100 as percentage
FROM {full_table_name}
WHERE {time_column} >= DATE_SUB(NOW(), INTERVAL {days} DAY)
AND {time_column} IS NOT NULL
GROUP BY DATE({time_column})
WHERE {quoted_time_column} >= DATE_SUB(NOW(), INTERVAL %s DAY)
AND {quoted_time_column} IS NOT NULL
GROUP BY DATE({quoted_time_column})
ORDER BY date DESC
"""
result = await connection.execute(growth_sql)
result = await connection.execute(growth_sql, params=(days,), auth_context=auth_context)
if not result.data:
return []
@@ -1257,11 +1342,22 @@ class PerformanceAnalyticsTools:
async def _find_timestamp_columns(self, connection, table_name: str, schema_name: str) -> List[str]:
"""Find timestamp fields in table"""
try:
timestamp_sql = f"""
# SECURITY FIX: Validate identifiers and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(table_name, "table name")
if schema_name:
validate_identifier(schema_name, "schema name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
timestamp_sql = """
SELECT column_name, data_type
FROM information_schema.columns
WHERE table_schema = '{schema_name or ""}'
AND table_name = '{table_name}'
WHERE table_schema = %s
AND table_name = %s
AND (
data_type IN ('datetime', 'timestamp', 'date')
OR column_name REGEXP '(create|insert|update|modify).*time'
@@ -1278,9 +1374,16 @@ class PerformanceAnalyticsTools:
END
"""
result = await connection.execute(timestamp_sql)
result = await connection.execute(
timestamp_sql,
params=(schema_name or "", table_name),
auth_context=auth_context
)
return [row["column_name"] for row in result.data] if result.data else []
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e:
logger.warning(f"Failed to find timestamp columns: {str(e)}")
return []
@@ -1290,8 +1393,22 @@ class PerformanceAnalyticsTools:
) -> List[Dict]:
"""Estimate growth data based on audit logs"""
try:
# SECURITY FIX: Validate table_name and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(table_name.split(".")[-1], "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return []
# Extract just the table name for LIKE pattern
table_name_only = table_name.split(".")[-1]
like_pattern_full = f"%{table_name}%"
like_pattern_short = f"%{table_name_only}%"
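# Analyze operation history for this table.
# NOTE(review): this query is now executed with params; if execute() applies
# %-style formatting, the literal '%' in the LIKE 'UPDATE%' / 'DELETE%'
# patterns below may need to be doubled ('%%') - confirm against the driver.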
audit_sql = f"""
audit_sql = """
SELECT
DATE(`time`) as operation_date,
COUNT(*) as operation_count,
@@ -1299,17 +1416,21 @@ class PerformanceAnalyticsTools:
SUM(CASE WHEN stmt LIKE 'UPDATE%' THEN 1 ELSE 0 END) as update_count,
SUM(CASE WHEN stmt LIKE 'DELETE%' THEN 1 ELSE 0 END) as delete_count
FROM internal.__internal_schema.audit_log
WHERE `time` >= DATE_SUB(NOW(), INTERVAL {days} DAY)
WHERE `time` >= DATE_SUB(NOW(), INTERVAL %s DAY)
AND stmt IS NOT NULL
AND (
stmt LIKE '%{table_name}%'
OR stmt LIKE '%{table_name.split(".")[-1]}%'
stmt LIKE %s
OR stmt LIKE %s
)
GROUP BY DATE(`time`)
ORDER BY operation_date DESC
"""
result = await connection.execute(audit_sql)
result = await connection.execute(
audit_sql,
params=(days, like_pattern_full, like_pattern_short),
auth_context=auth_context
)
if not result.data:
return []

View File

@@ -35,6 +35,7 @@ from decimal import Decimal
from .db import DorisConnectionManager, QueryResult
from .logger import get_logger
from .sql_security_utils import get_auth_context
@dataclass
@@ -497,7 +498,8 @@ class DorisQueryExecutor:
explain_sql = f"EXPLAIN {sql}"
connection = await self.connection_manager.get_connection(session_id)
result = await connection.execute(explain_sql)
auth_context = get_auth_context()
result = await connection.execute(explain_sql, auth_context=auth_context)
return {
"query": sql,

View File

@@ -32,6 +32,11 @@ from datetime import datetime, timedelta
# Import unified logging configuration
from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier
)
# Configure logging
logger = get_logger(__name__)
@@ -431,6 +436,16 @@ class MetadataExtractor:
logger.warning("Database name not specified")
return {}
# SECURITY FIX: Validate identifiers to prevent SQL injection
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected in get_table_schema: {e}")
return {}
cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
@@ -536,6 +551,16 @@ class MetadataExtractor:
logger.warning("Database name not specified")
return ""
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return ""
cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
@@ -587,6 +612,16 @@ class MetadataExtractor:
logger.warning("Database name not specified")
return {}
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return {}
cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
@@ -643,17 +678,30 @@ class MetadataExtractor:
logger.error("Database name not specified")
return []
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
try:
# Build query with catalog prefix if specified
# Build query with catalog prefix if specified (identifiers already validated)
safe_table = quote_identifier(table_name, "table name")
safe_db = quote_identifier(db_name, "database name")
if effective_catalog:
query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`"
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
logger.info(f"Using three-part naming for index query: {query}")
else:
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
try:
# NOTE: Deprecated sync path retained for compatibility; use async variant instead.
@@ -1188,12 +1236,28 @@ class MetadataExtractor:
try:
# Use async query method
effective_catalog = catalog_name or self.catalog_name
effective_db = db_name or self.db_name
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
if effective_catalog and effective_catalog != "internal":
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build query statement using safe identifiers
safe_table = quote_identifier(table_name, "table name")
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
# Build query statement
if effective_catalog and effective_catalog != "internal":
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`"
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"DESCRIBE {safe_catalog}.{safe_db}.{safe_table}"
else:
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`"
query = f"DESCRIBE {safe_db}.{safe_table}"
# Execute async query
result = await self._execute_query_async(query, db_name)
@@ -1226,8 +1290,15 @@ class MetadataExtractor:
try:
effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate catalog name if provided
if effective_catalog and effective_catalog != "internal":
query = f"SHOW DATABASES FROM `{effective_catalog}`"
try:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid catalog name rejected: {e}")
return []
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW DATABASES FROM {safe_catalog}"
else:
query = "SHOW DATABASES"
@@ -1257,10 +1328,23 @@ class MetadataExtractor:
effective_catalog = catalog_name or self.catalog_name
effective_db = db_name or self.db_name
# SECURITY FIX: Validate identifiers
try:
if effective_db:
validate_identifier(effective_db, "database name")
if effective_catalog and effective_catalog != "internal":
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
if effective_catalog and effective_catalog != "internal":
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`"
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW TABLES FROM {safe_catalog}.{safe_db}"
else:
query = f"SHOW TABLES FROM `{effective_db}`"
query = f"SHOW TABLES FROM {safe_db}"
result = await self._execute_query_async(query, effective_db)
@@ -1319,6 +1403,15 @@ class MetadataExtractor:
effective_db = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return ""
query = f"""
SELECT
TABLE_COMMENT
@@ -1343,6 +1436,15 @@ class MetadataExtractor:
effective_db = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return {}
query = f"""
SELECT
COLUMN_NAME,
@@ -1373,12 +1475,27 @@ class MetadataExtractor:
effective_db = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
# Build query with catalog prefix if specified
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build query with catalog prefix if specified (using safe identifiers)
safe_table = quote_identifier(table_name, "table name")
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
if effective_catalog:
query = f"SHOW INDEX FROM `{effective_catalog}`.`{effective_db}`.`{table_name}`"
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
logger.info(f"Using three-part naming for async index query: {query}")
else:
query = f"SHOW INDEX FROM `{effective_db}`.`{table_name}`"
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
rows = await self._execute_query_async(query, effective_db)
indexes: List[Dict[str, Any]] = []
@@ -1475,21 +1592,45 @@ class MetadataExtractor:
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
# FIX for Issue #62 Bug 3: Build context switching SQL if db_name or catalog_name is specified
# SECURITY FIX: Validate catalog_name and db_name to prevent SQL injection
final_sql = sql
if catalog_name or db_name:
context_statements = []
# Validate and sanitize catalog_name
if catalog_name:
# Switch to specified catalog
context_statements.append(f"USE CATALOG `{catalog_name}`")
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid catalog name rejected: {e}")
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
# Use quote_identifier to safely escape the catalog name
safe_catalog = quote_identifier(catalog_name, "catalog name")
context_statements.append(f"USE CATALOG {safe_catalog}")
logger.debug(f"Switching to catalog: {catalog_name}")
# Validate and sanitize db_name
if db_name:
# Switch to specified database
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
# Use quote_identifier to safely escape the database name
safe_db = quote_identifier(db_name, "database name")
if catalog_name:
context_statements.append(f"USE `{catalog_name}`.`{db_name}`")
safe_catalog = quote_identifier(catalog_name, "catalog name")
context_statements.append(f"USE {safe_catalog}.{safe_db}")
else:
context_statements.append(f"USE `{db_name}`")
context_statements.append(f"USE {safe_db}")
logger.debug(f"Switching to database: {db_name}")
# Combine context switching with original SQL
@@ -1551,6 +1692,36 @@ class MetadataExtractor:
if not table_name:
return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers before processing
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try:
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
@@ -1574,6 +1745,27 @@ class MetadataExtractor:
"""Get list of all table names in specified database - MCP interface"""
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
# SECURITY: Validate identifiers
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try:
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=tables)
@@ -1604,6 +1796,36 @@ class MetadataExtractor:
if not table_name:
return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try:
comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=comment)
@@ -1623,6 +1845,36 @@ class MetadataExtractor:
if not table_name:
return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try:
comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=comments)
@@ -1642,6 +1894,36 @@ class MetadataExtractor:
if not table_name:
return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try:
indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=indexes)

View File

@@ -901,30 +901,50 @@ class SQLSecurityValidator:
if not self.enable_security_check:
self.logger.debug("SQL security check is disabled, allowing all queries")
return ValidationResult(is_valid=True)
try:
# Parse SQL statement
parsed = sqlparse.parse(sql)[0]
# SECURITY FIX: Parse ALL SQL statements, not just the first one
# This prevents bypassing security checks by injecting additional statements
all_statements = sqlparse.parse(sql)
# Check blocked operations first (more specific)
keyword_result = await self._check_blocked_keywords(parsed)
if not keyword_result.is_valid:
return keyword_result
if not all_statements:
return ValidationResult(
is_valid=False,
error_message="Empty or invalid SQL statement",
risk_level="medium"
)
# Check SQL injection risks
injection_result = await self._check_sql_injection(sql, parsed)
if not injection_result.is_valid:
return injection_result
# SECURITY FIX: Validate each statement individually
for idx, parsed in enumerate(all_statements):
# Skip empty statements (e.g., from trailing semicolons)
if not parsed.tokens or str(parsed).strip() == '':
continue
# Check query complexity
complexity_result = await self._check_query_complexity(parsed)
if not complexity_result.is_valid:
return complexity_result
self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...")
# Check table access permissions
table_result = await self._check_table_access(parsed, auth_context)
if not table_result.is_valid:
return table_result
# Check blocked operations first (more specific)
keyword_result = await self._check_blocked_keywords(parsed)
if not keyword_result.is_valid:
keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}"
return keyword_result
# Check SQL injection risks
injection_result = await self._check_sql_injection(sql, parsed)
if not injection_result.is_valid:
injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}"
return injection_result
# Check query complexity
complexity_result = await self._check_query_complexity(parsed)
if not complexity_result.is_valid:
complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}"
return complexity_result
# Check table access permissions
table_result = await self._check_table_access(parsed, auth_context)
if not table_result.is_valid:
table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}"
return table_result
return ValidationResult(is_valid=True)
@@ -1134,6 +1154,10 @@ class SQLSecurityValidator:
self, parsed: Statement, auth_context: AuthContext
) -> ValidationResult:
"""Check table access permissions"""
# If no auth_context, skip table access checks (rely on other security checks)
if auth_context is None:
return ValidationResult(is_valid=True)
# Extract table names from query
tables = self._extract_table_names(parsed)

View File

@@ -26,6 +26,7 @@ from collections import Counter, defaultdict
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import get_auth_context
logger = get_logger(__name__)
@@ -192,7 +193,9 @@ class SecurityAnalyticsTools:
LIMIT 10000
"""
result = await connection.execute(audit_sql)
# SECURITY FIX: Pass auth_context to execute
auth_context = get_auth_context()
result = await connection.execute(audit_sql, auth_context=auth_context)
return result.data if result.data else []
except Exception as e:
@@ -215,7 +218,8 @@ class SecurityAnalyticsTools:
LIMIT 10000
"""
result = await connection.execute(simple_audit_sql)
auth_context = get_auth_context()
result = await connection.execute(simple_audit_sql, auth_context=auth_context)
return result.data if result.data else []
except Exception as e2:
@@ -498,7 +502,8 @@ class SecurityAnalyticsTools:
FROM mysql.user
"""
result = await connection.execute(roles_sql)
auth_context = get_auth_context()
result = await connection.execute(roles_sql, auth_context=auth_context)
user_roles = defaultdict(list)
if result.data:

View File

@@ -0,0 +1,301 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL Security Utilities Module
Provides SQL identifier validation, escaping, and safe query building utilities
to prevent SQL injection attacks.
"""
import re
from contextvars import ContextVar
from typing import Optional, Tuple, List, Any
from .logger import get_logger
# Module-level logger shared by every helper in this file.
logger = get_logger(__name__)
# Context variable holding the auth context of the request being processed.
# It defaults to None when nothing has been stored; it is written through
# SQLSecurityUtils.set_auth_context (typically by the HTTP middleware) and
# read through SQLSecurityUtils.get_auth_context.
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
class SQLSecurityError(Exception):
    """Signals that a SQL identifier or operator failed security validation."""
class SQLSecurityUtils:
    """
    Helpers for building SQL text from caller-supplied names without opening
    the door to SQL injection.

    Provides:
    - Identifier validation (database names, table names, column names)
    - Backtick quoting of validated identifiers
    - Construction of qualified table and column references
    - Construction of simple WHERE-clause fragments
    - Access to the auth context stored in ``auth_context_var``
    """

    # Accepted identifier shape: ASCII letters, digits and underscores, plus
    # CJK ideographs in U+4E00-U+9FFF. The first character must not be a digit.
    # NOTE: other non-ASCII letters (accented Latin, Cyrillic, ...) are rejected.
    IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$')

    # Maximum accepted identifier length in characters (MySQL/Doris standard)
    MAX_IDENTIFIER_LENGTH = 64

    # Common SQL reserved keywords. Not referenced by any method of this class:
    # quoting is applied whenever it is requested, keyword or not.
    SQL_KEYWORDS = {
        'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP',
        'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND',
        'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN',
        'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER',
        'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
        'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY',
        'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT'
    }

    # Substrings rejected up front so the error can name the offending
    # sequence; checked in this order, the first hit is the one reported.
    _FORBIDDEN_SEQUENCES = ("'", '"', ';', '--', '/*', '*/', '\\', '\x00')

    # Comparison operators accepted by validate_and_build_where_condition
    # (matched case-insensitively).
    _ALLOWED_OPERATORS = frozenset(
        {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'}
    )

    @classmethod
    def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
        """
        Validate a SQL identifier (database name, table name, column name, etc.)

        Args:
            name: The identifier to validate
            identifier_type: Type description for error messages
                (e.g., "database name", "table name")

        Returns:
            The validated identifier with surrounding whitespace stripped

        Raises:
            SQLSecurityError: If the identifier is missing, not a string, empty
                after stripping, too long, or contains characters outside
                IDENTIFIER_PATTERN
        """
        if name is None:
            raise SQLSecurityError(f"Empty {identifier_type} is not allowed")

        # Type check comes before the emptiness check so that falsy
        # non-strings (0, False, []) are reported as a type problem.
        if not isinstance(name, str):
            raise SQLSecurityError(f"Invalid {identifier_type}: must be a string, got {type(name).__name__}")

        name = name.strip()
        if not name:
            raise SQLSecurityError(f"Empty {identifier_type} is not allowed")

        if len(name) > cls.MAX_IDENTIFIER_LENGTH:
            raise SQLSecurityError(
                f"Invalid {identifier_type}: '{name[:20]}...' exceeds maximum length of {cls.MAX_IDENTIFIER_LENGTH} characters"
            )

        # The pattern below would reject these as well; checking them first
        # yields a more specific error message.
        for sequence in cls._FORBIDDEN_SEQUENCES:
            if sequence in name:
                raise SQLSecurityError(
                    f"Invalid {identifier_type}: '{name}' contains forbidden character '{sequence}'"
                )

        # fullmatch: the whole string must fit, with no reliance on '$'
        # (which also matches just before a trailing newline).
        if not cls.IDENTIFIER_PATTERN.fullmatch(name):
            raise SQLSecurityError(
                f"Invalid {identifier_type}: '{name}' contains invalid characters. "
                f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore."
            )

        logger.debug(f"Validated {identifier_type}: {name}")
        return name

    @classmethod
    def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
        """
        Safely quote a SQL identifier using backticks.

        Args:
            name: The identifier to quote
            identifier_type: Type description for error messages

        Returns:
            The quoted identifier (e.g., `table_name`)

        Raises:
            SQLSecurityError: If the identifier is invalid
        """
        validated_name = cls.validate_identifier(name, identifier_type)
        # Validation already rejects backticks; doubling them is kept as
        # defense in depth in case the accepted pattern is ever relaxed.
        escaped_name = validated_name.replace('`', '``')
        return f"`{escaped_name}`"

    @classmethod
    def build_table_reference(
        cls,
        table_name: str,
        db_name: Optional[str] = None,
        catalog_name: Optional[str] = None,
        quote: bool = True
    ) -> str:
        """
        Build a safe, fully-qualified table reference.

        Args:
            table_name: The table name (required)
            db_name: The database name (optional)
            catalog_name: The catalog name (optional)
            quote: Whether to quote identifiers with backticks (default: True)

        Returns:
            A dot-joined table reference (e.g., `catalog`.`db`.`table`);
            empty/None db_name and catalog_name are simply left out

        Raises:
            SQLSecurityError: If any identifier is invalid
        """
        # Both callables validate; quote_identifier additionally wraps the
        # result in backticks.
        render = cls.quote_identifier if quote else cls.validate_identifier

        parts = []
        if catalog_name:
            parts.append(render(catalog_name, "catalog name"))
        if db_name:
            parts.append(render(db_name, "database name"))
        parts.append(render(table_name, "table name"))
        return '.'.join(parts)

    @classmethod
    def build_column_reference(
        cls,
        column_name: str,
        table_name: Optional[str] = None,
        quote: bool = True
    ) -> str:
        """
        Build a safe column reference.

        Args:
            column_name: The column name (required)
            table_name: The table name (optional, for qualified references)
            quote: Whether to quote identifiers with backticks (default: True)

        Returns:
            A dot-joined column reference (e.g., `table`.`column`)

        Raises:
            SQLSecurityError: If any identifier is invalid
        """
        render = cls.quote_identifier if quote else cls.validate_identifier

        parts = []
        if table_name:
            parts.append(render(table_name, "table name"))
        parts.append(render(column_name, "column name"))
        return '.'.join(parts)

    @classmethod
    def validate_and_build_where_condition(
        cls,
        column_name: str,
        operator: str = "=",
        use_param: bool = True
    ) -> Tuple[str, bool]:
        """
        Build a safe WHERE condition fragment for a column.

        Args:
            column_name: The column name
            operator: The comparison operator, matched case-insensitively
                (=, !=, <>, <, >, <=, >=, LIKE, IN, IS)
            use_param: Whether to append a parameter placeholder (%s)

        Returns:
            Tuple of (condition_string, needs_param):
            ("`column` = %s", True) when use_param is true, otherwise
            ("`column` =", False) - the caller appends the right-hand side.
            The operator is emitted exactly as it was passed in.

        Raises:
            SQLSecurityError: If column name is invalid or operator is not allowed
        """
        quoted_column = cls.quote_identifier(column_name, "column name")

        # A non-string operator must be rejected as a security error rather
        # than leaking an AttributeError from .upper().
        if not isinstance(operator, str) or operator.upper() not in cls._ALLOWED_OPERATORS:
            allowed = ', '.join(sorted(cls._ALLOWED_OPERATORS))
            raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed}")

        if use_param:
            return f"{quoted_column} {operator} %s", True
        return f"{quoted_column} {operator}", False

    @staticmethod
    def get_auth_context() -> Optional[Any]:
        """
        Get auth_context from the context variable.

        Returns:
            Whatever was last stored through set_auth_context in the current
            context, or None if nothing is available
        """
        try:
            auth_context = auth_context_var.get()
        except LookupError as e:
            # Only possible if the variable was created without a default.
            logger.debug(f"Could not retrieve auth_context: {e}")
            return None

        if auth_context is not None:
            logger.debug("Retrieved auth_context from context variable")
        return auth_context

    @staticmethod
    def set_auth_context(auth_context: Any) -> None:
        """
        Set auth_context in the context variable.

        This is typically called by the HTTP middleware during request processing.

        Args:
            auth_context: The auth_context object to set
        """
        auth_context_var.set(auth_context)
        logger.debug("Set auth_context in context variable")
# Convenience functions for direct use: module-level aliases of the
# SQLSecurityUtils methods, so callers can write
# ``from .sql_security_utils import validate_identifier`` instead of going
# through the class. validate_and_build_where_condition has no alias here and
# is reached through SQLSecurityUtils.
validate_identifier = SQLSecurityUtils.validate_identifier
quote_identifier = SQLSecurityUtils.quote_identifier
build_table_reference = SQLSecurityUtils.build_table_reference
build_column_reference = SQLSecurityUtils.build_column_reference
get_auth_context = SQLSecurityUtils.get_auth_context
set_auth_context = SQLSecurityUtils.set_auth_context