fix some security issues (#68)

This commit is contained in:
bingquanzhao
2025-12-10 09:11:03 +08:00
committed by GitHub
parent a125a2f5f8
commit e58361e04b
17 changed files with 2520 additions and 214 deletions

View File

@@ -29,6 +29,13 @@ from pathlib import Path
from .db import DorisConnectionManager
from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__)
@@ -46,10 +53,17 @@ class TableAnalyzer:
sample_size: int = 10
) -> Dict[str, Any]:
"""Get table summary information"""
# SECURITY FIX: Validate table_name and get auth_context
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
raise ValueError(f"Invalid table name: {e}")
auth_context = get_auth_context()
connection = await self.connection_manager.get_connection("query")
# Get table basic information
table_info_sql = f"""
# Get table basic information using parameterized query
table_info_sql = """
SELECT
table_name,
table_comment,
@@ -58,17 +72,17 @@ class TableAnalyzer:
engine
FROM information_schema.tables
WHERE table_schema = DATABASE()
AND table_name = '{table_name}'
AND table_name = %s
"""
table_info_result = await connection.execute(table_info_sql)
table_info_result = await connection.execute(table_info_sql, params=(table_name,), auth_context=auth_context)
if not table_info_result.data:
raise ValueError(f"Table {table_name} does not exist")
table_info = table_info_result.data[0]
# Get column information
columns_sql = f"""
# Get column information using parameterized query
columns_sql = """
SELECT
column_name,
data_type,
@@ -76,11 +90,11 @@ class TableAnalyzer:
column_comment
FROM information_schema.columns
WHERE table_schema = DATABASE()
AND table_name = '{table_name}'
AND table_name = %s
ORDER BY ordinal_position
"""
columns_result = await connection.execute(columns_sql)
columns_result = await connection.execute(columns_sql, params=(table_name,), auth_context=auth_context)
summary = {
"table_name": table_info["table_name"],
@@ -92,10 +106,11 @@ class TableAnalyzer:
"columns": columns_result.data,
}
# Get sample data
# Get sample data using quoted identifier
if include_sample and sample_size > 0:
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
sample_result = await connection.execute(sample_sql)
quoted_table = quote_identifier(table_name, "table name")
sample_sql = f"SELECT * FROM {quoted_table} LIMIT {sample_size}"
sample_result = await connection.execute(sample_sql, auth_context=auth_context)
summary["sample_data"] = sample_result.data
return summary
@@ -120,7 +135,8 @@ class TableAnalyzer:
FROM {table_name}
"""
basic_result = await connection.execute(basic_stats_sql)
auth_context = get_auth_context()
basic_result = await connection.execute(basic_stats_sql, auth_context=auth_context)
if not basic_result.data:
return {
"success": False,
@@ -144,7 +160,7 @@ class TableAnalyzer:
LIMIT 20
"""
distribution_result = await connection.execute(distribution_sql)
distribution_result = await connection.execute(distribution_sql, auth_context=auth_context)
analysis["value_distribution"] = distribution_result.data
if analysis_type == "detailed":
@@ -159,7 +175,7 @@ class TableAnalyzer:
WHERE {column_name} IS NOT NULL
"""
numeric_result = await connection.execute(numeric_stats_sql)
numeric_result = await connection.execute(numeric_stats_sql, auth_context=auth_context)
if numeric_result.data:
analysis.update(numeric_result.data[0])
except Exception:
@@ -196,7 +212,8 @@ class TableAnalyzer:
AND table_name = '{table_name}'
"""
table_result = await connection.execute(table_info_sql)
auth_context = get_auth_context()
table_result = await connection.execute(table_info_sql, auth_context=auth_context)
if not table_result.data:
raise ValueError(f"Table {table_name} does not exist")
@@ -211,7 +228,7 @@ class TableAnalyzer:
AND table_name != %s
"""
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
all_tables_result = await connection.execute(all_tables_sql, params=(table_name,), auth_context=auth_context)
return {
"center_table": table_result.data[0],
@@ -291,7 +308,8 @@ class PerformanceMonitor:
LIMIT 20
"""
tables_result = await connection.execute(tables_sql)
auth_context = get_auth_context()
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
stats = {
"metric_type": "tables",
"time_range": time_range,
@@ -380,8 +398,14 @@ class SQLAnalyzer:
logger.info(f"Generating SQL explain for query ID: {query_id}")
# Switch database if specified
# SECURITY FIX: Validate and quote db_name
if db_name:
await self.connection_manager.execute_query("explain_session", f"USE {db_name}")
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid database name: {e}"}
safe_db = quote_identifier(db_name, "database name")
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}")
# Construct EXPLAIN query
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
@@ -515,24 +539,36 @@ class SQLAnalyzer:
try:
# Switch to specified database/catalog if provided
# SECURITY FIX: Validate identifiers before using in SQL
if catalog_name:
await connection.execute(f"SWITCH `{catalog_name}`")
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid catalog name: {e}"}
safe_catalog = quote_identifier(catalog_name, "catalog name")
auth_context = get_auth_context()
await connection.execute(f"SWITCH {safe_catalog}", auth_context=auth_context)
if db_name:
await connection.execute(f"USE `{db_name}`")
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid database name: {e}"}
safe_db = quote_identifier(db_name, "database name")
await connection.execute(f"USE {safe_db}", auth_context=auth_context)
# Set trace ID for the session using session variable
# According to official docs: set session_context="trace_id:your_trace_id"
await connection.execute(f'set session_context="trace_id:{trace_id}"')
await connection.execute(f'set session_context="trace_id:{trace_id}"', auth_context=auth_context)
logger.info(f"Set trace ID: {trace_id}")
# Enable profile
await connection.execute(f'set enable_profile=true')
await connection.execute(f'set enable_profile=true', auth_context=auth_context)
logger.info(f"Enabled profile")
# Execute the SQL statement
logger.info(f"Executing SQL with trace ID: {sql}")
start_time = time.time()
sql_result = await connection.execute(sql)
sql_result = await connection.execute(sql, auth_context=auth_context)
execution_time = time.time() - start_time
logger.info(f"SQL execution completed in {execution_time:.3f}s")