Fix some security issues (#68)
This commit is contained in:
@@ -29,6 +29,13 @@ from pathlib import Path
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -46,10 +53,17 @@ class TableAnalyzer:
|
||||
sample_size: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""Get table summary information"""
|
||||
# SECURITY FIX: Validate table_name and get auth_context
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
raise ValueError(f"Invalid table name: {e}")
|
||||
|
||||
auth_context = get_auth_context()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Get table basic information
|
||||
table_info_sql = f"""
|
||||
# Get table basic information using parameterized query
|
||||
table_info_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
@@ -58,17 +72,17 @@ class TableAnalyzer:
|
||||
engine
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
table_info_result = await connection.execute(table_info_sql)
|
||||
table_info_result = await connection.execute(table_info_sql, params=(table_name,), auth_context=auth_context)
|
||||
if not table_info_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
table_info = table_info_result.data[0]
|
||||
|
||||
# Get column information
|
||||
columns_sql = f"""
|
||||
# Get column information using parameterized query
|
||||
columns_sql = """
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
@@ -76,11 +90,11 @@ class TableAnalyzer:
|
||||
column_comment
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
AND table_name = %s
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
columns_result = await connection.execute(columns_sql)
|
||||
columns_result = await connection.execute(columns_sql, params=(table_name,), auth_context=auth_context)
|
||||
|
||||
summary = {
|
||||
"table_name": table_info["table_name"],
|
||||
@@ -92,10 +106,11 @@ class TableAnalyzer:
|
||||
"columns": columns_result.data,
|
||||
}
|
||||
|
||||
# Get sample data
|
||||
# Get sample data using quoted identifier
|
||||
if include_sample and sample_size > 0:
|
||||
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql)
|
||||
quoted_table = quote_identifier(table_name, "table name")
|
||||
sample_sql = f"SELECT * FROM {quoted_table} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql, auth_context=auth_context)
|
||||
summary["sample_data"] = sample_result.data
|
||||
|
||||
return summary
|
||||
@@ -120,7 +135,8 @@ class TableAnalyzer:
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
basic_result = await connection.execute(basic_stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
basic_result = await connection.execute(basic_stats_sql, auth_context=auth_context)
|
||||
if not basic_result.data:
|
||||
return {
|
||||
"success": False,
|
||||
@@ -144,7 +160,7 @@ class TableAnalyzer:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
distribution_result = await connection.execute(distribution_sql)
|
||||
distribution_result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||
analysis["value_distribution"] = distribution_result.data
|
||||
|
||||
if analysis_type == "detailed":
|
||||
@@ -159,7 +175,7 @@ class TableAnalyzer:
|
||||
WHERE {column_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
numeric_result = await connection.execute(numeric_stats_sql)
|
||||
numeric_result = await connection.execute(numeric_stats_sql, auth_context=auth_context)
|
||||
if numeric_result.data:
|
||||
analysis.update(numeric_result.data[0])
|
||||
except Exception:
|
||||
@@ -196,7 +212,8 @@ class TableAnalyzer:
|
||||
AND table_name = '{table_name}'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_sql)
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_info_sql, auth_context=auth_context)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
@@ -211,7 +228,7 @@ class TableAnalyzer:
|
||||
AND table_name != %s
|
||||
"""
|
||||
|
||||
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
|
||||
all_tables_result = await connection.execute(all_tables_sql, params=(table_name,), auth_context=auth_context)
|
||||
|
||||
return {
|
||||
"center_table": table_result.data[0],
|
||||
@@ -291,7 +308,8 @@ class PerformanceMonitor:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
auth_context = get_auth_context()
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
stats = {
|
||||
"metric_type": "tables",
|
||||
"time_range": time_range,
|
||||
@@ -380,8 +398,14 @@ class SQLAnalyzer:
|
||||
logger.info(f"Generating SQL explain for query ID: {query_id}")
|
||||
|
||||
# Switch database if specified
|
||||
# SECURITY FIX: Validate and quote db_name
|
||||
if db_name:
|
||||
await self.connection_manager.execute_query("explain_session", f"USE {db_name}")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}")
|
||||
|
||||
# Construct EXPLAIN query
|
||||
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
|
||||
@@ -515,24 +539,36 @@ class SQLAnalyzer:
|
||||
|
||||
try:
|
||||
# Switch to specified database/catalog if provided
|
||||
# SECURITY FIX: Validate identifiers before using in SQL
|
||||
if catalog_name:
|
||||
await connection.execute(f"SWITCH `{catalog_name}`")
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid catalog name: {e}"}
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
auth_context = get_auth_context()
|
||||
await connection.execute(f"SWITCH {safe_catalog}", auth_context=auth_context)
|
||||
if db_name:
|
||||
await connection.execute(f"USE `{db_name}`")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
await connection.execute(f"USE {safe_db}", auth_context=auth_context)
|
||||
|
||||
# Set trace ID for the session using session variable
|
||||
# According to official docs: set session_context="trace_id:your_trace_id"
|
||||
await connection.execute(f'set session_context="trace_id:{trace_id}"')
|
||||
await connection.execute(f'set session_context="trace_id:{trace_id}"', auth_context=auth_context)
|
||||
logger.info(f"Set trace ID: {trace_id}")
|
||||
|
||||
# Enable profile
|
||||
await connection.execute(f'set enable_profile=true')
|
||||
await connection.execute(f'set enable_profile=true', auth_context=auth_context)
|
||||
logger.info(f"Enabled profile")
|
||||
|
||||
# Execute the SQL statement
|
||||
logger.info(f"Executing SQL with trace ID: {sql}")
|
||||
start_time = time.time()
|
||||
sql_result = await connection.execute(sql)
|
||||
sql_result = await connection.execute(sql, auth_context=auth_context)
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"SQL execution completed in {execution_time:.3f}s")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user