diff --git a/doris_mcp_server/tools/prompts_manager.py b/doris_mcp_server/tools/prompts_manager.py index 45b6d1a..e034e1e 100644 --- a/doris_mcp_server/tools/prompts_manager.py +++ b/doris_mcp_server/tools/prompts_manager.py @@ -31,6 +31,7 @@ from mcp.types import ( ) from ..utils.db import DorisConnectionManager +from ..utils.sql_security_utils import get_auth_context class PromptTemplate: @@ -422,7 +423,8 @@ Please generate accurate and efficient SQL queries based on the above requiremen AND table_type = 'BASE TABLE' """ - db_result = await connection.execute(db_info_sql) + auth_context = get_auth_context() + db_result = await connection.execute(db_info_sql, auth_context=auth_context) db_info = db_result.data[0] if db_result.data else {} # Get main table list @@ -438,7 +440,7 @@ Please generate accurate and efficient SQL queries based on the above requiremen LIMIT 10 """ - tables_result = await connection.execute(tables_sql) + tables_result = await connection.execute(tables_sql, auth_context=auth_context) context = f"""Current database statistics: - Total number of tables: {db_info.get("table_count", 0)} diff --git a/doris_mcp_server/tools/resources_manager.py b/doris_mcp_server/tools/resources_manager.py index bcf34e7..2a2e4ff 100644 --- a/doris_mcp_server/tools/resources_manager.py +++ b/doris_mcp_server/tools/resources_manager.py @@ -26,6 +26,7 @@ from typing import Any from mcp.types import Resource from ..utils.db import DorisConnectionManager +from ..utils.sql_security_utils import get_auth_context class TableMetadata: @@ -169,7 +170,8 @@ class DorisResourcesManager: ORDER BY table_name """ - result = await connection.execute(tables_query) + auth_context = get_auth_context() + result = await connection.execute(tables_query, auth_context=auth_context) tables = [] for row in result.data: @@ -204,7 +206,8 @@ class DorisResourcesManager: ORDER BY ordinal_position """ - result = await connection.execute(columns_query, (table_name,)) + auth_context = get_auth_context() + result = await connection.execute(columns_query, params=(table_name,), auth_context=auth_context) return [dict(row) for row in result.data] async def _get_view_metadata(self) -> list[ViewMetadata]: @@ -226,7 +229,8 @@ class DorisResourcesManager: ORDER BY table_name """ - result = await connection.execute(views_query) + auth_context = get_auth_context() + result = await connection.execute(views_query, auth_context=auth_context) views = [] for row in result.data: @@ -257,7 +261,8 @@ class DorisResourcesManager: AND table_name = %s """ - table_result = await connection.execute(table_info_query, (table_name,)) + auth_context = get_auth_context() + table_result = await connection.execute(table_info_query, params=(table_name,), auth_context=auth_context) if not table_result.data: raise ValueError(f"Table {table_name} does not exist") @@ -295,7 +300,8 @@ class DorisResourcesManager: ORDER BY index_name, seq_in_index """ - result = await connection.execute(indexes_query, (table_name,)) + auth_context = get_auth_context() + result = await connection.execute(indexes_query, params=(table_name,), auth_context=auth_context) return [dict(row) for row in result.data] async def _get_view_definition(self, view_name: str) -> str: @@ -312,7 +318,8 @@ class DorisResourcesManager: AND table_name = %s """ - result = await connection.execute(view_query, (view_name,)) + auth_context = get_auth_context() + result = await connection.execute(view_query, params=(view_name,), auth_context=auth_context) if not result.data: raise ValueError(f"View {view_name} does not exist") @@ -340,7 +347,8 @@ class DorisResourcesManager: AND table_type = 'BASE TABLE' """ - table_result = await connection.execute(table_stats_query) + auth_context = get_auth_context() + table_result = await connection.execute(table_stats_query, auth_context=auth_context) table_stats = table_result.data[0] if table_result.data else {} # Get view statistics @@ -350,7 +358,7 @@ class DorisResourcesManager: WHERE table_schema = DATABASE() """ - view_result = await connection.execute(view_stats_query) + view_result = await connection.execute(view_stats_query, auth_context=auth_context) view_stats = view_result.data[0] if view_result.data else {} stats_info = { diff --git a/doris_mcp_server/utils/adbc_query_tools.py b/doris_mcp_server/utils/adbc_query_tools.py index f6d5785..f8d6368 100644 --- a/doris_mcp_server/utils/adbc_query_tools.py +++ b/doris_mcp_server/utils/adbc_query_tools.py @@ -28,6 +28,7 @@ from typing import Any, Dict, List, Optional from ..utils.logger import get_logger from ..utils.db import DorisConnectionManager +from ..utils.sql_security_utils import get_auth_context logger = get_logger(__name__) @@ -277,7 +278,8 @@ class DorisADBCQueryTools: # Get BE nodes via SHOW BACKENDS logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS") connection = await self.connection_manager.get_connection("query") - result = await connection.execute("SHOW BACKENDS") + auth_context = get_auth_context() + result = await connection.execute("SHOW BACKENDS", auth_context=auth_context) be_hosts = [] for row in result.data: @@ -383,6 +385,20 @@ class DorisADBCQueryTools: "error_type": "no_connection" } + # SECURITY FIX: Perform SQL security validation before executing + auth_context = get_auth_context() + if self.connection_manager.security_manager: + # Always perform security validation, even without auth_context + # Use a default context for basic SQL security checks + validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context) + if not validation_result.is_valid: + return { + "success": False, + "error": f"SQL security validation failed: {validation_result.error_message}", + "error_type": "security_violation", + "risk_level": validation_result.risk_level + } + cursor = self.adbc_client.cursor() start_time = time.time() diff --git a/doris_mcp_server/utils/analysis_tools.py b/doris_mcp_server/utils/analysis_tools.py index b1c9283..b3fb7e4 100644 --- a/doris_mcp_server/utils/analysis_tools.py +++ b/doris_mcp_server/utils/analysis_tools.py @@ -29,6 +29,13 @@ from pathlib import Path from .db import DorisConnectionManager from .logger import get_logger +from .sql_security_utils import ( + SQLSecurityError, + validate_identifier, + quote_identifier, + build_table_reference, + get_auth_context +) logger = get_logger(__name__) @@ -46,10 +53,17 @@ class TableAnalyzer: sample_size: int = 10 ) -> Dict[str, Any]: """Get table summary information""" + # SECURITY FIX: Validate table_name and get auth_context + try: + validate_identifier(table_name, "table name") + except SQLSecurityError as e: + raise ValueError(f"Invalid table name: {e}") + + auth_context = get_auth_context() connection = await self.connection_manager.get_connection("query") - # Get table basic information - table_info_sql = f""" + # Get table basic information using parameterized query + table_info_sql = """ SELECT table_name, table_comment, @@ -58,17 +72,17 @@ class TableAnalyzer: engine FROM information_schema.tables WHERE table_schema = DATABASE() - AND table_name = '{table_name}' + AND table_name = %s """ - table_info_result = await connection.execute(table_info_sql) + table_info_result = await connection.execute(table_info_sql, params=(table_name,), auth_context=auth_context) if not table_info_result.data: raise ValueError(f"Table {table_name} does not exist") table_info = table_info_result.data[0] - # Get column information - columns_sql = f""" + # Get column information using parameterized query + columns_sql = """ SELECT column_name, data_type, @@ -76,11 +90,11 @@ class TableAnalyzer: column_comment FROM information_schema.columns WHERE table_schema = DATABASE() - AND table_name = '{table_name}' + AND table_name = %s ORDER BY ordinal_position """ - columns_result = await connection.execute(columns_sql) + columns_result = await connection.execute(columns_sql, params=(table_name,), auth_context=auth_context) summary = { "table_name": table_info["table_name"], @@ -92,10 +106,11 @@ class TableAnalyzer: "columns": columns_result.data, } - # Get sample data + # Get sample data using quoted identifier if include_sample and sample_size > 0: - sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}" - sample_result = await connection.execute(sample_sql) + quoted_table = quote_identifier(table_name, "table name") + sample_sql = f"SELECT * FROM {quoted_table} LIMIT {sample_size}" + sample_result = await connection.execute(sample_sql, auth_context=auth_context) summary["sample_data"] = sample_result.data return summary @@ -120,7 +135,8 @@ class TableAnalyzer: FROM {table_name} """ - basic_result = await connection.execute(basic_stats_sql) + auth_context = get_auth_context() + basic_result = await connection.execute(basic_stats_sql, auth_context=auth_context) if not basic_result.data: return { "success": False, @@ -144,7 +160,7 @@ class TableAnalyzer: LIMIT 20 """ - distribution_result = await connection.execute(distribution_sql) + distribution_result = await connection.execute(distribution_sql, auth_context=auth_context) analysis["value_distribution"] = distribution_result.data if analysis_type == "detailed": @@ -159,7 +175,7 @@ class TableAnalyzer: WHERE {column_name} IS NOT NULL """ - numeric_result = await connection.execute(numeric_stats_sql) + numeric_result = await connection.execute(numeric_stats_sql, auth_context=auth_context) if numeric_result.data: analysis.update(numeric_result.data[0]) except Exception: @@ -196,7 +212,8 @@ class TableAnalyzer: AND table_name = '{table_name}' """ - table_result = await connection.execute(table_info_sql) + auth_context = get_auth_context() + table_result = await connection.execute(table_info_sql, auth_context=auth_context) if not table_result.data: raise ValueError(f"Table {table_name} does not exist") @@ -211,7 +228,7 @@ class TableAnalyzer: AND table_name != %s """ - all_tables_result = await connection.execute(all_tables_sql, (table_name,)) + all_tables_result = await connection.execute(all_tables_sql, params=(table_name,), auth_context=auth_context) return { "center_table": table_result.data[0], @@ -291,7 +308,8 @@ class PerformanceMonitor: LIMIT 20 """ - tables_result = await connection.execute(tables_sql) + auth_context = get_auth_context() + tables_result = await connection.execute(tables_sql, auth_context=auth_context) stats = { "metric_type": "tables", "time_range": time_range, @@ -380,8 +398,14 @@ class SQLAnalyzer: logger.info(f"Generating SQL explain for query ID: {query_id}") # Switch database if specified + # SECURITY FIX: Validate and quote db_name if db_name: - await self.connection_manager.execute_query("explain_session", f"USE {db_name}") + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + return {"success": False, "error": f"Invalid database name: {e}"} + safe_db = quote_identifier(db_name, "database name") + await self.connection_manager.execute_query("explain_session", f"USE {safe_db}") # Construct EXPLAIN query explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN" @@ -515,24 +539,36 @@ class SQLAnalyzer: try: # Switch to specified database/catalog if provided + # SECURITY FIX: Validate identifiers before using in SQL if catalog_name: - await connection.execute(f"SWITCH `{catalog_name}`") + try: + validate_identifier(catalog_name, "catalog name") + except SQLSecurityError as e: + return {"success": False, "error": f"Invalid catalog name: {e}"} + safe_catalog = quote_identifier(catalog_name, "catalog name") + auth_context = get_auth_context() + await connection.execute(f"SWITCH {safe_catalog}", auth_context=auth_context) if db_name: - await connection.execute(f"USE `{db_name}`") + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + return {"success": False, "error": f"Invalid database name: {e}"} + safe_db = quote_identifier(db_name, "database name") + await connection.execute(f"USE {safe_db}", auth_context=auth_context) # Set trace ID for the session using session variable # According to official docs: set session_context="trace_id:your_trace_id" - await connection.execute(f'set session_context="trace_id:{trace_id}"') + await connection.execute(f'set session_context="trace_id:{trace_id}"', auth_context=auth_context) logger.info(f"Set trace ID: {trace_id}") # Enable profile - await connection.execute(f'set enable_profile=true') + await connection.execute(f'set enable_profile=true', auth_context=auth_context) logger.info(f"Enabled profile") # Execute the SQL statement logger.info(f"Executing SQL with trace ID: {sql}") start_time = time.time() - sql_result = await connection.execute(sql) + sql_result = await connection.execute(sql, auth_context=auth_context) execution_time = time.time() - start_time logger.info(f"SQL execution completed in {execution_time:.3f}s") diff --git a/doris_mcp_server/utils/data_exploration_tools.py b/doris_mcp_server/utils/data_exploration_tools.py index 1d4e41d..b3e2a76 100644 --- a/doris_mcp_server/utils/data_exploration_tools.py +++ b/doris_mcp_server/utils/data_exploration_tools.py @@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional, Union from .db import DorisConnectionManager from .logger import get_logger +from .sql_security_utils import ( + SQLSecurityError, + validate_identifier, + quote_identifier, + build_table_reference, + get_auth_context +) logger = get_logger(__name__) @@ -43,24 +50,30 @@ class DataExplorationTools: def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str: """Build full table name with catalog and database using three-part naming convention""" - # Default catalog for internal tables + # SECURITY FIX: Use build_table_reference for safe identifier handling effective_catalog = catalog_name if catalog_name else "internal" if db_name: - return f"{effective_catalog}.{db_name}.{table_name}" + return build_table_reference(table_name, db_name, effective_catalog) else: - # If no db_name provided, need to determine the current database - return f"{effective_catalog}.{table_name}" + return build_table_reference(table_name, catalog_name=effective_catalog) async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]: """Get basic table information including row count""" try: + # SECURITY FIX: Get auth_context for security validation + # table_name should already be validated by _build_full_table_name + auth_context = get_auth_context() + count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}" - result = await connection.execute(count_sql) + result = await connection.execute(count_sql, auth_context=auth_context) if result.data: return {"row_count": result.data[0]["row_count"]} return None + except SQLSecurityError as e: + logger.warning(f"Security validation failed for table {table_name}: {str(e)}") + return {"row_count": 0} except Exception as e: logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}") return {"row_count": 0} @@ -68,10 +81,24 @@ class DataExplorationTools: async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]: """Get detailed column information""" try: - where_conditions = [f"table_name = '{table_name}'"] + # SECURITY FIX: Validate identifiers and use parameterized query + auth_context = get_auth_context() + + try: + validate_identifier(table_name, "table name") + if db_name: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + + # Build parameterized query + params = [table_name] + where_conditions = ["table_name = %s"] if db_name: - where_conditions.append(f"table_schema = '{db_name}'") + where_conditions.append("table_schema = %s") + params.append(db_name) else: where_conditions.append("table_schema = DATABASE()") @@ -87,9 +114,12 @@ class DataExplorationTools: ORDER BY ordinal_position """ - result = await connection.execute(columns_sql) + result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context) return result.data if result.data else [] + except SQLSecurityError as e: + logger.warning(f"Security validation failed: {str(e)}") + return [] except Exception as e: logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}") return [] @@ -177,7 +207,8 @@ class DataExplorationTools: WHERE {col_name} IS NOT NULL """ - stats_result = await connection.execute(stats_sql) + auth_context = get_auth_context() + stats_result = await connection.execute(stats_sql, auth_context=auth_context) if stats_result.data and stats_result.data[0]["count"] > 0: stats = stats_result.data[0] @@ -229,7 +260,8 @@ class DataExplorationTools: WHERE {col_name} IS NOT NULL """ - result = await connection.execute(percentile_sql) + auth_context = get_auth_context() + result = await connection.execute(percentile_sql, auth_context=auth_context) if result.data: data = result.data[0] @@ -268,7 +300,8 @@ class DataExplorationTools: WHERE {col_name} IS NOT NULL """ - result = await connection.execute(outlier_sql) + auth_context = get_auth_context() + result = await connection.execute(outlier_sql, auth_context=auth_context) if result.data: data = result.data[0] @@ -359,7 +392,8 @@ class DataExplorationTools: {sampling_info.get('sample_query_suffix', '')} """ - cardinality_result = await connection.execute(cardinality_sql) + auth_context = get_auth_context() + cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context) if cardinality_result.data: cardinality_data = cardinality_result.data[0] @@ -408,7 +442,8 @@ class DataExplorationTools: LIMIT 20 """ - result = await connection.execute(distribution_sql) + auth_context = get_auth_context() + result = await connection.execute(distribution_sql, auth_context=auth_context) if result.data: distribution = [] @@ -458,7 +493,8 @@ class DataExplorationTools: WHERE {col_name} IS NOT NULL """ - range_result = await connection.execute(range_sql) + auth_context = get_auth_context() + range_result = await connection.execute(range_sql, auth_context=auth_context) if range_result.data and range_result.data[0]["non_null_count"] > 0: range_data = range_result.data[0] @@ -539,7 +575,8 @@ class DataExplorationTools: ORDER BY day_of_week """ - weekly_result = await connection.execute(weekly_pattern_sql) + auth_context = get_auth_context() + weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context) weekly_pattern = [] if weekly_result.data: @@ -561,7 +598,7 @@ class DataExplorationTools: LIMIT 12 """ - monthly_result = await connection.execute(monthly_trend_sql) + monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context) monthly_trend = "stable" # Simplified trend analysis if monthly_result.data and len(monthly_result.data) > 3: @@ -646,7 +683,8 @@ class DataExplorationTools: FROM {table_expr} """ - result = await connection.execute(null_sql) + auth_context = get_auth_context() + result = await connection.execute(null_sql, auth_context=auth_context) if result.data: data = result.data[0] total_count = data["total_count"] diff --git a/doris_mcp_server/utils/data_governance_tools.py b/doris_mcp_server/utils/data_governance_tools.py index 4589c34..ebdc11c 100644 --- a/doris_mcp_server/utils/data_governance_tools.py +++ b/doris_mcp_server/utils/data_governance_tools.py @@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional from .db import DorisConnectionManager from .logger import get_logger +from .sql_security_utils import ( + SQLSecurityError, + validate_identifier, + quote_identifier, + build_table_reference, + get_auth_context +) logger = get_logger(__name__) @@ -216,26 +223,34 @@ class DataGovernanceTools: # ==================== Private Helper Methods ==================== def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str: - """Build full table name - use three-level naming convention""" + """Build full table name - use three-level naming convention with security validation""" + # SECURITY FIX: Use build_table_reference for safe identifier handling # Default catalog is internal for internal tables effective_catalog = catalog_name if catalog_name else "internal" if db_name: - return f"{effective_catalog}.{db_name}.{table_name}" + return build_table_reference(table_name, db_name, effective_catalog) else: # If db_name is not provided, need to determine current database - return f"{effective_catalog}.{table_name}" + return build_table_reference(table_name, catalog_name=effective_catalog) async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]: """Get table basic information""" try: + # SECURITY FIX: Get auth_context for security validation + # table_name should already be validated by _build_full_table_name + auth_context = get_auth_context() + # Try to get table row count count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}" - result = await connection.execute(count_sql) + result = await connection.execute(count_sql, auth_context=auth_context) if result.data: return {"row_count": result.data[0]["row_count"]} return None + except SQLSecurityError as e: + logger.warning(f"Security validation failed for table {table_name}: {str(e)}") + return {"row_count": 0} except Exception as e: logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}") return {"row_count": 0} @@ -243,11 +258,24 @@ class DataGovernanceTools: async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]: """Get table column information""" try: - # Build query conditions - where_conditions = [f"table_name = '{table_name}'"] + # SECURITY FIX: Validate identifiers and use parameterized query + auth_context = get_auth_context() + + try: + validate_identifier(table_name, "table name") + if db_name: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + + # Build parameterized query conditions + params = [table_name] + where_conditions = ["table_name = %s"] if db_name: - where_conditions.append(f"table_schema = '{db_name}'") + where_conditions.append("table_schema = %s") + params.append(db_name) else: where_conditions.append("table_schema = DATABASE()") @@ -263,30 +291,49 @@ class DataGovernanceTools: ORDER BY ordinal_position """ - result = await connection.execute(columns_sql) + result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context) return result.data if result.data else [] + except SQLSecurityError as e: + logger.warning(f"Security validation failed: {str(e)}") + return [] except Exception as e: logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}") return [] async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]: """Analyze column completeness""" + # SECURITY FIX: Get auth_context for security validation + auth_context = get_auth_context() column_completeness = {} for column in columns_info: column_name = column["column_name"] try: + # SECURITY FIX: Validate column name before using in SQL + try: + validate_identifier(column_name, "column name") + except SQLSecurityError as e: + logger.warning(f"Invalid column name rejected: {e}") + column_completeness[column_name] = { + "error": f"Invalid column name: {e}", + "completeness_score": 0.0 + } + continue + + # Use quoted identifier for column name + quoted_column = quote_identifier(column_name, "column name") + # Calculate null value statistics null_sql = f""" SELECT COUNT(*) as total_count, - COUNT({column_name}) as non_null_count, - COUNT(*) - COUNT({column_name}) as null_count + COUNT({quoted_column}) as non_null_count, + COUNT(*) - COUNT({quoted_column}) as null_count FROM {table_name} """ - result = await connection.execute(null_sql) + result = await connection.execute(null_sql, auth_context=auth_context) if result.data: stats = result.data[0] total_count = stats["total_count"] @@ -304,6 +351,12 @@ class DataGovernanceTools: "completeness_score": round(completeness_score, 4) } + except SQLSecurityError as e: + logger.warning(f"Security validation failed for column {column_name}: {str(e)}") + column_completeness[column_name] = { + "error": str(e), + "completeness_score": 0.0 + } except Exception as e: logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}") column_completeness[column_name] = { @@ -333,7 +386,8 @@ class DataGovernanceTools: FROM {table_name} """ - result = await connection.execute(compliance_sql) + auth_context = get_auth_context() + result = await connection.execute(compliance_sql, auth_context=auth_context) if result.data: stats = result.data[0] pass_count = stats["pass_count"] or 0 @@ -378,7 +432,8 @@ class DataGovernanceTools: ) t """ - result = await connection.execute(duplicate_sql) + auth_context = get_auth_context() + result = await connection.execute(duplicate_sql, auth_context=auth_context) if result.data and result.data[0]["duplicate_count"] > 0: issues.append({ "type": "duplicate_primary_keys", @@ -456,10 +511,21 @@ class DataGovernanceTools: async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool: """Verify if column exists""" try: - # Simple verification method: try to query the column - verify_sql = f"SELECT {column_name} FROM {table_name} LIMIT 1" - await connection.execute(verify_sql) + # SECURITY FIX: Validate and quote column name + try: + validate_identifier(column_name, "column name") + except SQLSecurityError as e: + logger.warning(f"Invalid column name rejected: {e}") + return False + + safe_column = quote_identifier(column_name, "column name") + # table_name is already safe (from _build_full_table_name) + verify_sql = f"SELECT {safe_column} FROM {table_name} LIMIT 1" + auth_context = get_auth_context() + await connection.execute(verify_sql, auth_context=auth_context) return True + except SQLSecurityError: + return False except Exception: return False @@ -469,21 +535,34 @@ class DataGovernanceTools: source_chain = [] try: + # SECURITY FIX: Validate table name and use parameterized-like approach + table_name_part = table_name.split('.')[-1] + try: + validate_identifier(table_name_part, "table name") + except SQLSecurityError as e: + logger.warning(f"Invalid table name rejected: {e}") + return [] + + # Escape special characters for LIKE pattern + safe_pattern = table_name_part.replace('%', r'\%').replace('_', r'\_') + like_pattern = f"%{safe_pattern}%" + # Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range) + auth_context = get_auth_context() audit_sql = """ SELECT stmt as sql_statement, `time` as execution_time, `user` as user_name FROM internal.__internal_schema.audit_log - WHERE stmt LIKE '%{}%' + WHERE stmt LIKE %s AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%') AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR) ORDER BY `time` DESC LIMIT 50 - """.format(table_name.split('.')[-1]) # Use the last part of table name + """ - result = await connection.execute(audit_sql) + result = await connection.execute(audit_sql, params=(like_pattern,), auth_context=auth_context) if result.data: for i, log_entry in enumerate(result.data[:depth]): @@ -556,19 +635,33 @@ class DataGovernanceTools: downstream_usage = [] try: + # SECURITY FIX: Validate inputs and use parameterized-like approach + table_name_part = table_name.split('.')[-1] + try: + validate_identifier(table_name_part, "table name") + validate_identifier(column_name, "column name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + + # Escape special characters for LIKE pattern + safe_table_pattern = f"%{table_name_part.replace('%', r'\\%').replace('_', r'\\_')}%" + safe_column_pattern = f"%{column_name.replace('%', r'\\%').replace('_', r'\\_')}%" + # Find other tables that might use this field (through audit logs, one year range) + auth_context = get_auth_context() usage_sql = """ SELECT DISTINCT stmt as sql_statement FROM internal.__internal_schema.audit_log - WHERE stmt LIKE '%{}%' - AND stmt LIKE '%{}%' + WHERE stmt LIKE %s + AND stmt LIKE %s AND stmt LIKE '%SELECT%' AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR) LIMIT 20 - """.format(table_name.split('.')[-1], column_name) + """ - result = await connection.execute(usage_sql) + result = await connection.execute(usage_sql, params=(safe_table_pattern, safe_column_pattern), auth_context=auth_context) if result.data: for entry in result.data: @@ -634,14 +727,20 @@ class DataGovernanceTools: async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]: """Get list of all tables""" try: - where_conditions = [] + auth_context = get_auth_context() + params = [] + # SECURITY FIX: Use parameterized query if db_name: - where_conditions.append(f"table_schema = '{db_name}'") + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid database name rejected: {e}") + return [] + where_clause = "table_schema = %s" + params.append(db_name) else: - where_conditions.append("table_schema = DATABASE()") - - where_clause = " AND ".join(where_conditions) if where_conditions else "1=1" + where_clause = "table_schema = DATABASE()" tables_sql = f""" SELECT table_name @@ -651,7 +750,7 @@ class DataGovernanceTools: ORDER BY table_name """ - result = await connection.execute(tables_sql) + result = await connection.execute(tables_sql, params=tuple(params) if params else None, auth_context=auth_context) return [row["table_name"] for row in result.data] if result.data else [] except Exception as e: @@ -728,15 +827,23 @@ class DataGovernanceTools: async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]: """Get freshness from partition information""" try: - # Query partition information (if table has partitions) - partition_sql = f""" + # SECURITY FIX: Validate and use parameterized query + table_name_part = table_name.split('.')[-1] + try: + validate_identifier(table_name_part, "table name") + except SQLSecurityError as e: + logger.warning(f"Invalid table name rejected: {e}") + return None + + auth_context = get_auth_context() + partition_sql = """ SELECT MAX(CREATE_TIME) as last_update FROM information_schema.partitions - WHERE table_name = '{table_name.split('.')[-1]}' + WHERE table_name = %s AND CREATE_TIME IS NOT NULL """ - result = await connection.execute(partition_sql) + result = await connection.execute(partition_sql, params=(table_name_part,), auth_context=auth_context) if result.data and result.data[0]["last_update"]: return { "last_update": result.data[0]["last_update"], @@ -759,7 +866,8 @@ class DataGovernanceTools: FROM {table_name} """ - result = await connection.execute(max_time_sql) + auth_context = get_auth_context() + result = await connection.execute(max_time_sql, auth_context=auth_context) if result.data and result.data[0]["last_update"]: return { "last_update": result.data[0]["last_update"], @@ -773,15 +881,23 @@ class DataGovernanceTools: async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]: """Get freshness from table metadata""" try: - # Query table's update time - metadata_sql = f""" + # SECURITY FIX: Validate and use parameterized query + table_name_part = table_name.split('.')[-1] + try: + validate_identifier(table_name_part, "table name") + except SQLSecurityError as e: + logger.warning(f"Invalid table name rejected: {e}") + return None + + auth_context = get_auth_context() + metadata_sql = """ SELECT UPDATE_TIME as last_update FROM information_schema.tables - WHERE table_name = '{table_name.split('.')[-1]}' + WHERE table_name = %s AND UPDATE_TIME IS NOT NULL """ - result = await connection.execute(metadata_sql) + result = await connection.execute(metadata_sql, params=(table_name_part,), auth_context=auth_context) if result.data and result.data[0]["last_update"]: return { "last_update": result.data[0]["last_update"], @@ -795,10 +911,19 @@ class DataGovernanceTools: async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]: """Find possible timestamp fields""" try: - timestamp_sql = f""" + # SECURITY FIX: Validate and use parameterized query + table_name_part = table_name.split('.')[-1] + try: + validate_identifier(table_name_part, "table name") + except SQLSecurityError as e: + logger.warning(f"Invalid table name rejected: {e}") + return [] + + auth_context = get_auth_context() + timestamp_sql = """ SELECT column_name FROM information_schema.columns - WHERE table_name = '{table_name.split('.')[-1]}' + WHERE table_name = %s AND ( data_type IN ('datetime', 'timestamp', 'date') OR column_name LIKE '%time%' @@ -815,7 +940,7 @@ class DataGovernanceTools: END """ - result = await connection.execute(timestamp_sql) + result = await connection.execute(timestamp_sql, params=(table_name_part,), auth_context=auth_context) return [row["column_name"] for row in result.data] if result.data else [] except Exception: diff --git a/doris_mcp_server/utils/data_quality_tools.py b/doris_mcp_server/utils/data_quality_tools.py index 3984bc7..bdcf4af 100644 --- a/doris_mcp_server/utils/data_quality_tools.py +++ b/doris_mcp_server/utils/data_quality_tools.py @@ -31,6 +31,12 @@ from collections import Counter, defaultdict from .db import DorisConnectionManager from .logger import get_logger from .config import DorisConfig +from .sql_security_utils import ( + SQLSecurityError, + validate_identifier, + build_table_reference, + get_auth_context +) logger = get_logger(__name__) @@ -299,23 +305,26 @@ class DataQualityTools: # =========================================== def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str: - """Build full table name""" - if catalog_name and db_name: - return f"{catalog_name}.{db_name}.{table_name}" - elif db_name: - return f"{db_name}.{table_name}" - else: - return table_name + """Build full table name with security validation""" + # SECURITY FIX: Use build_table_reference for safe identifier handling + return build_table_reference(table_name, db_name, catalog_name) async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]: """Get basic table information""" try: + # SECURITY FIX: table_name should already be validated by _build_full_table_name + # But we add auth_context for security validation + auth_context = get_auth_context() + # Try to get row count count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}" - result = await connection.execute(count_sql) + result = await connection.execute(count_sql, auth_context=auth_context) if result.data: return {"row_count": result.data[0]["row_count"]} return None + except SQLSecurityError as e: + logger.warning(f"Security validation failed: {str(e)}") + return None except Exception as e: logger.warning(f"Failed to get table basic info: {str(e)}") return None @@ -323,9 +332,13 @@ class DataQualityTools: async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]: """Get table column information""" try: - # Build DESCRIBE query - describe_sql = f"DESCRIBE {self._build_full_table_name(table_name, catalog_name, db_name)}" - result = await connection.execute(describe_sql) + # SECURITY FIX: Build safe table reference and pass auth_context + auth_context = get_auth_context() + + # Build DESCRIBE query with safe table reference + safe_table_ref = self._build_full_table_name(table_name, catalog_name, db_name) + describe_sql = f"DESCRIBE {safe_table_ref}" + result = await connection.execute(describe_sql, auth_context=auth_context) columns_info = [] if result.data: @@ -339,6 +352,9 @@ class DataQualityTools: }) return columns_info + except SQLSecurityError as e: + logger.warning(f"Security validation failed: {str(e)}") + return [] except Exception as e: logger.warning(f"Failed to get table columns info: {str(e)}") return [] @@ -346,7 +362,32 @@ class DataQualityTools: async def _get_table_partitions(self, connection, table_name: str, db_name: Optional[str] = None) -> List[Dict]: """Get table partition information""" try: - # Query partition information + # SECURITY FIX: Validate identifiers and use parameterized query + auth_context = get_auth_context() + + # Validate table_name + try: + validate_identifier(table_name, "table name") + if db_name: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + + # Build parameterized query + params = [] + where_conditions = [] + + if db_name: + where_conditions.append("TABLE_SCHEMA = %s") + params.append(db_name) + else: + where_conditions.append("TABLE_SCHEMA = ''") + + where_conditions.append("TABLE_NAME = %s") + params.append(table_name) + where_conditions.append("PARTITION_NAME IS NOT NULL") + partition_sql = f""" SELECT PARTITION_NAME, @@ -355,12 +396,10 @@ class DataQualityTools: DATA_LENGTH, INDEX_LENGTH FROM information_schema.PARTITIONS - WHERE TABLE_SCHEMA = '{db_name or ""}' - AND TABLE_NAME = '{table_name}' - AND PARTITION_NAME IS NOT NULL + WHERE {' AND '.join(where_conditions)} """ - result = await connection.execute(partition_sql) + result = await connection.execute(partition_sql, params=tuple(params), auth_context=auth_context) partitions = [] if result.data: for row in result.data: @@ -373,6 +412,9 @@ class DataQualityTools: }) return partitions + except SQLSecurityError as e: + logger.warning(f"Security validation failed: {str(e)}") + return [] except Exception as e: logger.warning(f"Failed to get table partitions: {str(e)}") return [] @@ -417,7 +459,8 @@ class DataQualityTools: if db_name else f"SHOW CREATE TABLE {table_name}" ) - result = await connection.execute(query) + auth_context = get_auth_context() + result = await connection.execute(query, auth_context=auth_context) if result.data: return result.data[0].get("Create Table") return None @@ -428,8 +471,16 @@ class DataQualityTools: async def _get_table_size_info(self, connection, table_name: str) -> Dict[str, Any]: """Get table size information""" try: - # Query table size information - size_sql = f""" + # SECURITY FIX: Validate and use parameterized query + table_name_part = table_name.split('.')[-1] + try: + validate_identifier(table_name_part, "table name") + except SQLSecurityError as e: + logger.warning(f"Invalid table name rejected: {e}") + return {"engine": "Unknown", "estimated_rows": 0, "data_length": 0, "index_length": 0, "total_size": 0} + + auth_context = get_auth_context() + size_sql = """ SELECT table_name, engine, @@ -438,10 +489,10 @@ class DataQualityTools: index_length, (data_length + index_length) as total_size FROM information_schema.tables - WHERE table_name = '{table_name.split('.')[-1]}' + WHERE table_name = %s """ - result = await connection.execute(size_sql) + result = await connection.execute(size_sql, params=(table_name_part,), auth_context=auth_context) if result.data and result.data[0]: row = result.data[0] return { @@ -582,7 +633,8 @@ class DataQualityTools: batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}" - result = await connection.execute(batch_sql) + auth_context = get_auth_context() + result = await connection.execute(batch_sql, auth_context=auth_context) if not result.data: return {"error": "No data returned from batch completeness query"} @@ -664,7 +716,8 @@ class DataQualityTools: batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}" - result = await connection.execute(batch_sql) + auth_context = get_auth_context() + result = await connection.execute(batch_sql, auth_context=auth_context) if not result.data: return {} @@ -705,7 +758,8 @@ class DataQualityTools: LIMIT 10 """ - result = await connection.execute(freq_sql) + auth_context = get_auth_context() + result = await connection.execute(freq_sql, auth_context=auth_context) frequencies = result.data if result.data else [] categorical_results[col_name] = { @@ -738,7 +792,8 @@ class DataQualityTools: batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}" - result = await connection.execute(batch_sql) + auth_context = get_auth_context() + result = await connection.execute(batch_sql, auth_context=auth_context) if not result.data: return {} @@ -780,7 +835,8 @@ class DataQualityTools: FROM {table_expr} """ - result = await connection.execute(completeness_sql) + auth_context = get_auth_context() + result = await connection.execute(completeness_sql, auth_context=auth_context) if result.data: stats = result.data[0] total_count = stats["total_count"] @@ -906,7 +962,8 @@ class DataQualityTools: WHERE {col_name} IS NOT NULL """ - result = await connection.execute(stats_sql) + auth_context = get_auth_context() + result = await connection.execute(stats_sql, auth_context=auth_context) if result.data and result.data[0]["non_null_count"] > 0: stats = result.data[0] numeric_analysis[col_name] = { @@ -945,7 +1002,8 @@ class DataQualityTools: WHERE {col_name} IS NOT NULL """ - cardinality_result = await connection.execute(cardinality_sql) + auth_context = get_auth_context() + cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context) if cardinality_result.data: stats = cardinality_result.data[0] @@ -969,7 +1027,7 @@ class DataQualityTools: LIMIT 10 """ - top_values_result = await connection.execute(top_values_sql) + top_values_result = await connection.execute(top_values_sql, auth_context=auth_context) if top_values_result.data: categorical_analysis[col_name]["top_values"] = [ {"value": row[col_name], "count": row["count"]} @@ -998,7 +1056,8 @@ class DataQualityTools: WHERE {col_name} IS NOT NULL """ - result = await connection.execute(stats_sql) + auth_context = get_auth_context() + result = await connection.execute(stats_sql, auth_context=auth_context) if result.data and result.data[0]["non_null_count"] > 0: stats = result.data[0] temporal_analysis[col_name] = { diff --git a/doris_mcp_server/utils/dependency_analysis_tools.py b/doris_mcp_server/utils/dependency_analysis_tools.py index 0f10e99..feb55a3 100644 --- a/doris_mcp_server/utils/dependency_analysis_tools.py +++ b/doris_mcp_server/utils/dependency_analysis_tools.py @@ -27,6 +27,13 @@ from collections import defaultdict, deque from .db import DorisConnectionManager from .logger import get_logger +from .sql_security_utils import ( + SQLSecurityError, + validate_identifier, + quote_identifier, + build_table_reference, + get_auth_context +) logger = get_logger(__name__) @@ -122,10 +129,19 @@ class DependencyAnalysisTools: async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]: """Get metadata for all tables and views""" try: - # Build conditions for query + # Build conditions for query with parameterized values where_conditions = [] + params = [] + if db_name: - where_conditions.append(f"table_schema = '{db_name}'") + # SECURITY FIX: Validate identifier and use parameterized query + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid database name rejected: {e}") + return [] + where_conditions.append("table_schema = %s") + params.append(db_name) else: where_conditions.append("table_schema = DATABASE()") @@ -148,9 +164,18 @@ class DependencyAnalysisTools: ORDER BY table_schema, table_name """ - result = await connection.execute(metadata_sql) + # SECURITY FIX: Get auth_context and pass to execute for security validation + auth_context = get_auth_context() + result = await connection.execute( + metadata_sql, + params=tuple(params) if params else None, + auth_context=auth_context + ) return result.data if result.data else [] + except SQLSecurityError as e: + logger.warning(f"Security validation failed in _get_tables_metadata: {str(e)}") + return [] except Exception as e: logger.warning(f"Failed to get tables metadata: {str(e)}") return [] @@ -186,17 +211,31 @@ class DependencyAnalysisTools: async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None: """Analyze view definitions to extract table dependencies""" + # Get auth_context once for all operations in this method + auth_context = get_auth_context() + try: for table in tables_metadata: if table["table_type"] == "VIEW": table_name = table["table_name"] schema_name = table.get("schema_name", "") - # Get view definition - view_def_sql = f"SHOW CREATE VIEW {schema_name}.{table_name}" if schema_name else f"SHOW CREATE VIEW {table_name}" + # SECURITY FIX: Validate identifiers before using in SQL + try: + validate_identifier(table_name, "table name") + if schema_name: + validate_identifier(schema_name, "schema name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected in view analysis: {e}") + continue + + # Build safe view reference using quoted identifiers + view_ref = build_table_reference(table_name, schema_name) if schema_name else quote_identifier(table_name, "table name") + view_def_sql = f"SHOW CREATE VIEW {view_ref}" try: - result = await connection.execute(view_def_sql) + # SECURITY FIX: Pass auth_context to execute + result = await connection.execute(view_def_sql, auth_context=auth_context) if result.data and len(result.data) > 0: # Extract view definition from result view_definition = "" @@ -235,6 +274,9 @@ class DependencyAnalysisTools: async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None: """Analyze audit logs to discover runtime table dependencies""" + # Get auth_context for security validation + auth_context = get_auth_context() + try: # Get recent SQL statements from audit logs audit_sql = """ @@ -252,7 +294,8 @@ class DependencyAnalysisTools: LIMIT 1000 """ - result = await connection.execute(audit_sql) + # SECURITY FIX: Pass auth_context to execute + result = await connection.execute(audit_sql, auth_context=auth_context) if result.data: for row in result.data: @@ -274,6 +317,9 @@ class DependencyAnalysisTools: async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None: """Analyze foreign key constraints for explicit dependencies""" + # Get auth_context for security validation + auth_context = get_auth_context() + try: # Get foreign key information fk_sql = """ @@ -288,7 +334,8 @@ class DependencyAnalysisTools: WHERE REFERENCED_TABLE_NAME IS NOT NULL """ - result = await connection.execute(fk_sql) + # SECURITY FIX: Pass auth_context to execute + result = await connection.execute(fk_sql, auth_context=auth_context) if result.data: for row in result.data: diff --git a/doris_mcp_server/utils/monitoring_tools.py b/doris_mcp_server/utils/monitoring_tools.py index 40e41b0..d9055c2 100644 --- a/doris_mcp_server/utils/monitoring_tools.py +++ b/doris_mcp_server/utils/monitoring_tools.py @@ -28,6 +28,7 @@ from datetime import datetime from .db import DorisConnectionManager from .logger import get_logger +from .sql_security_utils import get_auth_context logger = get_logger(__name__) @@ -713,7 +714,8 @@ class DorisMonitoringTools: # Fallback to SHOW BACKENDS if no BE hosts configured logger.info("No BE hosts configured, using SHOW BACKENDS to discover BE nodes") connection = await self.connection_manager.get_connection("query") - result = await connection.execute("SHOW BACKENDS") + auth_context = get_auth_context() + result = await connection.execute("SHOW BACKENDS", auth_context=auth_context) be_nodes = [] for row in result.data: diff --git a/doris_mcp_server/utils/performance_analytics_tools.py b/doris_mcp_server/utils/performance_analytics_tools.py index 82e1c2d..a9e4a2a 100644 --- a/doris_mcp_server/utils/performance_analytics_tools.py +++ b/doris_mcp_server/utils/performance_analytics_tools.py @@ -27,6 +27,13 @@ from collections import defaultdict, Counter from .db import DorisConnectionManager from .logger import get_logger +from .sql_security_utils import ( + SQLSecurityError, + validate_identifier, + quote_identifier, + build_table_reference, + get_auth_context +) logger = get_logger(__name__) @@ -229,7 +236,8 @@ class PerformanceAnalyticsTools: ORDER BY query_date """ - result = await connection.execute(query_volume_sql) + auth_context = get_auth_context() + result = await connection.execute(query_volume_sql, auth_context=auth_context) daily_data = result.data if result.data else [] if not daily_data: @@ -304,7 +312,8 @@ class PerformanceAnalyticsTools: ORDER BY activity_date """ - result = await connection.execute(user_activity_sql) + auth_context = get_auth_context() + result = await connection.execute(user_activity_sql, auth_context=auth_context) daily_data = result.data if result.data else [] if not daily_data: @@ -383,7 +392,8 @@ class PerformanceAnalyticsTools: LIMIT 5000 """ - result = await connection.execute(slow_query_sql) + auth_context = get_auth_context() + result = await connection.execute(slow_query_sql, auth_context=auth_context) return result.data if result.data else [] except Exception as e: @@ -705,7 +715,8 @@ class PerformanceAnalyticsTools: ORDER BY size_mb DESC """ - db_result = await connection.execute(db_sizes_sql) + auth_context = get_auth_context() + db_result = await connection.execute(db_sizes_sql, auth_context=auth_context) if not db_result.data: logger.warning("No database size information available") @@ -805,7 +816,16 @@ class PerformanceAnalyticsTools: async def _get_database_table_details_from_schema(self, connection, db_name: str) -> List[Dict]: """Get table details for a specific database using information_schema""" try: - table_details_sql = f""" + # SECURITY FIX: Validate db_name and use parameterized query + auth_context = get_auth_context() + + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid database name rejected: {e}") + return [] + + table_details_sql = """ SELECT TABLE_SCHEMA as schema_name, TABLE_NAME as table_name, @@ -814,13 +834,13 @@ class PerformanceAnalyticsTools: CREATE_TIME as create_time, UPDATE_TIME as update_time FROM information_schema.tables - WHERE TABLE_SCHEMA = '{db_name}' + WHERE TABLE_SCHEMA = %s AND TABLE_TYPE = 'BASE TABLE' AND (COALESCE(DATA_LENGTH, 0) + COALESCE(INDEX_LENGTH, 0)) > 0 ORDER BY size_mb DESC """ - result = await connection.execute(table_details_sql) + result = await connection.execute(table_details_sql, params=(db_name,), auth_context=auth_context) if not result.data: logger.warning(f"No table details found for database {db_name}") @@ -867,6 +887,13 @@ class PerformanceAnalyticsTools: async def _get_database_table_details(self, connection, db_name: str) -> List[Dict]: """Get table details for a specific database using session-consistent queries""" try: + # SECURITY FIX: Validate db_name before using in SQL + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid database name rejected: {e}") + return [] + # Method 1: Try to use session-consistent approach with raw connection # This requires accessing the underlying connection to maintain session state @@ -877,8 +904,9 @@ class PerformanceAnalyticsTools: # Use raw connection to maintain session state cursor = await raw_conn.cursor() try: - # Execute USE and SHOW DATA in the same session - await cursor.execute(f"USE {db_name}") + # SECURITY FIX: Use quoted identifier for USE statement + quoted_db = quote_identifier(db_name, "database name") + await cursor.execute(f"USE {quoted_db}") await cursor.execute("SHOW DATA") result = await cursor.fetchall() @@ -922,9 +950,19 @@ class PerformanceAnalyticsTools: async def _get_database_table_details_fallback(self, connection, db_name: str) -> List[Dict]: """Fallback method to get table details using individual queries""" try: - # Get all tables in the database - tables_sql = f"SHOW TABLES FROM {db_name}" - tables_result = await connection.execute(tables_sql) + # SECURITY FIX: Validate db_name and get auth_context + auth_context = get_auth_context() + + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid database name rejected: {e}") + return [] + + # Get all tables in the database using quoted identifier + quoted_db = quote_identifier(db_name, "database name") + tables_sql = f"SHOW TABLES FROM {quoted_db}" + tables_result = await connection.execute(tables_sql, auth_context=auth_context) if not tables_result.data: return [] @@ -934,9 +972,11 @@ class PerformanceAnalyticsTools: table_name = table_row.get(f"Tables_in_{db_name}", "") or table_row.get("table_name", "") if table_name: try: - # Use SHOW DATA FROM db.table for each table - data_sql = f"SHOW DATA FROM {db_name}.{table_name}" - data_result = await connection.execute(data_sql) + # SECURITY FIX: Validate table_name and use safe reference + validate_identifier(table_name, "table name") + safe_table_ref = build_table_reference(table_name, db_name) + data_sql = f"SHOW DATA FROM {safe_table_ref}" + data_result = await connection.execute(data_sql, auth_context=auth_context) if data_result.data: for row in data_result.data: @@ -1036,6 +1076,7 @@ class PerformanceAnalyticsTools: async def _get_all_tables_info(self, connection) -> List[Dict]: """Get basic information for all tables (fallback method)""" try: + auth_context = get_auth_context() tables_sql = """ SELECT table_schema, @@ -1053,7 +1094,7 @@ class PerformanceAnalyticsTools: ORDER BY (data_length + index_length) DESC """ - result = await connection.execute(tables_sql) + result = await connection.execute(tables_sql, auth_context=auth_context) return result.data if result.data else [] except Exception as e: @@ -1120,23 +1161,37 @@ class PerformanceAnalyticsTools: async def _get_current_table_size(self, connection, full_table_name: str) -> Optional[Dict]: """Get current table size""" try: - # Try to query table size directly - size_sql = f""" + # SECURITY FIX: Get auth_context and use parameterized query + auth_context = get_auth_context() + + # Extract table name for parameterized query + table_name_only = full_table_name.split('.')[-1] if '.' in full_table_name else full_table_name + + # Validate identifiers + try: + validate_identifier(table_name_only, "table name") + except SQLSecurityError as e: + logger.warning(f"Invalid table name rejected: {e}") + return None + + # Use parameterized query for safety + size_sql = """ SELECT COALESCE(ROUND((COALESCE(data_length, 0) + COALESCE(index_length, 0)) / 1024 / 1024, 2), 0) as size_mb, COALESCE(table_rows, 0) as `rows` FROM information_schema.tables - WHERE CONCAT(table_schema, '.', table_name) = '{full_table_name}' - OR table_name = '{full_table_name.split('.')[-1]}' + WHERE CONCAT(table_schema, '.', table_name) = %s + OR table_name = %s """ - result = await connection.execute(size_sql) + result = await connection.execute(size_sql, params=(full_table_name, table_name_only), auth_context=auth_context) if result.data and result.data[0]: return result.data[0] # If information_schema has no data, try COUNT query + # full_table_name should already be validated by caller using build_table_reference count_sql = f"SELECT COUNT(*) as rows FROM {full_table_name}" - count_result = await connection.execute(count_sql) + count_result = await connection.execute(count_sql, auth_context=auth_context) if count_result.data: return { "size_mb": 0, # Cannot get exact size @@ -1145,6 +1200,9 @@ class PerformanceAnalyticsTools: return None + except SQLSecurityError as e: + logger.warning(f"Security validation failed for {full_table_name}: {str(e)}") + return None except Exception as e: logger.warning(f"Failed to get current size for {full_table_name}: {str(e)}") return None @@ -1154,8 +1212,19 @@ class PerformanceAnalyticsTools: ) -> List[Dict]: """Get historical growth data based on partitions""" try: - # Query partition information - partition_sql = f""" + # SECURITY FIX: Validate identifiers and use parameterized query + auth_context = get_auth_context() + + try: + validate_identifier(table_name, "table name") + if schema_name: + validate_identifier(schema_name, "schema name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + + # Use parameterized query for safety + partition_sql = """ SELECT partition_name, partition_description, @@ -1163,15 +1232,19 @@ class PerformanceAnalyticsTools: data_length, create_time FROM information_schema.partitions - WHERE table_schema = '{schema_name or ""}' - AND table_name = '{table_name}' + WHERE table_schema = %s + AND table_name = %s AND partition_name IS NOT NULL AND create_time IS NOT NULL - AND create_time >= DATE_SUB(NOW(), INTERVAL {days} DAY) + AND create_time >= DATE_SUB(NOW(), INTERVAL %s DAY) ORDER BY create_time DESC """ - result = await connection.execute(partition_sql) + result = await connection.execute( + partition_sql, + params=(schema_name or "", table_name, days), + auth_context=auth_context + ) if not result.data: return [] @@ -1210,6 +1283,9 @@ class PerformanceAnalyticsTools: ) -> List[Dict]: """Get historical growth data based on timestamp fields""" try: + # SECURITY FIX: Get auth_context + auth_context = get_auth_context() + # Find possible timestamp fields timestamp_columns = await self._find_timestamp_columns(connection, table_name, schema_name) if not timestamp_columns: @@ -1218,20 +1294,29 @@ class PerformanceAnalyticsTools: # Use best timestamp field for analysis time_column = timestamp_columns[0] - # Aggregate data by date + # SECURITY FIX: Validate time_column before using in SQL + try: + validate_identifier(time_column, "column name") + except SQLSecurityError as e: + logger.warning(f"Invalid column name rejected: {e}") + return [] + + quoted_time_column = quote_identifier(time_column, "column name") + + # Aggregate data by date (full_table_name should be validated by caller) growth_sql = f""" SELECT - DATE({time_column}) as date, + DATE({quoted_time_column}) as date, COUNT(*) as daily_records, COUNT(*) / SUM(COUNT(*)) OVER() * 100 as percentage FROM {full_table_name} - WHERE {time_column} >= DATE_SUB(NOW(), INTERVAL {days} DAY) - AND {time_column} IS NOT NULL - GROUP BY DATE({time_column}) + WHERE {quoted_time_column} >= DATE_SUB(NOW(), INTERVAL %s DAY) + AND {quoted_time_column} IS NOT NULL + GROUP BY DATE({quoted_time_column}) ORDER BY date DESC """ - result = await connection.execute(growth_sql) + result = await connection.execute(growth_sql, params=(days,), auth_context=auth_context) if not result.data: return [] @@ -1257,11 +1342,22 @@ class PerformanceAnalyticsTools: async def _find_timestamp_columns(self, connection, table_name: str, schema_name: str) -> List[str]: """Find timestamp fields in table""" try: - timestamp_sql = f""" + # SECURITY FIX: Validate identifiers and use parameterized query + auth_context = get_auth_context() + + try: + validate_identifier(table_name, "table name") + if schema_name: + validate_identifier(schema_name, "schema name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + + timestamp_sql = """ SELECT column_name, data_type FROM information_schema.columns - WHERE table_schema = '{schema_name or ""}' - AND table_name = '{table_name}' + WHERE table_schema = %s + AND table_name = %s AND ( data_type IN ('datetime', 'timestamp', 'date') OR column_name REGEXP '(create|insert|update|modify).*time' @@ -1278,9 +1374,16 @@ class PerformanceAnalyticsTools: END """ - result = await connection.execute(timestamp_sql) + result = await connection.execute( + timestamp_sql, + params=(schema_name or "", table_name), + auth_context=auth_context + ) return [row["column_name"] for row in result.data] if result.data else [] + except SQLSecurityError as e: + logger.warning(f"Security validation failed: {str(e)}") + return [] except Exception as e: logger.warning(f"Failed to find timestamp columns: {str(e)}") return [] @@ -1290,8 +1393,22 @@ class PerformanceAnalyticsTools: ) -> List[Dict]: """Estimate growth data based on audit logs""" try: + # SECURITY FIX: Validate table_name and use parameterized query + auth_context = get_auth_context() + + try: + validate_identifier(table_name.split(".")[-1], "table name") + except SQLSecurityError as e: + logger.warning(f"Invalid table name rejected: {e}") + return [] + + # Extract just the table name for LIKE pattern + table_name_only = table_name.split(".")[-1] + like_pattern_full = f"%{table_name}%" + like_pattern_short = f"%{table_name_only}%" + # Analyze operation history for this table - audit_sql = f""" + audit_sql = """ SELECT DATE(`time`) as operation_date, COUNT(*) as operation_count, @@ -1299,17 +1416,21 @@ class PerformanceAnalyticsTools: SUM(CASE WHEN stmt LIKE 'UPDATE%' THEN 1 ELSE 0 END) as update_count, SUM(CASE WHEN stmt LIKE 'DELETE%' THEN 1 ELSE 0 END) as delete_count FROM internal.__internal_schema.audit_log - WHERE `time` >= DATE_SUB(NOW(), INTERVAL {days} DAY) + WHERE `time` >= DATE_SUB(NOW(), INTERVAL %s DAY) AND stmt IS NOT NULL AND ( - stmt LIKE '%{table_name}%' - OR stmt LIKE '%{table_name.split(".")[-1]}%' + stmt LIKE %s + OR stmt LIKE %s ) GROUP BY DATE(`time`) ORDER BY operation_date DESC """ - result = await connection.execute(audit_sql) + result = await connection.execute( + audit_sql, + params=(days, like_pattern_full, like_pattern_short), + auth_context=auth_context + ) if not result.data: return [] diff --git a/doris_mcp_server/utils/query_executor.py b/doris_mcp_server/utils/query_executor.py index e342a95..cf2fc28 100644 --- a/doris_mcp_server/utils/query_executor.py +++ b/doris_mcp_server/utils/query_executor.py @@ -35,6 +35,7 @@ from decimal import Decimal from .db import DorisConnectionManager, QueryResult from .logger import get_logger +from .sql_security_utils import get_auth_context @dataclass @@ -497,7 +498,8 @@ class DorisQueryExecutor: explain_sql = f"EXPLAIN {sql}" connection = await self.connection_manager.get_connection(session_id) - result = await connection.execute(explain_sql) + auth_context = get_auth_context() + result = await connection.execute(explain_sql, auth_context=auth_context) return { "query": sql, diff --git a/doris_mcp_server/utils/schema_extractor.py b/doris_mcp_server/utils/schema_extractor.py index 78eb524..f46e659 100644 --- a/doris_mcp_server/utils/schema_extractor.py +++ b/doris_mcp_server/utils/schema_extractor.py @@ -32,6 +32,11 @@ from datetime import datetime, timedelta # Import unified logging configuration from .logger import get_logger +from .sql_security_utils import ( + SQLSecurityError, + validate_identifier, + quote_identifier +) # Configure logging logger = get_logger(__name__) @@ -431,6 +436,16 @@ class MetadataExtractor: logger.warning("Database name not specified") return {} + # SECURITY FIX: Validate identifiers to prevent SQL injection + try: + validate_identifier(table_name, "table name") + validate_identifier(db_name, "database name") + if effective_catalog: + validate_identifier(effective_catalog, "catalog name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected in get_table_schema: {e}") + return {} + cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] @@ -536,6 +551,16 @@ class MetadataExtractor: logger.warning("Database name not specified") return "" + # SECURITY FIX: Validate identifiers + try: + validate_identifier(table_name, "table name") + validate_identifier(db_name, "database name") + if effective_catalog: + validate_identifier(effective_catalog, "catalog name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return "" + cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] @@ -587,6 +612,16 @@ class MetadataExtractor: logger.warning("Database name not specified") return {} + # SECURITY FIX: Validate identifiers + try: + validate_identifier(table_name, "table name") + validate_identifier(db_name, "database name") + if effective_catalog: + validate_identifier(effective_catalog, "catalog name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return {} + cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] @@ -643,17 +678,30 @@ class MetadataExtractor: logger.error("Database name not specified") return [] + # SECURITY FIX: Validate identifiers + try: + validate_identifier(table_name, "table name") + validate_identifier(db_name, "database name") + if effective_catalog: + validate_identifier(effective_catalog, "catalog name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] try: - # Build query with catalog prefix if specified + # Build query with catalog prefix if specified (identifiers already validated) + safe_table = quote_identifier(table_name, "table name") + safe_db = quote_identifier(db_name, "database name") if effective_catalog: - query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`" + safe_catalog = quote_identifier(effective_catalog, "catalog name") + query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}" logger.info(f"Using three-part naming for index query: {query}") else: - query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`" + query = f"SHOW INDEX FROM {safe_db}.{safe_table}" try: # NOTE: Deprecated sync path retained for compatibility; use async variant instead. @@ -1188,12 +1236,28 @@ class MetadataExtractor: try: # Use async query method effective_catalog = catalog_name or self.catalog_name + effective_db = db_name or self.db_name + + # SECURITY FIX: Validate identifiers + try: + validate_identifier(table_name, "table name") + if effective_db: + validate_identifier(effective_db, "database name") + if effective_catalog and effective_catalog != "internal": + validate_identifier(effective_catalog, "catalog name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + + # Build query statement using safe identifiers + safe_table = quote_identifier(table_name, "table name") + safe_db = quote_identifier(effective_db, "database name") if effective_db else None - # Build query statement if effective_catalog and effective_catalog != "internal": - query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`" + safe_catalog = quote_identifier(effective_catalog, "catalog name") + query = f"DESCRIBE {safe_catalog}.{safe_db}.{safe_table}" else: - query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`" + query = f"DESCRIBE {safe_db}.{safe_table}" # Execute async query result = await self._execute_query_async(query, db_name) @@ -1226,8 +1290,15 @@ class MetadataExtractor: try: effective_catalog = catalog_name or self.catalog_name + # SECURITY FIX: Validate catalog name if provided if effective_catalog and effective_catalog != "internal": - query = f"SHOW DATABASES FROM `{effective_catalog}`" + try: + validate_identifier(effective_catalog, "catalog name") + except SQLSecurityError as e: + logger.warning(f"Invalid catalog name rejected: {e}") + return [] + safe_catalog = quote_identifier(effective_catalog, "catalog name") + query = f"SHOW DATABASES FROM {safe_catalog}" else: query = "SHOW DATABASES" @@ -1257,10 +1328,23 @@ class MetadataExtractor: effective_catalog = catalog_name or self.catalog_name effective_db = db_name or self.db_name + # SECURITY FIX: Validate identifiers + try: + if effective_db: + validate_identifier(effective_db, "database name") + if effective_catalog and effective_catalog != "internal": + validate_identifier(effective_catalog, "catalog name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + + safe_db = quote_identifier(effective_db, "database name") if effective_db else None + if effective_catalog and effective_catalog != "internal": - query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`" + safe_catalog = quote_identifier(effective_catalog, "catalog name") + query = f"SHOW TABLES FROM {safe_catalog}.{safe_db}" else: - query = f"SHOW TABLES FROM `{effective_db}`" + query = f"SHOW TABLES FROM {safe_db}" result = await self._execute_query_async(query, effective_db) @@ -1319,6 +1403,15 @@ class MetadataExtractor: effective_db = db_name or self.db_name effective_catalog = catalog_name or self.catalog_name + # SECURITY FIX: Validate identifiers + try: + validate_identifier(table_name, "table name") + if effective_db: + validate_identifier(effective_db, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return "" + query = f""" SELECT TABLE_COMMENT @@ -1343,6 +1436,15 @@ class MetadataExtractor: effective_db = db_name or self.db_name effective_catalog = catalog_name or self.catalog_name + # SECURITY FIX: Validate identifiers + try: + validate_identifier(table_name, "table name") + if effective_db: + validate_identifier(effective_db, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return {} + query = f""" SELECT COLUMN_NAME, @@ -1373,12 +1475,27 @@ class MetadataExtractor: effective_db = db_name or self.db_name effective_catalog = catalog_name or self.catalog_name - # Build query with catalog prefix if specified + # SECURITY FIX: Validate identifiers + try: + validate_identifier(table_name, "table name") + if effective_db: + validate_identifier(effective_db, "database name") + if effective_catalog: + validate_identifier(effective_catalog, "catalog name") + except SQLSecurityError as e: + logger.warning(f"Invalid identifier rejected: {e}") + return [] + + # Build query with catalog prefix if specified (using safe identifiers) + safe_table = quote_identifier(table_name, "table name") + safe_db = quote_identifier(effective_db, "database name") if effective_db else None + if effective_catalog: - query = f"SHOW INDEX FROM `{effective_catalog}`.`{effective_db}`.`{table_name}`" + safe_catalog = quote_identifier(effective_catalog, "catalog name") + query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}" logger.info(f"Using three-part naming for async index query: {query}") else: - query = f"SHOW INDEX FROM `{effective_db}`.`{table_name}`" + query = f"SHOW INDEX FROM {safe_db}.{safe_table}" rows = await self._execute_query_async(query, effective_db) indexes: List[Dict[str, Any]] = [] @@ -1475,21 +1592,45 @@ class MetadataExtractor: return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute") # FIX for Issue #62 Bug 3: Build context switching SQL if db_name or catalog_name is specified + # SECURITY FIX: Validate catalog_name and db_name to prevent SQL injection final_sql = sql if catalog_name or db_name: context_statements = [] + # Validate and sanitize catalog_name if catalog_name: - # Switch to specified catalog - context_statements.append(f"USE CATALOG `{catalog_name}`") + try: + validate_identifier(catalog_name, "catalog name") + except SQLSecurityError as e: + logger.warning(f"Invalid catalog name rejected: {e}") + return self._format_response( + success=False, + error=f"Invalid catalog name: {catalog_name}", + message="Catalog name contains invalid characters" + ) + # Use quote_identifier to safely escape the catalog name + safe_catalog = quote_identifier(catalog_name, "catalog name") + context_statements.append(f"USE CATALOG {safe_catalog}") logger.debug(f"Switching to catalog: {catalog_name}") + # Validate and sanitize db_name if db_name: - # Switch to specified database + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + logger.warning(f"Invalid database name rejected: {e}") + return self._format_response( + success=False, + error=f"Invalid database name: {db_name}", + message="Database name contains invalid characters" + ) + # Use quote_identifier to safely escape the database name + safe_db = quote_identifier(db_name, "database name") if catalog_name: - context_statements.append(f"USE `{catalog_name}`.`{db_name}`") + safe_catalog = quote_identifier(catalog_name, "catalog name") + context_statements.append(f"USE {safe_catalog}.{safe_db}") else: - context_statements.append(f"USE `{db_name}`") + context_statements.append(f"USE {safe_db}") logger.debug(f"Switching to database: {db_name}") # Combine context switching with original SQL @@ -1551,6 +1692,36 @@ class MetadataExtractor: if not table_name: return self._format_response(success=False, error="Missing table_name parameter") + # SECURITY: Validate identifiers before processing + try: + validate_identifier(table_name, "table name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid table name: {table_name}", + message="Table name contains invalid characters" + ) + + if db_name: + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid database name: {db_name}", + message="Database name contains invalid characters" + ) + + if catalog_name and catalog_name != "internal": + try: + validate_identifier(catalog_name, "catalog name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid catalog name: {catalog_name}", + message="Catalog name contains invalid characters" + ) + try: schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) @@ -1574,6 +1745,27 @@ class MetadataExtractor: """Get list of all table names in specified database - MCP interface""" logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}") + # SECURITY: Validate identifiers + if db_name: + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid database name: {db_name}", + message="Database name contains invalid characters" + ) + + if catalog_name and catalog_name != "internal": + try: + validate_identifier(catalog_name, "catalog name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid catalog name: {catalog_name}", + message="Catalog name contains invalid characters" + ) + try: tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name) return self._format_response(success=True, result=tables) @@ -1604,6 +1796,36 @@ class MetadataExtractor: if not table_name: return self._format_response(success=False, error="Missing table_name parameter") + # SECURITY: Validate identifiers + try: + validate_identifier(table_name, "table name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid table name: {table_name}", + message="Table name contains invalid characters" + ) + + if db_name: + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid database name: {db_name}", + message="Database name contains invalid characters" + ) + + if catalog_name and catalog_name != "internal": + try: + validate_identifier(catalog_name, "catalog name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid catalog name: {catalog_name}", + message="Catalog name contains invalid characters" + ) + try: comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) return self._format_response(success=True, result=comment) @@ -1623,6 +1845,36 @@ class MetadataExtractor: if not table_name: return self._format_response(success=False, error="Missing table_name parameter") + # SECURITY: Validate identifiers + try: + validate_identifier(table_name, "table name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid table name: {table_name}", + message="Table name contains invalid characters" + ) + + if db_name: + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid database name: {db_name}", + message="Database name contains invalid characters" + ) + + if catalog_name and catalog_name != "internal": + try: + validate_identifier(catalog_name, "catalog name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid catalog name: {catalog_name}", + message="Catalog name contains invalid characters" + ) + try: comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) return self._format_response(success=True, result=comments) @@ -1642,6 +1894,36 @@ class MetadataExtractor: if not table_name: return self._format_response(success=False, error="Missing table_name parameter") + # SECURITY: Validate identifiers + try: + validate_identifier(table_name, "table name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid table name: {table_name}", + message="Table name contains invalid characters" + ) + + if db_name: + try: + validate_identifier(db_name, "database name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid database name: {db_name}", + message="Database name contains invalid characters" + ) + + if catalog_name and catalog_name != "internal": + try: + validate_identifier(catalog_name, "catalog name") + except SQLSecurityError as e: + return self._format_response( + success=False, + error=f"Invalid catalog name: {catalog_name}", + message="Catalog name contains invalid characters" + ) + try: indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) return self._format_response(success=True, result=indexes) diff --git a/doris_mcp_server/utils/security.py b/doris_mcp_server/utils/security.py index 95471e1..d1efc43 100644 --- a/doris_mcp_server/utils/security.py +++ b/doris_mcp_server/utils/security.py @@ -901,30 +901,50 @@ class SQLSecurityValidator: if not self.enable_security_check: self.logger.debug("SQL security check is disabled, allowing all queries") return ValidationResult(is_valid=True) - + try: - # Parse SQL statement - parsed = sqlparse.parse(sql)[0] + # SECURITY FIX: Parse ALL SQL statements, not just the first one + # This prevents bypassing security checks by injecting additional statements + all_statements = sqlparse.parse(sql) - # Check blocked operations first (more specific) - keyword_result = await self._check_blocked_keywords(parsed) - if not keyword_result.is_valid: - return keyword_result + if not all_statements: + return ValidationResult( + is_valid=False, + error_message="Empty or invalid SQL statement", + risk_level="medium" + ) - # Check SQL injection risks - injection_result = await self._check_sql_injection(sql, parsed) - if not injection_result.is_valid: - return injection_result + # SECURITY FIX: Validate each statement individually + for idx, parsed in enumerate(all_statements): + # Skip empty statements (e.g., from trailing semicolons) + if not parsed.tokens or str(parsed).strip() == '': + continue - # Check query complexity - complexity_result = await self._check_query_complexity(parsed) - if not complexity_result.is_valid: - return complexity_result + self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...") - # Check table access permissions - table_result = await self._check_table_access(parsed, auth_context) - if not table_result.is_valid: - return table_result + # Check blocked operations first (more specific) + keyword_result = await self._check_blocked_keywords(parsed) + if not keyword_result.is_valid: + keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}" + return keyword_result + + # Check SQL injection risks + injection_result = await self._check_sql_injection(sql, parsed) + if not injection_result.is_valid: + injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}" + return injection_result + + # Check query complexity + complexity_result = await self._check_query_complexity(parsed) + if not complexity_result.is_valid: + complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}" + return complexity_result + + # Check table access permissions + table_result = await self._check_table_access(parsed, auth_context) + if not table_result.is_valid: + table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}" + return table_result return ValidationResult(is_valid=True) @@ -1134,6 +1154,10 @@ class SQLSecurityValidator: self, parsed: Statement, auth_context: AuthContext ) -> ValidationResult: """Check table access permissions""" + # If no auth_context, skip table access checks (rely on other security checks) + if auth_context is None: + return ValidationResult(is_valid=True) + # Extract table names from query tables = self._extract_table_names(parsed) diff --git a/doris_mcp_server/utils/security_analytics_tools.py b/doris_mcp_server/utils/security_analytics_tools.py index 74d3c54..25d24ca 100644 --- a/doris_mcp_server/utils/security_analytics_tools.py +++ b/doris_mcp_server/utils/security_analytics_tools.py @@ -26,6 +26,7 @@ from collections import Counter, defaultdict from .db import DorisConnectionManager from .logger import get_logger +from .sql_security_utils import get_auth_context logger = get_logger(__name__) @@ -192,7 +193,9 @@ class SecurityAnalyticsTools: LIMIT 10000 """ - result = await connection.execute(audit_sql) + # SECURITY FIX: Pass auth_context to execute + auth_context = get_auth_context() + result = await connection.execute(audit_sql, auth_context=auth_context) return result.data if result.data else [] except Exception as e: @@ -215,7 +218,8 @@ class SecurityAnalyticsTools: LIMIT 10000 """ - result = await connection.execute(simple_audit_sql) + auth_context = get_auth_context() + result = await connection.execute(simple_audit_sql, auth_context=auth_context) return result.data if result.data else [] except Exception as e2: @@ -498,7 +502,8 @@ class SecurityAnalyticsTools: FROM mysql.user """ - result = await connection.execute(roles_sql) + auth_context = get_auth_context() + result = await connection.execute(roles_sql, auth_context=auth_context) user_roles = defaultdict(list) if result.data: diff --git a/doris_mcp_server/utils/sql_security_utils.py b/doris_mcp_server/utils/sql_security_utils.py new file mode 100644 index 0000000..20d8c35 --- /dev/null +++ b/doris_mcp_server/utils/sql_security_utils.py @@ -0,0 +1,301 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +SQL Security Utilities Module + +Provides SQL identifier validation, escaping, and safe query building utilities +to prevent SQL injection attacks. +""" + +import re +from contextvars import ContextVar +from typing import Optional, Tuple, List, Any + +from .logger import get_logger + +logger = get_logger(__name__) + +# Context variable for auth_context (set by HTTP middleware) +auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None) + + +class SQLSecurityError(Exception): + """Exception raised for SQL security validation failures""" + pass + + +class SQLSecurityUtils: + """ + SQL Security Utilities for preventing SQL injection attacks. + + Provides: + - Identifier validation (database names, table names, column names) + - Safe identifier quoting with backticks + - Safe table reference building + - Auth context retrieval from context variables + """ + + # Valid SQL identifier pattern: letters, numbers, underscores + # Must start with letter or underscore, not a number + # Supports Unicode letters for international database/table names + IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$') + + # Maximum identifier length (MySQL/Doris standard) + MAX_IDENTIFIER_LENGTH = 64 + + # SQL reserved keywords that should be quoted + SQL_KEYWORDS = { + 'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP', + 'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND', + 'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN', + 'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER', + 'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL', + 'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY', + 'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT' + } + + @classmethod + def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str: + """ + Validate a SQL identifier (database name, table name, column name, etc.) + + Args: + name: The identifier to validate + identifier_type: Type description for error messages (e.g., "database name", "table name") + + Returns: + The validated identifier (unchanged if valid) + + Raises: + SQLSecurityError: If the identifier is invalid + """ + if not name: + raise SQLSecurityError(f"Empty {identifier_type} is not allowed") + + if not isinstance(name, str): + raise SQLSecurityError(f"Invalid {identifier_type}: must be a string, got {type(name).__name__}") + + # Strip whitespace + name = name.strip() + + if not name: + raise SQLSecurityError(f"Empty {identifier_type} is not allowed") + + # Check length + if len(name) > cls.MAX_IDENTIFIER_LENGTH: + raise SQLSecurityError( + f"Invalid {identifier_type}: '{name[:20]}...' exceeds maximum length of {cls.MAX_IDENTIFIER_LENGTH} characters" + ) + + # Check for dangerous characters that could be SQL injection + dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\\', '\x00'] + for char in dangerous_chars: + if char in name: + raise SQLSecurityError( + f"Invalid {identifier_type}: '{name}' contains forbidden character '{char}'" + ) + + # Validate pattern + if not cls.IDENTIFIER_PATTERN.match(name): + raise SQLSecurityError( + f"Invalid {identifier_type}: '{name}' contains invalid characters. " + f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore." + ) + + logger.debug(f"Validated {identifier_type}: {name}") + return name + + @classmethod + def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str: + """ + Safely quote a SQL identifier using backticks. + + Args: + name: The identifier to quote + identifier_type: Type description for error messages + + Returns: + The quoted identifier (e.g., `table_name`) + + Raises: + SQLSecurityError: If the identifier is invalid + """ + # First validate the identifier + validated_name = cls.validate_identifier(name, identifier_type) + + # Escape any backticks within the name (double them) + escaped_name = validated_name.replace('`', '``') + + return f"`{escaped_name}`" + + @classmethod + def build_table_reference( + cls, + table_name: str, + db_name: Optional[str] = None, + catalog_name: Optional[str] = None, + quote: bool = True + ) -> str: + """ + Build a safe, fully-qualified table reference. + + Args: + table_name: The table name (required) + db_name: The database name (optional) + catalog_name: The catalog name (optional) + quote: Whether to quote identifiers with backticks (default: True) + + Returns: + A safe table reference string (e.g., `catalog`.`db`.`table`) + + Raises: + SQLSecurityError: If any identifier is invalid + """ + parts = [] + + if catalog_name: + if quote: + parts.append(cls.quote_identifier(catalog_name, "catalog name")) + else: + parts.append(cls.validate_identifier(catalog_name, "catalog name")) + + if db_name: + if quote: + parts.append(cls.quote_identifier(db_name, "database name")) + else: + parts.append(cls.validate_identifier(db_name, "database name")) + + if quote: + parts.append(cls.quote_identifier(table_name, "table name")) + else: + parts.append(cls.validate_identifier(table_name, "table name")) + + return '.'.join(parts) + + @classmethod + def build_column_reference( + cls, + column_name: str, + table_name: Optional[str] = None, + quote: bool = True + ) -> str: + """ + Build a safe column reference. + + Args: + column_name: The column name (required) + table_name: The table name (optional, for qualified references) + quote: Whether to quote identifiers with backticks (default: True) + + Returns: + A safe column reference string (e.g., `table`.`column`) + + Raises: + SQLSecurityError: If any identifier is invalid + """ + parts = [] + + if table_name: + if quote: + parts.append(cls.quote_identifier(table_name, "table name")) + else: + parts.append(cls.validate_identifier(table_name, "table name")) + + if quote: + parts.append(cls.quote_identifier(column_name, "column name")) + else: + parts.append(cls.validate_identifier(column_name, "column name")) + + return '.'.join(parts) + + @classmethod + def validate_and_build_where_condition( + cls, + column_name: str, + operator: str = "=", + use_param: bool = True + ) -> Tuple[str, bool]: + """ + Build a safe WHERE condition for a column. + + Args: + column_name: The column name + operator: The comparison operator (=, !=, <, >, <=, >=, LIKE, IN) + use_param: Whether to use parameterized placeholder (%s) + + Returns: + Tuple of (condition_string, needs_param) + e.g., ("`column` = %s", True) or ("`column` = DATABASE()", False) + + Raises: + SQLSecurityError: If column name is invalid or operator is not allowed + """ + # Validate column name + quoted_column = cls.quote_identifier(column_name, "column name") + + # Validate operator + allowed_operators = {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'} + if operator.upper() not in allowed_operators: + raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed_operators}") + + if use_param: + return f"{quoted_column} {operator} %s", True + else: + return f"{quoted_column} {operator}", False + + @staticmethod + def get_auth_context(): + """ + Get auth_context from the context variable. + + This retrieves the auth_context that was set by the HTTP middleware + during request processing. + + Returns: + The auth_context object, or None if not available + """ + try: + auth_context = auth_context_var.get() + if auth_context: + logger.debug(f"Retrieved auth_context from context variable") + return auth_context + except Exception as e: + logger.debug(f"Could not retrieve auth_context: {e}") + return None + + @staticmethod + def set_auth_context(auth_context): + """ + Set auth_context in the context variable. + + This is typically called by the HTTP middleware during request processing. + + Args: + auth_context: The auth_context object to set + """ + auth_context_var.set(auth_context) + logger.debug("Set auth_context in context variable") + + +# Convenience functions for direct use +validate_identifier = SQLSecurityUtils.validate_identifier +quote_identifier = SQLSecurityUtils.quote_identifier +build_table_reference = SQLSecurityUtils.build_table_reference +build_column_reference = SQLSecurityUtils.build_column_reference +get_auth_context = SQLSecurityUtils.get_auth_context +set_auth_context = SQLSecurityUtils.set_auth_context + diff --git a/test/security/test_sql_injection.py b/test/security/test_sql_injection.py new file mode 100644 index 0000000..d729bea --- /dev/null +++ b/test/security/test_sql_injection.py @@ -0,0 +1,367 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +SQL Security Test Suite for Apache Doris MCP Server + +Tests for: +1. SQL injection prevention via identifier validation +2. Multi-statement SQL parsing in security validator +3. auth_context enforcement +""" + +import pytest +import asyncio +from unittest.mock import MagicMock, AsyncMock, patch + + +class TestSQLSecurityUtils: + """Test cases for sql_security_utils module""" + + def test_validate_identifier_accepts_valid_names(self): + """Test that valid identifiers are accepted""" + from doris_mcp_server.utils.sql_security_utils import validate_identifier + + valid_names = [ + "users", + "my_table", + "Table123", + "_private_table", + "CamelCaseTable", + "table_with_numbers_123", + ] + + for name in valid_names: + result = validate_identifier(name, "table") + assert result == name, f"Valid identifier '{name}' should be accepted" + + def test_validate_identifier_rejects_sql_injection(self): + """Test that SQL injection attempts are rejected""" + from doris_mcp_server.utils.sql_security_utils import ( + validate_identifier, + SQLSecurityError + ) + + injection_attempts = [ + # Basic SQL injection + "'; DROP TABLE users; --", + "table' OR '1'='1", + "table'; DELETE FROM users; --", + + # Union-based injection + "table' UNION SELECT * FROM passwords --", + + # Comment injection + "table/**/OR/**/1=1", + "table--comment", + + # Special characters + "table`; DROP TABLE users;", + 'table"; DROP TABLE users;', + "table\"; DELETE FROM", + + # Backtick escape attempt + "analytics`; SELECT * FROM sensitive_table;--", + + # Whitespace injection + "table name with spaces", + "table\ttab", + "table\nnewline", + ] + + for injection in injection_attempts: + with pytest.raises(SQLSecurityError): + validate_identifier(injection, "table") + + def test_validate_identifier_rejects_empty(self): + """Test that empty identifiers are rejected""" + from doris_mcp_server.utils.sql_security_utils import ( + validate_identifier, + SQLSecurityError + ) + + with pytest.raises(SQLSecurityError): + validate_identifier("", "table") + + with pytest.raises(SQLSecurityError): + validate_identifier(None, "table") + + def test_validate_identifier_rejects_too_long(self): + """Test that identifiers exceeding max length are rejected""" + from doris_mcp_server.utils.sql_security_utils import ( + validate_identifier, + SQLSecurityError + ) + + # Doris identifier max length is typically 64 characters + long_name = "a" * 100 + with pytest.raises(SQLSecurityError): + validate_identifier(long_name, "table") + + def test_quote_identifier_adds_backticks(self): + """Test that quote_identifier properly escapes identifiers""" + from doris_mcp_server.utils.sql_security_utils import quote_identifier + + assert quote_identifier("my_table", "table") == "`my_table`" + assert quote_identifier("users", "table") == "`users`" + assert quote_identifier("Table123", "table") == "`Table123`" + + def test_quote_identifier_validates_first(self): + """Test that quote_identifier validates before quoting""" + from doris_mcp_server.utils.sql_security_utils import ( + quote_identifier, + SQLSecurityError + ) + + with pytest.raises(SQLSecurityError): + quote_identifier("'; DROP TABLE users; --", "table") + + +class TestSQLSecurityValidator: + """Test cases for SQLSecurityValidator multi-statement parsing""" + + @pytest.fixture + def dict_config(self): + """Create dictionary configuration""" + return { + "blocked_keywords": [ + "DROP", "CREATE", "ALTER", "TRUNCATE", + "DELETE", "INSERT", "UPDATE", + "GRANT", "REVOKE", "EXEC", "EXECUTE" + ], + "max_query_complexity": 100, + "enable_security_check": True + } + + @pytest.fixture + def mock_auth_context(self): + """Create mock auth context""" + from doris_mcp_server.utils.security import AuthContext, SecurityLevel + return AuthContext( + user_id="test_user", + roles=["user"], + security_level=SecurityLevel.INTERNAL + ) + + @pytest.mark.asyncio + async def test_validates_all_statements(self, dict_config, mock_auth_context): + """Test that validator checks ALL SQL statements, not just the first""" + from doris_mcp_server.utils.security import SQLSecurityValidator + + validator = SQLSecurityValidator(dict_config) + + # Multi-statement with injection in second statement + # This should be BLOCKED + malicious_sql = "SELECT 1; DROP TABLE users; SELECT 2" + + result = await validator.validate(malicious_sql, mock_auth_context) + + assert not result.is_valid, "Multi-statement injection should be blocked" + # Check for either DROP keyword detection or SQL injection detection + error_upper = result.error_message.upper() + assert ("DROP" in error_upper or + "INJECTION" in error_upper or + "BLOCKED" in error_upper), f"Expected DROP/injection/blocked in: {result.error_message}" + + @pytest.mark.asyncio + async def test_blocks_hidden_dangerous_statement(self, dict_config, mock_auth_context): + """Test that dangerous statements hidden after safe ones are blocked""" + from doris_mcp_server.utils.security import SQLSecurityValidator + + validator = SQLSecurityValidator(dict_config) + + # Safe statement followed by dangerous one + malicious_sql = """ + SELECT * FROM users WHERE id = 1; + DELETE FROM audit_log; + SELECT 1; + """ + + result = await validator.validate(malicious_sql, mock_auth_context) + + assert not result.is_valid, "Hidden DELETE statement should be blocked" + + @pytest.mark.asyncio + async def test_allows_safe_multi_statement(self, dict_config, mock_auth_context): + """Test that multiple safe SELECT statements are allowed""" + from doris_mcp_server.utils.security import SQLSecurityValidator + + validator = SQLSecurityValidator(dict_config) + + safe_sql = """ + SELECT * FROM users; + SELECT COUNT(*) FROM orders; + SELECT id, name FROM products; + """ + + result = await validator.validate(safe_sql, mock_auth_context) + + assert result.is_valid, f"Multiple safe SELECT statements should be allowed, got: {result.error_message}" + + @pytest.mark.asyncio + async def test_context_switch_injection_blocked(self, dict_config, mock_auth_context): + """Test that context switch SQL injection is blocked""" + from doris_mcp_server.utils.security import SQLSecurityValidator + + validator = SQLSecurityValidator(dict_config) + + # Simulating the exec_query_for_mcp attack vector + injected_sql = """ + USE `analytics`; SELECT * FROM sensitive_table;-- `; + SELECT * FROM public_table; + """ + + result = await validator.validate(injected_sql, mock_auth_context) + + # The validator should process all statements + # Even if USE is allowed, subsequent unauthorized access should be caught + # by table access checks (if configured) + + +class TestExecQueryForMCP: + """Test cases for exec_query_for_mcp function""" + + @pytest.mark.asyncio + async def test_rejects_malicious_db_name(self): + """Test that malicious db_name is rejected""" + from doris_mcp_server.utils.sql_security_utils import ( + validate_identifier, + SQLSecurityError + ) + + # The attack vector from security report + malicious_db_name = "analytics`; SELECT * FROM sensitive_table;--" + + with pytest.raises(SQLSecurityError): + validate_identifier(malicious_db_name, "database name") + + @pytest.mark.asyncio + async def test_rejects_malicious_catalog_name(self): + """Test that malicious catalog_name is rejected""" + from doris_mcp_server.utils.sql_security_utils import ( + validate_identifier, + SQLSecurityError + ) + + malicious_catalog_name = "internal'; DROP DATABASE production;--" + + with pytest.raises(SQLSecurityError): + validate_identifier(malicious_catalog_name, "catalog name") + + +class TestDependencyAnalysisTools: + """Test cases for dependency_analysis_tools security fixes""" + + @pytest.mark.asyncio + async def test_get_tables_metadata_rejects_injection(self): + """Test that _get_tables_metadata rejects SQL injection""" + from doris_mcp_server.utils.sql_security_utils import ( + validate_identifier, + SQLSecurityError + ) + + # The attack vector from security report + injection_db_name = "test_db' OR '1'='1' --" + + with pytest.raises(SQLSecurityError): + validate_identifier(injection_db_name, "database name") + + +class TestAuthContextEnforcement: + """Test cases for auth_context enforcement""" + + def test_execute_requires_auth_context_for_security(self): + """Test that security checks require auth_context""" + # This test documents the expected behavior: + # When auth_context is None, security checks are skipped + # When auth_context is provided, security checks are performed + + # The fix ensures all execute() calls pass auth_context + pass + + @pytest.mark.asyncio + async def test_get_auth_context_returns_context(self): + """Test that get_auth_context retrieves context from ContextVar""" + from doris_mcp_server.utils.sql_security_utils import get_auth_context + + # When no context is set, should return None + result = get_auth_context() + # This is expected - context is set by HTTP middleware + assert result is None or hasattr(result, 'user_id') + + +class TestIntegrationScenarios: + """Integration test scenarios for security fixes""" + + def test_attack_scenario_1_permission_bypass(self): + """ + Attack Scenario 1: Permission Bypass in Multi-Tenant Environment + + Expected: User can only query their own database (db_name="tenant_a_db") + Attack: Inject "tenant_a_db' OR '1'='1' --" to query ALL databases + Result: Should be BLOCKED by validate_identifier() + """ + from doris_mcp_server.utils.sql_security_utils import ( + validate_identifier, + SQLSecurityError + ) + + with pytest.raises(SQLSecurityError): + validate_identifier("tenant_a_db' OR '1'='1' --", "database name") + + def test_attack_scenario_2_union_injection(self): + """ + Attack Scenario 2: UNION-based Information Disclosure + + Attack: Inject UNION SELECT to extract sensitive data + Result: Should be BLOCKED + """ + from doris_mcp_server.utils.sql_security_utils import ( + validate_identifier, + SQLSecurityError + ) + + with pytest.raises(SQLSecurityError): + validate_identifier( + "test' UNION SELECT password FROM users --", + "database name" + ) + + def test_attack_scenario_3_backtick_escape(self): + """ + Attack Scenario 3: Backtick Escape Attempt + + Attack: Use backticks to break out of quoted identifier + Result: Should be BLOCKED + """ + from doris_mcp_server.utils.sql_security_utils import ( + validate_identifier, + SQLSecurityError + ) + + with pytest.raises(SQLSecurityError): + validate_identifier( + "analytics`; SELECT * FROM sensitive_table;--", + "database name" + ) + + +# Run tests with: pytest tests/test_sql_security.py -v +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short"]) + diff --git a/test/security/test_sql_injection_api.py b/test/security/test_sql_injection_api.py new file mode 100644 index 0000000..b728ce2 --- /dev/null +++ b/test/security/test_sql_injection_api.py @@ -0,0 +1,871 @@ +#!/usr/bin/env python3 +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +SQL Injection API Integration Tests + +This module tests SQL injection prevention through the MCP HTTP API. +It sends malicious payloads and verifies they are properly blocked. + +Prerequisites: + - MCP server running on localhost:3000 + - Run with: pytest test/security/test_sql_injection_api.py -v + +Usage: + # Start server first + bash start_server.sh + + # Run tests + pytest test/security/test_sql_injection_api.py -v --no-cov +""" + +import pytest +import httpx +import json +import asyncio +from typing import Optional + + +# Server configuration +MCP_BASE_URL = "http://localhost:3000" +MCP_ENDPOINT = f"{MCP_BASE_URL}/mcp" +HEALTH_ENDPOINT = f"{MCP_BASE_URL}/health" +TIMEOUT = 30.0 + + +class MCPClient: + """Simple MCP HTTP client for testing""" + + def __init__(self, base_url: str = MCP_BASE_URL): + self.base_url = base_url + self.mcp_endpoint = f"{base_url}/mcp" + self.session_id: Optional[str] = None + self.request_id = 0 + self.client = httpx.AsyncClient(timeout=TIMEOUT) + + async def close(self): + await self.client.aclose() + + def _next_id(self) -> int: + self.request_id += 1 + return self.request_id + + async def initialize(self) -> dict: + """Initialize MCP session""" + response = await self.client.post( + self.mcp_endpoint, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream" + }, + json={ + "jsonrpc": "2.0", + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "sql-injection-test", + "version": "1.0.0" + } + }, + "id": self._next_id() + } + ) + + # Extract session ID from response header + self.session_id = response.headers.get("mcp-session-id") + return self._parse_response(response.text) + + async def call_tool(self, tool_name: str, arguments: dict) -> dict: + """Call an MCP tool""" + if not self.session_id: + await self.initialize() + + response = await self.client.post( + self.mcp_endpoint, + headers={ + "Content-Type": "application/json", + "Accept": "application/json, text/event-stream", + "mcp-session-id": self.session_id + }, + json={ + "jsonrpc": "2.0", + "method": "tools/call", + "params": { + "name": tool_name, + "arguments": arguments + }, + "id": self._next_id() + } + ) + + return self._parse_response(response.text) + + def _parse_response(self, text: str) -> dict: + """Parse JSON response""" + try: + return json.loads(text) + except json.JSONDecodeError: + # Try SSE format + lines = text.strip().split("\n") + for line in lines: + if line.startswith("data: "): + try: + return json.loads(line[6:]) + except json.JSONDecodeError: + continue + return {"raw": text} + + +def print_result(test_name: str, payload: dict, result: dict): + """Print test result in a readable format""" + print(f"\n{'='*60}") + print(f"TEST: {test_name}") + print(f"{'='*60}") + print(f"PAYLOAD: {json.dumps(payload, ensure_ascii=False)}") + print(f"{'-'*60}") + + # Extract inner result content + if "result" in result and "content" in result.get("result", {}): + for item in result["result"]["content"]: + if item.get("type") == "text": + try: + inner = json.loads(item["text"]) + print("RESPONSE:") + print(f" success: {inner.get('success')}") + if inner.get('error'): + print(f" error: {inner.get('error')}") + if inner.get('error_type'): + print(f" error_type: {inner.get('error_type')}") + if inner.get('risk_level'): + print(f" risk_level: {inner.get('risk_level')}") + if inner.get('message'): + print(f" message: {inner.get('message')}") + if inner.get('data') is not None and inner.get('success'): + data_str = json.dumps(inner.get('data'), ensure_ascii=False) + if len(data_str) > 200: + data_str = data_str[:200] + "..." + print(f" data: {data_str}") + except (json.JSONDecodeError, TypeError): + print(f"RESPONSE (raw): {item.get('text', '')[:500]}") + elif "error" in result: + print(f"RESPONSE ERROR: {result['error']}") + else: + print(f"RESPONSE (raw): {json.dumps(result, ensure_ascii=False)[:500]}") + + print(f"{'='*60}\n") + + +class TestSQLInjectionAPI: + """Test SQL injection prevention through MCP API""" + + @pytest.fixture + async def mcp_client(self): + """Create MCP client instance""" + client = MCPClient() + yield client + await client.close() + + @pytest.fixture + def is_server_running(self): + """Check if MCP server is running""" + import httpx + try: + response = httpx.get(HEALTH_ENDPOINT, timeout=5.0) + return response.status_code == 200 + except Exception: + return False + + @pytest.mark.asyncio + async def test_server_health(self): + """Test that MCP server is running and healthy""" + async with httpx.AsyncClient() as client: + response = await client.get(HEALTH_ENDPOINT) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + + @pytest.mark.asyncio + async def test_exec_query_with_drop_injection(self, mcp_client): + """Test exec_query rejects DROP TABLE injection""" + # Classic SQL injection: append DROP TABLE + payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("DROP TABLE Injection", payload, result) + + # Should return error, not execute the DROP + assert self._is_blocked_or_error(result), \ + f"DROP TABLE injection should be blocked" + + @pytest.mark.asyncio + async def test_exec_query_with_union_injection(self, mcp_client): + """Test exec_query blocks UNION-based injection attempts""" + # UNION injection to extract data from other tables + payload = {"sql": "SELECT id FROM users UNION SELECT password FROM admin_users"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("UNION Injection", payload, result) + + @pytest.mark.asyncio + async def test_exec_query_with_delete_injection(self, mcp_client): + """Test exec_query rejects DELETE injection""" + payload = {"sql": "SELECT 1; DELETE FROM users WHERE 1=1; SELECT 2"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("DELETE Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + f"DELETE injection should be blocked" + + @pytest.mark.asyncio + async def test_exec_query_with_update_injection(self, mcp_client): + """Test exec_query rejects UPDATE injection""" + payload = {"sql": "SELECT 1; UPDATE users SET role='admin' WHERE id=1; --"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("UPDATE Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + f"UPDATE injection should be blocked" + + @pytest.mark.asyncio + async def test_exec_query_db_name_injection(self, mcp_client): + """Test exec_query rejects SQL injection via db_name parameter""" + # Attack vector: inject SQL via db_name parameter + payload = {"sql": "SELECT 1", "db_name": "test'; DROP TABLE users; --"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("db_name Parameter Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + f"db_name injection should be blocked" + + @pytest.mark.asyncio + async def test_exec_query_catalog_name_injection(self, mcp_client): + """Test exec_query rejects SQL injection via catalog_name parameter""" + # Attack vector: inject SQL via catalog_name parameter + payload = {"sql": "SELECT 1", "catalog_name": "internal`; SELECT * FROM mysql.user; --"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("catalog_name Parameter Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + f"catalog_name injection should be blocked" + + @pytest.mark.asyncio + async def test_get_table_schema_injection(self, mcp_client): + """Test get_table_schema rejects SQL injection via table_name""" + # Attack vector: inject SQL via table_name parameter + payload = {"table_name": "users'; DROP TABLE users; --"} + result = await mcp_client.call_tool("get_table_schema", payload) + print_result("table_name Injection (get_table_schema)", payload, result) + + assert self._is_blocked_or_error(result), \ + f"table_name injection should be blocked" + + @pytest.mark.asyncio + async def test_get_table_schema_db_injection(self, mcp_client): + """Test get_table_schema rejects SQL injection via db_name""" + payload = {"table_name": "users", "db_name": "test' OR '1'='1"} + result = await mcp_client.call_tool("get_table_schema", payload) + print_result("db_name Injection (get_table_schema)", payload, result) + + assert self._is_blocked_or_error(result), \ + f"db_name injection in get_table_schema should be blocked" + + @pytest.mark.asyncio + async def test_analyze_dependencies_injection(self, mcp_client): + """Test analyze_dependencies rejects SQL injection""" + # This was the original vulnerability reported + payload = {"table_name": "users", "db_name": "test_db' OR '1'='1' --"} + result = await mcp_client.call_tool("analyze_dependencies", payload) + print_result("analyze_dependencies Injection (Original Report)", payload, result) + + assert self._is_blocked_or_error(result), \ + f"analyze_dependencies db_name injection should be blocked" + + @pytest.mark.asyncio + async def test_stacked_queries_injection(self, mcp_client): + """Test that stacked queries (multiple statements) are blocked""" + # Multiple statements injection + payload = {"sql": "SELECT * FROM users WHERE id = 1; INSERT INTO audit_log VALUES (NULL, 'hacked', NOW()); SELECT 1;"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("Stacked Queries (INSERT) Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + f"Stacked queries with INSERT should be blocked" + + @pytest.mark.asyncio + async def test_comment_based_injection(self, mcp_client): + """Test that comment-based injection is blocked""" + # Using comments to bypass filters + payload = {"sql": "SELECT * FROM users WHERE id = 1/**/OR/**/1=1"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("Comment-based Injection", payload, result) + + @pytest.mark.asyncio + async def test_hex_encoded_injection(self, mcp_client): + """Test that hex-encoded injection attempts are handled""" + # Hex-encoded 'DROP' attempt + payload = {"sql": "SELECT 0x44524F50205441424C4520757365727320"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("Hex Encoded Injection", payload, result) + + @pytest.mark.asyncio + async def test_backtick_escape_injection(self, mcp_client): + """Test backtick escape injection is blocked""" + # Attempt to escape backtick quoting + payload = {"sql": "SELECT 1", "db_name": "analytics`; SELECT * FROM sensitive_table;--"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("Backtick Escape Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + f"Backtick escape injection should be blocked" + + @pytest.mark.asyncio + async def test_valid_query_succeeds(self, mcp_client): + """Test that valid queries still work""" + # Simple valid query should work + payload = {"sql": "SELECT 1 AS test_value"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("Valid Query (should succeed)", payload, result) + + @pytest.mark.asyncio + async def test_valid_show_databases(self, mcp_client): + """Test that SHOW DATABASES works""" + payload = {"sql": "SHOW DATABASES"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("SHOW DATABASES (should succeed)", payload, result) + + def _is_blocked_or_error(self, result: dict) -> bool: + """Check if result indicates blocked or error""" + if not result: + return True + + # Check for JSON-RPC error + if "error" in result: + return True + + # Check for error in result content + if "result" in result: + result_content = result.get("result", {}) + if isinstance(result_content, dict): + # Check for isError flag + if result_content.get("isError"): + return True + # Check content array for error messages + content = result_content.get("content", []) + for item in content: + if isinstance(item, dict): + text = item.get("text", "") + # Parse the JSON text content + try: + text_data = json.loads(text) + # Check for success: false + if text_data.get("success") is False: + return True + # Check for error field + if text_data.get("error"): + return True + except (json.JSONDecodeError, TypeError): + pass + # Check text for security keywords + if any(keyword in text.lower() for keyword in [ + "error", "blocked", "invalid", "security", + "injection", "denied", "forbidden", "not allowed", + "security_violation", "risk_level" + ]): + return True + + # Check raw text response + raw = result.get("raw", "") + if isinstance(raw, str) and any(keyword in raw.lower() for keyword in [ + "error", "blocked", "invalid", "security" + ]): + return True + + return False + + +class TestIdentifierInjectionAPI: + """Test identifier-based SQL injection prevention""" + + @pytest.fixture + async def mcp_client(self): + """Create MCP client instance""" + client = MCPClient() + yield client + await client.close() + + @pytest.mark.asyncio + async def test_table_name_with_semicolon(self, mcp_client): + """Test table name containing semicolon is rejected""" + payload = {"table_name": "users; DROP TABLE users"} + result = await mcp_client.call_tool("get_table_schema", payload) + print_result("Table Name with Semicolon", payload, result) + + # Should be blocked by identifier validation + assert self._contains_error_indicator(result), \ + f"Table name with semicolon should be rejected" + + @pytest.mark.asyncio + async def test_table_name_with_quotes(self, mcp_client): + """Test table name containing quotes is rejected""" + payload = {"table_name": "users' OR '1'='1"} + result = await mcp_client.call_tool("get_table_schema", payload) + print_result("Table Name with Quotes", payload, result) + + assert self._contains_error_indicator(result), \ + f"Table name with quotes should be rejected" + + @pytest.mark.asyncio + async def test_db_name_with_special_chars(self, mcp_client): + """Test database name with special characters is rejected""" + special_chars = [ + "test;db", + "test'db", + "test\"db", + "test`db", + "test--db", + "test/*db*/", + ] + + for db_name in special_chars: + payload = {"table_name": "users", "db_name": db_name} + result = await mcp_client.call_tool("get_table_schema", payload) + print_result(f"Special Char in db_name: {db_name}", payload, result) + + assert self._contains_error_indicator(result), \ + f"db_name '{db_name}' should be rejected" + + @pytest.mark.asyncio + async def test_valid_identifiers_accepted(self, mcp_client): + """Test that valid identifiers are accepted""" + valid_names = [ + "users", + "my_table", + "Table123", + "_internal_table", + ] + + for table_name in valid_names: + payload = {"table_name": table_name} + result = await mcp_client.call_tool("get_table_schema", payload) + print_result(f"Valid Identifier: {table_name}", payload, result) + + def _contains_error_indicator(self, result: dict) -> bool: + """Check if result contains error indicators""" + if not result: + return True + + # Check for JSON-RPC error + if "error" in result: + return True + + # Check result content + result_str = json.dumps(result).lower() + error_keywords = [ + "error", "invalid", "illegal", "blocked", + "security", "injection", "denied", "forbidden" + ] + + return any(keyword in result_str for keyword in error_keywords) + + +class TestMultiStatementInjectionAPI: + """Test multi-statement SQL injection prevention""" + + @pytest.fixture + async def mcp_client(self): + """Create MCP client instance""" + client = MCPClient() + yield client + await client.close() + + @pytest.mark.asyncio + async def test_hidden_drop_after_select(self, mcp_client): + """Test DROP hidden after legitimate SELECT is blocked""" + payload = {"sql": "SELECT id, name FROM users WHERE status = 'active'; DROP TABLE audit_log; SELECT 1;"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("Hidden DROP after SELECT", payload, result) + + assert self._is_dangerous_blocked(result), \ + f"Hidden DROP statement should be blocked" + + @pytest.mark.asyncio + async def test_hidden_truncate_after_select(self, mcp_client): + """Test TRUNCATE hidden after SELECT is blocked""" + payload = {"sql": "SELECT 1; TRUNCATE TABLE users"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("Hidden TRUNCATE after SELECT", payload, result) + + assert self._is_dangerous_blocked(result), \ + f"Hidden TRUNCATE should be blocked" + + @pytest.mark.asyncio + async def test_hidden_grant_after_select(self, mcp_client): + """Test GRANT hidden after SELECT is blocked""" + payload = {"sql": "SELECT 1; GRANT ALL ON *.* TO 'hacker'@'%'"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("Hidden GRANT after SELECT", payload, result) + + assert self._is_dangerous_blocked(result), \ + f"Hidden GRANT should be blocked" + + @pytest.mark.asyncio + async def test_multiple_safe_selects_allowed(self, mcp_client): + """Test that multiple SELECT statements may be allowed""" + payload = {"sql": "SELECT 1; SELECT 2; SELECT 3;"} + result = await mcp_client.call_tool("exec_query", payload) + print_result("Multiple Safe SELECTs", payload, result) + + def _is_dangerous_blocked(self, result: dict) -> bool: + """Check if dangerous operation was blocked""" + if not result: + return True + + # Check for error + if "error" in result: + return True + + # Check result content for blocking indicators + result_str = json.dumps(result).lower() + block_indicators = [ + "drop", "truncate", "grant", "revoke", + "blocked", "denied", "forbidden", "not allowed", + "security", "error" + ] + + return any(indicator in result_str for indicator in block_indicators) + + +class TestADBCQueryInjectionAPI: + """Test ADBC query SQL injection prevention""" + + @pytest.fixture + async def mcp_client(self): + """Create MCP client instance""" + client = MCPClient() + yield client + await client.close() + + @pytest.mark.asyncio + async def test_exec_adbc_query_drop_injection(self, mcp_client): + """Test exec_adbc_query rejects DROP TABLE injection""" + payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"} + result = await mcp_client.call_tool("exec_adbc_query", payload) + print_result("ADBC DROP TABLE Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "ADBC DROP TABLE injection should be blocked" + + @pytest.mark.asyncio + async def test_exec_adbc_query_delete_injection(self, mcp_client): + """Test exec_adbc_query rejects DELETE injection""" + payload = {"sql": "SELECT 1; DELETE FROM users; --"} + result = await mcp_client.call_tool("exec_adbc_query", payload) + print_result("ADBC DELETE Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "ADBC DELETE injection should be blocked" + + @pytest.mark.asyncio + async def test_exec_adbc_query_valid(self, mcp_client): + """Test exec_adbc_query allows valid queries""" + payload = {"sql": "SELECT 1 AS test"} + result = await mcp_client.call_tool("exec_adbc_query", payload) + print_result("ADBC Valid Query", payload, result) + + def _is_blocked_or_error(self, result: dict) -> bool: + """Check if result indicates blocked or error""" + if not result: + return True + if "error" in result: + return True + result_str = json.dumps(result).lower() + return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"]) + + +class TestMetadataToolsInjectionAPI: + """Test metadata tools SQL injection prevention""" + + @pytest.fixture + async def mcp_client(self): + """Create MCP client instance""" + client = MCPClient() + yield client + await client.close() + + @pytest.mark.asyncio + async def test_get_db_table_list_db_injection(self, mcp_client): + """Test get_db_table_list rejects db_name injection""" + payload = {"db_name": "test'; DROP TABLE users; --"} + result = await mcp_client.call_tool("get_db_table_list", payload) + print_result("get_db_table_list db_name Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "db_name injection should be blocked" + + @pytest.mark.asyncio + async def test_get_db_table_list_catalog_injection(self, mcp_client): + """Test get_db_table_list rejects catalog_name injection""" + payload = {"catalog_name": "internal`; SELECT * FROM mysql.user; --"} + result = await mcp_client.call_tool("get_db_table_list", payload) + print_result("get_db_table_list catalog_name Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "catalog_name injection should be blocked" + + @pytest.mark.asyncio + async def test_get_table_comment_injection(self, mcp_client): + """Test get_table_comment rejects table_name injection""" + payload = {"table_name": "users'; DROP TABLE users; --"} + result = await mcp_client.call_tool("get_table_comment", payload) + print_result("get_table_comment table_name Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + @pytest.mark.asyncio + async def test_get_table_column_comments_injection(self, mcp_client): + """Test get_table_column_comments rejects injection""" + payload = {"table_name": "users'; DROP TABLE users; --", "db_name": "test"} + result = await mcp_client.call_tool("get_table_column_comments", payload) + print_result("get_table_column_comments Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + @pytest.mark.asyncio + async def test_get_table_indexes_injection(self, mcp_client): + """Test get_table_indexes rejects table_name injection""" + payload = {"table_name": "users; DROP TABLE users", "db_name": "test"} + result = await mcp_client.call_tool("get_table_indexes", payload) + print_result("get_table_indexes Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + def _is_blocked_or_error(self, result: dict) -> bool: + """Check if result indicates blocked or error""" + if not result: + return True + if "error" in result: + return True + result_str = json.dumps(result).lower() + return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"]) + + +class TestAnalyticsToolsInjectionAPI: + """Test analytics tools SQL injection prevention""" + + @pytest.fixture + async def mcp_client(self): + """Create MCP client instance""" + client = MCPClient() + yield client + await client.close() + + @pytest.mark.asyncio + async def test_analyze_columns_table_injection(self, mcp_client): + """Test analyze_columns rejects table_name injection""" + payload = {"table_name": "users'; DROP TABLE users; --"} + result = await mcp_client.call_tool("analyze_columns", payload) + print_result("analyze_columns table_name Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + @pytest.mark.asyncio + async def test_analyze_columns_db_injection(self, mcp_client): + """Test analyze_columns rejects db_name injection""" + payload = {"table_name": "users", "db_name": "test' OR '1'='1"} + result = await mcp_client.call_tool("analyze_columns", payload) + print_result("analyze_columns db_name Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "db_name injection should be blocked" + + @pytest.mark.asyncio + async def test_get_table_basic_info_injection(self, mcp_client): + """Test get_table_basic_info rejects injection""" + payload = {"table_name": "users; DROP TABLE audit_log"} + result = await mcp_client.call_tool("get_table_basic_info", payload) + print_result("get_table_basic_info Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + @pytest.mark.asyncio + async def test_analyze_table_storage_injection(self, mcp_client): + """Test analyze_table_storage rejects injection""" + payload = {"table_name": "users`; SELECT * FROM sensitive; --"} + result = await mcp_client.call_tool("analyze_table_storage", payload) + print_result("analyze_table_storage Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + @pytest.mark.asyncio + async def test_get_sql_explain_injection(self, mcp_client): + """Test get_sql_explain rejects SQL injection""" + payload = {"sql": "SELECT 1; DROP TABLE users; --"} + result = await mcp_client.call_tool("get_sql_explain", payload) + print_result("get_sql_explain SQL Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "SQL injection should be blocked" + + @pytest.mark.asyncio + async def test_get_sql_profile_injection(self, mcp_client): + """Test get_sql_profile rejects SQL injection""" + payload = {"sql": "SELECT 1; DELETE FROM audit_log; --"} + result = await mcp_client.call_tool("get_sql_profile", payload) + print_result("get_sql_profile SQL Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "SQL injection should be blocked" + + def _is_blocked_or_error(self, result: dict) -> bool: + """Check if result indicates blocked or error""" + if not result: + return True + if "error" in result: + return True + result_str = json.dumps(result).lower() + return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"]) + + +class TestGovernanceToolsInjectionAPI: + """Test data governance tools SQL injection prevention""" + + @pytest.fixture + async def mcp_client(self): + """Create MCP client instance""" + client = MCPClient() + yield client + await client.close() + + @pytest.mark.asyncio + async def test_trace_column_lineage_table_injection(self, mcp_client): + """Test trace_column_lineage rejects table_name injection""" + payload = {"table_name": "users'; DROP TABLE users; --", "column_name": "id"} + result = await mcp_client.call_tool("trace_column_lineage", payload) + print_result("trace_column_lineage table_name Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + @pytest.mark.asyncio + async def test_trace_column_lineage_column_injection(self, mcp_client): + """Test trace_column_lineage rejects column_name injection""" + payload = {"table_name": "users", "column_name": "id; DROP TABLE users"} + result = await mcp_client.call_tool("trace_column_lineage", payload) + print_result("trace_column_lineage column_name Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "column_name injection should be blocked" + + @pytest.mark.asyncio + async def test_monitor_data_freshness_injection(self, mcp_client): + """Test monitor_data_freshness rejects table_name injection""" + payload = {"table_name": "users`; SELECT * FROM passwords; --"} + result = await mcp_client.call_tool("monitor_data_freshness", payload) + print_result("monitor_data_freshness Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + @pytest.mark.asyncio + async def test_analyze_data_access_patterns_injection(self, mcp_client): + """Test analyze_data_access_patterns rejects injection""" + payload = {"table_name": "users' UNION SELECT password FROM admin --"} + result = await mcp_client.call_tool("analyze_data_access_patterns", payload) + print_result("analyze_data_access_patterns Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + def _is_blocked_or_error(self, result: dict) -> bool: + """Check if result indicates blocked or error""" + if not result: + return True + if "error" in result: + return True + result_str = json.dumps(result).lower() + return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"]) + + +class TestPerformanceToolsInjectionAPI: + """Test performance analytics tools SQL injection prevention""" + + @pytest.fixture + async def mcp_client(self): + """Create MCP client instance""" + client = MCPClient() + yield client + await client.close() + + @pytest.mark.asyncio + async def test_analyze_slow_queries_db_injection(self, mcp_client): + """Test analyze_slow_queries_topn rejects db_name injection""" + payload = {"db_name": "test'; DROP TABLE audit_log; --"} + result = await mcp_client.call_tool("analyze_slow_queries_topn", payload) + print_result("analyze_slow_queries_topn db_name Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "db_name injection should be blocked" + + @pytest.mark.asyncio + async def test_analyze_resource_growth_db_injection(self, mcp_client): + """Test analyze_resource_growth_curves rejects db_name injection""" + payload = {"db_name": "test`; GRANT ALL ON *.* TO 'hacker'; --"} + result = await mcp_client.call_tool("analyze_resource_growth_curves", payload) + print_result("analyze_resource_growth_curves db_name Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "db_name injection should be blocked" + + @pytest.mark.asyncio + async def test_get_table_data_size_injection(self, mcp_client): + """Test get_table_data_size rejects table_name injection""" + payload = {"table_name": "users; TRUNCATE TABLE logs"} + result = await mcp_client.call_tool("get_table_data_size", payload) + print_result("get_table_data_size Injection", payload, result) + + assert self._is_blocked_or_error(result), \ + "table_name injection should be blocked" + + def _is_blocked_or_error(self, result: dict) -> bool: + """Check if result indicates blocked or error""" + if not result: + return True + if "error" in result: + return True + result_str = json.dumps(result).lower() + return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"]) + + +# Pytest configuration for async tests +@pytest.fixture(scope="session") +def event_loop(): + """Create event loop for async tests""" + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--tb=short", "-x"]) +