fix some security issues (#68)
This commit is contained in:
@@ -31,6 +31,7 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
@@ -422,7 +423,8 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
db_result = await connection.execute(db_info_sql)
|
||||
auth_context = get_auth_context()
|
||||
db_result = await connection.execute(db_info_sql, auth_context=auth_context)
|
||||
db_info = db_result.data[0] if db_result.data else {}
|
||||
|
||||
# Get main table list
|
||||
@@ -438,7 +440,7 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
|
||||
context = f"""Current database statistics:
|
||||
- Total number of tables: {db_info.get("table_count", 0)}
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import Any
|
||||
from mcp.types import Resource
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
class TableMetadata:
|
||||
@@ -169,7 +170,8 @@ class DorisResourcesManager:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(tables_query, auth_context=auth_context)
|
||||
tables = []
|
||||
|
||||
for row in result.data:
|
||||
@@ -204,7 +206,8 @@ class DorisResourcesManager:
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(columns_query, params=(table_name,), auth_context=auth_context)
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_metadata(self) -> list[ViewMetadata]:
|
||||
@@ -226,7 +229,8 @@ class DorisResourcesManager:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(views_query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(views_query, auth_context=auth_context)
|
||||
views = []
|
||||
|
||||
for row in result.data:
|
||||
@@ -257,7 +261,8 @@ class DorisResourcesManager:
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_info_query, params=(table_name,), auth_context=auth_context)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
@@ -295,7 +300,8 @@ class DorisResourcesManager:
|
||||
ORDER BY index_name, seq_in_index
|
||||
"""
|
||||
|
||||
result = await connection.execute(indexes_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(indexes_query, params=(table_name,), auth_context=auth_context)
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_definition(self, view_name: str) -> str:
|
||||
@@ -312,7 +318,8 @@ class DorisResourcesManager:
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
result = await connection.execute(view_query, (view_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(view_query, params=(view_name,), auth_context=auth_context)
|
||||
if not result.data:
|
||||
raise ValueError(f"View {view_name} does not exist")
|
||||
|
||||
@@ -340,7 +347,8 @@ class DorisResourcesManager:
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_stats_query)
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_stats_query, auth_context=auth_context)
|
||||
table_stats = table_result.data[0] if table_result.data else {}
|
||||
|
||||
# Get view statistics
|
||||
@@ -350,7 +358,7 @@ class DorisResourcesManager:
|
||||
WHERE table_schema = DATABASE()
|
||||
"""
|
||||
|
||||
view_result = await connection.execute(view_stats_query)
|
||||
view_result = await connection.execute(view_stats_query, auth_context=auth_context)
|
||||
view_stats = view_result.data[0] if view_result.data else {}
|
||||
|
||||
stats_info = {
|
||||
|
||||
@@ -28,6 +28,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -277,7 +278,8 @@ class DorisADBCQueryTools:
|
||||
# Get BE nodes via SHOW BACKENDS
|
||||
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
result = await connection.execute("SHOW BACKENDS")
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||
|
||||
be_hosts = []
|
||||
for row in result.data:
|
||||
@@ -383,6 +385,20 @@ class DorisADBCQueryTools:
|
||||
"error_type": "no_connection"
|
||||
}
|
||||
|
||||
# SECURITY FIX: Perform SQL security validation before executing
|
||||
auth_context = get_auth_context()
|
||||
if self.connection_manager.security_manager:
|
||||
# Always perform security validation, even without auth_context
|
||||
# Use a default context for basic SQL security checks
|
||||
validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context)
|
||||
if not validation_result.is_valid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||
"error_type": "security_violation",
|
||||
"risk_level": validation_result.risk_level
|
||||
}
|
||||
|
||||
cursor = self.adbc_client.cursor()
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
@@ -29,6 +29,13 @@ from pathlib import Path
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -46,10 +53,17 @@ class TableAnalyzer:
|
||||
sample_size: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""Get table summary information"""
|
||||
# SECURITY FIX: Validate table_name and get auth_context
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
raise ValueError(f"Invalid table name: {e}")
|
||||
|
||||
auth_context = get_auth_context()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Get table basic information
|
||||
table_info_sql = f"""
|
||||
# Get table basic information using parameterized query
|
||||
table_info_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
@@ -58,17 +72,17 @@ class TableAnalyzer:
|
||||
engine
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
table_info_result = await connection.execute(table_info_sql)
|
||||
table_info_result = await connection.execute(table_info_sql, params=(table_name,), auth_context=auth_context)
|
||||
if not table_info_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
table_info = table_info_result.data[0]
|
||||
|
||||
# Get column information
|
||||
columns_sql = f"""
|
||||
# Get column information using parameterized query
|
||||
columns_sql = """
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
@@ -76,11 +90,11 @@ class TableAnalyzer:
|
||||
column_comment
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
AND table_name = %s
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
columns_result = await connection.execute(columns_sql)
|
||||
columns_result = await connection.execute(columns_sql, params=(table_name,), auth_context=auth_context)
|
||||
|
||||
summary = {
|
||||
"table_name": table_info["table_name"],
|
||||
@@ -92,10 +106,11 @@ class TableAnalyzer:
|
||||
"columns": columns_result.data,
|
||||
}
|
||||
|
||||
# Get sample data
|
||||
# Get sample data using quoted identifier
|
||||
if include_sample and sample_size > 0:
|
||||
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql)
|
||||
quoted_table = quote_identifier(table_name, "table name")
|
||||
sample_sql = f"SELECT * FROM {quoted_table} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql, auth_context=auth_context)
|
||||
summary["sample_data"] = sample_result.data
|
||||
|
||||
return summary
|
||||
@@ -120,7 +135,8 @@ class TableAnalyzer:
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
basic_result = await connection.execute(basic_stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
basic_result = await connection.execute(basic_stats_sql, auth_context=auth_context)
|
||||
if not basic_result.data:
|
||||
return {
|
||||
"success": False,
|
||||
@@ -144,7 +160,7 @@ class TableAnalyzer:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
distribution_result = await connection.execute(distribution_sql)
|
||||
distribution_result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||
analysis["value_distribution"] = distribution_result.data
|
||||
|
||||
if analysis_type == "detailed":
|
||||
@@ -159,7 +175,7 @@ class TableAnalyzer:
|
||||
WHERE {column_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
numeric_result = await connection.execute(numeric_stats_sql)
|
||||
numeric_result = await connection.execute(numeric_stats_sql, auth_context=auth_context)
|
||||
if numeric_result.data:
|
||||
analysis.update(numeric_result.data[0])
|
||||
except Exception:
|
||||
@@ -196,7 +212,8 @@ class TableAnalyzer:
|
||||
AND table_name = '{table_name}'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_sql)
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_info_sql, auth_context=auth_context)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
@@ -211,7 +228,7 @@ class TableAnalyzer:
|
||||
AND table_name != %s
|
||||
"""
|
||||
|
||||
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
|
||||
all_tables_result = await connection.execute(all_tables_sql, params=(table_name,), auth_context=auth_context)
|
||||
|
||||
return {
|
||||
"center_table": table_result.data[0],
|
||||
@@ -291,7 +308,8 @@ class PerformanceMonitor:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
auth_context = get_auth_context()
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
stats = {
|
||||
"metric_type": "tables",
|
||||
"time_range": time_range,
|
||||
@@ -380,8 +398,14 @@ class SQLAnalyzer:
|
||||
logger.info(f"Generating SQL explain for query ID: {query_id}")
|
||||
|
||||
# Switch database if specified
|
||||
# SECURITY FIX: Validate and quote db_name
|
||||
if db_name:
|
||||
await self.connection_manager.execute_query("explain_session", f"USE {db_name}")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}")
|
||||
|
||||
# Construct EXPLAIN query
|
||||
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
|
||||
@@ -515,24 +539,36 @@ class SQLAnalyzer:
|
||||
|
||||
try:
|
||||
# Switch to specified database/catalog if provided
|
||||
# SECURITY FIX: Validate identifiers before using in SQL
|
||||
if catalog_name:
|
||||
await connection.execute(f"SWITCH `{catalog_name}`")
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid catalog name: {e}"}
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
auth_context = get_auth_context()
|
||||
await connection.execute(f"SWITCH {safe_catalog}", auth_context=auth_context)
|
||||
if db_name:
|
||||
await connection.execute(f"USE `{db_name}`")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
await connection.execute(f"USE {safe_db}", auth_context=auth_context)
|
||||
|
||||
# Set trace ID for the session using session variable
|
||||
# According to official docs: set session_context="trace_id:your_trace_id"
|
||||
await connection.execute(f'set session_context="trace_id:{trace_id}"')
|
||||
await connection.execute(f'set session_context="trace_id:{trace_id}"', auth_context=auth_context)
|
||||
logger.info(f"Set trace ID: {trace_id}")
|
||||
|
||||
# Enable profile
|
||||
await connection.execute(f'set enable_profile=true')
|
||||
await connection.execute(f'set enable_profile=true', auth_context=auth_context)
|
||||
logger.info(f"Enabled profile")
|
||||
|
||||
# Execute the SQL statement
|
||||
logger.info(f"Executing SQL with trace ID: {sql}")
|
||||
start_time = time.time()
|
||||
sql_result = await connection.execute(sql)
|
||||
sql_result = await connection.execute(sql, auth_context=auth_context)
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"SQL execution completed in {execution_time:.3f}s")
|
||||
|
||||
|
||||
@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -43,24 +50,30 @@ class DataExplorationTools:
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||
"""Build full table name with catalog and database using three-part naming convention"""
|
||||
# Default catalog for internal tables
|
||||
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||
effective_catalog = catalog_name if catalog_name else "internal"
|
||||
|
||||
if db_name:
|
||||
return f"{effective_catalog}.{db_name}.{table_name}"
|
||||
return build_table_reference(table_name, db_name, effective_catalog)
|
||||
else:
|
||||
# If no db_name provided, need to determine the current database
|
||||
return f"{effective_catalog}.{table_name}"
|
||||
return build_table_reference(table_name, catalog_name=effective_catalog)
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get basic table information including row count"""
|
||||
try:
|
||||
# SECURITY FIX: Get auth_context for security validation
|
||||
# table_name should already be validated by _build_full_table_name
|
||||
auth_context = get_auth_context()
|
||||
|
||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||
result = await connection.execute(count_sql)
|
||||
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
return {"row_count": result.data[0]["row_count"]}
|
||||
return None
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
@@ -68,10 +81,24 @@ class DataExplorationTools:
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get detailed column information"""
|
||||
try:
|
||||
where_conditions = [f"table_name = '{table_name}'"]
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if db_name:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build parameterized query
|
||||
params = [table_name]
|
||||
where_conditions = ["table_name = %s"]
|
||||
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
where_conditions.append("table_schema = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
@@ -87,9 +114,12 @@ class DataExplorationTools:
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_sql)
|
||||
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||
return []
|
||||
@@ -177,7 +207,8 @@ class DataExplorationTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
stats_result = await connection.execute(stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
stats_result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||
|
||||
if stats_result.data and stats_result.data[0]["count"] > 0:
|
||||
stats = stats_result.data[0]
|
||||
@@ -229,7 +260,8 @@ class DataExplorationTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(percentile_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(percentile_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
@@ -268,7 +300,8 @@ class DataExplorationTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(outlier_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(outlier_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
@@ -359,7 +392,8 @@ class DataExplorationTools:
|
||||
{sampling_info.get('sample_query_suffix', '')}
|
||||
"""
|
||||
|
||||
cardinality_result = await connection.execute(cardinality_sql)
|
||||
auth_context = get_auth_context()
|
||||
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
|
||||
|
||||
if cardinality_result.data:
|
||||
cardinality_data = cardinality_result.data[0]
|
||||
@@ -408,7 +442,8 @@ class DataExplorationTools:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
result = await connection.execute(distribution_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
distribution = []
|
||||
@@ -458,7 +493,8 @@ class DataExplorationTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
range_result = await connection.execute(range_sql)
|
||||
auth_context = get_auth_context()
|
||||
range_result = await connection.execute(range_sql, auth_context=auth_context)
|
||||
|
||||
if range_result.data and range_result.data[0]["non_null_count"] > 0:
|
||||
range_data = range_result.data[0]
|
||||
@@ -539,7 +575,8 @@ class DataExplorationTools:
|
||||
ORDER BY day_of_week
|
||||
"""
|
||||
|
||||
weekly_result = await connection.execute(weekly_pattern_sql)
|
||||
auth_context = get_auth_context()
|
||||
weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context)
|
||||
|
||||
weekly_pattern = []
|
||||
if weekly_result.data:
|
||||
@@ -561,7 +598,7 @@ class DataExplorationTools:
|
||||
LIMIT 12
|
||||
"""
|
||||
|
||||
monthly_result = await connection.execute(monthly_trend_sql)
|
||||
monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context)
|
||||
monthly_trend = "stable" # Simplified trend analysis
|
||||
|
||||
if monthly_result.data and len(monthly_result.data) > 3:
|
||||
@@ -646,7 +683,8 @@ class DataExplorationTools:
|
||||
FROM {table_expr}
|
||||
"""
|
||||
|
||||
result = await connection.execute(null_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(null_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
total_count = data["total_count"]
|
||||
|
||||
@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -216,26 +223,34 @@ class DataGovernanceTools:
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||
"""Build full table name - use three-level naming convention"""
|
||||
"""Build full table name - use three-level naming convention with security validation"""
|
||||
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||
# Default catalog is internal for internal tables
|
||||
effective_catalog = catalog_name if catalog_name else "internal"
|
||||
|
||||
if db_name:
|
||||
return f"{effective_catalog}.{db_name}.{table_name}"
|
||||
return build_table_reference(table_name, db_name, effective_catalog)
|
||||
else:
|
||||
# If db_name is not provided, need to determine current database
|
||||
return f"{effective_catalog}.{table_name}"
|
||||
return build_table_reference(table_name, catalog_name=effective_catalog)
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get table basic information"""
|
||||
try:
|
||||
# SECURITY FIX: Get auth_context for security validation
|
||||
# table_name should already be validated by _build_full_table_name
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Try to get table row count
|
||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||
result = await connection.execute(count_sql)
|
||||
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
return {"row_count": result.data[0]["row_count"]}
|
||||
return None
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
@@ -243,11 +258,24 @@ class DataGovernanceTools:
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get table column information"""
|
||||
try:
|
||||
# Build query conditions
|
||||
where_conditions = [f"table_name = '{table_name}'"]
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if db_name:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build parameterized query conditions
|
||||
params = [table_name]
|
||||
where_conditions = ["table_name = %s"]
|
||||
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
where_conditions.append("table_schema = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
@@ -263,30 +291,49 @@ class DataGovernanceTools:
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_sql)
|
||||
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]:
|
||||
"""Analyze column completeness"""
|
||||
# SECURITY FIX: Get auth_context for security validation
|
||||
auth_context = get_auth_context()
|
||||
column_completeness = {}
|
||||
|
||||
for column in columns_info:
|
||||
column_name = column["column_name"]
|
||||
try:
|
||||
# SECURITY FIX: Validate column name before using in SQL
|
||||
try:
|
||||
validate_identifier(column_name, "column name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid column name rejected: {e}")
|
||||
column_completeness[column_name] = {
|
||||
"error": f"Invalid column name: {e}",
|
||||
"completeness_score": 0.0
|
||||
}
|
||||
continue
|
||||
|
||||
# Use quoted identifier for column name
|
||||
quoted_column = quote_identifier(column_name, "column name")
|
||||
|
||||
# Calculate null value statistics
|
||||
null_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
COUNT({column_name}) as non_null_count,
|
||||
COUNT(*) - COUNT({column_name}) as null_count
|
||||
COUNT({quoted_column}) as non_null_count,
|
||||
COUNT(*) - COUNT({quoted_column}) as null_count
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
result = await connection.execute(null_sql)
|
||||
result = await connection.execute(null_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
stats = result.data[0]
|
||||
total_count = stats["total_count"]
|
||||
@@ -304,6 +351,12 @@ class DataGovernanceTools:
|
||||
"completeness_score": round(completeness_score, 4)
|
||||
}
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for column {column_name}: {str(e)}")
|
||||
column_completeness[column_name] = {
|
||||
"error": str(e),
|
||||
"completeness_score": 0.0
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}")
|
||||
column_completeness[column_name] = {
|
||||
@@ -333,7 +386,8 @@ class DataGovernanceTools:
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
result = await connection.execute(compliance_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(compliance_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
stats = result.data[0]
|
||||
pass_count = stats["pass_count"] or 0
|
||||
@@ -378,7 +432,8 @@ class DataGovernanceTools:
|
||||
) t
|
||||
"""
|
||||
|
||||
result = await connection.execute(duplicate_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(duplicate_sql, auth_context=auth_context)
|
||||
if result.data and result.data[0]["duplicate_count"] > 0:
|
||||
issues.append({
|
||||
"type": "duplicate_primary_keys",
|
||||
@@ -456,10 +511,21 @@ class DataGovernanceTools:
|
||||
async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool:
|
||||
"""Verify if column exists"""
|
||||
try:
|
||||
# Simple verification method: try to query the column
|
||||
verify_sql = f"SELECT {column_name} FROM {table_name} LIMIT 1"
|
||||
await connection.execute(verify_sql)
|
||||
# SECURITY FIX: Validate and quote column name
|
||||
try:
|
||||
validate_identifier(column_name, "column name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid column name rejected: {e}")
|
||||
return False
|
||||
|
||||
safe_column = quote_identifier(column_name, "column name")
|
||||
# table_name is already safe (from _build_full_table_name)
|
||||
verify_sql = f"SELECT {safe_column} FROM {table_name} LIMIT 1"
|
||||
auth_context = get_auth_context()
|
||||
await connection.execute(verify_sql, auth_context=auth_context)
|
||||
return True
|
||||
except SQLSecurityError:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -469,21 +535,34 @@ class DataGovernanceTools:
|
||||
source_chain = []
|
||||
|
||||
try:
|
||||
# SECURITY FIX: Validate table name and use parameterized-like approach
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return []
|
||||
|
||||
# Escape special characters for LIKE pattern
|
||||
safe_pattern = table_name_part.replace('%', r'\%').replace('_', r'\_')
|
||||
like_pattern = f"%{safe_pattern}%"
|
||||
|
||||
# Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range)
|
||||
auth_context = get_auth_context()
|
||||
audit_sql = """
|
||||
SELECT
|
||||
stmt as sql_statement,
|
||||
`time` as execution_time,
|
||||
`user` as user_name
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE stmt LIKE '%{}%'
|
||||
WHERE stmt LIKE %s
|
||||
AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%')
|
||||
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
||||
ORDER BY `time` DESC
|
||||
LIMIT 50
|
||||
""".format(table_name.split('.')[-1]) # Use the last part of table name
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
result = await connection.execute(audit_sql, params=(like_pattern,), auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
for i, log_entry in enumerate(result.data[:depth]):
|
||||
@@ -556,19 +635,33 @@ class DataGovernanceTools:
|
||||
downstream_usage = []
|
||||
|
||||
try:
|
||||
# SECURITY FIX: Validate inputs and use parameterized-like approach
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
validate_identifier(column_name, "column name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Escape special characters for LIKE pattern
|
||||
safe_table_pattern = f"%{table_name_part.replace('%', r'\\%').replace('_', r'\\_')}%"
|
||||
safe_column_pattern = f"%{column_name.replace('%', r'\\%').replace('_', r'\\_')}%"
|
||||
|
||||
# Find other tables that might use this field (through audit logs, one year range)
|
||||
auth_context = get_auth_context()
|
||||
usage_sql = """
|
||||
SELECT DISTINCT
|
||||
stmt as sql_statement
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE stmt LIKE '%{}%'
|
||||
AND stmt LIKE '%{}%'
|
||||
WHERE stmt LIKE %s
|
||||
AND stmt LIKE %s
|
||||
AND stmt LIKE '%SELECT%'
|
||||
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
||||
LIMIT 20
|
||||
""".format(table_name.split('.')[-1], column_name)
|
||||
"""
|
||||
|
||||
result = await connection.execute(usage_sql)
|
||||
result = await connection.execute(usage_sql, params=(safe_table_pattern, safe_column_pattern), auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
for entry in result.data:
|
||||
@@ -634,14 +727,20 @@ class DataGovernanceTools:
|
||||
async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]:
|
||||
"""Get list of all tables"""
|
||||
try:
|
||||
where_conditions = []
|
||||
auth_context = get_auth_context()
|
||||
params = []
|
||||
|
||||
# SECURITY FIX: Use parameterized query
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
where_clause = "table_schema = %s"
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||
where_clause = "table_schema = DATABASE()"
|
||||
|
||||
tables_sql = f"""
|
||||
SELECT table_name
|
||||
@@ -651,7 +750,7 @@ class DataGovernanceTools:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_sql)
|
||||
result = await connection.execute(tables_sql, params=tuple(params) if params else None, auth_context=auth_context)
|
||||
return [row["table_name"] for row in result.data] if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
@@ -728,15 +827,23 @@ class DataGovernanceTools:
|
||||
async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get freshness from partition information"""
|
||||
try:
|
||||
# Query partition information (if table has partitions)
|
||||
partition_sql = f"""
|
||||
# SECURITY FIX: Validate and use parameterized query
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return None
|
||||
|
||||
auth_context = get_auth_context()
|
||||
partition_sql = """
|
||||
SELECT MAX(CREATE_TIME) as last_update
|
||||
FROM information_schema.partitions
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
WHERE table_name = %s
|
||||
AND CREATE_TIME IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(partition_sql)
|
||||
result = await connection.execute(partition_sql, params=(table_name_part,), auth_context=auth_context)
|
||||
if result.data and result.data[0]["last_update"]:
|
||||
return {
|
||||
"last_update": result.data[0]["last_update"],
|
||||
@@ -759,7 +866,8 @@ class DataGovernanceTools:
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
result = await connection.execute(max_time_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(max_time_sql, auth_context=auth_context)
|
||||
if result.data and result.data[0]["last_update"]:
|
||||
return {
|
||||
"last_update": result.data[0]["last_update"],
|
||||
@@ -773,15 +881,23 @@ class DataGovernanceTools:
|
||||
async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get freshness from table metadata"""
|
||||
try:
|
||||
# Query table's update time
|
||||
metadata_sql = f"""
|
||||
# SECURITY FIX: Validate and use parameterized query
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return None
|
||||
|
||||
auth_context = get_auth_context()
|
||||
metadata_sql = """
|
||||
SELECT UPDATE_TIME as last_update
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
WHERE table_name = %s
|
||||
AND UPDATE_TIME IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(metadata_sql)
|
||||
result = await connection.execute(metadata_sql, params=(table_name_part,), auth_context=auth_context)
|
||||
if result.data and result.data[0]["last_update"]:
|
||||
return {
|
||||
"last_update": result.data[0]["last_update"],
|
||||
@@ -795,10 +911,19 @@ class DataGovernanceTools:
|
||||
async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]:
|
||||
"""Find possible timestamp fields"""
|
||||
try:
|
||||
timestamp_sql = f"""
|
||||
# SECURITY FIX: Validate and use parameterized query
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return []
|
||||
|
||||
auth_context = get_auth_context()
|
||||
timestamp_sql = """
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
WHERE table_name = %s
|
||||
AND (
|
||||
data_type IN ('datetime', 'timestamp', 'date')
|
||||
OR column_name LIKE '%time%'
|
||||
@@ -815,7 +940,7 @@ class DataGovernanceTools:
|
||||
END
|
||||
"""
|
||||
|
||||
result = await connection.execute(timestamp_sql)
|
||||
result = await connection.execute(timestamp_sql, params=(table_name_part,), auth_context=auth_context)
|
||||
return [row["column_name"] for row in result.data] if result.data else []
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -31,6 +31,12 @@ from collections import Counter, defaultdict
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .config import DorisConfig
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -299,23 +305,26 @@ class DataQualityTools:
|
||||
# ===========================================
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
    """Build the full table reference used when interpolating a table into SQL.

    The name is assembled exclusively by ``build_table_reference`` instead of
    by raw string concatenation, so that no unchecked identifier text can end
    up inside a statement.

    Args:
        table_name: Bare table name.
        catalog_name: Optional catalog the table lives in.
        db_name: Optional database (schema) the table lives in.

    Returns:
        Whatever ``build_table_reference(table_name, db_name, catalog_name)``
        returns.

    NOTE(review): ``build_table_reference`` is defined in
    ``sql_security_utils`` and is not visible here; it presumably validates
    and/or quotes each part and raises ``SQLSecurityError`` on bad input
    (callers of this method catch that exception) -- confirm against the
    helper.
    """
    # Argument order differs from this method's signature: the helper takes
    # (table, database, catalog).
    return build_table_reference(table_name, db_name, catalog_name)
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
    """Get basic table information (currently only the row count).

    Args:
        connection: Object exposing an awaitable
            ``execute(sql, auth_context=...)`` whose result has a ``data``
            sequence of row mappings.
        table_name: Table reference interpolated verbatim into the SQL text.
            It must already be a safe reference (the caller is expected to
            have produced it via ``_build_full_table_name``); it is not
            re-validated here.

    Returns:
        ``{"row_count": <value>}`` taken from the first result row, or
        ``None`` when the query yields no rows or any error occurs. Errors
        are logged as warnings, never raised.
    """
    try:
        # The auth context is forwarded so the connection layer can apply its
        # security checks to this statement.
        auth_context = get_auth_context()

        # Run the count exactly once, and only through the auth-checked path.
        count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
        result = await connection.execute(count_sql, auth_context=auth_context)
        if result.data:
            return {"row_count": result.data[0]["row_count"]}
        return None
    except SQLSecurityError as e:
        # Logged separately from generic failures so security rejections are
        # distinguishable in the logs.
        logger.warning(f"Security validation failed: {str(e)}")
        return None
    except Exception as e:
        # Best-effort lookup: any other failure degrades to "no info".
        logger.warning(f"Failed to get table basic info: {str(e)}")
        return None
|
||||
@@ -323,9 +332,13 @@ class DataQualityTools:
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get table column information"""
|
||||
try:
|
||||
# Build DESCRIBE query
|
||||
describe_sql = f"DESCRIBE {self._build_full_table_name(table_name, catalog_name, db_name)}"
|
||||
result = await connection.execute(describe_sql)
|
||||
# SECURITY FIX: Build safe table reference and pass auth_context
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Build DESCRIBE query with safe table reference
|
||||
safe_table_ref = self._build_full_table_name(table_name, catalog_name, db_name)
|
||||
describe_sql = f"DESCRIBE {safe_table_ref}"
|
||||
result = await connection.execute(describe_sql, auth_context=auth_context)
|
||||
|
||||
columns_info = []
|
||||
if result.data:
|
||||
@@ -339,6 +352,9 @@ class DataQualityTools:
|
||||
})
|
||||
|
||||
return columns_info
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get table columns info: {str(e)}")
|
||||
return []
|
||||
@@ -346,7 +362,32 @@ class DataQualityTools:
|
||||
async def _get_table_partitions(self, connection, table_name: str, db_name: Optional[str] = None) -> List[Dict]:
|
||||
"""Get table partition information"""
|
||||
try:
|
||||
# Query partition information
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Validate table_name
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if db_name:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build parameterized query
|
||||
params = []
|
||||
where_conditions = []
|
||||
|
||||
if db_name:
|
||||
where_conditions.append("TABLE_SCHEMA = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("TABLE_SCHEMA = ''")
|
||||
|
||||
where_conditions.append("TABLE_NAME = %s")
|
||||
params.append(table_name)
|
||||
where_conditions.append("PARTITION_NAME IS NOT NULL")
|
||||
|
||||
partition_sql = f"""
|
||||
SELECT
|
||||
PARTITION_NAME,
|
||||
@@ -355,12 +396,10 @@ class DataQualityTools:
|
||||
DATA_LENGTH,
|
||||
INDEX_LENGTH
|
||||
FROM information_schema.PARTITIONS
|
||||
WHERE TABLE_SCHEMA = '{db_name or ""}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
AND PARTITION_NAME IS NOT NULL
|
||||
WHERE {' AND '.join(where_conditions)}
|
||||
"""
|
||||
|
||||
result = await connection.execute(partition_sql)
|
||||
result = await connection.execute(partition_sql, params=tuple(params), auth_context=auth_context)
|
||||
partitions = []
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
@@ -373,6 +412,9 @@ class DataQualityTools:
|
||||
})
|
||||
|
||||
return partitions
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get table partitions: {str(e)}")
|
||||
return []
|
||||
@@ -417,7 +459,8 @@ class DataQualityTools:
|
||||
if db_name
|
||||
else f"SHOW CREATE TABLE {table_name}"
|
||||
)
|
||||
result = await connection.execute(query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(query, auth_context=auth_context)
|
||||
if result.data:
|
||||
return result.data[0].get("Create Table")
|
||||
return None
|
||||
@@ -428,8 +471,16 @@ class DataQualityTools:
|
||||
async def _get_table_size_info(self, connection, table_name: str) -> Dict[str, Any]:
|
||||
"""Get table size information"""
|
||||
try:
|
||||
# Query table size information
|
||||
size_sql = f"""
|
||||
# SECURITY FIX: Validate and use parameterized query
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return {"engine": "Unknown", "estimated_rows": 0, "data_length": 0, "index_length": 0, "total_size": 0}
|
||||
|
||||
auth_context = get_auth_context()
|
||||
size_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
engine,
|
||||
@@ -438,10 +489,10 @@ class DataQualityTools:
|
||||
index_length,
|
||||
(data_length + index_length) as total_size
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
WHERE table_name = %s
|
||||
"""
|
||||
|
||||
result = await connection.execute(size_sql)
|
||||
result = await connection.execute(size_sql, params=(table_name_part,), auth_context=auth_context)
|
||||
if result.data and result.data[0]:
|
||||
row = result.data[0]
|
||||
return {
|
||||
@@ -582,7 +633,8 @@ class DataQualityTools:
|
||||
|
||||
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
||||
|
||||
result = await connection.execute(batch_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(batch_sql, auth_context=auth_context)
|
||||
if not result.data:
|
||||
return {"error": "No data returned from batch completeness query"}
|
||||
|
||||
@@ -664,7 +716,8 @@ class DataQualityTools:
|
||||
|
||||
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
||||
|
||||
result = await connection.execute(batch_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(batch_sql, auth_context=auth_context)
|
||||
if not result.data:
|
||||
return {}
|
||||
|
||||
@@ -705,7 +758,8 @@ class DataQualityTools:
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
result = await connection.execute(freq_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(freq_sql, auth_context=auth_context)
|
||||
frequencies = result.data if result.data else []
|
||||
|
||||
categorical_results[col_name] = {
|
||||
@@ -738,7 +792,8 @@ class DataQualityTools:
|
||||
|
||||
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
||||
|
||||
result = await connection.execute(batch_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(batch_sql, auth_context=auth_context)
|
||||
if not result.data:
|
||||
return {}
|
||||
|
||||
@@ -780,7 +835,8 @@ class DataQualityTools:
|
||||
FROM {table_expr}
|
||||
"""
|
||||
|
||||
result = await connection.execute(completeness_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(completeness_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
stats = result.data[0]
|
||||
total_count = stats["total_count"]
|
||||
@@ -906,7 +962,8 @@ class DataQualityTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||
if result.data and result.data[0]["non_null_count"] > 0:
|
||||
stats = result.data[0]
|
||||
numeric_analysis[col_name] = {
|
||||
@@ -945,7 +1002,8 @@ class DataQualityTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
cardinality_result = await connection.execute(cardinality_sql)
|
||||
auth_context = get_auth_context()
|
||||
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
|
||||
|
||||
if cardinality_result.data:
|
||||
stats = cardinality_result.data[0]
|
||||
@@ -969,7 +1027,7 @@ class DataQualityTools:
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
top_values_result = await connection.execute(top_values_sql)
|
||||
top_values_result = await connection.execute(top_values_sql, auth_context=auth_context)
|
||||
if top_values_result.data:
|
||||
categorical_analysis[col_name]["top_values"] = [
|
||||
{"value": row[col_name], "count": row["count"]}
|
||||
@@ -998,7 +1056,8 @@ class DataQualityTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||
if result.data and result.data[0]["non_null_count"] > 0:
|
||||
stats = result.data[0]
|
||||
temporal_analysis[col_name] = {
|
||||
|
||||
@@ -27,6 +27,13 @@ from collections import defaultdict, deque
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -122,10 +129,19 @@ class DependencyAnalysisTools:
|
||||
async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]:
|
||||
"""Get metadata for all tables and views"""
|
||||
try:
|
||||
# Build conditions for query
|
||||
# Build conditions for query with parameterized values
|
||||
where_conditions = []
|
||||
params = []
|
||||
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
# SECURITY FIX: Validate identifier and use parameterized query
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
where_conditions.append("table_schema = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
@@ -148,9 +164,18 @@ class DependencyAnalysisTools:
|
||||
ORDER BY table_schema, table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(metadata_sql)
|
||||
# SECURITY FIX: Get auth_context and pass to execute for security validation
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(
|
||||
metadata_sql,
|
||||
params=tuple(params) if params else None,
|
||||
auth_context=auth_context
|
||||
)
|
||||
return result.data if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed in _get_tables_metadata: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get tables metadata: {str(e)}")
|
||||
return []
|
||||
@@ -186,17 +211,31 @@ class DependencyAnalysisTools:
|
||||
|
||||
async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
||||
"""Analyze view definitions to extract table dependencies"""
|
||||
# Get auth_context once for all operations in this method
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
for table in tables_metadata:
|
||||
if table["table_type"] == "VIEW":
|
||||
table_name = table["table_name"]
|
||||
schema_name = table.get("schema_name", "")
|
||||
|
||||
# Get view definition
|
||||
view_def_sql = f"SHOW CREATE VIEW {schema_name}.{table_name}" if schema_name else f"SHOW CREATE VIEW {table_name}"
|
||||
# SECURITY FIX: Validate identifiers before using in SQL
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if schema_name:
|
||||
validate_identifier(schema_name, "schema name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected in view analysis: {e}")
|
||||
continue
|
||||
|
||||
# Build safe view reference using quoted identifiers
|
||||
view_ref = build_table_reference(table_name, schema_name) if schema_name else quote_identifier(table_name, "table name")
|
||||
view_def_sql = f"SHOW CREATE VIEW {view_ref}"
|
||||
|
||||
try:
|
||||
result = await connection.execute(view_def_sql)
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
result = await connection.execute(view_def_sql, auth_context=auth_context)
|
||||
if result.data and len(result.data) > 0:
|
||||
# Extract view definition from result
|
||||
view_definition = ""
|
||||
@@ -235,6 +274,9 @@ class DependencyAnalysisTools:
|
||||
|
||||
async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None:
|
||||
"""Analyze audit logs to discover runtime table dependencies"""
|
||||
# Get auth_context for security validation
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
# Get recent SQL statements from audit logs
|
||||
audit_sql = """
|
||||
@@ -252,7 +294,8 @@ class DependencyAnalysisTools:
|
||||
LIMIT 1000
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
result = await connection.execute(audit_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
@@ -274,6 +317,9 @@ class DependencyAnalysisTools:
|
||||
|
||||
async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
||||
"""Analyze foreign key constraints for explicit dependencies"""
|
||||
# Get auth_context for security validation
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
# Get foreign key information
|
||||
fk_sql = """
|
||||
@@ -288,7 +334,8 @@ class DependencyAnalysisTools:
|
||||
WHERE REFERENCED_TABLE_NAME IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(fk_sql)
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
result = await connection.execute(fk_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
|
||||
@@ -28,6 +28,7 @@ from datetime import datetime
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -713,7 +714,8 @@ class DorisMonitoringTools:
|
||||
# Fallback to SHOW BACKENDS if no BE hosts configured
|
||||
logger.info("No BE hosts configured, using SHOW BACKENDS to discover BE nodes")
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
result = await connection.execute("SHOW BACKENDS")
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||
|
||||
be_nodes = []
|
||||
for row in result.data:
|
||||
|
||||
@@ -27,6 +27,13 @@ from collections import defaultdict, Counter
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -229,7 +236,8 @@ class PerformanceAnalyticsTools:
|
||||
ORDER BY query_date
|
||||
"""
|
||||
|
||||
result = await connection.execute(query_volume_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(query_volume_sql, auth_context=auth_context)
|
||||
daily_data = result.data if result.data else []
|
||||
|
||||
if not daily_data:
|
||||
@@ -304,7 +312,8 @@ class PerformanceAnalyticsTools:
|
||||
ORDER BY activity_date
|
||||
"""
|
||||
|
||||
result = await connection.execute(user_activity_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(user_activity_sql, auth_context=auth_context)
|
||||
daily_data = result.data if result.data else []
|
||||
|
||||
if not daily_data:
|
||||
@@ -383,7 +392,8 @@ class PerformanceAnalyticsTools:
|
||||
LIMIT 5000
|
||||
"""
|
||||
|
||||
result = await connection.execute(slow_query_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(slow_query_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
@@ -705,7 +715,8 @@ class PerformanceAnalyticsTools:
|
||||
ORDER BY size_mb DESC
|
||||
"""
|
||||
|
||||
db_result = await connection.execute(db_sizes_sql)
|
||||
auth_context = get_auth_context()
|
||||
db_result = await connection.execute(db_sizes_sql, auth_context=auth_context)
|
||||
|
||||
if not db_result.data:
|
||||
logger.warning("No database size information available")
|
||||
@@ -805,7 +816,16 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_database_table_details_from_schema(self, connection, db_name: str) -> List[Dict]:
|
||||
"""Get table details for a specific database using information_schema"""
|
||||
try:
|
||||
table_details_sql = f"""
|
||||
# SECURITY FIX: Validate db_name and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
|
||||
table_details_sql = """
|
||||
SELECT
|
||||
TABLE_SCHEMA as schema_name,
|
||||
TABLE_NAME as table_name,
|
||||
@@ -814,13 +834,13 @@ class PerformanceAnalyticsTools:
|
||||
CREATE_TIME as create_time,
|
||||
UPDATE_TIME as update_time
|
||||
FROM information_schema.tables
|
||||
WHERE TABLE_SCHEMA = '{db_name}'
|
||||
WHERE TABLE_SCHEMA = %s
|
||||
AND TABLE_TYPE = 'BASE TABLE'
|
||||
AND (COALESCE(DATA_LENGTH, 0) + COALESCE(INDEX_LENGTH, 0)) > 0
|
||||
ORDER BY size_mb DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(table_details_sql)
|
||||
result = await connection.execute(table_details_sql, params=(db_name,), auth_context=auth_context)
|
||||
|
||||
if not result.data:
|
||||
logger.warning(f"No table details found for database {db_name}")
|
||||
@@ -867,6 +887,13 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_database_table_details(self, connection, db_name: str) -> List[Dict]:
|
||||
"""Get table details for a specific database using session-consistent queries"""
|
||||
try:
|
||||
# SECURITY FIX: Validate db_name before using in SQL
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
|
||||
# Method 1: Try to use session-consistent approach with raw connection
|
||||
# This requires accessing the underlying connection to maintain session state
|
||||
|
||||
@@ -877,8 +904,9 @@ class PerformanceAnalyticsTools:
|
||||
# Use raw connection to maintain session state
|
||||
cursor = await raw_conn.cursor()
|
||||
try:
|
||||
# Execute USE and SHOW DATA in the same session
|
||||
await cursor.execute(f"USE {db_name}")
|
||||
# SECURITY FIX: Use quoted identifier for USE statement
|
||||
quoted_db = quote_identifier(db_name, "database name")
|
||||
await cursor.execute(f"USE {quoted_db}")
|
||||
await cursor.execute("SHOW DATA")
|
||||
|
||||
result = await cursor.fetchall()
|
||||
@@ -922,9 +950,19 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_database_table_details_fallback(self, connection, db_name: str) -> List[Dict]:
|
||||
"""Fallback method to get table details using individual queries"""
|
||||
try:
|
||||
# Get all tables in the database
|
||||
tables_sql = f"SHOW TABLES FROM {db_name}"
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
# SECURITY FIX: Validate db_name and get auth_context
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
|
||||
# Get all tables in the database using quoted identifier
|
||||
quoted_db = quote_identifier(db_name, "database name")
|
||||
tables_sql = f"SHOW TABLES FROM {quoted_db}"
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
|
||||
if not tables_result.data:
|
||||
return []
|
||||
@@ -934,9 +972,11 @@ class PerformanceAnalyticsTools:
|
||||
table_name = table_row.get(f"Tables_in_{db_name}", "") or table_row.get("table_name", "")
|
||||
if table_name:
|
||||
try:
|
||||
# Use SHOW DATA FROM db.table for each table
|
||||
data_sql = f"SHOW DATA FROM {db_name}.{table_name}"
|
||||
data_result = await connection.execute(data_sql)
|
||||
# SECURITY FIX: Validate table_name and use safe reference
|
||||
validate_identifier(table_name, "table name")
|
||||
safe_table_ref = build_table_reference(table_name, db_name)
|
||||
data_sql = f"SHOW DATA FROM {safe_table_ref}"
|
||||
data_result = await connection.execute(data_sql, auth_context=auth_context)
|
||||
|
||||
if data_result.data:
|
||||
for row in data_result.data:
|
||||
@@ -1036,6 +1076,7 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_all_tables_info(self, connection) -> List[Dict]:
|
||||
"""Get basic information for all tables (fallback method)"""
|
||||
try:
|
||||
auth_context = get_auth_context()
|
||||
tables_sql = """
|
||||
SELECT
|
||||
table_schema,
|
||||
@@ -1053,7 +1094,7 @@ class PerformanceAnalyticsTools:
|
||||
ORDER BY (data_length + index_length) DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_sql)
|
||||
result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
@@ -1120,23 +1161,37 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_current_table_size(self, connection, full_table_name: str) -> Optional[Dict]:
|
||||
"""Get current table size"""
|
||||
try:
|
||||
# Try to query table size directly
|
||||
size_sql = f"""
|
||||
# SECURITY FIX: Get auth_context and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Extract table name for parameterized query
|
||||
table_name_only = full_table_name.split('.')[-1] if '.' in full_table_name else full_table_name
|
||||
|
||||
# Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name_only, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return None
|
||||
|
||||
# Use parameterized query for safety
|
||||
size_sql = """
|
||||
SELECT
|
||||
COALESCE(ROUND((COALESCE(data_length, 0) + COALESCE(index_length, 0)) / 1024 / 1024, 2), 0) as size_mb,
|
||||
COALESCE(table_rows, 0) as `rows`
|
||||
FROM information_schema.tables
|
||||
WHERE CONCAT(table_schema, '.', table_name) = '{full_table_name}'
|
||||
OR table_name = '{full_table_name.split('.')[-1]}'
|
||||
WHERE CONCAT(table_schema, '.', table_name) = %s
|
||||
OR table_name = %s
|
||||
"""
|
||||
|
||||
result = await connection.execute(size_sql)
|
||||
result = await connection.execute(size_sql, params=(full_table_name, table_name_only), auth_context=auth_context)
|
||||
if result.data and result.data[0]:
|
||||
return result.data[0]
|
||||
|
||||
# If information_schema has no data, try COUNT query
|
||||
# full_table_name should already be validated by caller using build_table_reference
|
||||
count_sql = f"SELECT COUNT(*) as rows FROM {full_table_name}"
|
||||
count_result = await connection.execute(count_sql)
|
||||
count_result = await connection.execute(count_sql, auth_context=auth_context)
|
||||
if count_result.data:
|
||||
return {
|
||||
"size_mb": 0, # Cannot get exact size
|
||||
@@ -1145,6 +1200,9 @@ class PerformanceAnalyticsTools:
|
||||
|
||||
return None
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for {full_table_name}: {str(e)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get current size for {full_table_name}: {str(e)}")
|
||||
return None
|
||||
@@ -1154,8 +1212,19 @@ class PerformanceAnalyticsTools:
|
||||
) -> List[Dict]:
|
||||
"""Get historical growth data based on partitions"""
|
||||
try:
|
||||
# Query partition information
|
||||
partition_sql = f"""
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if schema_name:
|
||||
validate_identifier(schema_name, "schema name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Use parameterized query for safety
|
||||
partition_sql = """
|
||||
SELECT
|
||||
partition_name,
|
||||
partition_description,
|
||||
@@ -1163,15 +1232,19 @@ class PerformanceAnalyticsTools:
|
||||
data_length,
|
||||
create_time
|
||||
FROM information_schema.partitions
|
||||
WHERE table_schema = '{schema_name or ""}'
|
||||
AND table_name = '{table_name}'
|
||||
WHERE table_schema = %s
|
||||
AND table_name = %s
|
||||
AND partition_name IS NOT NULL
|
||||
AND create_time IS NOT NULL
|
||||
AND create_time >= DATE_SUB(NOW(), INTERVAL {days} DAY)
|
||||
AND create_time >= DATE_SUB(NOW(), INTERVAL %s DAY)
|
||||
ORDER BY create_time DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(partition_sql)
|
||||
result = await connection.execute(
|
||||
partition_sql,
|
||||
params=(schema_name or "", table_name, days),
|
||||
auth_context=auth_context
|
||||
)
|
||||
if not result.data:
|
||||
return []
|
||||
|
||||
@@ -1210,6 +1283,9 @@ class PerformanceAnalyticsTools:
|
||||
) -> List[Dict]:
|
||||
"""Get historical growth data based on timestamp fields"""
|
||||
try:
|
||||
# SECURITY FIX: Get auth_context
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Find possible timestamp fields
|
||||
timestamp_columns = await self._find_timestamp_columns(connection, table_name, schema_name)
|
||||
if not timestamp_columns:
|
||||
@@ -1218,20 +1294,29 @@ class PerformanceAnalyticsTools:
|
||||
# Use best timestamp field for analysis
|
||||
time_column = timestamp_columns[0]
|
||||
|
||||
# Aggregate data by date
|
||||
# SECURITY FIX: Validate time_column before using in SQL
|
||||
try:
|
||||
validate_identifier(time_column, "column name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid column name rejected: {e}")
|
||||
return []
|
||||
|
||||
quoted_time_column = quote_identifier(time_column, "column name")
|
||||
|
||||
# Aggregate data by date (full_table_name should be validated by caller)
|
||||
growth_sql = f"""
|
||||
SELECT
|
||||
DATE({time_column}) as date,
|
||||
DATE({quoted_time_column}) as date,
|
||||
COUNT(*) as daily_records,
|
||||
COUNT(*) / SUM(COUNT(*)) OVER() * 100 as percentage
|
||||
FROM {full_table_name}
|
||||
WHERE {time_column} >= DATE_SUB(NOW(), INTERVAL {days} DAY)
|
||||
AND {time_column} IS NOT NULL
|
||||
GROUP BY DATE({time_column})
|
||||
WHERE {quoted_time_column} >= DATE_SUB(NOW(), INTERVAL %s DAY)
|
||||
AND {quoted_time_column} IS NOT NULL
|
||||
GROUP BY DATE({quoted_time_column})
|
||||
ORDER BY date DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(growth_sql)
|
||||
result = await connection.execute(growth_sql, params=(days,), auth_context=auth_context)
|
||||
if not result.data:
|
||||
return []
|
||||
|
||||
@@ -1257,11 +1342,22 @@ class PerformanceAnalyticsTools:
|
||||
async def _find_timestamp_columns(self, connection, table_name: str, schema_name: str) -> List[str]:
|
||||
"""Find timestamp fields in table"""
|
||||
try:
|
||||
timestamp_sql = f"""
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if schema_name:
|
||||
validate_identifier(schema_name, "schema name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
timestamp_sql = """
|
||||
SELECT column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '{schema_name or ""}'
|
||||
AND table_name = '{table_name}'
|
||||
WHERE table_schema = %s
|
||||
AND table_name = %s
|
||||
AND (
|
||||
data_type IN ('datetime', 'timestamp', 'date')
|
||||
OR column_name REGEXP '(create|insert|update|modify).*time'
|
||||
@@ -1278,9 +1374,16 @@ class PerformanceAnalyticsTools:
|
||||
END
|
||||
"""
|
||||
|
||||
result = await connection.execute(timestamp_sql)
|
||||
result = await connection.execute(
|
||||
timestamp_sql,
|
||||
params=(schema_name or "", table_name),
|
||||
auth_context=auth_context
|
||||
)
|
||||
return [row["column_name"] for row in result.data] if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to find timestamp columns: {str(e)}")
|
||||
return []
|
||||
@@ -1290,8 +1393,22 @@ class PerformanceAnalyticsTools:
|
||||
) -> List[Dict]:
|
||||
"""Estimate growth data based on audit logs"""
|
||||
try:
|
||||
# SECURITY FIX: Validate table_name and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name.split(".")[-1], "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return []
|
||||
|
||||
# Extract just the table name for LIKE pattern
|
||||
table_name_only = table_name.split(".")[-1]
|
||||
like_pattern_full = f"%{table_name}%"
|
||||
like_pattern_short = f"%{table_name_only}%"
|
||||
|
||||
# Analyze operation history for this table
|
||||
audit_sql = f"""
|
||||
audit_sql = """
|
||||
SELECT
|
||||
DATE(`time`) as operation_date,
|
||||
COUNT(*) as operation_count,
|
||||
@@ -1299,17 +1416,21 @@ class PerformanceAnalyticsTools:
|
||||
SUM(CASE WHEN stmt LIKE 'UPDATE%' THEN 1 ELSE 0 END) as update_count,
|
||||
SUM(CASE WHEN stmt LIKE 'DELETE%' THEN 1 ELSE 0 END) as delete_count
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= DATE_SUB(NOW(), INTERVAL {days} DAY)
|
||||
WHERE `time` >= DATE_SUB(NOW(), INTERVAL %s DAY)
|
||||
AND stmt IS NOT NULL
|
||||
AND (
|
||||
stmt LIKE '%{table_name}%'
|
||||
OR stmt LIKE '%{table_name.split(".")[-1]}%'
|
||||
stmt LIKE %s
|
||||
OR stmt LIKE %s
|
||||
)
|
||||
GROUP BY DATE(`time`)
|
||||
ORDER BY operation_date DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
result = await connection.execute(
|
||||
audit_sql,
|
||||
params=(days, like_pattern_full, like_pattern_short),
|
||||
auth_context=auth_context
|
||||
)
|
||||
if not result.data:
|
||||
return []
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ from decimal import Decimal
|
||||
|
||||
from .db import DorisConnectionManager, QueryResult
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -497,7 +498,8 @@ class DorisQueryExecutor:
|
||||
explain_sql = f"EXPLAIN {sql}"
|
||||
|
||||
connection = await self.connection_manager.get_connection(session_id)
|
||||
result = await connection.execute(explain_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(explain_sql, auth_context=auth_context)
|
||||
|
||||
return {
|
||||
"query": sql,
|
||||
|
||||
@@ -32,6 +32,11 @@ from datetime import datetime, timedelta
|
||||
|
||||
# Import unified logging configuration
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logger = get_logger(__name__)
|
||||
@@ -431,6 +436,16 @@ class MetadataExtractor:
|
||||
logger.warning("Database name not specified")
|
||||
return {}
|
||||
|
||||
# SECURITY FIX: Validate identifiers to prevent SQL injection
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected in get_table_schema: {e}")
|
||||
return {}
|
||||
|
||||
cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
@@ -536,6 +551,16 @@ class MetadataExtractor:
|
||||
logger.warning("Database name not specified")
|
||||
return ""
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return ""
|
||||
|
||||
cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
@@ -587,6 +612,16 @@ class MetadataExtractor:
|
||||
logger.warning("Database name not specified")
|
||||
return {}
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return {}
|
||||
|
||||
cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
@@ -643,17 +678,30 @@ class MetadataExtractor:
|
||||
logger.error("Database name not specified")
|
||||
return []
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
|
||||
try:
|
||||
# Build query with catalog prefix if specified
|
||||
# Build query with catalog prefix if specified (identifiers already validated)
|
||||
safe_table = quote_identifier(table_name, "table name")
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`"
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
|
||||
logger.info(f"Using three-part naming for index query: {query}")
|
||||
else:
|
||||
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
|
||||
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
|
||||
|
||||
try:
|
||||
# NOTE: Deprecated sync path retained for compatibility; use async variant instead.
|
||||
@@ -1188,12 +1236,28 @@ class MetadataExtractor:
|
||||
try:
|
||||
# Use async query method
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
effective_db = db_name or self.db_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build query statement using safe identifiers
|
||||
safe_table = quote_identifier(table_name, "table name")
|
||||
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||
|
||||
# Build query statement
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`"
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"DESCRIBE {safe_catalog}.{safe_db}.{safe_table}"
|
||||
else:
|
||||
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`"
|
||||
query = f"DESCRIBE {safe_db}.{safe_table}"
|
||||
|
||||
# Execute async query
|
||||
result = await self._execute_query_async(query, db_name)
|
||||
@@ -1226,8 +1290,15 @@ class MetadataExtractor:
|
||||
try:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate catalog name if provided
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW DATABASES FROM `{effective_catalog}`"
|
||||
try:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid catalog name rejected: {e}")
|
||||
return []
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW DATABASES FROM {safe_catalog}"
|
||||
else:
|
||||
query = "SHOW DATABASES"
|
||||
|
||||
@@ -1257,10 +1328,23 @@ class MetadataExtractor:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
effective_db = db_name or self.db_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`"
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW TABLES FROM {safe_catalog}.{safe_db}"
|
||||
else:
|
||||
query = f"SHOW TABLES FROM `{effective_db}`"
|
||||
query = f"SHOW TABLES FROM {safe_db}"
|
||||
|
||||
result = await self._execute_query_async(query, effective_db)
|
||||
|
||||
@@ -1319,6 +1403,15 @@ class MetadataExtractor:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return ""
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
TABLE_COMMENT
|
||||
@@ -1343,6 +1436,15 @@ class MetadataExtractor:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return {}
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
COLUMN_NAME,
|
||||
@@ -1373,12 +1475,27 @@ class MetadataExtractor:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# Build query with catalog prefix if specified
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build query with catalog prefix if specified (using safe identifiers)
|
||||
safe_table = quote_identifier(table_name, "table name")
|
||||
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||
|
||||
if effective_catalog:
|
||||
query = f"SHOW INDEX FROM `{effective_catalog}`.`{effective_db}`.`{table_name}`"
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
|
||||
logger.info(f"Using three-part naming for async index query: {query}")
|
||||
else:
|
||||
query = f"SHOW INDEX FROM `{effective_db}`.`{table_name}`"
|
||||
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
|
||||
|
||||
rows = await self._execute_query_async(query, effective_db)
|
||||
indexes: List[Dict[str, Any]] = []
|
||||
@@ -1475,21 +1592,45 @@ class MetadataExtractor:
|
||||
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
|
||||
|
||||
# FIX for Issue #62 Bug 3: Build context switching SQL if db_name or catalog_name is specified
|
||||
# SECURITY FIX: Validate catalog_name and db_name to prevent SQL injection
|
||||
final_sql = sql
|
||||
if catalog_name or db_name:
|
||||
context_statements = []
|
||||
|
||||
# Validate and sanitize catalog_name
|
||||
if catalog_name:
|
||||
# Switch to specified catalog
|
||||
context_statements.append(f"USE CATALOG `{catalog_name}`")
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid catalog name rejected: {e}")
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
# Use quote_identifier to safely escape the catalog name
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
context_statements.append(f"USE CATALOG {safe_catalog}")
|
||||
logger.debug(f"Switching to catalog: {catalog_name}")
|
||||
|
||||
# Validate and sanitize db_name
|
||||
if db_name:
|
||||
# Switch to specified database
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
# Use quote_identifier to safely escape the database name
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
if catalog_name:
|
||||
context_statements.append(f"USE `{catalog_name}`.`{db_name}`")
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
context_statements.append(f"USE {safe_catalog}.{safe_db}")
|
||||
else:
|
||||
context_statements.append(f"USE `{db_name}`")
|
||||
context_statements.append(f"USE {safe_db}")
|
||||
logger.debug(f"Switching to database: {db_name}")
|
||||
|
||||
# Combine context switching with original SQL
|
||||
@@ -1551,6 +1692,36 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers before processing
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
@@ -1574,6 +1745,27 @@ class MetadataExtractor:
|
||||
"""Get list of all table names in specified database - MCP interface"""
|
||||
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=tables)
|
||||
@@ -1604,6 +1796,36 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comment)
|
||||
@@ -1623,6 +1845,36 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comments)
|
||||
@@ -1642,6 +1894,36 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=indexes)
|
||||
|
||||
@@ -901,30 +901,50 @@ class SQLSecurityValidator:
|
||||
if not self.enable_security_check:
|
||||
self.logger.debug("SQL security check is disabled, allowing all queries")
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
|
||||
try:
|
||||
# Parse SQL statement
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
# SECURITY FIX: Parse ALL SQL statements, not just the first one
|
||||
# This prevents bypassing security checks by injecting additional statements
|
||||
all_statements = sqlparse.parse(sql)
|
||||
|
||||
# Check blocked operations first (more specific)
|
||||
keyword_result = await self._check_blocked_keywords(parsed)
|
||||
if not keyword_result.is_valid:
|
||||
return keyword_result
|
||||
if not all_statements:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Empty or invalid SQL statement",
|
||||
risk_level="medium"
|
||||
)
|
||||
|
||||
# Check SQL injection risks
|
||||
injection_result = await self._check_sql_injection(sql, parsed)
|
||||
if not injection_result.is_valid:
|
||||
return injection_result
|
||||
# SECURITY FIX: Validate each statement individually
|
||||
for idx, parsed in enumerate(all_statements):
|
||||
# Skip empty statements (e.g., from trailing semicolons)
|
||||
if not parsed.tokens or str(parsed).strip() == '':
|
||||
continue
|
||||
|
||||
# Check query complexity
|
||||
complexity_result = await self._check_query_complexity(parsed)
|
||||
if not complexity_result.is_valid:
|
||||
return complexity_result
|
||||
self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...")
|
||||
|
||||
# Check table access permissions
|
||||
table_result = await self._check_table_access(parsed, auth_context)
|
||||
if not table_result.is_valid:
|
||||
return table_result
|
||||
# Check blocked operations first (more specific)
|
||||
keyword_result = await self._check_blocked_keywords(parsed)
|
||||
if not keyword_result.is_valid:
|
||||
keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}"
|
||||
return keyword_result
|
||||
|
||||
# Check SQL injection risks
|
||||
injection_result = await self._check_sql_injection(sql, parsed)
|
||||
if not injection_result.is_valid:
|
||||
injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}"
|
||||
return injection_result
|
||||
|
||||
# Check query complexity
|
||||
complexity_result = await self._check_query_complexity(parsed)
|
||||
if not complexity_result.is_valid:
|
||||
complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}"
|
||||
return complexity_result
|
||||
|
||||
# Check table access permissions
|
||||
table_result = await self._check_table_access(parsed, auth_context)
|
||||
if not table_result.is_valid:
|
||||
table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}"
|
||||
return table_result
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
@@ -1134,6 +1154,10 @@ class SQLSecurityValidator:
|
||||
self, parsed: Statement, auth_context: AuthContext
|
||||
) -> ValidationResult:
|
||||
"""Check table access permissions"""
|
||||
# If no auth_context, skip table access checks (rely on other security checks)
|
||||
if auth_context is None:
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
# Extract table names from query
|
||||
tables = self._extract_table_names(parsed)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from collections import Counter, defaultdict
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -192,7 +193,9 @@ class SecurityAnalyticsTools:
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(audit_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
@@ -215,7 +218,8 @@ class SecurityAnalyticsTools:
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
result = await connection.execute(simple_audit_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(simple_audit_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e2:
|
||||
@@ -498,7 +502,8 @@ class SecurityAnalyticsTools:
|
||||
FROM mysql.user
|
||||
"""
|
||||
|
||||
result = await connection.execute(roles_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(roles_sql, auth_context=auth_context)
|
||||
|
||||
user_roles = defaultdict(list)
|
||||
if result.data:
|
||||
|
||||
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
SQL Security Utilities Module
|
||||
|
||||
Provides SQL identifier validation, escaping, and safe query building utilities
|
||||
to prevent SQL injection attacks.
|
||||
"""
|
||||
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional, Tuple, List, Any
|
||||
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Context variable for auth_context (set by HTTP middleware)
|
||||
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
|
||||
|
||||
|
||||
class SQLSecurityError(Exception):
    """Signals that a value failed SQL security validation.

    Raised by the helpers in this module when an identifier or operator is
    rejected; it adds no behavior on top of ``Exception``.
    """
|
||||
|
||||
|
||||
class SQLSecurityUtils:
    """
    SQL security helpers for preventing SQL injection attacks.

    Provides:
    - Identifier validation (database names, table names, column names)
    - Safe identifier quoting with backticks
    - Safe table and column reference building
    - Whitelisted WHERE-condition fragments
    - Auth context storage/retrieval through a context variable
    """

    # Valid SQL identifier pattern: letters, digits, underscores.
    # Must start with a letter or underscore, never a digit.
    # Besides ASCII, only the range U+4E00-U+9FFF (CJK ideographs) is accepted;
    # other non-ASCII letters are rejected.
    IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$')

    # Maximum identifier length (MySQL/Doris standard)
    MAX_IDENTIFIER_LENGTH = 64

    # SQL reserved keywords that should be quoted.
    # NOTE: none of the methods in this class consult this set; quoting is
    # applied unconditionally by quote_identifier().
    SQL_KEYWORDS = {
        'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP',
        'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND',
        'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN',
        'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER',
        'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
        'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY',
        'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT'
    }

    # Substrings rejected outright before the pattern check, so the error
    # message can name the offending sequence. Order determines which
    # sequence is reported when several are present.
    _FORBIDDEN_SEQUENCES = ("'", '"', ';', '--', '/*', '*/', '\\', '\x00')

    @classmethod
    def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
        """
        Validate a SQL identifier (database name, table name, column name, etc.)

        Args:
            name: The identifier to validate
            identifier_type: Type description for error messages (e.g., "database name", "table name")

        Returns:
            The validated identifier with leading/trailing whitespace stripped

        Raises:
            SQLSecurityError: If the identifier is empty, not a string, too long,
                contains a forbidden sequence, or does not match IDENTIFIER_PATTERN
        """
        # Any falsy value (None, "", 0, ...) is reported as "empty".
        if not name:
            raise SQLSecurityError(f"Empty {identifier_type} is not allowed")

        if not isinstance(name, str):
            raise SQLSecurityError(f"Invalid {identifier_type}: must be a string, got {type(name).__name__}")

        # Surrounding whitespace is tolerated and removed; a whitespace-only
        # name is treated as empty.
        name = name.strip()
        if not name:
            raise SQLSecurityError(f"Empty {identifier_type} is not allowed")

        if len(name) > cls.MAX_IDENTIFIER_LENGTH:
            raise SQLSecurityError(
                f"Invalid {identifier_type}: '{name[:20]}...' exceeds maximum length of {cls.MAX_IDENTIFIER_LENGTH} characters"
            )

        # Check for dangerous characters that could be SQL injection
        for char in cls._FORBIDDEN_SEQUENCES:
            if char in name:
                raise SQLSecurityError(
                    f"Invalid {identifier_type}: '{name}' contains forbidden character '{char}'"
                )

        # fullmatch() instead of match(): '$' alone would also accept a
        # trailing newline. The name is already stripped, so results are the
        # same; this just removes the reliance on the strip above.
        if not cls.IDENTIFIER_PATTERN.fullmatch(name):
            raise SQLSecurityError(
                f"Invalid {identifier_type}: '{name}' contains invalid characters. "
                f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore."
            )

        logger.debug(f"Validated {identifier_type}: {name}")
        return name

    @classmethod
    def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
        """
        Safely quote a SQL identifier using backticks.

        Args:
            name: The identifier to quote
            identifier_type: Type description for error messages

        Returns:
            The quoted identifier (e.g., `table_name`)

        Raises:
            SQLSecurityError: If the identifier is invalid
        """
        validated_name = cls.validate_identifier(name, identifier_type)

        # Defence in depth: validation already rejects backticks, but any
        # that slipped through would be neutralised by doubling them.
        escaped_name = validated_name.replace('`', '``')

        return f"`{escaped_name}`"

    @classmethod
    def _render_identifier(cls, name: str, identifier_type: str, quote: bool) -> str:
        """Validate ``name`` and return it backtick-quoted if ``quote`` else bare."""
        if quote:
            return cls.quote_identifier(name, identifier_type)
        return cls.validate_identifier(name, identifier_type)

    @classmethod
    def build_table_reference(
        cls,
        table_name: str,
        db_name: Optional[str] = None,
        catalog_name: Optional[str] = None,
        quote: bool = True
    ) -> str:
        """
        Build a safe, fully-qualified table reference.

        Args:
            table_name: The table name (required)
            db_name: The database name (optional; skipped when falsy)
            catalog_name: The catalog name (optional; skipped when falsy)
            quote: Whether to quote identifiers with backticks (default: True)

        Returns:
            A safe table reference string (e.g., `catalog`.`db`.`table`)

        Raises:
            SQLSecurityError: If any identifier is invalid
        """
        parts = []

        # Validation order (catalog, database, table) decides which invalid
        # part is reported first.
        if catalog_name:
            parts.append(cls._render_identifier(catalog_name, "catalog name", quote))

        if db_name:
            parts.append(cls._render_identifier(db_name, "database name", quote))

        parts.append(cls._render_identifier(table_name, "table name", quote))

        return '.'.join(parts)

    @classmethod
    def build_column_reference(
        cls,
        column_name: str,
        table_name: Optional[str] = None,
        quote: bool = True
    ) -> str:
        """
        Build a safe column reference.

        Args:
            column_name: The column name (required)
            table_name: The table name (optional, for qualified references)
            quote: Whether to quote identifiers with backticks (default: True)

        Returns:
            A safe column reference string (e.g., `table`.`column`)

        Raises:
            SQLSecurityError: If any identifier is invalid
        """
        parts = []

        if table_name:
            parts.append(cls._render_identifier(table_name, "table name", quote))

        parts.append(cls._render_identifier(column_name, "column name", quote))

        return '.'.join(parts)

    @classmethod
    def validate_and_build_where_condition(
        cls,
        column_name: str,
        operator: str = "=",
        use_param: bool = True
    ) -> Tuple[str, bool]:
        """
        Build a safe WHERE condition for a column.

        Args:
            column_name: The column name
            operator: The comparison operator, matched case-insensitively
                against =, !=, <>, <, >, <=, >=, LIKE, IN, IS
            use_param: Whether to append a parameterized placeholder (%s)

        Returns:
            Tuple of (condition_string, needs_param):
            - use_param=True:  ("`column` = %s", True)
            - use_param=False: ("`column` =", False); the caller appends the
              right-hand side itself

            The operator is emitted in its upper-cased form.

        Raises:
            SQLSecurityError: If column name is invalid or operator is not allowed
        """
        quoted_column = cls.quote_identifier(column_name, "column name")

        allowed_operators = {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'}

        # Non-string operators are rejected with SQLSecurityError instead of
        # failing on .upper() with AttributeError.
        normalized_operator = operator.upper() if isinstance(operator, str) else None
        if normalized_operator not in allowed_operators:
            raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed_operators}")

        # Emit the normalised operator, not the caller's raw text. Some
        # non-ASCII characters upper-case to ASCII letters (e.g. 'ı' -> 'I'),
        # so the raw text can pass the whitelist check without being one of
        # the whitelisted operators.
        if use_param:
            return f"{quoted_column} {normalized_operator} %s", True
        return f"{quoted_column} {normalized_operator}", False

    @staticmethod
    def get_auth_context() -> Optional[Any]:
        """
        Get auth_context from the context variable.

        This retrieves the auth_context that was set by the HTTP middleware
        during request processing.

        Returns:
            The auth_context object, or None if it is unset, falsy, or could
            not be read
        """
        # Best-effort lookup: any failure is logged and treated as "no auth
        # context" so callers never fail on retrieval.
        try:
            auth_context = auth_context_var.get()
            # Truthiness test: a falsy stored value is reported as None.
            if auth_context:
                logger.debug("Retrieved auth_context from context variable")
                return auth_context
        except Exception as e:
            logger.debug(f"Could not retrieve auth_context: {e}")
        return None

    @staticmethod
    def set_auth_context(auth_context: Any) -> None:
        """
        Set auth_context in the context variable.

        This is typically called by the HTTP middleware during request processing.

        Args:
            auth_context: The auth_context object to set
        """
        auth_context_var.set(auth_context)
        logger.debug("Set auth_context in context variable")
|
||||
|
||||
|
||||
# Convenience functions for direct use: module-level aliases of the
# SQLSecurityUtils methods, so callers can write e.g.
# `from ..utils.sql_security_utils import get_auth_context` without going
# through the class. validate_and_build_where_condition has no alias here.
validate_identifier = SQLSecurityUtils.validate_identifier
quote_identifier = SQLSecurityUtils.quote_identifier
build_table_reference = SQLSecurityUtils.build_table_reference
build_column_reference = SQLSecurityUtils.build_column_reference
get_auth_context = SQLSecurityUtils.get_auth_context
set_auth_context = SQLSecurityUtils.set_auth_context
|
||||
|
||||
Reference in New Issue
Block a user