fix some security issues (#68)
This commit is contained in:
@@ -31,6 +31,7 @@ from mcp.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from ..utils.db import DorisConnectionManager
|
from ..utils.db import DorisConnectionManager
|
||||||
|
from ..utils.sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplate:
|
class PromptTemplate:
|
||||||
@@ -422,7 +423,8 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
|||||||
AND table_type = 'BASE TABLE'
|
AND table_type = 'BASE TABLE'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db_result = await connection.execute(db_info_sql)
|
auth_context = get_auth_context()
|
||||||
|
db_result = await connection.execute(db_info_sql, auth_context=auth_context)
|
||||||
db_info = db_result.data[0] if db_result.data else {}
|
db_info = db_result.data[0] if db_result.data else {}
|
||||||
|
|
||||||
# Get main table list
|
# Get main table list
|
||||||
@@ -438,7 +440,7 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
|||||||
LIMIT 10
|
LIMIT 10
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tables_result = await connection.execute(tables_sql)
|
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||||
|
|
||||||
context = f"""Current database statistics:
|
context = f"""Current database statistics:
|
||||||
- Total number of tables: {db_info.get("table_count", 0)}
|
- Total number of tables: {db_info.get("table_count", 0)}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from typing import Any
|
|||||||
from mcp.types import Resource
|
from mcp.types import Resource
|
||||||
|
|
||||||
from ..utils.db import DorisConnectionManager
|
from ..utils.db import DorisConnectionManager
|
||||||
|
from ..utils.sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
|
||||||
class TableMetadata:
|
class TableMetadata:
|
||||||
@@ -169,7 +170,8 @@ class DorisResourcesManager:
|
|||||||
ORDER BY table_name
|
ORDER BY table_name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(tables_query)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(tables_query, auth_context=auth_context)
|
||||||
tables = []
|
tables = []
|
||||||
|
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
@@ -204,7 +206,8 @@ class DorisResourcesManager:
|
|||||||
ORDER BY ordinal_position
|
ORDER BY ordinal_position
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(columns_query, (table_name,))
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(columns_query, params=(table_name,), auth_context=auth_context)
|
||||||
return [dict(row) for row in result.data]
|
return [dict(row) for row in result.data]
|
||||||
|
|
||||||
async def _get_view_metadata(self) -> list[ViewMetadata]:
|
async def _get_view_metadata(self) -> list[ViewMetadata]:
|
||||||
@@ -226,7 +229,8 @@ class DorisResourcesManager:
|
|||||||
ORDER BY table_name
|
ORDER BY table_name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(views_query)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(views_query, auth_context=auth_context)
|
||||||
views = []
|
views = []
|
||||||
|
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
@@ -257,7 +261,8 @@ class DorisResourcesManager:
|
|||||||
AND table_name = %s
|
AND table_name = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
table_result = await connection.execute(table_info_query, (table_name,))
|
auth_context = get_auth_context()
|
||||||
|
table_result = await connection.execute(table_info_query, params=(table_name,), auth_context=auth_context)
|
||||||
if not table_result.data:
|
if not table_result.data:
|
||||||
raise ValueError(f"Table {table_name} does not exist")
|
raise ValueError(f"Table {table_name} does not exist")
|
||||||
|
|
||||||
@@ -295,7 +300,8 @@ class DorisResourcesManager:
|
|||||||
ORDER BY index_name, seq_in_index
|
ORDER BY index_name, seq_in_index
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(indexes_query, (table_name,))
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(indexes_query, params=(table_name,), auth_context=auth_context)
|
||||||
return [dict(row) for row in result.data]
|
return [dict(row) for row in result.data]
|
||||||
|
|
||||||
async def _get_view_definition(self, view_name: str) -> str:
|
async def _get_view_definition(self, view_name: str) -> str:
|
||||||
@@ -312,7 +318,8 @@ class DorisResourcesManager:
|
|||||||
AND table_name = %s
|
AND table_name = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(view_query, (view_name,))
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(view_query, params=(view_name,), auth_context=auth_context)
|
||||||
if not result.data:
|
if not result.data:
|
||||||
raise ValueError(f"View {view_name} does not exist")
|
raise ValueError(f"View {view_name} does not exist")
|
||||||
|
|
||||||
@@ -340,7 +347,8 @@ class DorisResourcesManager:
|
|||||||
AND table_type = 'BASE TABLE'
|
AND table_type = 'BASE TABLE'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
table_result = await connection.execute(table_stats_query)
|
auth_context = get_auth_context()
|
||||||
|
table_result = await connection.execute(table_stats_query, auth_context=auth_context)
|
||||||
table_stats = table_result.data[0] if table_result.data else {}
|
table_stats = table_result.data[0] if table_result.data else {}
|
||||||
|
|
||||||
# Get view statistics
|
# Get view statistics
|
||||||
@@ -350,7 +358,7 @@ class DorisResourcesManager:
|
|||||||
WHERE table_schema = DATABASE()
|
WHERE table_schema = DATABASE()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
view_result = await connection.execute(view_stats_query)
|
view_result = await connection.execute(view_stats_query, auth_context=auth_context)
|
||||||
view_stats = view_result.data[0] if view_result.data else {}
|
view_stats = view_result.data[0] if view_result.data else {}
|
||||||
|
|
||||||
stats_info = {
|
stats_info = {
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
from ..utils.db import DorisConnectionManager
|
from ..utils.db import DorisConnectionManager
|
||||||
|
from ..utils.sql_security_utils import get_auth_context
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -277,7 +278,8 @@ class DorisADBCQueryTools:
|
|||||||
# Get BE nodes via SHOW BACKENDS
|
# Get BE nodes via SHOW BACKENDS
|
||||||
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
|
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
|
||||||
connection = await self.connection_manager.get_connection("query")
|
connection = await self.connection_manager.get_connection("query")
|
||||||
result = await connection.execute("SHOW BACKENDS")
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||||
|
|
||||||
be_hosts = []
|
be_hosts = []
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
@@ -383,6 +385,20 @@ class DorisADBCQueryTools:
|
|||||||
"error_type": "no_connection"
|
"error_type": "no_connection"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# SECURITY FIX: Perform SQL security validation before executing
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
if self.connection_manager.security_manager:
|
||||||
|
# Always perform security validation, even without auth_context
|
||||||
|
# Use a default context for basic SQL security checks
|
||||||
|
validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context)
|
||||||
|
if not validation_result.is_valid:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||||
|
"error_type": "security_violation",
|
||||||
|
"risk_level": validation_result.risk_level
|
||||||
|
}
|
||||||
|
|
||||||
cursor = self.adbc_client.cursor()
|
cursor = self.adbc_client.cursor()
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,13 @@ from pathlib import Path
|
|||||||
|
|
||||||
from .db import DorisConnectionManager
|
from .db import DorisConnectionManager
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import (
|
||||||
|
SQLSecurityError,
|
||||||
|
validate_identifier,
|
||||||
|
quote_identifier,
|
||||||
|
build_table_reference,
|
||||||
|
get_auth_context
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -46,10 +53,17 @@ class TableAnalyzer:
|
|||||||
sample_size: int = 10
|
sample_size: int = 10
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Get table summary information"""
|
"""Get table summary information"""
|
||||||
|
# SECURITY FIX: Validate table_name and get auth_context
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
raise ValueError(f"Invalid table name: {e}")
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
connection = await self.connection_manager.get_connection("query")
|
connection = await self.connection_manager.get_connection("query")
|
||||||
|
|
||||||
# Get table basic information
|
# Get table basic information using parameterized query
|
||||||
table_info_sql = f"""
|
table_info_sql = """
|
||||||
SELECT
|
SELECT
|
||||||
table_name,
|
table_name,
|
||||||
table_comment,
|
table_comment,
|
||||||
@@ -58,17 +72,17 @@ class TableAnalyzer:
|
|||||||
engine
|
engine
|
||||||
FROM information_schema.tables
|
FROM information_schema.tables
|
||||||
WHERE table_schema = DATABASE()
|
WHERE table_schema = DATABASE()
|
||||||
AND table_name = '{table_name}'
|
AND table_name = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
table_info_result = await connection.execute(table_info_sql)
|
table_info_result = await connection.execute(table_info_sql, params=(table_name,), auth_context=auth_context)
|
||||||
if not table_info_result.data:
|
if not table_info_result.data:
|
||||||
raise ValueError(f"Table {table_name} does not exist")
|
raise ValueError(f"Table {table_name} does not exist")
|
||||||
|
|
||||||
table_info = table_info_result.data[0]
|
table_info = table_info_result.data[0]
|
||||||
|
|
||||||
# Get column information
|
# Get column information using parameterized query
|
||||||
columns_sql = f"""
|
columns_sql = """
|
||||||
SELECT
|
SELECT
|
||||||
column_name,
|
column_name,
|
||||||
data_type,
|
data_type,
|
||||||
@@ -76,11 +90,11 @@ class TableAnalyzer:
|
|||||||
column_comment
|
column_comment
|
||||||
FROM information_schema.columns
|
FROM information_schema.columns
|
||||||
WHERE table_schema = DATABASE()
|
WHERE table_schema = DATABASE()
|
||||||
AND table_name = '{table_name}'
|
AND table_name = %s
|
||||||
ORDER BY ordinal_position
|
ORDER BY ordinal_position
|
||||||
"""
|
"""
|
||||||
|
|
||||||
columns_result = await connection.execute(columns_sql)
|
columns_result = await connection.execute(columns_sql, params=(table_name,), auth_context=auth_context)
|
||||||
|
|
||||||
summary = {
|
summary = {
|
||||||
"table_name": table_info["table_name"],
|
"table_name": table_info["table_name"],
|
||||||
@@ -92,10 +106,11 @@ class TableAnalyzer:
|
|||||||
"columns": columns_result.data,
|
"columns": columns_result.data,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Get sample data
|
# Get sample data using quoted identifier
|
||||||
if include_sample and sample_size > 0:
|
if include_sample and sample_size > 0:
|
||||||
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
|
quoted_table = quote_identifier(table_name, "table name")
|
||||||
sample_result = await connection.execute(sample_sql)
|
sample_sql = f"SELECT * FROM {quoted_table} LIMIT {sample_size}"
|
||||||
|
sample_result = await connection.execute(sample_sql, auth_context=auth_context)
|
||||||
summary["sample_data"] = sample_result.data
|
summary["sample_data"] = sample_result.data
|
||||||
|
|
||||||
return summary
|
return summary
|
||||||
@@ -120,7 +135,8 @@ class TableAnalyzer:
|
|||||||
FROM {table_name}
|
FROM {table_name}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
basic_result = await connection.execute(basic_stats_sql)
|
auth_context = get_auth_context()
|
||||||
|
basic_result = await connection.execute(basic_stats_sql, auth_context=auth_context)
|
||||||
if not basic_result.data:
|
if not basic_result.data:
|
||||||
return {
|
return {
|
||||||
"success": False,
|
"success": False,
|
||||||
@@ -144,7 +160,7 @@ class TableAnalyzer:
|
|||||||
LIMIT 20
|
LIMIT 20
|
||||||
"""
|
"""
|
||||||
|
|
||||||
distribution_result = await connection.execute(distribution_sql)
|
distribution_result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||||
analysis["value_distribution"] = distribution_result.data
|
analysis["value_distribution"] = distribution_result.data
|
||||||
|
|
||||||
if analysis_type == "detailed":
|
if analysis_type == "detailed":
|
||||||
@@ -159,7 +175,7 @@ class TableAnalyzer:
|
|||||||
WHERE {column_name} IS NOT NULL
|
WHERE {column_name} IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
numeric_result = await connection.execute(numeric_stats_sql)
|
numeric_result = await connection.execute(numeric_stats_sql, auth_context=auth_context)
|
||||||
if numeric_result.data:
|
if numeric_result.data:
|
||||||
analysis.update(numeric_result.data[0])
|
analysis.update(numeric_result.data[0])
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -196,7 +212,8 @@ class TableAnalyzer:
|
|||||||
AND table_name = '{table_name}'
|
AND table_name = '{table_name}'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
table_result = await connection.execute(table_info_sql)
|
auth_context = get_auth_context()
|
||||||
|
table_result = await connection.execute(table_info_sql, auth_context=auth_context)
|
||||||
if not table_result.data:
|
if not table_result.data:
|
||||||
raise ValueError(f"Table {table_name} does not exist")
|
raise ValueError(f"Table {table_name} does not exist")
|
||||||
|
|
||||||
@@ -211,7 +228,7 @@ class TableAnalyzer:
|
|||||||
AND table_name != %s
|
AND table_name != %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
|
all_tables_result = await connection.execute(all_tables_sql, params=(table_name,), auth_context=auth_context)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"center_table": table_result.data[0],
|
"center_table": table_result.data[0],
|
||||||
@@ -291,7 +308,8 @@ class PerformanceMonitor:
|
|||||||
LIMIT 20
|
LIMIT 20
|
||||||
"""
|
"""
|
||||||
|
|
||||||
tables_result = await connection.execute(tables_sql)
|
auth_context = get_auth_context()
|
||||||
|
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||||
stats = {
|
stats = {
|
||||||
"metric_type": "tables",
|
"metric_type": "tables",
|
||||||
"time_range": time_range,
|
"time_range": time_range,
|
||||||
@@ -380,8 +398,14 @@ class SQLAnalyzer:
|
|||||||
logger.info(f"Generating SQL explain for query ID: {query_id}")
|
logger.info(f"Generating SQL explain for query ID: {query_id}")
|
||||||
|
|
||||||
# Switch database if specified
|
# Switch database if specified
|
||||||
|
# SECURITY FIX: Validate and quote db_name
|
||||||
if db_name:
|
if db_name:
|
||||||
await self.connection_manager.execute_query("explain_session", f"USE {db_name}")
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||||
|
safe_db = quote_identifier(db_name, "database name")
|
||||||
|
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}")
|
||||||
|
|
||||||
# Construct EXPLAIN query
|
# Construct EXPLAIN query
|
||||||
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
|
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
|
||||||
@@ -515,24 +539,36 @@ class SQLAnalyzer:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Switch to specified database/catalog if provided
|
# Switch to specified database/catalog if provided
|
||||||
|
# SECURITY FIX: Validate identifiers before using in SQL
|
||||||
if catalog_name:
|
if catalog_name:
|
||||||
await connection.execute(f"SWITCH `{catalog_name}`")
|
try:
|
||||||
|
validate_identifier(catalog_name, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return {"success": False, "error": f"Invalid catalog name: {e}"}
|
||||||
|
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
await connection.execute(f"SWITCH {safe_catalog}", auth_context=auth_context)
|
||||||
if db_name:
|
if db_name:
|
||||||
await connection.execute(f"USE `{db_name}`")
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||||
|
safe_db = quote_identifier(db_name, "database name")
|
||||||
|
await connection.execute(f"USE {safe_db}", auth_context=auth_context)
|
||||||
|
|
||||||
# Set trace ID for the session using session variable
|
# Set trace ID for the session using session variable
|
||||||
# According to official docs: set session_context="trace_id:your_trace_id"
|
# According to official docs: set session_context="trace_id:your_trace_id"
|
||||||
await connection.execute(f'set session_context="trace_id:{trace_id}"')
|
await connection.execute(f'set session_context="trace_id:{trace_id}"', auth_context=auth_context)
|
||||||
logger.info(f"Set trace ID: {trace_id}")
|
logger.info(f"Set trace ID: {trace_id}")
|
||||||
|
|
||||||
# Enable profile
|
# Enable profile
|
||||||
await connection.execute(f'set enable_profile=true')
|
await connection.execute(f'set enable_profile=true', auth_context=auth_context)
|
||||||
logger.info(f"Enabled profile")
|
logger.info(f"Enabled profile")
|
||||||
|
|
||||||
# Execute the SQL statement
|
# Execute the SQL statement
|
||||||
logger.info(f"Executing SQL with trace ID: {sql}")
|
logger.info(f"Executing SQL with trace ID: {sql}")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
sql_result = await connection.execute(sql)
|
sql_result = await connection.execute(sql, auth_context=auth_context)
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
logger.info(f"SQL execution completed in {execution_time:.3f}s")
|
logger.info(f"SQL execution completed in {execution_time:.3f}s")
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional, Union
|
|||||||
|
|
||||||
from .db import DorisConnectionManager
|
from .db import DorisConnectionManager
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import (
|
||||||
|
SQLSecurityError,
|
||||||
|
validate_identifier,
|
||||||
|
quote_identifier,
|
||||||
|
build_table_reference,
|
||||||
|
get_auth_context
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -43,24 +50,30 @@ class DataExplorationTools:
|
|||||||
|
|
||||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||||
"""Build full table name with catalog and database using three-part naming convention"""
|
"""Build full table name with catalog and database using three-part naming convention"""
|
||||||
# Default catalog for internal tables
|
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||||
effective_catalog = catalog_name if catalog_name else "internal"
|
effective_catalog = catalog_name if catalog_name else "internal"
|
||||||
|
|
||||||
if db_name:
|
if db_name:
|
||||||
return f"{effective_catalog}.{db_name}.{table_name}"
|
return build_table_reference(table_name, db_name, effective_catalog)
|
||||||
else:
|
else:
|
||||||
# If no db_name provided, need to determine the current database
|
return build_table_reference(table_name, catalog_name=effective_catalog)
|
||||||
return f"{effective_catalog}.{table_name}"
|
|
||||||
|
|
||||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||||
"""Get basic table information including row count"""
|
"""Get basic table information including row count"""
|
||||||
try:
|
try:
|
||||||
|
# SECURITY FIX: Get auth_context for security validation
|
||||||
|
# table_name should already be validated by _build_full_table_name
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||||
result = await connection.execute(count_sql)
|
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if result.data:
|
if result.data:
|
||||||
return {"row_count": result.data[0]["row_count"]}
|
return {"row_count": result.data[0]["row_count"]}
|
||||||
return None
|
return None
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
|
||||||
|
return {"row_count": 0}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||||
return {"row_count": 0}
|
return {"row_count": 0}
|
||||||
@@ -68,10 +81,24 @@ class DataExplorationTools:
|
|||||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||||
"""Get detailed column information"""
|
"""Get detailed column information"""
|
||||||
try:
|
try:
|
||||||
where_conditions = [f"table_name = '{table_name}'"]
|
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if db_name:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Build parameterized query
|
||||||
|
params = [table_name]
|
||||||
|
where_conditions = ["table_name = %s"]
|
||||||
|
|
||||||
if db_name:
|
if db_name:
|
||||||
where_conditions.append(f"table_schema = '{db_name}'")
|
where_conditions.append("table_schema = %s")
|
||||||
|
params.append(db_name)
|
||||||
else:
|
else:
|
||||||
where_conditions.append("table_schema = DATABASE()")
|
where_conditions.append("table_schema = DATABASE()")
|
||||||
|
|
||||||
@@ -87,9 +114,12 @@ class DataExplorationTools:
|
|||||||
ORDER BY ordinal_position
|
ORDER BY ordinal_position
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(columns_sql)
|
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
|
||||||
return result.data if result.data else []
|
return result.data if result.data else []
|
||||||
|
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed: {str(e)}")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||||
return []
|
return []
|
||||||
@@ -177,7 +207,8 @@ class DataExplorationTools:
|
|||||||
WHERE {col_name} IS NOT NULL
|
WHERE {col_name} IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
stats_result = await connection.execute(stats_sql)
|
auth_context = get_auth_context()
|
||||||
|
stats_result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if stats_result.data and stats_result.data[0]["count"] > 0:
|
if stats_result.data and stats_result.data[0]["count"] > 0:
|
||||||
stats = stats_result.data[0]
|
stats = stats_result.data[0]
|
||||||
@@ -229,7 +260,8 @@ class DataExplorationTools:
|
|||||||
WHERE {col_name} IS NOT NULL
|
WHERE {col_name} IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(percentile_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(percentile_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if result.data:
|
if result.data:
|
||||||
data = result.data[0]
|
data = result.data[0]
|
||||||
@@ -268,7 +300,8 @@ class DataExplorationTools:
|
|||||||
WHERE {col_name} IS NOT NULL
|
WHERE {col_name} IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(outlier_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(outlier_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if result.data:
|
if result.data:
|
||||||
data = result.data[0]
|
data = result.data[0]
|
||||||
@@ -359,7 +392,8 @@ class DataExplorationTools:
|
|||||||
{sampling_info.get('sample_query_suffix', '')}
|
{sampling_info.get('sample_query_suffix', '')}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cardinality_result = await connection.execute(cardinality_sql)
|
auth_context = get_auth_context()
|
||||||
|
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if cardinality_result.data:
|
if cardinality_result.data:
|
||||||
cardinality_data = cardinality_result.data[0]
|
cardinality_data = cardinality_result.data[0]
|
||||||
@@ -408,7 +442,8 @@ class DataExplorationTools:
|
|||||||
LIMIT 20
|
LIMIT 20
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(distribution_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if result.data:
|
if result.data:
|
||||||
distribution = []
|
distribution = []
|
||||||
@@ -458,7 +493,8 @@ class DataExplorationTools:
|
|||||||
WHERE {col_name} IS NOT NULL
|
WHERE {col_name} IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
range_result = await connection.execute(range_sql)
|
auth_context = get_auth_context()
|
||||||
|
range_result = await connection.execute(range_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if range_result.data and range_result.data[0]["non_null_count"] > 0:
|
if range_result.data and range_result.data[0]["non_null_count"] > 0:
|
||||||
range_data = range_result.data[0]
|
range_data = range_result.data[0]
|
||||||
@@ -539,7 +575,8 @@ class DataExplorationTools:
|
|||||||
ORDER BY day_of_week
|
ORDER BY day_of_week
|
||||||
"""
|
"""
|
||||||
|
|
||||||
weekly_result = await connection.execute(weekly_pattern_sql)
|
auth_context = get_auth_context()
|
||||||
|
weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context)
|
||||||
|
|
||||||
weekly_pattern = []
|
weekly_pattern = []
|
||||||
if weekly_result.data:
|
if weekly_result.data:
|
||||||
@@ -561,7 +598,7 @@ class DataExplorationTools:
|
|||||||
LIMIT 12
|
LIMIT 12
|
||||||
"""
|
"""
|
||||||
|
|
||||||
monthly_result = await connection.execute(monthly_trend_sql)
|
monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context)
|
||||||
monthly_trend = "stable" # Simplified trend analysis
|
monthly_trend = "stable" # Simplified trend analysis
|
||||||
|
|
||||||
if monthly_result.data and len(monthly_result.data) > 3:
|
if monthly_result.data and len(monthly_result.data) > 3:
|
||||||
@@ -646,7 +683,8 @@ class DataExplorationTools:
|
|||||||
FROM {table_expr}
|
FROM {table_expr}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(null_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(null_sql, auth_context=auth_context)
|
||||||
if result.data:
|
if result.data:
|
||||||
data = result.data[0]
|
data = result.data[0]
|
||||||
total_count = data["total_count"]
|
total_count = data["total_count"]
|
||||||
|
|||||||
@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from .db import DorisConnectionManager
|
from .db import DorisConnectionManager
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import (
|
||||||
|
SQLSecurityError,
|
||||||
|
validate_identifier,
|
||||||
|
quote_identifier,
|
||||||
|
build_table_reference,
|
||||||
|
get_auth_context
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -216,26 +223,34 @@ class DataGovernanceTools:
|
|||||||
# ==================== Private Helper Methods ====================
|
# ==================== Private Helper Methods ====================
|
||||||
|
|
||||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||||
"""Build full table name - use three-level naming convention"""
|
"""Build full table name - use three-level naming convention with security validation"""
|
||||||
|
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||||
# Default catalog is internal for internal tables
|
# Default catalog is internal for internal tables
|
||||||
effective_catalog = catalog_name if catalog_name else "internal"
|
effective_catalog = catalog_name if catalog_name else "internal"
|
||||||
|
|
||||||
if db_name:
|
if db_name:
|
||||||
return f"{effective_catalog}.{db_name}.{table_name}"
|
return build_table_reference(table_name, db_name, effective_catalog)
|
||||||
else:
|
else:
|
||||||
# If db_name is not provided, need to determine current database
|
# If db_name is not provided, need to determine current database
|
||||||
return f"{effective_catalog}.{table_name}"
|
return build_table_reference(table_name, catalog_name=effective_catalog)
|
||||||
|
|
||||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||||
"""Get table basic information"""
|
"""Get table basic information"""
|
||||||
try:
|
try:
|
||||||
|
# SECURITY FIX: Get auth_context for security validation
|
||||||
|
# table_name should already be validated by _build_full_table_name
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
# Try to get table row count
|
# Try to get table row count
|
||||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||||
result = await connection.execute(count_sql)
|
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if result.data:
|
if result.data:
|
||||||
return {"row_count": result.data[0]["row_count"]}
|
return {"row_count": result.data[0]["row_count"]}
|
||||||
return None
|
return None
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
|
||||||
|
return {"row_count": 0}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||||
return {"row_count": 0}
|
return {"row_count": 0}
|
||||||
@@ -243,11 +258,24 @@ class DataGovernanceTools:
|
|||||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||||
"""Get table column information"""
|
"""Get table column information"""
|
||||||
try:
|
try:
|
||||||
# Build query conditions
|
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||||
where_conditions = [f"table_name = '{table_name}'"]
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if db_name:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Build parameterized query conditions
|
||||||
|
params = [table_name]
|
||||||
|
where_conditions = ["table_name = %s"]
|
||||||
|
|
||||||
if db_name:
|
if db_name:
|
||||||
where_conditions.append(f"table_schema = '{db_name}'")
|
where_conditions.append("table_schema = %s")
|
||||||
|
params.append(db_name)
|
||||||
else:
|
else:
|
||||||
where_conditions.append("table_schema = DATABASE()")
|
where_conditions.append("table_schema = DATABASE()")
|
||||||
|
|
||||||
@@ -263,30 +291,49 @@ class DataGovernanceTools:
|
|||||||
ORDER BY ordinal_position
|
ORDER BY ordinal_position
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(columns_sql)
|
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
|
||||||
return result.data if result.data else []
|
return result.data if result.data else []
|
||||||
|
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed: {str(e)}")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]:
|
async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]:
|
||||||
"""Analyze column completeness"""
|
"""Analyze column completeness"""
|
||||||
|
# SECURITY FIX: Get auth_context for security validation
|
||||||
|
auth_context = get_auth_context()
|
||||||
column_completeness = {}
|
column_completeness = {}
|
||||||
|
|
||||||
for column in columns_info:
|
for column in columns_info:
|
||||||
column_name = column["column_name"]
|
column_name = column["column_name"]
|
||||||
try:
|
try:
|
||||||
|
# SECURITY FIX: Validate column name before using in SQL
|
||||||
|
try:
|
||||||
|
validate_identifier(column_name, "column name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid column name rejected: {e}")
|
||||||
|
column_completeness[column_name] = {
|
||||||
|
"error": f"Invalid column name: {e}",
|
||||||
|
"completeness_score": 0.0
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Use quoted identifier for column name
|
||||||
|
quoted_column = quote_identifier(column_name, "column name")
|
||||||
|
|
||||||
# Calculate null value statistics
|
# Calculate null value statistics
|
||||||
null_sql = f"""
|
null_sql = f"""
|
||||||
SELECT
|
SELECT
|
||||||
COUNT(*) as total_count,
|
COUNT(*) as total_count,
|
||||||
COUNT({column_name}) as non_null_count,
|
COUNT({quoted_column}) as non_null_count,
|
||||||
COUNT(*) - COUNT({column_name}) as null_count
|
COUNT(*) - COUNT({quoted_column}) as null_count
|
||||||
FROM {table_name}
|
FROM {table_name}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(null_sql)
|
result = await connection.execute(null_sql, auth_context=auth_context)
|
||||||
if result.data:
|
if result.data:
|
||||||
stats = result.data[0]
|
stats = result.data[0]
|
||||||
total_count = stats["total_count"]
|
total_count = stats["total_count"]
|
||||||
@@ -304,6 +351,12 @@ class DataGovernanceTools:
|
|||||||
"completeness_score": round(completeness_score, 4)
|
"completeness_score": round(completeness_score, 4)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed for column {column_name}: {str(e)}")
|
||||||
|
column_completeness[column_name] = {
|
||||||
|
"error": str(e),
|
||||||
|
"completeness_score": 0.0
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}")
|
logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}")
|
||||||
column_completeness[column_name] = {
|
column_completeness[column_name] = {
|
||||||
@@ -333,7 +386,8 @@ class DataGovernanceTools:
|
|||||||
FROM {table_name}
|
FROM {table_name}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(compliance_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(compliance_sql, auth_context=auth_context)
|
||||||
if result.data:
|
if result.data:
|
||||||
stats = result.data[0]
|
stats = result.data[0]
|
||||||
pass_count = stats["pass_count"] or 0
|
pass_count = stats["pass_count"] or 0
|
||||||
@@ -378,7 +432,8 @@ class DataGovernanceTools:
|
|||||||
) t
|
) t
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(duplicate_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(duplicate_sql, auth_context=auth_context)
|
||||||
if result.data and result.data[0]["duplicate_count"] > 0:
|
if result.data and result.data[0]["duplicate_count"] > 0:
|
||||||
issues.append({
|
issues.append({
|
||||||
"type": "duplicate_primary_keys",
|
"type": "duplicate_primary_keys",
|
||||||
@@ -456,10 +511,21 @@ class DataGovernanceTools:
|
|||||||
async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool:
|
async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool:
|
||||||
"""Verify if column exists"""
|
"""Verify if column exists"""
|
||||||
try:
|
try:
|
||||||
# Simple verification method: try to query the column
|
# SECURITY FIX: Validate and quote column name
|
||||||
verify_sql = f"SELECT {column_name} FROM {table_name} LIMIT 1"
|
try:
|
||||||
await connection.execute(verify_sql)
|
validate_identifier(column_name, "column name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid column name rejected: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
safe_column = quote_identifier(column_name, "column name")
|
||||||
|
# table_name is already safe (from _build_full_table_name)
|
||||||
|
verify_sql = f"SELECT {safe_column} FROM {table_name} LIMIT 1"
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
await connection.execute(verify_sql, auth_context=auth_context)
|
||||||
return True
|
return True
|
||||||
|
except SQLSecurityError:
|
||||||
|
return False
|
||||||
except Exception:
|
except Exception:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@@ -469,21 +535,34 @@ class DataGovernanceTools:
|
|||||||
source_chain = []
|
source_chain = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# SECURITY FIX: Validate table name and use parameterized-like approach
|
||||||
|
table_name_part = table_name.split('.')[-1]
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name_part, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid table name rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Escape special characters for LIKE pattern
|
||||||
|
safe_pattern = table_name_part.replace('%', r'\%').replace('_', r'\_')
|
||||||
|
like_pattern = f"%{safe_pattern}%"
|
||||||
|
|
||||||
# Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range)
|
# Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range)
|
||||||
|
auth_context = get_auth_context()
|
||||||
audit_sql = """
|
audit_sql = """
|
||||||
SELECT
|
SELECT
|
||||||
stmt as sql_statement,
|
stmt as sql_statement,
|
||||||
`time` as execution_time,
|
`time` as execution_time,
|
||||||
`user` as user_name
|
`user` as user_name
|
||||||
FROM internal.__internal_schema.audit_log
|
FROM internal.__internal_schema.audit_log
|
||||||
WHERE stmt LIKE '%{}%'
|
WHERE stmt LIKE %s
|
||||||
AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%')
|
AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%')
|
||||||
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
||||||
ORDER BY `time` DESC
|
ORDER BY `time` DESC
|
||||||
LIMIT 50
|
LIMIT 50
|
||||||
""".format(table_name.split('.')[-1]) # Use the last part of table name
|
"""
|
||||||
|
|
||||||
result = await connection.execute(audit_sql)
|
result = await connection.execute(audit_sql, params=(like_pattern,), auth_context=auth_context)
|
||||||
|
|
||||||
if result.data:
|
if result.data:
|
||||||
for i, log_entry in enumerate(result.data[:depth]):
|
for i, log_entry in enumerate(result.data[:depth]):
|
||||||
@@ -556,19 +635,33 @@ class DataGovernanceTools:
|
|||||||
downstream_usage = []
|
downstream_usage = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# SECURITY FIX: Validate inputs and use parameterized-like approach
|
||||||
|
table_name_part = table_name.split('.')[-1]
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name_part, "table name")
|
||||||
|
validate_identifier(column_name, "column name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Escape special characters for LIKE pattern
|
||||||
|
safe_table_pattern = f"%{table_name_part.replace('%', r'\\%').replace('_', r'\\_')}%"
|
||||||
|
safe_column_pattern = f"%{column_name.replace('%', r'\\%').replace('_', r'\\_')}%"
|
||||||
|
|
||||||
# Find other tables that might use this field (through audit logs, one year range)
|
# Find other tables that might use this field (through audit logs, one year range)
|
||||||
|
auth_context = get_auth_context()
|
||||||
usage_sql = """
|
usage_sql = """
|
||||||
SELECT DISTINCT
|
SELECT DISTINCT
|
||||||
stmt as sql_statement
|
stmt as sql_statement
|
||||||
FROM internal.__internal_schema.audit_log
|
FROM internal.__internal_schema.audit_log
|
||||||
WHERE stmt LIKE '%{}%'
|
WHERE stmt LIKE %s
|
||||||
AND stmt LIKE '%{}%'
|
AND stmt LIKE %s
|
||||||
AND stmt LIKE '%SELECT%'
|
AND stmt LIKE '%SELECT%'
|
||||||
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
||||||
LIMIT 20
|
LIMIT 20
|
||||||
""".format(table_name.split('.')[-1], column_name)
|
"""
|
||||||
|
|
||||||
result = await connection.execute(usage_sql)
|
result = await connection.execute(usage_sql, params=(safe_table_pattern, safe_column_pattern), auth_context=auth_context)
|
||||||
|
|
||||||
if result.data:
|
if result.data:
|
||||||
for entry in result.data:
|
for entry in result.data:
|
||||||
@@ -634,14 +727,20 @@ class DataGovernanceTools:
|
|||||||
async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]:
|
async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]:
|
||||||
"""Get list of all tables"""
|
"""Get list of all tables"""
|
||||||
try:
|
try:
|
||||||
where_conditions = []
|
auth_context = get_auth_context()
|
||||||
|
params = []
|
||||||
|
|
||||||
|
# SECURITY FIX: Use parameterized query
|
||||||
if db_name:
|
if db_name:
|
||||||
where_conditions.append(f"table_schema = '{db_name}'")
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid database name rejected: {e}")
|
||||||
|
return []
|
||||||
|
where_clause = "table_schema = %s"
|
||||||
|
params.append(db_name)
|
||||||
else:
|
else:
|
||||||
where_conditions.append("table_schema = DATABASE()")
|
where_clause = "table_schema = DATABASE()"
|
||||||
|
|
||||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
|
||||||
|
|
||||||
tables_sql = f"""
|
tables_sql = f"""
|
||||||
SELECT table_name
|
SELECT table_name
|
||||||
@@ -651,7 +750,7 @@ class DataGovernanceTools:
|
|||||||
ORDER BY table_name
|
ORDER BY table_name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(tables_sql)
|
result = await connection.execute(tables_sql, params=tuple(params) if params else None, auth_context=auth_context)
|
||||||
return [row["table_name"] for row in result.data] if result.data else []
|
return [row["table_name"] for row in result.data] if result.data else []
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -728,15 +827,23 @@ class DataGovernanceTools:
|
|||||||
async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]:
|
async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||||
"""Get freshness from partition information"""
|
"""Get freshness from partition information"""
|
||||||
try:
|
try:
|
||||||
# Query partition information (if table has partitions)
|
# SECURITY FIX: Validate and use parameterized query
|
||||||
partition_sql = f"""
|
table_name_part = table_name.split('.')[-1]
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name_part, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid table name rejected: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
partition_sql = """
|
||||||
SELECT MAX(CREATE_TIME) as last_update
|
SELECT MAX(CREATE_TIME) as last_update
|
||||||
FROM information_schema.partitions
|
FROM information_schema.partitions
|
||||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
WHERE table_name = %s
|
||||||
AND CREATE_TIME IS NOT NULL
|
AND CREATE_TIME IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(partition_sql)
|
result = await connection.execute(partition_sql, params=(table_name_part,), auth_context=auth_context)
|
||||||
if result.data and result.data[0]["last_update"]:
|
if result.data and result.data[0]["last_update"]:
|
||||||
return {
|
return {
|
||||||
"last_update": result.data[0]["last_update"],
|
"last_update": result.data[0]["last_update"],
|
||||||
@@ -759,7 +866,8 @@ class DataGovernanceTools:
|
|||||||
FROM {table_name}
|
FROM {table_name}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(max_time_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(max_time_sql, auth_context=auth_context)
|
||||||
if result.data and result.data[0]["last_update"]:
|
if result.data and result.data[0]["last_update"]:
|
||||||
return {
|
return {
|
||||||
"last_update": result.data[0]["last_update"],
|
"last_update": result.data[0]["last_update"],
|
||||||
@@ -773,15 +881,23 @@ class DataGovernanceTools:
|
|||||||
async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]:
|
async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]:
|
||||||
"""Get freshness from table metadata"""
|
"""Get freshness from table metadata"""
|
||||||
try:
|
try:
|
||||||
# Query table's update time
|
# SECURITY FIX: Validate and use parameterized query
|
||||||
metadata_sql = f"""
|
table_name_part = table_name.split('.')[-1]
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name_part, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid table name rejected: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
metadata_sql = """
|
||||||
SELECT UPDATE_TIME as last_update
|
SELECT UPDATE_TIME as last_update
|
||||||
FROM information_schema.tables
|
FROM information_schema.tables
|
||||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
WHERE table_name = %s
|
||||||
AND UPDATE_TIME IS NOT NULL
|
AND UPDATE_TIME IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(metadata_sql)
|
result = await connection.execute(metadata_sql, params=(table_name_part,), auth_context=auth_context)
|
||||||
if result.data and result.data[0]["last_update"]:
|
if result.data and result.data[0]["last_update"]:
|
||||||
return {
|
return {
|
||||||
"last_update": result.data[0]["last_update"],
|
"last_update": result.data[0]["last_update"],
|
||||||
@@ -795,10 +911,19 @@ class DataGovernanceTools:
|
|||||||
async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]:
|
async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]:
|
||||||
"""Find possible timestamp fields"""
|
"""Find possible timestamp fields"""
|
||||||
try:
|
try:
|
||||||
timestamp_sql = f"""
|
# SECURITY FIX: Validate and use parameterized query
|
||||||
|
table_name_part = table_name.split('.')[-1]
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name_part, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid table name rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
timestamp_sql = """
|
||||||
SELECT column_name
|
SELECT column_name
|
||||||
FROM information_schema.columns
|
FROM information_schema.columns
|
||||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
WHERE table_name = %s
|
||||||
AND (
|
AND (
|
||||||
data_type IN ('datetime', 'timestamp', 'date')
|
data_type IN ('datetime', 'timestamp', 'date')
|
||||||
OR column_name LIKE '%time%'
|
OR column_name LIKE '%time%'
|
||||||
@@ -815,7 +940,7 @@ class DataGovernanceTools:
|
|||||||
END
|
END
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(timestamp_sql)
|
result = await connection.execute(timestamp_sql, params=(table_name_part,), auth_context=auth_context)
|
||||||
return [row["column_name"] for row in result.data] if result.data else []
|
return [row["column_name"] for row in result.data] if result.data else []
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -31,6 +31,12 @@ from collections import Counter, defaultdict
|
|||||||
from .db import DorisConnectionManager
|
from .db import DorisConnectionManager
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .config import DorisConfig
|
from .config import DorisConfig
|
||||||
|
from .sql_security_utils import (
|
||||||
|
SQLSecurityError,
|
||||||
|
validate_identifier,
|
||||||
|
build_table_reference,
|
||||||
|
get_auth_context
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -299,23 +305,26 @@ class DataQualityTools:
|
|||||||
# ===========================================
|
# ===========================================
|
||||||
|
|
||||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||||
"""Build full table name"""
|
"""Build full table name with security validation"""
|
||||||
if catalog_name and db_name:
|
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||||
return f"{catalog_name}.{db_name}.{table_name}"
|
return build_table_reference(table_name, db_name, catalog_name)
|
||||||
elif db_name:
|
|
||||||
return f"{db_name}.{table_name}"
|
|
||||||
else:
|
|
||||||
return table_name
|
|
||||||
|
|
||||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||||
"""Get basic table information"""
|
"""Get basic table information"""
|
||||||
try:
|
try:
|
||||||
|
# SECURITY FIX: table_name should already be validated by _build_full_table_name
|
||||||
|
# But we add auth_context for security validation
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
# Try to get row count
|
# Try to get row count
|
||||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||||
result = await connection.execute(count_sql)
|
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||||
if result.data:
|
if result.data:
|
||||||
return {"row_count": result.data[0]["row_count"]}
|
return {"row_count": result.data[0]["row_count"]}
|
||||||
return None
|
return None
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed: {str(e)}")
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get table basic info: {str(e)}")
|
logger.warning(f"Failed to get table basic info: {str(e)}")
|
||||||
return None
|
return None
|
||||||
@@ -323,9 +332,13 @@ class DataQualityTools:
|
|||||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||||
"""Get table column information"""
|
"""Get table column information"""
|
||||||
try:
|
try:
|
||||||
# Build DESCRIBE query
|
# SECURITY FIX: Build safe table reference and pass auth_context
|
||||||
describe_sql = f"DESCRIBE {self._build_full_table_name(table_name, catalog_name, db_name)}"
|
auth_context = get_auth_context()
|
||||||
result = await connection.execute(describe_sql)
|
|
||||||
|
# Build DESCRIBE query with safe table reference
|
||||||
|
safe_table_ref = self._build_full_table_name(table_name, catalog_name, db_name)
|
||||||
|
describe_sql = f"DESCRIBE {safe_table_ref}"
|
||||||
|
result = await connection.execute(describe_sql, auth_context=auth_context)
|
||||||
|
|
||||||
columns_info = []
|
columns_info = []
|
||||||
if result.data:
|
if result.data:
|
||||||
@@ -339,6 +352,9 @@ class DataQualityTools:
|
|||||||
})
|
})
|
||||||
|
|
||||||
return columns_info
|
return columns_info
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed: {str(e)}")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get table columns info: {str(e)}")
|
logger.warning(f"Failed to get table columns info: {str(e)}")
|
||||||
return []
|
return []
|
||||||
@@ -346,7 +362,32 @@ class DataQualityTools:
|
|||||||
async def _get_table_partitions(self, connection, table_name: str, db_name: Optional[str] = None) -> List[Dict]:
|
async def _get_table_partitions(self, connection, table_name: str, db_name: Optional[str] = None) -> List[Dict]:
|
||||||
"""Get table partition information"""
|
"""Get table partition information"""
|
||||||
try:
|
try:
|
||||||
# Query partition information
|
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
# Validate table_name
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if db_name:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Build parameterized query
|
||||||
|
params = []
|
||||||
|
where_conditions = []
|
||||||
|
|
||||||
|
if db_name:
|
||||||
|
where_conditions.append("TABLE_SCHEMA = %s")
|
||||||
|
params.append(db_name)
|
||||||
|
else:
|
||||||
|
where_conditions.append("TABLE_SCHEMA = ''")
|
||||||
|
|
||||||
|
where_conditions.append("TABLE_NAME = %s")
|
||||||
|
params.append(table_name)
|
||||||
|
where_conditions.append("PARTITION_NAME IS NOT NULL")
|
||||||
|
|
||||||
partition_sql = f"""
|
partition_sql = f"""
|
||||||
SELECT
|
SELECT
|
||||||
PARTITION_NAME,
|
PARTITION_NAME,
|
||||||
@@ -355,12 +396,10 @@ class DataQualityTools:
|
|||||||
DATA_LENGTH,
|
DATA_LENGTH,
|
||||||
INDEX_LENGTH
|
INDEX_LENGTH
|
||||||
FROM information_schema.PARTITIONS
|
FROM information_schema.PARTITIONS
|
||||||
WHERE TABLE_SCHEMA = '{db_name or ""}'
|
WHERE {' AND '.join(where_conditions)}
|
||||||
AND TABLE_NAME = '{table_name}'
|
|
||||||
AND PARTITION_NAME IS NOT NULL
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(partition_sql)
|
result = await connection.execute(partition_sql, params=tuple(params), auth_context=auth_context)
|
||||||
partitions = []
|
partitions = []
|
||||||
if result.data:
|
if result.data:
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
@@ -373,6 +412,9 @@ class DataQualityTools:
|
|||||||
})
|
})
|
||||||
|
|
||||||
return partitions
|
return partitions
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed: {str(e)}")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get table partitions: {str(e)}")
|
logger.warning(f"Failed to get table partitions: {str(e)}")
|
||||||
return []
|
return []
|
||||||
@@ -417,7 +459,8 @@ class DataQualityTools:
|
|||||||
if db_name
|
if db_name
|
||||||
else f"SHOW CREATE TABLE {table_name}"
|
else f"SHOW CREATE TABLE {table_name}"
|
||||||
)
|
)
|
||||||
result = await connection.execute(query)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(query, auth_context=auth_context)
|
||||||
if result.data:
|
if result.data:
|
||||||
return result.data[0].get("Create Table")
|
return result.data[0].get("Create Table")
|
||||||
return None
|
return None
|
||||||
@@ -428,8 +471,16 @@ class DataQualityTools:
|
|||||||
async def _get_table_size_info(self, connection, table_name: str) -> Dict[str, Any]:
|
async def _get_table_size_info(self, connection, table_name: str) -> Dict[str, Any]:
|
||||||
"""Get table size information"""
|
"""Get table size information"""
|
||||||
try:
|
try:
|
||||||
# Query table size information
|
# SECURITY FIX: Validate and use parameterized query
|
||||||
size_sql = f"""
|
table_name_part = table_name.split('.')[-1]
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name_part, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid table name rejected: {e}")
|
||||||
|
return {"engine": "Unknown", "estimated_rows": 0, "data_length": 0, "index_length": 0, "total_size": 0}
|
||||||
|
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
size_sql = """
|
||||||
SELECT
|
SELECT
|
||||||
table_name,
|
table_name,
|
||||||
engine,
|
engine,
|
||||||
@@ -438,10 +489,10 @@ class DataQualityTools:
|
|||||||
index_length,
|
index_length,
|
||||||
(data_length + index_length) as total_size
|
(data_length + index_length) as total_size
|
||||||
FROM information_schema.tables
|
FROM information_schema.tables
|
||||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
WHERE table_name = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(size_sql)
|
result = await connection.execute(size_sql, params=(table_name_part,), auth_context=auth_context)
|
||||||
if result.data and result.data[0]:
|
if result.data and result.data[0]:
|
||||||
row = result.data[0]
|
row = result.data[0]
|
||||||
return {
|
return {
|
||||||
@@ -582,7 +633,8 @@ class DataQualityTools:
|
|||||||
|
|
||||||
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
||||||
|
|
||||||
result = await connection.execute(batch_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(batch_sql, auth_context=auth_context)
|
||||||
if not result.data:
|
if not result.data:
|
||||||
return {"error": "No data returned from batch completeness query"}
|
return {"error": "No data returned from batch completeness query"}
|
||||||
|
|
||||||
@@ -664,7 +716,8 @@ class DataQualityTools:
|
|||||||
|
|
||||||
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
||||||
|
|
||||||
result = await connection.execute(batch_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(batch_sql, auth_context=auth_context)
|
||||||
if not result.data:
|
if not result.data:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -705,7 +758,8 @@ class DataQualityTools:
|
|||||||
LIMIT 10
|
LIMIT 10
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(freq_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(freq_sql, auth_context=auth_context)
|
||||||
frequencies = result.data if result.data else []
|
frequencies = result.data if result.data else []
|
||||||
|
|
||||||
categorical_results[col_name] = {
|
categorical_results[col_name] = {
|
||||||
@@ -738,7 +792,8 @@ class DataQualityTools:
|
|||||||
|
|
||||||
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
||||||
|
|
||||||
result = await connection.execute(batch_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(batch_sql, auth_context=auth_context)
|
||||||
if not result.data:
|
if not result.data:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@@ -780,7 +835,8 @@ class DataQualityTools:
|
|||||||
FROM {table_expr}
|
FROM {table_expr}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(completeness_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(completeness_sql, auth_context=auth_context)
|
||||||
if result.data:
|
if result.data:
|
||||||
stats = result.data[0]
|
stats = result.data[0]
|
||||||
total_count = stats["total_count"]
|
total_count = stats["total_count"]
|
||||||
@@ -906,7 +962,8 @@ class DataQualityTools:
|
|||||||
WHERE {col_name} IS NOT NULL
|
WHERE {col_name} IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(stats_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||||
if result.data and result.data[0]["non_null_count"] > 0:
|
if result.data and result.data[0]["non_null_count"] > 0:
|
||||||
stats = result.data[0]
|
stats = result.data[0]
|
||||||
numeric_analysis[col_name] = {
|
numeric_analysis[col_name] = {
|
||||||
@@ -945,7 +1002,8 @@ class DataQualityTools:
|
|||||||
WHERE {col_name} IS NOT NULL
|
WHERE {col_name} IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cardinality_result = await connection.execute(cardinality_sql)
|
auth_context = get_auth_context()
|
||||||
|
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if cardinality_result.data:
|
if cardinality_result.data:
|
||||||
stats = cardinality_result.data[0]
|
stats = cardinality_result.data[0]
|
||||||
@@ -969,7 +1027,7 @@ class DataQualityTools:
|
|||||||
LIMIT 10
|
LIMIT 10
|
||||||
"""
|
"""
|
||||||
|
|
||||||
top_values_result = await connection.execute(top_values_sql)
|
top_values_result = await connection.execute(top_values_sql, auth_context=auth_context)
|
||||||
if top_values_result.data:
|
if top_values_result.data:
|
||||||
categorical_analysis[col_name]["top_values"] = [
|
categorical_analysis[col_name]["top_values"] = [
|
||||||
{"value": row[col_name], "count": row["count"]}
|
{"value": row[col_name], "count": row["count"]}
|
||||||
@@ -998,7 +1056,8 @@ class DataQualityTools:
|
|||||||
WHERE {col_name} IS NOT NULL
|
WHERE {col_name} IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(stats_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||||
if result.data and result.data[0]["non_null_count"] > 0:
|
if result.data and result.data[0]["non_null_count"] > 0:
|
||||||
stats = result.data[0]
|
stats = result.data[0]
|
||||||
temporal_analysis[col_name] = {
|
temporal_analysis[col_name] = {
|
||||||
|
|||||||
@@ -27,6 +27,13 @@ from collections import defaultdict, deque
|
|||||||
|
|
||||||
from .db import DorisConnectionManager
|
from .db import DorisConnectionManager
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import (
|
||||||
|
SQLSecurityError,
|
||||||
|
validate_identifier,
|
||||||
|
quote_identifier,
|
||||||
|
build_table_reference,
|
||||||
|
get_auth_context
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -122,10 +129,19 @@ class DependencyAnalysisTools:
|
|||||||
async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]:
|
async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]:
|
||||||
"""Get metadata for all tables and views"""
|
"""Get metadata for all tables and views"""
|
||||||
try:
|
try:
|
||||||
# Build conditions for query
|
# Build conditions for query with parameterized values
|
||||||
where_conditions = []
|
where_conditions = []
|
||||||
|
params = []
|
||||||
|
|
||||||
if db_name:
|
if db_name:
|
||||||
where_conditions.append(f"table_schema = '{db_name}'")
|
# SECURITY FIX: Validate identifier and use parameterized query
|
||||||
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid database name rejected: {e}")
|
||||||
|
return []
|
||||||
|
where_conditions.append("table_schema = %s")
|
||||||
|
params.append(db_name)
|
||||||
else:
|
else:
|
||||||
where_conditions.append("table_schema = DATABASE()")
|
where_conditions.append("table_schema = DATABASE()")
|
||||||
|
|
||||||
@@ -148,9 +164,18 @@ class DependencyAnalysisTools:
|
|||||||
ORDER BY table_schema, table_name
|
ORDER BY table_schema, table_name
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(metadata_sql)
|
# SECURITY FIX: Get auth_context and pass to execute for security validation
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(
|
||||||
|
metadata_sql,
|
||||||
|
params=tuple(params) if params else None,
|
||||||
|
auth_context=auth_context
|
||||||
|
)
|
||||||
return result.data if result.data else []
|
return result.data if result.data else []
|
||||||
|
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed in _get_tables_metadata: {str(e)}")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get tables metadata: {str(e)}")
|
logger.warning(f"Failed to get tables metadata: {str(e)}")
|
||||||
return []
|
return []
|
||||||
@@ -186,17 +211,31 @@ class DependencyAnalysisTools:
|
|||||||
|
|
||||||
async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
||||||
"""Analyze view definitions to extract table dependencies"""
|
"""Analyze view definitions to extract table dependencies"""
|
||||||
|
# Get auth_context once for all operations in this method
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for table in tables_metadata:
|
for table in tables_metadata:
|
||||||
if table["table_type"] == "VIEW":
|
if table["table_type"] == "VIEW":
|
||||||
table_name = table["table_name"]
|
table_name = table["table_name"]
|
||||||
schema_name = table.get("schema_name", "")
|
schema_name = table.get("schema_name", "")
|
||||||
|
|
||||||
# Get view definition
|
# SECURITY FIX: Validate identifiers before using in SQL
|
||||||
view_def_sql = f"SHOW CREATE VIEW {schema_name}.{table_name}" if schema_name else f"SHOW CREATE VIEW {table_name}"
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if schema_name:
|
||||||
|
validate_identifier(schema_name, "schema name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected in view analysis: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Build safe view reference using quoted identifiers
|
||||||
|
view_ref = build_table_reference(table_name, schema_name) if schema_name else quote_identifier(table_name, "table name")
|
||||||
|
view_def_sql = f"SHOW CREATE VIEW {view_ref}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = await connection.execute(view_def_sql)
|
# SECURITY FIX: Pass auth_context to execute
|
||||||
|
result = await connection.execute(view_def_sql, auth_context=auth_context)
|
||||||
if result.data and len(result.data) > 0:
|
if result.data and len(result.data) > 0:
|
||||||
# Extract view definition from result
|
# Extract view definition from result
|
||||||
view_definition = ""
|
view_definition = ""
|
||||||
@@ -235,6 +274,9 @@ class DependencyAnalysisTools:
|
|||||||
|
|
||||||
async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None:
|
async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None:
|
||||||
"""Analyze audit logs to discover runtime table dependencies"""
|
"""Analyze audit logs to discover runtime table dependencies"""
|
||||||
|
# Get auth_context for security validation
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get recent SQL statements from audit logs
|
# Get recent SQL statements from audit logs
|
||||||
audit_sql = """
|
audit_sql = """
|
||||||
@@ -252,7 +294,8 @@ class DependencyAnalysisTools:
|
|||||||
LIMIT 1000
|
LIMIT 1000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(audit_sql)
|
# SECURITY FIX: Pass auth_context to execute
|
||||||
|
result = await connection.execute(audit_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if result.data:
|
if result.data:
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
@@ -274,6 +317,9 @@ class DependencyAnalysisTools:
|
|||||||
|
|
||||||
async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
||||||
"""Analyze foreign key constraints for explicit dependencies"""
|
"""Analyze foreign key constraints for explicit dependencies"""
|
||||||
|
# Get auth_context for security validation
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Get foreign key information
|
# Get foreign key information
|
||||||
fk_sql = """
|
fk_sql = """
|
||||||
@@ -288,7 +334,8 @@ class DependencyAnalysisTools:
|
|||||||
WHERE REFERENCED_TABLE_NAME IS NOT NULL
|
WHERE REFERENCED_TABLE_NAME IS NOT NULL
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(fk_sql)
|
# SECURITY FIX: Pass auth_context to execute
|
||||||
|
result = await connection.execute(fk_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if result.data:
|
if result.data:
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from datetime import datetime
|
|||||||
|
|
||||||
from .db import DorisConnectionManager
|
from .db import DorisConnectionManager
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import get_auth_context
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -713,7 +714,8 @@ class DorisMonitoringTools:
|
|||||||
# Fallback to SHOW BACKENDS if no BE hosts configured
|
# Fallback to SHOW BACKENDS if no BE hosts configured
|
||||||
logger.info("No BE hosts configured, using SHOW BACKENDS to discover BE nodes")
|
logger.info("No BE hosts configured, using SHOW BACKENDS to discover BE nodes")
|
||||||
connection = await self.connection_manager.get_connection("query")
|
connection = await self.connection_manager.get_connection("query")
|
||||||
result = await connection.execute("SHOW BACKENDS")
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||||
|
|
||||||
be_nodes = []
|
be_nodes = []
|
||||||
for row in result.data:
|
for row in result.data:
|
||||||
|
|||||||
@@ -27,6 +27,13 @@ from collections import defaultdict, Counter
|
|||||||
|
|
||||||
from .db import DorisConnectionManager
|
from .db import DorisConnectionManager
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import (
|
||||||
|
SQLSecurityError,
|
||||||
|
validate_identifier,
|
||||||
|
quote_identifier,
|
||||||
|
build_table_reference,
|
||||||
|
get_auth_context
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -229,7 +236,8 @@ class PerformanceAnalyticsTools:
|
|||||||
ORDER BY query_date
|
ORDER BY query_date
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(query_volume_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(query_volume_sql, auth_context=auth_context)
|
||||||
daily_data = result.data if result.data else []
|
daily_data = result.data if result.data else []
|
||||||
|
|
||||||
if not daily_data:
|
if not daily_data:
|
||||||
@@ -304,7 +312,8 @@ class PerformanceAnalyticsTools:
|
|||||||
ORDER BY activity_date
|
ORDER BY activity_date
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(user_activity_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(user_activity_sql, auth_context=auth_context)
|
||||||
daily_data = result.data if result.data else []
|
daily_data = result.data if result.data else []
|
||||||
|
|
||||||
if not daily_data:
|
if not daily_data:
|
||||||
@@ -383,7 +392,8 @@ class PerformanceAnalyticsTools:
|
|||||||
LIMIT 5000
|
LIMIT 5000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(slow_query_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(slow_query_sql, auth_context=auth_context)
|
||||||
return result.data if result.data else []
|
return result.data if result.data else []
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -705,7 +715,8 @@ class PerformanceAnalyticsTools:
|
|||||||
ORDER BY size_mb DESC
|
ORDER BY size_mb DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
db_result = await connection.execute(db_sizes_sql)
|
auth_context = get_auth_context()
|
||||||
|
db_result = await connection.execute(db_sizes_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if not db_result.data:
|
if not db_result.data:
|
||||||
logger.warning("No database size information available")
|
logger.warning("No database size information available")
|
||||||
@@ -805,7 +816,16 @@ class PerformanceAnalyticsTools:
|
|||||||
async def _get_database_table_details_from_schema(self, connection, db_name: str) -> List[Dict]:
|
async def _get_database_table_details_from_schema(self, connection, db_name: str) -> List[Dict]:
|
||||||
"""Get table details for a specific database using information_schema"""
|
"""Get table details for a specific database using information_schema"""
|
||||||
try:
|
try:
|
||||||
table_details_sql = f"""
|
# SECURITY FIX: Validate db_name and use parameterized query
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid database name rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
table_details_sql = """
|
||||||
SELECT
|
SELECT
|
||||||
TABLE_SCHEMA as schema_name,
|
TABLE_SCHEMA as schema_name,
|
||||||
TABLE_NAME as table_name,
|
TABLE_NAME as table_name,
|
||||||
@@ -814,13 +834,13 @@ class PerformanceAnalyticsTools:
|
|||||||
CREATE_TIME as create_time,
|
CREATE_TIME as create_time,
|
||||||
UPDATE_TIME as update_time
|
UPDATE_TIME as update_time
|
||||||
FROM information_schema.tables
|
FROM information_schema.tables
|
||||||
WHERE TABLE_SCHEMA = '{db_name}'
|
WHERE TABLE_SCHEMA = %s
|
||||||
AND TABLE_TYPE = 'BASE TABLE'
|
AND TABLE_TYPE = 'BASE TABLE'
|
||||||
AND (COALESCE(DATA_LENGTH, 0) + COALESCE(INDEX_LENGTH, 0)) > 0
|
AND (COALESCE(DATA_LENGTH, 0) + COALESCE(INDEX_LENGTH, 0)) > 0
|
||||||
ORDER BY size_mb DESC
|
ORDER BY size_mb DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(table_details_sql)
|
result = await connection.execute(table_details_sql, params=(db_name,), auth_context=auth_context)
|
||||||
|
|
||||||
if not result.data:
|
if not result.data:
|
||||||
logger.warning(f"No table details found for database {db_name}")
|
logger.warning(f"No table details found for database {db_name}")
|
||||||
@@ -867,6 +887,13 @@ class PerformanceAnalyticsTools:
|
|||||||
async def _get_database_table_details(self, connection, db_name: str) -> List[Dict]:
|
async def _get_database_table_details(self, connection, db_name: str) -> List[Dict]:
|
||||||
"""Get table details for a specific database using session-consistent queries"""
|
"""Get table details for a specific database using session-consistent queries"""
|
||||||
try:
|
try:
|
||||||
|
# SECURITY FIX: Validate db_name before using in SQL
|
||||||
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid database name rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
# Method 1: Try to use session-consistent approach with raw connection
|
# Method 1: Try to use session-consistent approach with raw connection
|
||||||
# This requires accessing the underlying connection to maintain session state
|
# This requires accessing the underlying connection to maintain session state
|
||||||
|
|
||||||
@@ -877,8 +904,9 @@ class PerformanceAnalyticsTools:
|
|||||||
# Use raw connection to maintain session state
|
# Use raw connection to maintain session state
|
||||||
cursor = await raw_conn.cursor()
|
cursor = await raw_conn.cursor()
|
||||||
try:
|
try:
|
||||||
# Execute USE and SHOW DATA in the same session
|
# SECURITY FIX: Use quoted identifier for USE statement
|
||||||
await cursor.execute(f"USE {db_name}")
|
quoted_db = quote_identifier(db_name, "database name")
|
||||||
|
await cursor.execute(f"USE {quoted_db}")
|
||||||
await cursor.execute("SHOW DATA")
|
await cursor.execute("SHOW DATA")
|
||||||
|
|
||||||
result = await cursor.fetchall()
|
result = await cursor.fetchall()
|
||||||
@@ -922,9 +950,19 @@ class PerformanceAnalyticsTools:
|
|||||||
async def _get_database_table_details_fallback(self, connection, db_name: str) -> List[Dict]:
|
async def _get_database_table_details_fallback(self, connection, db_name: str) -> List[Dict]:
|
||||||
"""Fallback method to get table details using individual queries"""
|
"""Fallback method to get table details using individual queries"""
|
||||||
try:
|
try:
|
||||||
# Get all tables in the database
|
# SECURITY FIX: Validate db_name and get auth_context
|
||||||
tables_sql = f"SHOW TABLES FROM {db_name}"
|
auth_context = get_auth_context()
|
||||||
tables_result = await connection.execute(tables_sql)
|
|
||||||
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid database name rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Get all tables in the database using quoted identifier
|
||||||
|
quoted_db = quote_identifier(db_name, "database name")
|
||||||
|
tables_sql = f"SHOW TABLES FROM {quoted_db}"
|
||||||
|
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if not tables_result.data:
|
if not tables_result.data:
|
||||||
return []
|
return []
|
||||||
@@ -934,9 +972,11 @@ class PerformanceAnalyticsTools:
|
|||||||
table_name = table_row.get(f"Tables_in_{db_name}", "") or table_row.get("table_name", "")
|
table_name = table_row.get(f"Tables_in_{db_name}", "") or table_row.get("table_name", "")
|
||||||
if table_name:
|
if table_name:
|
||||||
try:
|
try:
|
||||||
# Use SHOW DATA FROM db.table for each table
|
# SECURITY FIX: Validate table_name and use safe reference
|
||||||
data_sql = f"SHOW DATA FROM {db_name}.{table_name}"
|
validate_identifier(table_name, "table name")
|
||||||
data_result = await connection.execute(data_sql)
|
safe_table_ref = build_table_reference(table_name, db_name)
|
||||||
|
data_sql = f"SHOW DATA FROM {safe_table_ref}"
|
||||||
|
data_result = await connection.execute(data_sql, auth_context=auth_context)
|
||||||
|
|
||||||
if data_result.data:
|
if data_result.data:
|
||||||
for row in data_result.data:
|
for row in data_result.data:
|
||||||
@@ -1036,6 +1076,7 @@ class PerformanceAnalyticsTools:
|
|||||||
async def _get_all_tables_info(self, connection) -> List[Dict]:
|
async def _get_all_tables_info(self, connection) -> List[Dict]:
|
||||||
"""Get basic information for all tables (fallback method)"""
|
"""Get basic information for all tables (fallback method)"""
|
||||||
try:
|
try:
|
||||||
|
auth_context = get_auth_context()
|
||||||
tables_sql = """
|
tables_sql = """
|
||||||
SELECT
|
SELECT
|
||||||
table_schema,
|
table_schema,
|
||||||
@@ -1053,7 +1094,7 @@ class PerformanceAnalyticsTools:
|
|||||||
ORDER BY (data_length + index_length) DESC
|
ORDER BY (data_length + index_length) DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(tables_sql)
|
result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||||
return result.data if result.data else []
|
return result.data if result.data else []
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -1120,23 +1161,37 @@ class PerformanceAnalyticsTools:
|
|||||||
async def _get_current_table_size(self, connection, full_table_name: str) -> Optional[Dict]:
|
async def _get_current_table_size(self, connection, full_table_name: str) -> Optional[Dict]:
|
||||||
"""Get current table size"""
|
"""Get current table size"""
|
||||||
try:
|
try:
|
||||||
# Try to query table size directly
|
# SECURITY FIX: Get auth_context and use parameterized query
|
||||||
size_sql = f"""
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
# Extract table name for parameterized query
|
||||||
|
table_name_only = full_table_name.split('.')[-1] if '.' in full_table_name else full_table_name
|
||||||
|
|
||||||
|
# Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name_only, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid table name rejected: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Use parameterized query for safety
|
||||||
|
size_sql = """
|
||||||
SELECT
|
SELECT
|
||||||
COALESCE(ROUND((COALESCE(data_length, 0) + COALESCE(index_length, 0)) / 1024 / 1024, 2), 0) as size_mb,
|
COALESCE(ROUND((COALESCE(data_length, 0) + COALESCE(index_length, 0)) / 1024 / 1024, 2), 0) as size_mb,
|
||||||
COALESCE(table_rows, 0) as `rows`
|
COALESCE(table_rows, 0) as `rows`
|
||||||
FROM information_schema.tables
|
FROM information_schema.tables
|
||||||
WHERE CONCAT(table_schema, '.', table_name) = '{full_table_name}'
|
WHERE CONCAT(table_schema, '.', table_name) = %s
|
||||||
OR table_name = '{full_table_name.split('.')[-1]}'
|
OR table_name = %s
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(size_sql)
|
result = await connection.execute(size_sql, params=(full_table_name, table_name_only), auth_context=auth_context)
|
||||||
if result.data and result.data[0]:
|
if result.data and result.data[0]:
|
||||||
return result.data[0]
|
return result.data[0]
|
||||||
|
|
||||||
# If information_schema has no data, try COUNT query
|
# If information_schema has no data, try COUNT query
|
||||||
|
# full_table_name should already be validated by caller using build_table_reference
|
||||||
count_sql = f"SELECT COUNT(*) as rows FROM {full_table_name}"
|
count_sql = f"SELECT COUNT(*) as rows FROM {full_table_name}"
|
||||||
count_result = await connection.execute(count_sql)
|
count_result = await connection.execute(count_sql, auth_context=auth_context)
|
||||||
if count_result.data:
|
if count_result.data:
|
||||||
return {
|
return {
|
||||||
"size_mb": 0, # Cannot get exact size
|
"size_mb": 0, # Cannot get exact size
|
||||||
@@ -1145,6 +1200,9 @@ class PerformanceAnalyticsTools:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed for {full_table_name}: {str(e)}")
|
||||||
|
return None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to get current size for {full_table_name}: {str(e)}")
|
logger.warning(f"Failed to get current size for {full_table_name}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
@@ -1154,8 +1212,19 @@ class PerformanceAnalyticsTools:
|
|||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""Get historical growth data based on partitions"""
|
"""Get historical growth data based on partitions"""
|
||||||
try:
|
try:
|
||||||
# Query partition information
|
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||||
partition_sql = f"""
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if schema_name:
|
||||||
|
validate_identifier(schema_name, "schema name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Use parameterized query for safety
|
||||||
|
partition_sql = """
|
||||||
SELECT
|
SELECT
|
||||||
partition_name,
|
partition_name,
|
||||||
partition_description,
|
partition_description,
|
||||||
@@ -1163,15 +1232,19 @@ class PerformanceAnalyticsTools:
|
|||||||
data_length,
|
data_length,
|
||||||
create_time
|
create_time
|
||||||
FROM information_schema.partitions
|
FROM information_schema.partitions
|
||||||
WHERE table_schema = '{schema_name or ""}'
|
WHERE table_schema = %s
|
||||||
AND table_name = '{table_name}'
|
AND table_name = %s
|
||||||
AND partition_name IS NOT NULL
|
AND partition_name IS NOT NULL
|
||||||
AND create_time IS NOT NULL
|
AND create_time IS NOT NULL
|
||||||
AND create_time >= DATE_SUB(NOW(), INTERVAL {days} DAY)
|
AND create_time >= DATE_SUB(NOW(), INTERVAL %s DAY)
|
||||||
ORDER BY create_time DESC
|
ORDER BY create_time DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(partition_sql)
|
result = await connection.execute(
|
||||||
|
partition_sql,
|
||||||
|
params=(schema_name or "", table_name, days),
|
||||||
|
auth_context=auth_context
|
||||||
|
)
|
||||||
if not result.data:
|
if not result.data:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -1210,6 +1283,9 @@ class PerformanceAnalyticsTools:
|
|||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""Get historical growth data based on timestamp fields"""
|
"""Get historical growth data based on timestamp fields"""
|
||||||
try:
|
try:
|
||||||
|
# SECURITY FIX: Get auth_context
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
# Find possible timestamp fields
|
# Find possible timestamp fields
|
||||||
timestamp_columns = await self._find_timestamp_columns(connection, table_name, schema_name)
|
timestamp_columns = await self._find_timestamp_columns(connection, table_name, schema_name)
|
||||||
if not timestamp_columns:
|
if not timestamp_columns:
|
||||||
@@ -1218,20 +1294,29 @@ class PerformanceAnalyticsTools:
|
|||||||
# Use best timestamp field for analysis
|
# Use best timestamp field for analysis
|
||||||
time_column = timestamp_columns[0]
|
time_column = timestamp_columns[0]
|
||||||
|
|
||||||
# Aggregate data by date
|
# SECURITY FIX: Validate time_column before using in SQL
|
||||||
|
try:
|
||||||
|
validate_identifier(time_column, "column name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid column name rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
quoted_time_column = quote_identifier(time_column, "column name")
|
||||||
|
|
||||||
|
# Aggregate data by date (full_table_name should be validated by caller)
|
||||||
growth_sql = f"""
|
growth_sql = f"""
|
||||||
SELECT
|
SELECT
|
||||||
DATE({time_column}) as date,
|
DATE({quoted_time_column}) as date,
|
||||||
COUNT(*) as daily_records,
|
COUNT(*) as daily_records,
|
||||||
COUNT(*) / SUM(COUNT(*)) OVER() * 100 as percentage
|
COUNT(*) / SUM(COUNT(*)) OVER() * 100 as percentage
|
||||||
FROM {full_table_name}
|
FROM {full_table_name}
|
||||||
WHERE {time_column} >= DATE_SUB(NOW(), INTERVAL {days} DAY)
|
WHERE {quoted_time_column} >= DATE_SUB(NOW(), INTERVAL %s DAY)
|
||||||
AND {time_column} IS NOT NULL
|
AND {quoted_time_column} IS NOT NULL
|
||||||
GROUP BY DATE({time_column})
|
GROUP BY DATE({quoted_time_column})
|
||||||
ORDER BY date DESC
|
ORDER BY date DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(growth_sql)
|
result = await connection.execute(growth_sql, params=(days,), auth_context=auth_context)
|
||||||
if not result.data:
|
if not result.data:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -1257,11 +1342,22 @@ class PerformanceAnalyticsTools:
|
|||||||
async def _find_timestamp_columns(self, connection, table_name: str, schema_name: str) -> List[str]:
|
async def _find_timestamp_columns(self, connection, table_name: str, schema_name: str) -> List[str]:
|
||||||
"""Find timestamp fields in table"""
|
"""Find timestamp fields in table"""
|
||||||
try:
|
try:
|
||||||
timestamp_sql = f"""
|
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if schema_name:
|
||||||
|
validate_identifier(schema_name, "schema name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
timestamp_sql = """
|
||||||
SELECT column_name, data_type
|
SELECT column_name, data_type
|
||||||
FROM information_schema.columns
|
FROM information_schema.columns
|
||||||
WHERE table_schema = '{schema_name or ""}'
|
WHERE table_schema = %s
|
||||||
AND table_name = '{table_name}'
|
AND table_name = %s
|
||||||
AND (
|
AND (
|
||||||
data_type IN ('datetime', 'timestamp', 'date')
|
data_type IN ('datetime', 'timestamp', 'date')
|
||||||
OR column_name REGEXP '(create|insert|update|modify).*time'
|
OR column_name REGEXP '(create|insert|update|modify).*time'
|
||||||
@@ -1278,9 +1374,16 @@ class PerformanceAnalyticsTools:
|
|||||||
END
|
END
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(timestamp_sql)
|
result = await connection.execute(
|
||||||
|
timestamp_sql,
|
||||||
|
params=(schema_name or "", table_name),
|
||||||
|
auth_context=auth_context
|
||||||
|
)
|
||||||
return [row["column_name"] for row in result.data] if result.data else []
|
return [row["column_name"] for row in result.data] if result.data else []
|
||||||
|
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Security validation failed: {str(e)}")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to find timestamp columns: {str(e)}")
|
logger.warning(f"Failed to find timestamp columns: {str(e)}")
|
||||||
return []
|
return []
|
||||||
@@ -1290,8 +1393,22 @@ class PerformanceAnalyticsTools:
|
|||||||
) -> List[Dict]:
|
) -> List[Dict]:
|
||||||
"""Estimate growth data based on audit logs"""
|
"""Estimate growth data based on audit logs"""
|
||||||
try:
|
try:
|
||||||
|
# SECURITY FIX: Validate table_name and use parameterized query
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name.split(".")[-1], "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid table name rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Extract just the table name for LIKE pattern
|
||||||
|
table_name_only = table_name.split(".")[-1]
|
||||||
|
like_pattern_full = f"%{table_name}%"
|
||||||
|
like_pattern_short = f"%{table_name_only}%"
|
||||||
|
|
||||||
# Analyze operation history for this table
|
# Analyze operation history for this table
|
||||||
audit_sql = f"""
|
audit_sql = """
|
||||||
SELECT
|
SELECT
|
||||||
DATE(`time`) as operation_date,
|
DATE(`time`) as operation_date,
|
||||||
COUNT(*) as operation_count,
|
COUNT(*) as operation_count,
|
||||||
@@ -1299,17 +1416,21 @@ class PerformanceAnalyticsTools:
|
|||||||
SUM(CASE WHEN stmt LIKE 'UPDATE%' THEN 1 ELSE 0 END) as update_count,
|
SUM(CASE WHEN stmt LIKE 'UPDATE%' THEN 1 ELSE 0 END) as update_count,
|
||||||
SUM(CASE WHEN stmt LIKE 'DELETE%' THEN 1 ELSE 0 END) as delete_count
|
SUM(CASE WHEN stmt LIKE 'DELETE%' THEN 1 ELSE 0 END) as delete_count
|
||||||
FROM internal.__internal_schema.audit_log
|
FROM internal.__internal_schema.audit_log
|
||||||
WHERE `time` >= DATE_SUB(NOW(), INTERVAL {days} DAY)
|
WHERE `time` >= DATE_SUB(NOW(), INTERVAL %s DAY)
|
||||||
AND stmt IS NOT NULL
|
AND stmt IS NOT NULL
|
||||||
AND (
|
AND (
|
||||||
stmt LIKE '%{table_name}%'
|
stmt LIKE %s
|
||||||
OR stmt LIKE '%{table_name.split(".")[-1]}%'
|
OR stmt LIKE %s
|
||||||
)
|
)
|
||||||
GROUP BY DATE(`time`)
|
GROUP BY DATE(`time`)
|
||||||
ORDER BY operation_date DESC
|
ORDER BY operation_date DESC
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(audit_sql)
|
result = await connection.execute(
|
||||||
|
audit_sql,
|
||||||
|
params=(days, like_pattern_full, like_pattern_short),
|
||||||
|
auth_context=auth_context
|
||||||
|
)
|
||||||
if not result.data:
|
if not result.data:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from decimal import Decimal
|
|||||||
|
|
||||||
from .db import DorisConnectionManager, QueryResult
|
from .db import DorisConnectionManager, QueryResult
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -497,7 +498,8 @@ class DorisQueryExecutor:
|
|||||||
explain_sql = f"EXPLAIN {sql}"
|
explain_sql = f"EXPLAIN {sql}"
|
||||||
|
|
||||||
connection = await self.connection_manager.get_connection(session_id)
|
connection = await self.connection_manager.get_connection(session_id)
|
||||||
result = await connection.execute(explain_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(explain_sql, auth_context=auth_context)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"query": sql,
|
"query": sql,
|
||||||
|
|||||||
@@ -32,6 +32,11 @@ from datetime import datetime, timedelta
|
|||||||
|
|
||||||
# Import unified logging configuration
|
# Import unified logging configuration
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import (
|
||||||
|
SQLSecurityError,
|
||||||
|
validate_identifier,
|
||||||
|
quote_identifier
|
||||||
|
)
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
@@ -431,6 +436,16 @@ class MetadataExtractor:
|
|||||||
logger.warning("Database name not specified")
|
logger.warning("Database name not specified")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
# SECURITY FIX: Validate identifiers to prevent SQL injection
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
if effective_catalog:
|
||||||
|
validate_identifier(effective_catalog, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected in get_table_schema: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||||
return self.metadata_cache[cache_key]
|
return self.metadata_cache[cache_key]
|
||||||
@@ -536,6 +551,16 @@ class MetadataExtractor:
|
|||||||
logger.warning("Database name not specified")
|
logger.warning("Database name not specified")
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# SECURITY FIX: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
if effective_catalog:
|
||||||
|
validate_identifier(effective_catalog, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||||
return self.metadata_cache[cache_key]
|
return self.metadata_cache[cache_key]
|
||||||
@@ -587,6 +612,16 @@ class MetadataExtractor:
|
|||||||
logger.warning("Database name not specified")
|
logger.warning("Database name not specified")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
# SECURITY FIX: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
if effective_catalog:
|
||||||
|
validate_identifier(effective_catalog, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||||
return self.metadata_cache[cache_key]
|
return self.metadata_cache[cache_key]
|
||||||
@@ -643,17 +678,30 @@ class MetadataExtractor:
|
|||||||
logger.error("Database name not specified")
|
logger.error("Database name not specified")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
# SECURITY FIX: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
if effective_catalog:
|
||||||
|
validate_identifier(effective_catalog, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||||
return self.metadata_cache[cache_key]
|
return self.metadata_cache[cache_key]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Build query with catalog prefix if specified
|
# Build query with catalog prefix if specified (identifiers already validated)
|
||||||
|
safe_table = quote_identifier(table_name, "table name")
|
||||||
|
safe_db = quote_identifier(db_name, "database name")
|
||||||
if effective_catalog:
|
if effective_catalog:
|
||||||
query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`"
|
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||||
|
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
|
||||||
logger.info(f"Using three-part naming for index query: {query}")
|
logger.info(f"Using three-part naming for index query: {query}")
|
||||||
else:
|
else:
|
||||||
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
|
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# NOTE: Deprecated sync path retained for compatibility; use async variant instead.
|
# NOTE: Deprecated sync path retained for compatibility; use async variant instead.
|
||||||
@@ -1188,12 +1236,28 @@ class MetadataExtractor:
|
|||||||
try:
|
try:
|
||||||
# Use async query method
|
# Use async query method
|
||||||
effective_catalog = catalog_name or self.catalog_name
|
effective_catalog = catalog_name or self.catalog_name
|
||||||
|
effective_db = db_name or self.db_name
|
||||||
|
|
||||||
|
# SECURITY FIX: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if effective_db:
|
||||||
|
validate_identifier(effective_db, "database name")
|
||||||
|
if effective_catalog and effective_catalog != "internal":
|
||||||
|
validate_identifier(effective_catalog, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Build query statement using safe identifiers
|
||||||
|
safe_table = quote_identifier(table_name, "table name")
|
||||||
|
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||||
|
|
||||||
# Build query statement
|
|
||||||
if effective_catalog and effective_catalog != "internal":
|
if effective_catalog and effective_catalog != "internal":
|
||||||
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`"
|
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||||
|
query = f"DESCRIBE {safe_catalog}.{safe_db}.{safe_table}"
|
||||||
else:
|
else:
|
||||||
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`"
|
query = f"DESCRIBE {safe_db}.{safe_table}"
|
||||||
|
|
||||||
# Execute async query
|
# Execute async query
|
||||||
result = await self._execute_query_async(query, db_name)
|
result = await self._execute_query_async(query, db_name)
|
||||||
@@ -1226,8 +1290,15 @@ class MetadataExtractor:
|
|||||||
try:
|
try:
|
||||||
effective_catalog = catalog_name or self.catalog_name
|
effective_catalog = catalog_name or self.catalog_name
|
||||||
|
|
||||||
|
# SECURITY FIX: Validate catalog name if provided
|
||||||
if effective_catalog and effective_catalog != "internal":
|
if effective_catalog and effective_catalog != "internal":
|
||||||
query = f"SHOW DATABASES FROM `{effective_catalog}`"
|
try:
|
||||||
|
validate_identifier(effective_catalog, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid catalog name rejected: {e}")
|
||||||
|
return []
|
||||||
|
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||||
|
query = f"SHOW DATABASES FROM {safe_catalog}"
|
||||||
else:
|
else:
|
||||||
query = "SHOW DATABASES"
|
query = "SHOW DATABASES"
|
||||||
|
|
||||||
@@ -1257,10 +1328,23 @@ class MetadataExtractor:
|
|||||||
effective_catalog = catalog_name or self.catalog_name
|
effective_catalog = catalog_name or self.catalog_name
|
||||||
effective_db = db_name or self.db_name
|
effective_db = db_name or self.db_name
|
||||||
|
|
||||||
|
# SECURITY FIX: Validate identifiers
|
||||||
|
try:
|
||||||
|
if effective_db:
|
||||||
|
validate_identifier(effective_db, "database name")
|
||||||
|
if effective_catalog and effective_catalog != "internal":
|
||||||
|
validate_identifier(effective_catalog, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||||
|
|
||||||
if effective_catalog and effective_catalog != "internal":
|
if effective_catalog and effective_catalog != "internal":
|
||||||
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`"
|
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||||
|
query = f"SHOW TABLES FROM {safe_catalog}.{safe_db}"
|
||||||
else:
|
else:
|
||||||
query = f"SHOW TABLES FROM `{effective_db}`"
|
query = f"SHOW TABLES FROM {safe_db}"
|
||||||
|
|
||||||
result = await self._execute_query_async(query, effective_db)
|
result = await self._execute_query_async(query, effective_db)
|
||||||
|
|
||||||
@@ -1319,6 +1403,15 @@ class MetadataExtractor:
|
|||||||
effective_db = db_name or self.db_name
|
effective_db = db_name or self.db_name
|
||||||
effective_catalog = catalog_name or self.catalog_name
|
effective_catalog = catalog_name or self.catalog_name
|
||||||
|
|
||||||
|
# SECURITY FIX: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if effective_db:
|
||||||
|
validate_identifier(effective_db, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
query = f"""
|
query = f"""
|
||||||
SELECT
|
SELECT
|
||||||
TABLE_COMMENT
|
TABLE_COMMENT
|
||||||
@@ -1343,6 +1436,15 @@ class MetadataExtractor:
|
|||||||
effective_db = db_name or self.db_name
|
effective_db = db_name or self.db_name
|
||||||
effective_catalog = catalog_name or self.catalog_name
|
effective_catalog = catalog_name or self.catalog_name
|
||||||
|
|
||||||
|
# SECURITY FIX: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if effective_db:
|
||||||
|
validate_identifier(effective_db, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
query = f"""
|
query = f"""
|
||||||
SELECT
|
SELECT
|
||||||
COLUMN_NAME,
|
COLUMN_NAME,
|
||||||
@@ -1373,12 +1475,27 @@ class MetadataExtractor:
|
|||||||
effective_db = db_name or self.db_name
|
effective_db = db_name or self.db_name
|
||||||
effective_catalog = catalog_name or self.catalog_name
|
effective_catalog = catalog_name or self.catalog_name
|
||||||
|
|
||||||
# Build query with catalog prefix if specified
|
# SECURITY FIX: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
if effective_db:
|
||||||
|
validate_identifier(effective_db, "database name")
|
||||||
|
if effective_catalog:
|
||||||
|
validate_identifier(effective_catalog, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid identifier rejected: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Build query with catalog prefix if specified (using safe identifiers)
|
||||||
|
safe_table = quote_identifier(table_name, "table name")
|
||||||
|
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||||
|
|
||||||
if effective_catalog:
|
if effective_catalog:
|
||||||
query = f"SHOW INDEX FROM `{effective_catalog}`.`{effective_db}`.`{table_name}`"
|
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||||
|
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
|
||||||
logger.info(f"Using three-part naming for async index query: {query}")
|
logger.info(f"Using three-part naming for async index query: {query}")
|
||||||
else:
|
else:
|
||||||
query = f"SHOW INDEX FROM `{effective_db}`.`{table_name}`"
|
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
|
||||||
|
|
||||||
rows = await self._execute_query_async(query, effective_db)
|
rows = await self._execute_query_async(query, effective_db)
|
||||||
indexes: List[Dict[str, Any]] = []
|
indexes: List[Dict[str, Any]] = []
|
||||||
@@ -1475,21 +1592,45 @@ class MetadataExtractor:
|
|||||||
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
|
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
|
||||||
|
|
||||||
# FIX for Issue #62 Bug 3: Build context switching SQL if db_name or catalog_name is specified
|
# FIX for Issue #62 Bug 3: Build context switching SQL if db_name or catalog_name is specified
|
||||||
|
# SECURITY FIX: Validate catalog_name and db_name to prevent SQL injection
|
||||||
final_sql = sql
|
final_sql = sql
|
||||||
if catalog_name or db_name:
|
if catalog_name or db_name:
|
||||||
context_statements = []
|
context_statements = []
|
||||||
|
|
||||||
|
# Validate and sanitize catalog_name
|
||||||
if catalog_name:
|
if catalog_name:
|
||||||
# Switch to specified catalog
|
try:
|
||||||
context_statements.append(f"USE CATALOG `{catalog_name}`")
|
validate_identifier(catalog_name, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid catalog name rejected: {e}")
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid catalog name: {catalog_name}",
|
||||||
|
message="Catalog name contains invalid characters"
|
||||||
|
)
|
||||||
|
# Use quote_identifier to safely escape the catalog name
|
||||||
|
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||||
|
context_statements.append(f"USE CATALOG {safe_catalog}")
|
||||||
logger.debug(f"Switching to catalog: {catalog_name}")
|
logger.debug(f"Switching to catalog: {catalog_name}")
|
||||||
|
|
||||||
|
# Validate and sanitize db_name
|
||||||
if db_name:
|
if db_name:
|
||||||
# Switch to specified database
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
logger.warning(f"Invalid database name rejected: {e}")
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid database name: {db_name}",
|
||||||
|
message="Database name contains invalid characters"
|
||||||
|
)
|
||||||
|
# Use quote_identifier to safely escape the database name
|
||||||
|
safe_db = quote_identifier(db_name, "database name")
|
||||||
if catalog_name:
|
if catalog_name:
|
||||||
context_statements.append(f"USE `{catalog_name}`.`{db_name}`")
|
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||||
|
context_statements.append(f"USE {safe_catalog}.{safe_db}")
|
||||||
else:
|
else:
|
||||||
context_statements.append(f"USE `{db_name}`")
|
context_statements.append(f"USE {safe_db}")
|
||||||
logger.debug(f"Switching to database: {db_name}")
|
logger.debug(f"Switching to database: {db_name}")
|
||||||
|
|
||||||
# Combine context switching with original SQL
|
# Combine context switching with original SQL
|
||||||
@@ -1551,6 +1692,36 @@ class MetadataExtractor:
|
|||||||
if not table_name:
|
if not table_name:
|
||||||
return self._format_response(success=False, error="Missing table_name parameter")
|
return self._format_response(success=False, error="Missing table_name parameter")
|
||||||
|
|
||||||
|
# SECURITY: Validate identifiers before processing
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid table name: {table_name}",
|
||||||
|
message="Table name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_name:
|
||||||
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid database name: {db_name}",
|
||||||
|
message="Database name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
if catalog_name and catalog_name != "internal":
|
||||||
|
try:
|
||||||
|
validate_identifier(catalog_name, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid catalog name: {catalog_name}",
|
||||||
|
message="Catalog name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||||
|
|
||||||
@@ -1574,6 +1745,27 @@ class MetadataExtractor:
|
|||||||
"""Get list of all table names in specified database - MCP interface"""
|
"""Get list of all table names in specified database - MCP interface"""
|
||||||
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
|
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
|
||||||
|
|
||||||
|
# SECURITY: Validate identifiers
|
||||||
|
if db_name:
|
||||||
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid database name: {db_name}",
|
||||||
|
message="Database name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
if catalog_name and catalog_name != "internal":
|
||||||
|
try:
|
||||||
|
validate_identifier(catalog_name, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid catalog name: {catalog_name}",
|
||||||
|
message="Catalog name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
|
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
|
||||||
return self._format_response(success=True, result=tables)
|
return self._format_response(success=True, result=tables)
|
||||||
@@ -1604,6 +1796,36 @@ class MetadataExtractor:
|
|||||||
if not table_name:
|
if not table_name:
|
||||||
return self._format_response(success=False, error="Missing table_name parameter")
|
return self._format_response(success=False, error="Missing table_name parameter")
|
||||||
|
|
||||||
|
# SECURITY: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid table name: {table_name}",
|
||||||
|
message="Table name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_name:
|
||||||
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid database name: {db_name}",
|
||||||
|
message="Database name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
if catalog_name and catalog_name != "internal":
|
||||||
|
try:
|
||||||
|
validate_identifier(catalog_name, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid catalog name: {catalog_name}",
|
||||||
|
message="Catalog name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||||
return self._format_response(success=True, result=comment)
|
return self._format_response(success=True, result=comment)
|
||||||
@@ -1623,6 +1845,36 @@ class MetadataExtractor:
|
|||||||
if not table_name:
|
if not table_name:
|
||||||
return self._format_response(success=False, error="Missing table_name parameter")
|
return self._format_response(success=False, error="Missing table_name parameter")
|
||||||
|
|
||||||
|
# SECURITY: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid table name: {table_name}",
|
||||||
|
message="Table name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_name:
|
||||||
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid database name: {db_name}",
|
||||||
|
message="Database name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
if catalog_name and catalog_name != "internal":
|
||||||
|
try:
|
||||||
|
validate_identifier(catalog_name, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid catalog name: {catalog_name}",
|
||||||
|
message="Catalog name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||||
return self._format_response(success=True, result=comments)
|
return self._format_response(success=True, result=comments)
|
||||||
@@ -1642,6 +1894,36 @@ class MetadataExtractor:
|
|||||||
if not table_name:
|
if not table_name:
|
||||||
return self._format_response(success=False, error="Missing table_name parameter")
|
return self._format_response(success=False, error="Missing table_name parameter")
|
||||||
|
|
||||||
|
# SECURITY: Validate identifiers
|
||||||
|
try:
|
||||||
|
validate_identifier(table_name, "table name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid table name: {table_name}",
|
||||||
|
message="Table name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
if db_name:
|
||||||
|
try:
|
||||||
|
validate_identifier(db_name, "database name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid database name: {db_name}",
|
||||||
|
message="Database name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
if catalog_name and catalog_name != "internal":
|
||||||
|
try:
|
||||||
|
validate_identifier(catalog_name, "catalog name")
|
||||||
|
except SQLSecurityError as e:
|
||||||
|
return self._format_response(
|
||||||
|
success=False,
|
||||||
|
error=f"Invalid catalog name: {catalog_name}",
|
||||||
|
message="Catalog name contains invalid characters"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||||
return self._format_response(success=True, result=indexes)
|
return self._format_response(success=True, result=indexes)
|
||||||
|
|||||||
@@ -901,30 +901,50 @@ class SQLSecurityValidator:
|
|||||||
if not self.enable_security_check:
|
if not self.enable_security_check:
|
||||||
self.logger.debug("SQL security check is disabled, allowing all queries")
|
self.logger.debug("SQL security check is disabled, allowing all queries")
|
||||||
return ValidationResult(is_valid=True)
|
return ValidationResult(is_valid=True)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Parse SQL statement
|
# SECURITY FIX: Parse ALL SQL statements, not just the first one
|
||||||
parsed = sqlparse.parse(sql)[0]
|
# This prevents bypassing security checks by injecting additional statements
|
||||||
|
all_statements = sqlparse.parse(sql)
|
||||||
|
|
||||||
# Check blocked operations first (more specific)
|
if not all_statements:
|
||||||
keyword_result = await self._check_blocked_keywords(parsed)
|
return ValidationResult(
|
||||||
if not keyword_result.is_valid:
|
is_valid=False,
|
||||||
return keyword_result
|
error_message="Empty or invalid SQL statement",
|
||||||
|
risk_level="medium"
|
||||||
|
)
|
||||||
|
|
||||||
# Check SQL injection risks
|
# SECURITY FIX: Validate each statement individually
|
||||||
injection_result = await self._check_sql_injection(sql, parsed)
|
for idx, parsed in enumerate(all_statements):
|
||||||
if not injection_result.is_valid:
|
# Skip empty statements (e.g., from trailing semicolons)
|
||||||
return injection_result
|
if not parsed.tokens or str(parsed).strip() == '':
|
||||||
|
continue
|
||||||
|
|
||||||
# Check query complexity
|
self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...")
|
||||||
complexity_result = await self._check_query_complexity(parsed)
|
|
||||||
if not complexity_result.is_valid:
|
|
||||||
return complexity_result
|
|
||||||
|
|
||||||
# Check table access permissions
|
# Check blocked operations first (more specific)
|
||||||
table_result = await self._check_table_access(parsed, auth_context)
|
keyword_result = await self._check_blocked_keywords(parsed)
|
||||||
if not table_result.is_valid:
|
if not keyword_result.is_valid:
|
||||||
return table_result
|
keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}"
|
||||||
|
return keyword_result
|
||||||
|
|
||||||
|
# Check SQL injection risks
|
||||||
|
injection_result = await self._check_sql_injection(sql, parsed)
|
||||||
|
if not injection_result.is_valid:
|
||||||
|
injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}"
|
||||||
|
return injection_result
|
||||||
|
|
||||||
|
# Check query complexity
|
||||||
|
complexity_result = await self._check_query_complexity(parsed)
|
||||||
|
if not complexity_result.is_valid:
|
||||||
|
complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}"
|
||||||
|
return complexity_result
|
||||||
|
|
||||||
|
# Check table access permissions
|
||||||
|
table_result = await self._check_table_access(parsed, auth_context)
|
||||||
|
if not table_result.is_valid:
|
||||||
|
table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}"
|
||||||
|
return table_result
|
||||||
|
|
||||||
return ValidationResult(is_valid=True)
|
return ValidationResult(is_valid=True)
|
||||||
|
|
||||||
@@ -1134,6 +1154,10 @@ class SQLSecurityValidator:
|
|||||||
self, parsed: Statement, auth_context: AuthContext
|
self, parsed: Statement, auth_context: AuthContext
|
||||||
) -> ValidationResult:
|
) -> ValidationResult:
|
||||||
"""Check table access permissions"""
|
"""Check table access permissions"""
|
||||||
|
# If no auth_context, skip table access checks (rely on other security checks)
|
||||||
|
if auth_context is None:
|
||||||
|
return ValidationResult(is_valid=True)
|
||||||
|
|
||||||
# Extract table names from query
|
# Extract table names from query
|
||||||
tables = self._extract_table_names(parsed)
|
tables = self._extract_table_names(parsed)
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from collections import Counter, defaultdict
|
|||||||
|
|
||||||
from .db import DorisConnectionManager
|
from .db import DorisConnectionManager
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
|
from .sql_security_utils import get_auth_context
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@@ -192,7 +193,9 @@ class SecurityAnalyticsTools:
|
|||||||
LIMIT 10000
|
LIMIT 10000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(audit_sql)
|
# SECURITY FIX: Pass auth_context to execute
|
||||||
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(audit_sql, auth_context=auth_context)
|
||||||
return result.data if result.data else []
|
return result.data if result.data else []
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -215,7 +218,8 @@ class SecurityAnalyticsTools:
|
|||||||
LIMIT 10000
|
LIMIT 10000
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(simple_audit_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(simple_audit_sql, auth_context=auth_context)
|
||||||
return result.data if result.data else []
|
return result.data if result.data else []
|
||||||
|
|
||||||
except Exception as e2:
|
except Exception as e2:
|
||||||
@@ -498,7 +502,8 @@ class SecurityAnalyticsTools:
|
|||||||
FROM mysql.user
|
FROM mysql.user
|
||||||
"""
|
"""
|
||||||
|
|
||||||
result = await connection.execute(roles_sql)
|
auth_context = get_auth_context()
|
||||||
|
result = await connection.execute(roles_sql, auth_context=auth_context)
|
||||||
|
|
||||||
user_roles = defaultdict(list)
|
user_roles = defaultdict(list)
|
||||||
if result.data:
|
if result.data:
|
||||||
|
|||||||
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
@@ -0,0 +1,301 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
"""
|
||||||
|
SQL Security Utilities Module
|
||||||
|
|
||||||
|
Provides SQL identifier validation, escaping, and safe query building utilities
|
||||||
|
to prevent SQL injection attacks.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Optional, Tuple, List, Any
|
||||||
|
|
||||||
|
from .logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# Context variable for auth_context (set by HTTP middleware)
|
||||||
|
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLSecurityError(Exception):
|
||||||
|
"""Exception raised for SQL security validation failures"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SQLSecurityUtils:
|
||||||
|
"""
|
||||||
|
SQL Security Utilities for preventing SQL injection attacks.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- Identifier validation (database names, table names, column names)
|
||||||
|
- Safe identifier quoting with backticks
|
||||||
|
- Safe table reference building
|
||||||
|
- Auth context retrieval from context variables
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Valid SQL identifier pattern: letters, numbers, underscores
|
||||||
|
# Must start with letter or underscore, not a number
|
||||||
|
# Supports Unicode letters for international database/table names
|
||||||
|
IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$')
|
||||||
|
|
||||||
|
# Maximum identifier length (MySQL/Doris standard)
|
||||||
|
MAX_IDENTIFIER_LENGTH = 64
|
||||||
|
|
||||||
|
# SQL reserved keywords that should be quoted
|
||||||
|
SQL_KEYWORDS = {
|
||||||
|
'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP',
|
||||||
|
'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND',
|
||||||
|
'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN',
|
||||||
|
'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER',
|
||||||
|
'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
|
||||||
|
'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY',
|
||||||
|
'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT'
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||||
|
"""
|
||||||
|
Validate a SQL identifier (database name, table name, column name, etc.)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The identifier to validate
|
||||||
|
identifier_type: Type description for error messages (e.g., "database name", "table name")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The validated identifier (unchanged if valid)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If the identifier is invalid
|
||||||
|
"""
|
||||||
|
if not name:
|
||||||
|
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||||
|
|
||||||
|
if not isinstance(name, str):
|
||||||
|
raise SQLSecurityError(f"Invalid {identifier_type}: must be a string, got {type(name).__name__}")
|
||||||
|
|
||||||
|
# Strip whitespace
|
||||||
|
name = name.strip()
|
||||||
|
|
||||||
|
if not name:
|
||||||
|
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||||
|
|
||||||
|
# Check length
|
||||||
|
if len(name) > cls.MAX_IDENTIFIER_LENGTH:
|
||||||
|
raise SQLSecurityError(
|
||||||
|
f"Invalid {identifier_type}: '{name[:20]}...' exceeds maximum length of {cls.MAX_IDENTIFIER_LENGTH} characters"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for dangerous characters that could be SQL injection
|
||||||
|
dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\\', '\x00']
|
||||||
|
for char in dangerous_chars:
|
||||||
|
if char in name:
|
||||||
|
raise SQLSecurityError(
|
||||||
|
f"Invalid {identifier_type}: '{name}' contains forbidden character '{char}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Validate pattern
|
||||||
|
if not cls.IDENTIFIER_PATTERN.match(name):
|
||||||
|
raise SQLSecurityError(
|
||||||
|
f"Invalid {identifier_type}: '{name}' contains invalid characters. "
|
||||||
|
f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore."
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"Validated {identifier_type}: {name}")
|
||||||
|
return name
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||||
|
"""
|
||||||
|
Safely quote a SQL identifier using backticks.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: The identifier to quote
|
||||||
|
identifier_type: Type description for error messages
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The quoted identifier (e.g., `table_name`)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If the identifier is invalid
|
||||||
|
"""
|
||||||
|
# First validate the identifier
|
||||||
|
validated_name = cls.validate_identifier(name, identifier_type)
|
||||||
|
|
||||||
|
# Escape any backticks within the name (double them)
|
||||||
|
escaped_name = validated_name.replace('`', '``')
|
||||||
|
|
||||||
|
return f"`{escaped_name}`"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_table_reference(
|
||||||
|
cls,
|
||||||
|
table_name: str,
|
||||||
|
db_name: Optional[str] = None,
|
||||||
|
catalog_name: Optional[str] = None,
|
||||||
|
quote: bool = True
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build a safe, fully-qualified table reference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
table_name: The table name (required)
|
||||||
|
db_name: The database name (optional)
|
||||||
|
catalog_name: The catalog name (optional)
|
||||||
|
quote: Whether to quote identifiers with backticks (default: True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A safe table reference string (e.g., `catalog`.`db`.`table`)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If any identifier is invalid
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if catalog_name:
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(catalog_name, "catalog name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(catalog_name, "catalog name"))
|
||||||
|
|
||||||
|
if db_name:
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(db_name, "database name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(db_name, "database name"))
|
||||||
|
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||||
|
|
||||||
|
return '.'.join(parts)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build_column_reference(
|
||||||
|
cls,
|
||||||
|
column_name: str,
|
||||||
|
table_name: Optional[str] = None,
|
||||||
|
quote: bool = True
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
Build a safe column reference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column_name: The column name (required)
|
||||||
|
table_name: The table name (optional, for qualified references)
|
||||||
|
quote: Whether to quote identifiers with backticks (default: True)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A safe column reference string (e.g., `table`.`column`)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If any identifier is invalid
|
||||||
|
"""
|
||||||
|
parts = []
|
||||||
|
|
||||||
|
if table_name:
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||||
|
|
||||||
|
if quote:
|
||||||
|
parts.append(cls.quote_identifier(column_name, "column name"))
|
||||||
|
else:
|
||||||
|
parts.append(cls.validate_identifier(column_name, "column name"))
|
||||||
|
|
||||||
|
return '.'.join(parts)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate_and_build_where_condition(
|
||||||
|
cls,
|
||||||
|
column_name: str,
|
||||||
|
operator: str = "=",
|
||||||
|
use_param: bool = True
|
||||||
|
) -> Tuple[str, bool]:
|
||||||
|
"""
|
||||||
|
Build a safe WHERE condition for a column.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
column_name: The column name
|
||||||
|
operator: The comparison operator (=, !=, <, >, <=, >=, LIKE, IN)
|
||||||
|
use_param: Whether to use parameterized placeholder (%s)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (condition_string, needs_param)
|
||||||
|
e.g., ("`column` = %s", True) or ("`column` = DATABASE()", False)
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
SQLSecurityError: If column name is invalid or operator is not allowed
|
||||||
|
"""
|
||||||
|
# Validate column name
|
||||||
|
quoted_column = cls.quote_identifier(column_name, "column name")
|
||||||
|
|
||||||
|
# Validate operator
|
||||||
|
allowed_operators = {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'}
|
||||||
|
if operator.upper() not in allowed_operators:
|
||||||
|
raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed_operators}")
|
||||||
|
|
||||||
|
if use_param:
|
||||||
|
return f"{quoted_column} {operator} %s", True
|
||||||
|
else:
|
||||||
|
return f"{quoted_column} {operator}", False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_auth_context():
|
||||||
|
"""
|
||||||
|
Get auth_context from the context variable.
|
||||||
|
|
||||||
|
This retrieves the auth_context that was set by the HTTP middleware
|
||||||
|
during request processing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The auth_context object, or None if not available
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
auth_context = auth_context_var.get()
|
||||||
|
if auth_context:
|
||||||
|
logger.debug(f"Retrieved auth_context from context variable")
|
||||||
|
return auth_context
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"Could not retrieve auth_context: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def set_auth_context(auth_context):
|
||||||
|
"""
|
||||||
|
Set auth_context in the context variable.
|
||||||
|
|
||||||
|
This is typically called by the HTTP middleware during request processing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auth_context: The auth_context object to set
|
||||||
|
"""
|
||||||
|
auth_context_var.set(auth_context)
|
||||||
|
logger.debug("Set auth_context in context variable")
|
||||||
|
|
||||||
|
|
||||||
|
# Convenience functions for direct use
|
||||||
|
validate_identifier = SQLSecurityUtils.validate_identifier
|
||||||
|
quote_identifier = SQLSecurityUtils.quote_identifier
|
||||||
|
build_table_reference = SQLSecurityUtils.build_table_reference
|
||||||
|
build_column_reference = SQLSecurityUtils.build_column_reference
|
||||||
|
get_auth_context = SQLSecurityUtils.get_auth_context
|
||||||
|
set_auth_context = SQLSecurityUtils.set_auth_context
|
||||||
|
|
||||||
367
test/security/test_sql_injection.py
Normal file
367
test/security/test_sql_injection.py
Normal file
@@ -0,0 +1,367 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
SQL Security Test Suite for Apache Doris MCP Server
|
||||||
|
|
||||||
|
Tests for:
|
||||||
|
1. SQL injection prevention via identifier validation
|
||||||
|
2. Multi-statement SQL parsing in security validator
|
||||||
|
3. auth_context enforcement
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
class TestSQLSecurityUtils:
|
||||||
|
"""Test cases for sql_security_utils module"""
|
||||||
|
|
||||||
|
def test_validate_identifier_accepts_valid_names(self):
|
||||||
|
"""Test that valid identifiers are accepted"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import validate_identifier
|
||||||
|
|
||||||
|
valid_names = [
|
||||||
|
"users",
|
||||||
|
"my_table",
|
||||||
|
"Table123",
|
||||||
|
"_private_table",
|
||||||
|
"CamelCaseTable",
|
||||||
|
"table_with_numbers_123",
|
||||||
|
]
|
||||||
|
|
||||||
|
for name in valid_names:
|
||||||
|
result = validate_identifier(name, "table")
|
||||||
|
assert result == name, f"Valid identifier '{name}' should be accepted"
|
||||||
|
|
||||||
|
def test_validate_identifier_rejects_sql_injection(self):
|
||||||
|
"""Test that SQL injection attempts are rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
injection_attempts = [
|
||||||
|
# Basic SQL injection
|
||||||
|
"'; DROP TABLE users; --",
|
||||||
|
"table' OR '1'='1",
|
||||||
|
"table'; DELETE FROM users; --",
|
||||||
|
|
||||||
|
# Union-based injection
|
||||||
|
"table' UNION SELECT * FROM passwords --",
|
||||||
|
|
||||||
|
# Comment injection
|
||||||
|
"table/**/OR/**/1=1",
|
||||||
|
"table--comment",
|
||||||
|
|
||||||
|
# Special characters
|
||||||
|
"table`; DROP TABLE users;",
|
||||||
|
'table"; DROP TABLE users;',
|
||||||
|
"table\"; DELETE FROM",
|
||||||
|
|
||||||
|
# Backtick escape attempt
|
||||||
|
"analytics`; SELECT * FROM sensitive_table;--",
|
||||||
|
|
||||||
|
# Whitespace injection
|
||||||
|
"table name with spaces",
|
||||||
|
"table\ttab",
|
||||||
|
"table\nnewline",
|
||||||
|
]
|
||||||
|
|
||||||
|
for injection in injection_attempts:
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(injection, "table")
|
||||||
|
|
||||||
|
def test_validate_identifier_rejects_empty(self):
|
||||||
|
"""Test that empty identifiers are rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier("", "table")
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(None, "table")
|
||||||
|
|
||||||
|
def test_validate_identifier_rejects_too_long(self):
|
||||||
|
"""Test that identifiers exceeding max length are rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
# Doris identifier max length is typically 64 characters
|
||||||
|
long_name = "a" * 100
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(long_name, "table")
|
||||||
|
|
||||||
|
def test_quote_identifier_adds_backticks(self):
|
||||||
|
"""Test that quote_identifier properly escapes identifiers"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import quote_identifier
|
||||||
|
|
||||||
|
assert quote_identifier("my_table", "table") == "`my_table`"
|
||||||
|
assert quote_identifier("users", "table") == "`users`"
|
||||||
|
assert quote_identifier("Table123", "table") == "`Table123`"
|
||||||
|
|
||||||
|
def test_quote_identifier_validates_first(self):
|
||||||
|
"""Test that quote_identifier validates before quoting"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
quote_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
quote_identifier("'; DROP TABLE users; --", "table")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSQLSecurityValidator:
|
||||||
|
"""Test cases for SQLSecurityValidator multi-statement parsing"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dict_config(self):
|
||||||
|
"""Create dictionary configuration"""
|
||||||
|
return {
|
||||||
|
"blocked_keywords": [
|
||||||
|
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
||||||
|
"DELETE", "INSERT", "UPDATE",
|
||||||
|
"GRANT", "REVOKE", "EXEC", "EXECUTE"
|
||||||
|
],
|
||||||
|
"max_query_complexity": 100,
|
||||||
|
"enable_security_check": True
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_auth_context(self):
|
||||||
|
"""Create mock auth context"""
|
||||||
|
from doris_mcp_server.utils.security import AuthContext, SecurityLevel
|
||||||
|
return AuthContext(
|
||||||
|
user_id="test_user",
|
||||||
|
roles=["user"],
|
||||||
|
security_level=SecurityLevel.INTERNAL
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_validates_all_statements(self, dict_config, mock_auth_context):
|
||||||
|
"""Test that validator checks ALL SQL statements, not just the first"""
|
||||||
|
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||||
|
|
||||||
|
validator = SQLSecurityValidator(dict_config)
|
||||||
|
|
||||||
|
# Multi-statement with injection in second statement
|
||||||
|
# This should be BLOCKED
|
||||||
|
malicious_sql = "SELECT 1; DROP TABLE users; SELECT 2"
|
||||||
|
|
||||||
|
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||||
|
|
||||||
|
assert not result.is_valid, "Multi-statement injection should be blocked"
|
||||||
|
# Check for either DROP keyword detection or SQL injection detection
|
||||||
|
error_upper = result.error_message.upper()
|
||||||
|
assert ("DROP" in error_upper or
|
||||||
|
"INJECTION" in error_upper or
|
||||||
|
"BLOCKED" in error_upper), f"Expected DROP/injection/blocked in: {result.error_message}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_blocks_hidden_dangerous_statement(self, dict_config, mock_auth_context):
|
||||||
|
"""Test that dangerous statements hidden after safe ones are blocked"""
|
||||||
|
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||||
|
|
||||||
|
validator = SQLSecurityValidator(dict_config)
|
||||||
|
|
||||||
|
# Safe statement followed by dangerous one
|
||||||
|
malicious_sql = """
|
||||||
|
SELECT * FROM users WHERE id = 1;
|
||||||
|
DELETE FROM audit_log;
|
||||||
|
SELECT 1;
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||||
|
|
||||||
|
assert not result.is_valid, "Hidden DELETE statement should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_allows_safe_multi_statement(self, dict_config, mock_auth_context):
|
||||||
|
"""Test that multiple safe SELECT statements are allowed"""
|
||||||
|
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||||
|
|
||||||
|
validator = SQLSecurityValidator(dict_config)
|
||||||
|
|
||||||
|
safe_sql = """
|
||||||
|
SELECT * FROM users;
|
||||||
|
SELECT COUNT(*) FROM orders;
|
||||||
|
SELECT id, name FROM products;
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await validator.validate(safe_sql, mock_auth_context)
|
||||||
|
|
||||||
|
assert result.is_valid, f"Multiple safe SELECT statements should be allowed, got: {result.error_message}"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_context_switch_injection_blocked(self, dict_config, mock_auth_context):
|
||||||
|
"""Test that context switch SQL injection is blocked"""
|
||||||
|
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||||
|
|
||||||
|
validator = SQLSecurityValidator(dict_config)
|
||||||
|
|
||||||
|
# Simulating the exec_query_for_mcp attack vector
|
||||||
|
injected_sql = """
|
||||||
|
USE `analytics`; SELECT * FROM sensitive_table;-- `;
|
||||||
|
SELECT * FROM public_table;
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await validator.validate(injected_sql, mock_auth_context)
|
||||||
|
|
||||||
|
# The validator should process all statements
|
||||||
|
# Even if USE is allowed, subsequent unauthorized access should be caught
|
||||||
|
# by table access checks (if configured)
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecQueryForMCP:
|
||||||
|
"""Test cases for exec_query_for_mcp function"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_malicious_db_name(self):
|
||||||
|
"""Test that malicious db_name is rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
# The attack vector from security report
|
||||||
|
malicious_db_name = "analytics`; SELECT * FROM sensitive_table;--"
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(malicious_db_name, "database name")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_rejects_malicious_catalog_name(self):
|
||||||
|
"""Test that malicious catalog_name is rejected"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
malicious_catalog_name = "internal'; DROP DATABASE production;--"
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(malicious_catalog_name, "catalog name")
|
||||||
|
|
||||||
|
|
||||||
|
class TestDependencyAnalysisTools:
|
||||||
|
"""Test cases for dependency_analysis_tools security fixes"""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_tables_metadata_rejects_injection(self):
|
||||||
|
"""Test that _get_tables_metadata rejects SQL injection"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
# The attack vector from security report
|
||||||
|
injection_db_name = "test_db' OR '1'='1' --"
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(injection_db_name, "database name")
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthContextEnforcement:
|
||||||
|
"""Test cases for auth_context enforcement"""
|
||||||
|
|
||||||
|
def test_execute_requires_auth_context_for_security(self):
|
||||||
|
"""Test that security checks require auth_context"""
|
||||||
|
# This test documents the expected behavior:
|
||||||
|
# When auth_context is None, security checks are skipped
|
||||||
|
# When auth_context is provided, security checks are performed
|
||||||
|
|
||||||
|
# The fix ensures all execute() calls pass auth_context
|
||||||
|
pass
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_auth_context_returns_context(self):
|
||||||
|
"""Test that get_auth_context retrieves context from ContextVar"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import get_auth_context
|
||||||
|
|
||||||
|
# When no context is set, should return None
|
||||||
|
result = get_auth_context()
|
||||||
|
# This is expected - context is set by HTTP middleware
|
||||||
|
assert result is None or hasattr(result, 'user_id')
|
||||||
|
|
||||||
|
|
||||||
|
class TestIntegrationScenarios:
|
||||||
|
"""Integration test scenarios for security fixes"""
|
||||||
|
|
||||||
|
def test_attack_scenario_1_permission_bypass(self):
|
||||||
|
"""
|
||||||
|
Attack Scenario 1: Permission Bypass in Multi-Tenant Environment
|
||||||
|
|
||||||
|
Expected: User can only query their own database (db_name="tenant_a_db")
|
||||||
|
Attack: Inject "tenant_a_db' OR '1'='1' --" to query ALL databases
|
||||||
|
Result: Should be BLOCKED by validate_identifier()
|
||||||
|
"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier("tenant_a_db' OR '1'='1' --", "database name")
|
||||||
|
|
||||||
|
def test_attack_scenario_2_union_injection(self):
|
||||||
|
"""
|
||||||
|
Attack Scenario 2: UNION-based Information Disclosure
|
||||||
|
|
||||||
|
Attack: Inject UNION SELECT to extract sensitive data
|
||||||
|
Result: Should be BLOCKED
|
||||||
|
"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(
|
||||||
|
"test' UNION SELECT password FROM users --",
|
||||||
|
"database name"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_attack_scenario_3_backtick_escape(self):
|
||||||
|
"""
|
||||||
|
Attack Scenario 3: Backtick Escape Attempt
|
||||||
|
|
||||||
|
Attack: Use backticks to break out of quoted identifier
|
||||||
|
Result: Should be BLOCKED
|
||||||
|
"""
|
||||||
|
from doris_mcp_server.utils.sql_security_utils import (
|
||||||
|
validate_identifier,
|
||||||
|
SQLSecurityError
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SQLSecurityError):
|
||||||
|
validate_identifier(
|
||||||
|
"analytics`; SELECT * FROM sensitive_table;--",
|
||||||
|
"database name"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Run tests with: pytest tests/test_sql_security.py -v
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v", "--tb=short"])
|
||||||
|
|
||||||
871
test/security/test_sql_injection_api.py
Normal file
871
test/security/test_sql_injection_api.py
Normal file
@@ -0,0 +1,871 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
SQL Injection API Integration Tests
|
||||||
|
|
||||||
|
This module tests SQL injection prevention through the MCP HTTP API.
|
||||||
|
It sends malicious payloads and verifies they are properly blocked.
|
||||||
|
|
||||||
|
Prerequisites:
|
||||||
|
- MCP server running on localhost:3000
|
||||||
|
- Run with: pytest test/security/test_sql_injection_api.py -v
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Start server first
|
||||||
|
bash start_server.sh
|
||||||
|
|
||||||
|
# Run tests
|
||||||
|
pytest test/security/test_sql_injection_api.py -v --no-cov
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
# Server configuration
|
||||||
|
MCP_BASE_URL = "http://localhost:3000"
|
||||||
|
MCP_ENDPOINT = f"{MCP_BASE_URL}/mcp"
|
||||||
|
HEALTH_ENDPOINT = f"{MCP_BASE_URL}/health"
|
||||||
|
TIMEOUT = 30.0
|
||||||
|
|
||||||
|
|
||||||
|
class MCPClient:
|
||||||
|
"""Simple MCP HTTP client for testing"""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str = MCP_BASE_URL):
|
||||||
|
self.base_url = base_url
|
||||||
|
self.mcp_endpoint = f"{base_url}/mcp"
|
||||||
|
self.session_id: Optional[str] = None
|
||||||
|
self.request_id = 0
|
||||||
|
self.client = httpx.AsyncClient(timeout=TIMEOUT)
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
await self.client.aclose()
|
||||||
|
|
||||||
|
def _next_id(self) -> int:
|
||||||
|
self.request_id += 1
|
||||||
|
return self.request_id
|
||||||
|
|
||||||
|
async def initialize(self) -> dict:
|
||||||
|
"""Initialize MCP session"""
|
||||||
|
response = await self.client.post(
|
||||||
|
self.mcp_endpoint,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream"
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "initialize",
|
||||||
|
"params": {
|
||||||
|
"protocolVersion": "2024-11-05",
|
||||||
|
"capabilities": {},
|
||||||
|
"clientInfo": {
|
||||||
|
"name": "sql-injection-test",
|
||||||
|
"version": "1.0.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": self._next_id()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract session ID from response header
|
||||||
|
self.session_id = response.headers.get("mcp-session-id")
|
||||||
|
return self._parse_response(response.text)
|
||||||
|
|
||||||
|
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||||
|
"""Call an MCP tool"""
|
||||||
|
if not self.session_id:
|
||||||
|
await self.initialize()
|
||||||
|
|
||||||
|
response = await self.client.post(
|
||||||
|
self.mcp_endpoint,
|
||||||
|
headers={
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Accept": "application/json, text/event-stream",
|
||||||
|
"mcp-session-id": self.session_id
|
||||||
|
},
|
||||||
|
json={
|
||||||
|
"jsonrpc": "2.0",
|
||||||
|
"method": "tools/call",
|
||||||
|
"params": {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": arguments
|
||||||
|
},
|
||||||
|
"id": self._next_id()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._parse_response(response.text)
|
||||||
|
|
||||||
|
def _parse_response(self, text: str) -> dict:
|
||||||
|
"""Parse JSON response"""
|
||||||
|
try:
|
||||||
|
return json.loads(text)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
# Try SSE format
|
||||||
|
lines = text.strip().split("\n")
|
||||||
|
for line in lines:
|
||||||
|
if line.startswith("data: "):
|
||||||
|
try:
|
||||||
|
return json.loads(line[6:])
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
return {"raw": text}
|
||||||
|
|
||||||
|
|
||||||
|
def print_result(test_name: str, payload: dict, result: dict):
|
||||||
|
"""Print test result in a readable format"""
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"TEST: {test_name}")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"PAYLOAD: {json.dumps(payload, ensure_ascii=False)}")
|
||||||
|
print(f"{'-'*60}")
|
||||||
|
|
||||||
|
# Extract inner result content
|
||||||
|
if "result" in result and "content" in result.get("result", {}):
|
||||||
|
for item in result["result"]["content"]:
|
||||||
|
if item.get("type") == "text":
|
||||||
|
try:
|
||||||
|
inner = json.loads(item["text"])
|
||||||
|
print("RESPONSE:")
|
||||||
|
print(f" success: {inner.get('success')}")
|
||||||
|
if inner.get('error'):
|
||||||
|
print(f" error: {inner.get('error')}")
|
||||||
|
if inner.get('error_type'):
|
||||||
|
print(f" error_type: {inner.get('error_type')}")
|
||||||
|
if inner.get('risk_level'):
|
||||||
|
print(f" risk_level: {inner.get('risk_level')}")
|
||||||
|
if inner.get('message'):
|
||||||
|
print(f" message: {inner.get('message')}")
|
||||||
|
if inner.get('data') is not None and inner.get('success'):
|
||||||
|
data_str = json.dumps(inner.get('data'), ensure_ascii=False)
|
||||||
|
if len(data_str) > 200:
|
||||||
|
data_str = data_str[:200] + "..."
|
||||||
|
print(f" data: {data_str}")
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
print(f"RESPONSE (raw): {item.get('text', '')[:500]}")
|
||||||
|
elif "error" in result:
|
||||||
|
print(f"RESPONSE ERROR: {result['error']}")
|
||||||
|
else:
|
||||||
|
print(f"RESPONSE (raw): {json.dumps(result, ensure_ascii=False)[:500]}")
|
||||||
|
|
||||||
|
print(f"{'='*60}\n")
|
||||||
|
|
||||||
|
|
||||||
|
class TestSQLInjectionAPI:
|
||||||
|
"""Test SQL injection prevention through MCP API"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def is_server_running(self):
|
||||||
|
"""Check if MCP server is running"""
|
||||||
|
import httpx
|
||||||
|
try:
|
||||||
|
response = httpx.get(HEALTH_ENDPOINT, timeout=5.0)
|
||||||
|
return response.status_code == 200
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_server_health(self):
|
||||||
|
"""Test that MCP server is running and healthy"""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(HEALTH_ENDPOINT)
|
||||||
|
assert response.status_code == 200
|
||||||
|
data = response.json()
|
||||||
|
assert data["status"] == "healthy"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_with_drop_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects DROP TABLE injection"""
|
||||||
|
# Classic SQL injection: append DROP TABLE
|
||||||
|
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("DROP TABLE Injection", payload, result)
|
||||||
|
|
||||||
|
# Should return error, not execute the DROP
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"DROP TABLE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_with_union_injection(self, mcp_client):
|
||||||
|
"""Test exec_query blocks UNION-based injection attempts"""
|
||||||
|
# UNION injection to extract data from other tables
|
||||||
|
payload = {"sql": "SELECT id FROM users UNION SELECT password FROM admin_users"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("UNION Injection", payload, result)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_with_delete_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects DELETE injection"""
|
||||||
|
payload = {"sql": "SELECT 1; DELETE FROM users WHERE 1=1; SELECT 2"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("DELETE Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"DELETE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_with_update_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects UPDATE injection"""
|
||||||
|
payload = {"sql": "SELECT 1; UPDATE users SET role='admin' WHERE id=1; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("UPDATE Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"UPDATE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_db_name_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects SQL injection via db_name parameter"""
|
||||||
|
# Attack vector: inject SQL via db_name parameter
|
||||||
|
payload = {"sql": "SELECT 1", "db_name": "test'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("db_name Parameter Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_query_catalog_name_injection(self, mcp_client):
|
||||||
|
"""Test exec_query rejects SQL injection via catalog_name parameter"""
|
||||||
|
# Attack vector: inject SQL via catalog_name parameter
|
||||||
|
payload = {"sql": "SELECT 1", "catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("catalog_name Parameter Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"catalog_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_schema_injection(self, mcp_client):
|
||||||
|
"""Test get_table_schema rejects SQL injection via table_name"""
|
||||||
|
# Attack vector: inject SQL via table_name parameter
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result("table_name Injection (get_table_schema)", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_schema_db_injection(self, mcp_client):
|
||||||
|
"""Test get_table_schema rejects SQL injection via db_name"""
|
||||||
|
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result("db_name Injection (get_table_schema)", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"db_name injection in get_table_schema should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_dependencies_injection(self, mcp_client):
|
||||||
|
"""Test analyze_dependencies rejects SQL injection"""
|
||||||
|
# This was the original vulnerability reported
|
||||||
|
payload = {"table_name": "users", "db_name": "test_db' OR '1'='1' --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_dependencies", payload)
|
||||||
|
print_result("analyze_dependencies Injection (Original Report)", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"analyze_dependencies db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stacked_queries_injection(self, mcp_client):
|
||||||
|
"""Test that stacked queries (multiple statements) are blocked"""
|
||||||
|
# Multiple statements injection
|
||||||
|
payload = {"sql": "SELECT * FROM users WHERE id = 1; INSERT INTO audit_log VALUES (NULL, 'hacked', NOW()); SELECT 1;"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Stacked Queries (INSERT) Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"Stacked queries with INSERT should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_comment_based_injection(self, mcp_client):
|
||||||
|
"""Test that comment-based injection is blocked"""
|
||||||
|
# Using comments to bypass filters
|
||||||
|
payload = {"sql": "SELECT * FROM users WHERE id = 1/**/OR/**/1=1"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Comment-based Injection", payload, result)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hex_encoded_injection(self, mcp_client):
|
||||||
|
"""Test that hex-encoded injection attempts are handled"""
|
||||||
|
# Hex-encoded 'DROP' attempt
|
||||||
|
payload = {"sql": "SELECT 0x44524F50205441424C4520757365727320"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Hex Encoded Injection", payload, result)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_backtick_escape_injection(self, mcp_client):
|
||||||
|
"""Test backtick escape injection is blocked"""
|
||||||
|
# Attempt to escape backtick quoting
|
||||||
|
payload = {"sql": "SELECT 1", "db_name": "analytics`; SELECT * FROM sensitive_table;--"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Backtick Escape Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
f"Backtick escape injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_query_succeeds(self, mcp_client):
|
||||||
|
"""Test that valid queries still work"""
|
||||||
|
# Simple valid query should work
|
||||||
|
payload = {"sql": "SELECT 1 AS test_value"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Valid Query (should succeed)", payload, result)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_show_databases(self, mcp_client):
|
||||||
|
"""Test that SHOW DATABASES works"""
|
||||||
|
payload = {"sql": "SHOW DATABASES"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("SHOW DATABASES (should succeed)", payload, result)
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for JSON-RPC error
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for error in result content
|
||||||
|
if "result" in result:
|
||||||
|
result_content = result.get("result", {})
|
||||||
|
if isinstance(result_content, dict):
|
||||||
|
# Check for isError flag
|
||||||
|
if result_content.get("isError"):
|
||||||
|
return True
|
||||||
|
# Check content array for error messages
|
||||||
|
content = result_content.get("content", [])
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, dict):
|
||||||
|
text = item.get("text", "")
|
||||||
|
# Parse the JSON text content
|
||||||
|
try:
|
||||||
|
text_data = json.loads(text)
|
||||||
|
# Check for success: false
|
||||||
|
if text_data.get("success") is False:
|
||||||
|
return True
|
||||||
|
# Check for error field
|
||||||
|
if text_data.get("error"):
|
||||||
|
return True
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
pass
|
||||||
|
# Check text for security keywords
|
||||||
|
if any(keyword in text.lower() for keyword in [
|
||||||
|
"error", "blocked", "invalid", "security",
|
||||||
|
"injection", "denied", "forbidden", "not allowed",
|
||||||
|
"security_violation", "risk_level"
|
||||||
|
]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check raw text response
|
||||||
|
raw = result.get("raw", "")
|
||||||
|
if isinstance(raw, str) and any(keyword in raw.lower() for keyword in [
|
||||||
|
"error", "blocked", "invalid", "security"
|
||||||
|
]):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class TestIdentifierInjectionAPI:
|
||||||
|
"""Test identifier-based SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_table_name_with_semicolon(self, mcp_client):
|
||||||
|
"""Test table name containing semicolon is rejected"""
|
||||||
|
payload = {"table_name": "users; DROP TABLE users"}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result("Table Name with Semicolon", payload, result)
|
||||||
|
|
||||||
|
# Should be blocked by identifier validation
|
||||||
|
assert self._contains_error_indicator(result), \
|
||||||
|
f"Table name with semicolon should be rejected"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_table_name_with_quotes(self, mcp_client):
|
||||||
|
"""Test table name containing quotes is rejected"""
|
||||||
|
payload = {"table_name": "users' OR '1'='1"}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result("Table Name with Quotes", payload, result)
|
||||||
|
|
||||||
|
assert self._contains_error_indicator(result), \
|
||||||
|
f"Table name with quotes should be rejected"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_db_name_with_special_chars(self, mcp_client):
|
||||||
|
"""Test database name with special characters is rejected"""
|
||||||
|
special_chars = [
|
||||||
|
"test;db",
|
||||||
|
"test'db",
|
||||||
|
"test\"db",
|
||||||
|
"test`db",
|
||||||
|
"test--db",
|
||||||
|
"test/*db*/",
|
||||||
|
]
|
||||||
|
|
||||||
|
for db_name in special_chars:
|
||||||
|
payload = {"table_name": "users", "db_name": db_name}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result(f"Special Char in db_name: {db_name}", payload, result)
|
||||||
|
|
||||||
|
assert self._contains_error_indicator(result), \
|
||||||
|
f"db_name '{db_name}' should be rejected"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_valid_identifiers_accepted(self, mcp_client):
|
||||||
|
"""Test that valid identifiers are accepted"""
|
||||||
|
valid_names = [
|
||||||
|
"users",
|
||||||
|
"my_table",
|
||||||
|
"Table123",
|
||||||
|
"_internal_table",
|
||||||
|
]
|
||||||
|
|
||||||
|
for table_name in valid_names:
|
||||||
|
payload = {"table_name": table_name}
|
||||||
|
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||||
|
print_result(f"Valid Identifier: {table_name}", payload, result)
|
||||||
|
|
||||||
|
def _contains_error_indicator(self, result: dict) -> bool:
|
||||||
|
"""Check if result contains error indicators"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for JSON-RPC error
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check result content
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
error_keywords = [
|
||||||
|
"error", "invalid", "illegal", "blocked",
|
||||||
|
"security", "injection", "denied", "forbidden"
|
||||||
|
]
|
||||||
|
|
||||||
|
return any(keyword in result_str for keyword in error_keywords)
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiStatementInjectionAPI:
|
||||||
|
"""Test multi-statement SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hidden_drop_after_select(self, mcp_client):
|
||||||
|
"""Test DROP hidden after legitimate SELECT is blocked"""
|
||||||
|
payload = {"sql": "SELECT id, name FROM users WHERE status = 'active'; DROP TABLE audit_log; SELECT 1;"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Hidden DROP after SELECT", payload, result)
|
||||||
|
|
||||||
|
assert self._is_dangerous_blocked(result), \
|
||||||
|
f"Hidden DROP statement should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hidden_truncate_after_select(self, mcp_client):
|
||||||
|
"""Test TRUNCATE hidden after SELECT is blocked"""
|
||||||
|
payload = {"sql": "SELECT 1; TRUNCATE TABLE users"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Hidden TRUNCATE after SELECT", payload, result)
|
||||||
|
|
||||||
|
assert self._is_dangerous_blocked(result), \
|
||||||
|
f"Hidden TRUNCATE should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hidden_grant_after_select(self, mcp_client):
|
||||||
|
"""Test GRANT hidden after SELECT is blocked"""
|
||||||
|
payload = {"sql": "SELECT 1; GRANT ALL ON *.* TO 'hacker'@'%'"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Hidden GRANT after SELECT", payload, result)
|
||||||
|
|
||||||
|
assert self._is_dangerous_blocked(result), \
|
||||||
|
f"Hidden GRANT should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_safe_selects_allowed(self, mcp_client):
|
||||||
|
"""Test that multiple SELECT statements may be allowed"""
|
||||||
|
payload = {"sql": "SELECT 1; SELECT 2; SELECT 3;"}
|
||||||
|
result = await mcp_client.call_tool("exec_query", payload)
|
||||||
|
print_result("Multiple Safe SELECTs", payload, result)
|
||||||
|
|
||||||
|
def _is_dangerous_blocked(self, result: dict) -> bool:
|
||||||
|
"""Check if dangerous operation was blocked"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check for error
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Check result content for blocking indicators
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
block_indicators = [
|
||||||
|
"drop", "truncate", "grant", "revoke",
|
||||||
|
"blocked", "denied", "forbidden", "not allowed",
|
||||||
|
"security", "error"
|
||||||
|
]
|
||||||
|
|
||||||
|
return any(indicator in result_str for indicator in block_indicators)
|
||||||
|
|
||||||
|
|
||||||
|
class TestADBCQueryInjectionAPI:
|
||||||
|
"""Test ADBC query SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_adbc_query_drop_injection(self, mcp_client):
|
||||||
|
"""Test exec_adbc_query rejects DROP TABLE injection"""
|
||||||
|
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||||
|
print_result("ADBC DROP TABLE Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"ADBC DROP TABLE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_adbc_query_delete_injection(self, mcp_client):
|
||||||
|
"""Test exec_adbc_query rejects DELETE injection"""
|
||||||
|
payload = {"sql": "SELECT 1; DELETE FROM users; --"}
|
||||||
|
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||||
|
print_result("ADBC DELETE Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"ADBC DELETE injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_exec_adbc_query_valid(self, mcp_client):
|
||||||
|
"""Test exec_adbc_query allows valid queries"""
|
||||||
|
payload = {"sql": "SELECT 1 AS test"}
|
||||||
|
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||||
|
print_result("ADBC Valid Query", payload, result)
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataToolsInjectionAPI:
|
||||||
|
"""Test metadata tools SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_table_list_db_injection(self, mcp_client):
|
||||||
|
"""Test get_db_table_list rejects db_name injection"""
|
||||||
|
payload = {"db_name": "test'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||||
|
print_result("get_db_table_list db_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_table_list_catalog_injection(self, mcp_client):
|
||||||
|
"""Test get_db_table_list rejects catalog_name injection"""
|
||||||
|
payload = {"catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||||
|
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||||
|
print_result("get_db_table_list catalog_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"catalog_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_comment_injection(self, mcp_client):
|
||||||
|
"""Test get_table_comment rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("get_table_comment", payload)
|
||||||
|
print_result("get_table_comment table_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_column_comments_injection(self, mcp_client):
|
||||||
|
"""Test get_table_column_comments rejects injection"""
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --", "db_name": "test"}
|
||||||
|
result = await mcp_client.call_tool("get_table_column_comments", payload)
|
||||||
|
print_result("get_table_column_comments Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_indexes_injection(self, mcp_client):
|
||||||
|
"""Test get_table_indexes rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users; DROP TABLE users", "db_name": "test"}
|
||||||
|
result = await mcp_client.call_tool("get_table_indexes", payload)
|
||||||
|
print_result("get_table_indexes Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestAnalyticsToolsInjectionAPI:
|
||||||
|
"""Test analytics tools SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_columns_table_injection(self, mcp_client):
|
||||||
|
"""Test analyze_columns rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||||
|
print_result("analyze_columns table_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_columns_db_injection(self, mcp_client):
|
||||||
|
"""Test analyze_columns rejects db_name injection"""
|
||||||
|
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||||
|
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||||
|
print_result("analyze_columns db_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_basic_info_injection(self, mcp_client):
|
||||||
|
"""Test get_table_basic_info rejects injection"""
|
||||||
|
payload = {"table_name": "users; DROP TABLE audit_log"}
|
||||||
|
result = await mcp_client.call_tool("get_table_basic_info", payload)
|
||||||
|
print_result("get_table_basic_info Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_table_storage_injection(self, mcp_client):
|
||||||
|
"""Test analyze_table_storage rejects injection"""
|
||||||
|
payload = {"table_name": "users`; SELECT * FROM sensitive; --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_table_storage", payload)
|
||||||
|
print_result("analyze_table_storage Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sql_explain_injection(self, mcp_client):
|
||||||
|
"""Test get_sql_explain rejects SQL injection"""
|
||||||
|
payload = {"sql": "SELECT 1; DROP TABLE users; --"}
|
||||||
|
result = await mcp_client.call_tool("get_sql_explain", payload)
|
||||||
|
print_result("get_sql_explain SQL Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"SQL injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_sql_profile_injection(self, mcp_client):
|
||||||
|
"""Test get_sql_profile rejects SQL injection"""
|
||||||
|
payload = {"sql": "SELECT 1; DELETE FROM audit_log; --"}
|
||||||
|
result = await mcp_client.call_tool("get_sql_profile", payload)
|
||||||
|
print_result("get_sql_profile SQL Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"SQL injection should be blocked"
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestGovernanceToolsInjectionAPI:
|
||||||
|
"""Test data governance tools SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trace_column_lineage_table_injection(self, mcp_client):
|
||||||
|
"""Test trace_column_lineage rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users'; DROP TABLE users; --", "column_name": "id"}
|
||||||
|
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||||
|
print_result("trace_column_lineage table_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trace_column_lineage_column_injection(self, mcp_client):
|
||||||
|
"""Test trace_column_lineage rejects column_name injection"""
|
||||||
|
payload = {"table_name": "users", "column_name": "id; DROP TABLE users"}
|
||||||
|
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||||
|
print_result("trace_column_lineage column_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"column_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_monitor_data_freshness_injection(self, mcp_client):
|
||||||
|
"""Test monitor_data_freshness rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users`; SELECT * FROM passwords; --"}
|
||||||
|
result = await mcp_client.call_tool("monitor_data_freshness", payload)
|
||||||
|
print_result("monitor_data_freshness Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_data_access_patterns_injection(self, mcp_client):
|
||||||
|
"""Test analyze_data_access_patterns rejects injection"""
|
||||||
|
payload = {"table_name": "users' UNION SELECT password FROM admin --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_data_access_patterns", payload)
|
||||||
|
print_result("analyze_data_access_patterns Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerformanceToolsInjectionAPI:
|
||||||
|
"""Test performance analytics tools SQL injection prevention"""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def mcp_client(self):
|
||||||
|
"""Create MCP client instance"""
|
||||||
|
client = MCPClient()
|
||||||
|
yield client
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_slow_queries_db_injection(self, mcp_client):
|
||||||
|
"""Test analyze_slow_queries_topn rejects db_name injection"""
|
||||||
|
payload = {"db_name": "test'; DROP TABLE audit_log; --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_slow_queries_topn", payload)
|
||||||
|
print_result("analyze_slow_queries_topn db_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_analyze_resource_growth_db_injection(self, mcp_client):
|
||||||
|
"""Test analyze_resource_growth_curves rejects db_name injection"""
|
||||||
|
payload = {"db_name": "test`; GRANT ALL ON *.* TO 'hacker'; --"}
|
||||||
|
result = await mcp_client.call_tool("analyze_resource_growth_curves", payload)
|
||||||
|
print_result("analyze_resource_growth_curves db_name Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"db_name injection should be blocked"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_table_data_size_injection(self, mcp_client):
|
||||||
|
"""Test get_table_data_size rejects table_name injection"""
|
||||||
|
payload = {"table_name": "users; TRUNCATE TABLE logs"}
|
||||||
|
result = await mcp_client.call_tool("get_table_data_size", payload)
|
||||||
|
print_result("get_table_data_size Injection", payload, result)
|
||||||
|
|
||||||
|
assert self._is_blocked_or_error(result), \
|
||||||
|
"table_name injection should be blocked"
|
||||||
|
|
||||||
|
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||||
|
"""Check if result indicates blocked or error"""
|
||||||
|
if not result:
|
||||||
|
return True
|
||||||
|
if "error" in result:
|
||||||
|
return True
|
||||||
|
result_str = json.dumps(result).lower()
|
||||||
|
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||||
|
|
||||||
|
|
||||||
|
# Pytest configuration for async tests
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def event_loop():
|
||||||
|
"""Create event loop for async tests"""
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
yield loop
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
pytest.main([__file__, "-v", "--tb=short", "-x"])
|
||||||
|
|
||||||
Reference in New Issue
Block a user