Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
81305ffbf9 | ||
|
|
43143f0b30 | ||
|
|
e58361e04b | ||
|
|
a125a2f5f8 | ||
|
|
2613912df3 |
@@ -432,9 +432,9 @@ class DorisServer:
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
self.logger.info("Connection manager initialization completed")
|
||||
# For stdio mode, we must establish a working database connection
|
||||
# Use the dedicated stdio mode initialization method
|
||||
await self.connection_manager.initialize_for_stdio_mode()
|
||||
|
||||
# Start stdio server - using compatible import approach
|
||||
try:
|
||||
@@ -502,8 +502,12 @@ class DorisServer:
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
# For HTTP mode, try to initialize global connection pool with graceful degradation
|
||||
global_pool_created = await self.connection_manager.initialize_for_http_mode()
|
||||
if global_pool_created:
|
||||
self.logger.info("Global database connection pool available for HTTP mode")
|
||||
else:
|
||||
self.logger.info("HTTP mode running without global database pool, will use token-bound configurations")
|
||||
|
||||
# Use Starlette and StreamableHTTPSessionManager according to official example
|
||||
import uvicorn
|
||||
@@ -638,6 +642,16 @@ class DorisServer:
|
||||
# Store auth context in scope for potential use by tools/resources
|
||||
scope["auth_context"] = auth_context
|
||||
|
||||
# FIX for Issue #62 Bug 1: Set auth_context in context variable
|
||||
# This allows tools to access token information for token-bound database configuration
|
||||
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
|
||||
try:
|
||||
from .utils.security import mcp_auth_context_var
|
||||
mcp_auth_context_var.set(auth_context)
|
||||
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||
except Exception as ctx_error:
|
||||
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")
|
||||
|
||||
except Exception as auth_error:
|
||||
self.logger.error(f"MCP authentication failed: {auth_error}")
|
||||
# Return 401 Unauthorized
|
||||
|
||||
@@ -31,6 +31,7 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
class PromptTemplate:
|
||||
@@ -422,7 +423,8 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
db_result = await connection.execute(db_info_sql)
|
||||
auth_context = get_auth_context()
|
||||
db_result = await connection.execute(db_info_sql, auth_context=auth_context)
|
||||
db_info = db_result.data[0] if db_result.data else {}
|
||||
|
||||
# Get main table list
|
||||
@@ -438,7 +440,7 @@ Please generate accurate and efficient SQL queries based on the above requiremen
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
|
||||
context = f"""Current database statistics:
|
||||
- Total number of tables: {db_info.get("table_count", 0)}
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import Any
|
||||
from mcp.types import Resource
|
||||
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
class TableMetadata:
|
||||
@@ -169,7 +170,8 @@ class DorisResourcesManager:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(tables_query, auth_context=auth_context)
|
||||
tables = []
|
||||
|
||||
for row in result.data:
|
||||
@@ -204,7 +206,8 @@ class DorisResourcesManager:
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(columns_query, params=(table_name,), auth_context=auth_context)
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_metadata(self) -> list[ViewMetadata]:
|
||||
@@ -226,7 +229,8 @@ class DorisResourcesManager:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(views_query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(views_query, auth_context=auth_context)
|
||||
views = []
|
||||
|
||||
for row in result.data:
|
||||
@@ -257,7 +261,8 @@ class DorisResourcesManager:
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_info_query, params=(table_name,), auth_context=auth_context)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
@@ -295,7 +300,8 @@ class DorisResourcesManager:
|
||||
ORDER BY index_name, seq_in_index
|
||||
"""
|
||||
|
||||
result = await connection.execute(indexes_query, (table_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(indexes_query, params=(table_name,), auth_context=auth_context)
|
||||
return [dict(row) for row in result.data]
|
||||
|
||||
async def _get_view_definition(self, view_name: str) -> str:
|
||||
@@ -312,7 +318,8 @@ class DorisResourcesManager:
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
result = await connection.execute(view_query, (view_name,))
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(view_query, params=(view_name,), auth_context=auth_context)
|
||||
if not result.data:
|
||||
raise ValueError(f"View {view_name} does not exist")
|
||||
|
||||
@@ -340,7 +347,8 @@ class DorisResourcesManager:
|
||||
AND table_type = 'BASE TABLE'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_stats_query)
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_stats_query, auth_context=auth_context)
|
||||
table_stats = table_result.data[0] if table_result.data else {}
|
||||
|
||||
# Get view statistics
|
||||
@@ -350,7 +358,7 @@ class DorisResourcesManager:
|
||||
WHERE table_schema = DATABASE()
|
||||
"""
|
||||
|
||||
view_result = await connection.execute(view_stats_query)
|
||||
view_result = await connection.execute(view_stats_query, auth_context=auth_context)
|
||||
view_stats = view_result.data[0] if view_result.data else {}
|
||||
|
||||
stats_info = {
|
||||
|
||||
@@ -28,6 +28,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.db import DorisConnectionManager
|
||||
from ..utils.sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -277,7 +278,8 @@ class DorisADBCQueryTools:
|
||||
# Get BE nodes via SHOW BACKENDS
|
||||
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
result = await connection.execute("SHOW BACKENDS")
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||
|
||||
be_hosts = []
|
||||
for row in result.data:
|
||||
@@ -383,6 +385,20 @@ class DorisADBCQueryTools:
|
||||
"error_type": "no_connection"
|
||||
}
|
||||
|
||||
# SECURITY FIX: Perform SQL security validation before executing
|
||||
auth_context = get_auth_context()
|
||||
if self.connection_manager.security_manager:
|
||||
# Always perform security validation, even without auth_context
|
||||
# Use a default context for basic SQL security checks
|
||||
validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context)
|
||||
if not validation_result.is_valid:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"SQL security validation failed: {validation_result.error_message}",
|
||||
"error_type": "security_violation",
|
||||
"risk_level": validation_result.risk_level
|
||||
}
|
||||
|
||||
cursor = self.adbc_client.cursor()
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
@@ -29,6 +29,13 @@ from pathlib import Path
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -46,10 +53,17 @@ class TableAnalyzer:
|
||||
sample_size: int = 10
|
||||
) -> Dict[str, Any]:
|
||||
"""Get table summary information"""
|
||||
# SECURITY FIX: Validate table_name and get auth_context
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
raise ValueError(f"Invalid table name: {e}")
|
||||
|
||||
auth_context = get_auth_context()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Get table basic information
|
||||
table_info_sql = f"""
|
||||
# Get table basic information using parameterized query
|
||||
table_info_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
table_comment,
|
||||
@@ -58,17 +72,17 @@ class TableAnalyzer:
|
||||
engine
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
AND table_name = %s
|
||||
"""
|
||||
|
||||
table_info_result = await connection.execute(table_info_sql)
|
||||
table_info_result = await connection.execute(table_info_sql, params=(table_name,), auth_context=auth_context)
|
||||
if not table_info_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
table_info = table_info_result.data[0]
|
||||
|
||||
# Get column information
|
||||
columns_sql = f"""
|
||||
# Get column information using parameterized query
|
||||
columns_sql = """
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
@@ -76,11 +90,11 @@ class TableAnalyzer:
|
||||
column_comment
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = '{table_name}'
|
||||
AND table_name = %s
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
columns_result = await connection.execute(columns_sql)
|
||||
columns_result = await connection.execute(columns_sql, params=(table_name,), auth_context=auth_context)
|
||||
|
||||
summary = {
|
||||
"table_name": table_info["table_name"],
|
||||
@@ -92,10 +106,11 @@ class TableAnalyzer:
|
||||
"columns": columns_result.data,
|
||||
}
|
||||
|
||||
# Get sample data
|
||||
# Get sample data using quoted identifier
|
||||
if include_sample and sample_size > 0:
|
||||
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql)
|
||||
quoted_table = quote_identifier(table_name, "table name")
|
||||
sample_sql = f"SELECT * FROM {quoted_table} LIMIT {sample_size}"
|
||||
sample_result = await connection.execute(sample_sql, auth_context=auth_context)
|
||||
summary["sample_data"] = sample_result.data
|
||||
|
||||
return summary
|
||||
@@ -120,7 +135,8 @@ class TableAnalyzer:
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
basic_result = await connection.execute(basic_stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
basic_result = await connection.execute(basic_stats_sql, auth_context=auth_context)
|
||||
if not basic_result.data:
|
||||
return {
|
||||
"success": False,
|
||||
@@ -144,7 +160,7 @@ class TableAnalyzer:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
distribution_result = await connection.execute(distribution_sql)
|
||||
distribution_result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||
analysis["value_distribution"] = distribution_result.data
|
||||
|
||||
if analysis_type == "detailed":
|
||||
@@ -159,7 +175,7 @@ class TableAnalyzer:
|
||||
WHERE {column_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
numeric_result = await connection.execute(numeric_stats_sql)
|
||||
numeric_result = await connection.execute(numeric_stats_sql, auth_context=auth_context)
|
||||
if numeric_result.data:
|
||||
analysis.update(numeric_result.data[0])
|
||||
except Exception:
|
||||
@@ -196,7 +212,8 @@ class TableAnalyzer:
|
||||
AND table_name = '{table_name}'
|
||||
"""
|
||||
|
||||
table_result = await connection.execute(table_info_sql)
|
||||
auth_context = get_auth_context()
|
||||
table_result = await connection.execute(table_info_sql, auth_context=auth_context)
|
||||
if not table_result.data:
|
||||
raise ValueError(f"Table {table_name} does not exist")
|
||||
|
||||
@@ -211,7 +228,7 @@ class TableAnalyzer:
|
||||
AND table_name != %s
|
||||
"""
|
||||
|
||||
all_tables_result = await connection.execute(all_tables_sql, (table_name,))
|
||||
all_tables_result = await connection.execute(all_tables_sql, params=(table_name,), auth_context=auth_context)
|
||||
|
||||
return {
|
||||
"center_table": table_result.data[0],
|
||||
@@ -291,7 +308,8 @@ class PerformanceMonitor:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
auth_context = get_auth_context()
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
stats = {
|
||||
"metric_type": "tables",
|
||||
"time_range": time_range,
|
||||
@@ -379,9 +397,23 @@ class SQLAnalyzer:
|
||||
|
||||
logger.info(f"Generating SQL explain for query ID: {query_id}")
|
||||
|
||||
# 🔧 FIX: Get auth_context for token-bound database configuration
|
||||
auth_context = None
|
||||
try:
|
||||
from .security import mcp_auth_context_var
|
||||
auth_context = mcp_auth_context_var.get()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Switch database if specified
|
||||
# SECURITY FIX: Validate and quote db_name
|
||||
if db_name:
|
||||
await self.connection_manager.execute_query("explain_session", f"USE {db_name}")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}", None, auth_context)
|
||||
|
||||
# Construct EXPLAIN query
|
||||
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
|
||||
@@ -390,7 +422,7 @@ class SQLAnalyzer:
|
||||
logger.info(f"Executing explain query: {explain_sql}")
|
||||
|
||||
# Execute explain query
|
||||
result = await self.connection_manager.execute_query("explain_session", explain_sql)
|
||||
result = await self.connection_manager.execute_query("explain_session", explain_sql, None, auth_context)
|
||||
|
||||
# Format explain output
|
||||
explain_content = []
|
||||
@@ -515,24 +547,36 @@ class SQLAnalyzer:
|
||||
|
||||
try:
|
||||
# Switch to specified database/catalog if provided
|
||||
# SECURITY FIX: Validate identifiers before using in SQL
|
||||
if catalog_name:
|
||||
await connection.execute(f"SWITCH `{catalog_name}`")
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid catalog name: {e}"}
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
auth_context = get_auth_context()
|
||||
await connection.execute(f"SWITCH {safe_catalog}", auth_context=auth_context)
|
||||
if db_name:
|
||||
await connection.execute(f"USE `{db_name}`")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return {"success": False, "error": f"Invalid database name: {e}"}
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
await connection.execute(f"USE {safe_db}", auth_context=auth_context)
|
||||
|
||||
# Set trace ID for the session using session variable
|
||||
# According to official docs: set session_context="trace_id:your_trace_id"
|
||||
await connection.execute(f'set session_context="trace_id:{trace_id}"')
|
||||
await connection.execute(f'set session_context="trace_id:{trace_id}"', auth_context=auth_context)
|
||||
logger.info(f"Set trace ID: {trace_id}")
|
||||
|
||||
# Enable profile
|
||||
await connection.execute(f'set enable_profile=true')
|
||||
await connection.execute(f'set enable_profile=true', auth_context=auth_context)
|
||||
logger.info(f"Enabled profile")
|
||||
|
||||
# Execute the SQL statement
|
||||
logger.info(f"Executing SQL with trace ID: {sql}")
|
||||
start_time = time.time()
|
||||
sql_result = await connection.execute(sql)
|
||||
sql_result = await connection.execute(sql, auth_context=auth_context)
|
||||
execution_time = time.time() - start_time
|
||||
logger.info(f"SQL execution completed in {execution_time:.3f}s")
|
||||
|
||||
|
||||
@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -43,24 +50,30 @@ class DataExplorationTools:
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||
"""Build full table name with catalog and database using three-part naming convention"""
|
||||
# Default catalog for internal tables
|
||||
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||
effective_catalog = catalog_name if catalog_name else "internal"
|
||||
|
||||
if db_name:
|
||||
return f"{effective_catalog}.{db_name}.{table_name}"
|
||||
return build_table_reference(table_name, db_name, effective_catalog)
|
||||
else:
|
||||
# If no db_name provided, need to determine the current database
|
||||
return f"{effective_catalog}.{table_name}"
|
||||
return build_table_reference(table_name, catalog_name=effective_catalog)
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get basic table information including row count"""
|
||||
try:
|
||||
# SECURITY FIX: Get auth_context for security validation
|
||||
# table_name should already be validated by _build_full_table_name
|
||||
auth_context = get_auth_context()
|
||||
|
||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||
result = await connection.execute(count_sql)
|
||||
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
return {"row_count": result.data[0]["row_count"]}
|
||||
return None
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
@@ -68,10 +81,24 @@ class DataExplorationTools:
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get detailed column information"""
|
||||
try:
|
||||
where_conditions = [f"table_name = '{table_name}'"]
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if db_name:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build parameterized query
|
||||
params = [table_name]
|
||||
where_conditions = ["table_name = %s"]
|
||||
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
where_conditions.append("table_schema = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
@@ -87,9 +114,12 @@ class DataExplorationTools:
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_sql)
|
||||
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||
return []
|
||||
@@ -177,7 +207,8 @@ class DataExplorationTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
stats_result = await connection.execute(stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
stats_result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||
|
||||
if stats_result.data and stats_result.data[0]["count"] > 0:
|
||||
stats = stats_result.data[0]
|
||||
@@ -229,7 +260,8 @@ class DataExplorationTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(percentile_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(percentile_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
@@ -268,7 +300,8 @@ class DataExplorationTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(outlier_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(outlier_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
@@ -359,7 +392,8 @@ class DataExplorationTools:
|
||||
{sampling_info.get('sample_query_suffix', '')}
|
||||
"""
|
||||
|
||||
cardinality_result = await connection.execute(cardinality_sql)
|
||||
auth_context = get_auth_context()
|
||||
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
|
||||
|
||||
if cardinality_result.data:
|
||||
cardinality_data = cardinality_result.data[0]
|
||||
@@ -408,7 +442,8 @@ class DataExplorationTools:
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
result = await connection.execute(distribution_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(distribution_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
distribution = []
|
||||
@@ -458,7 +493,8 @@ class DataExplorationTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
range_result = await connection.execute(range_sql)
|
||||
auth_context = get_auth_context()
|
||||
range_result = await connection.execute(range_sql, auth_context=auth_context)
|
||||
|
||||
if range_result.data and range_result.data[0]["non_null_count"] > 0:
|
||||
range_data = range_result.data[0]
|
||||
@@ -539,7 +575,8 @@ class DataExplorationTools:
|
||||
ORDER BY day_of_week
|
||||
"""
|
||||
|
||||
weekly_result = await connection.execute(weekly_pattern_sql)
|
||||
auth_context = get_auth_context()
|
||||
weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context)
|
||||
|
||||
weekly_pattern = []
|
||||
if weekly_result.data:
|
||||
@@ -561,7 +598,7 @@ class DataExplorationTools:
|
||||
LIMIT 12
|
||||
"""
|
||||
|
||||
monthly_result = await connection.execute(monthly_trend_sql)
|
||||
monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context)
|
||||
monthly_trend = "stable" # Simplified trend analysis
|
||||
|
||||
if monthly_result.data and len(monthly_result.data) > 3:
|
||||
@@ -646,7 +683,8 @@ class DataExplorationTools:
|
||||
FROM {table_expr}
|
||||
"""
|
||||
|
||||
result = await connection.execute(null_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(null_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
total_count = data["total_count"]
|
||||
|
||||
@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -216,26 +223,34 @@ class DataGovernanceTools:
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||
"""Build full table name - use three-level naming convention"""
|
||||
"""Build full table name - use three-level naming convention with security validation"""
|
||||
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||
# Default catalog is internal for internal tables
|
||||
effective_catalog = catalog_name if catalog_name else "internal"
|
||||
|
||||
if db_name:
|
||||
return f"{effective_catalog}.{db_name}.{table_name}"
|
||||
return build_table_reference(table_name, db_name, effective_catalog)
|
||||
else:
|
||||
# If db_name is not provided, need to determine current database
|
||||
return f"{effective_catalog}.{table_name}"
|
||||
return build_table_reference(table_name, catalog_name=effective_catalog)
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get table basic information"""
|
||||
try:
|
||||
# SECURITY FIX: Get auth_context for security validation
|
||||
# table_name should already be validated by _build_full_table_name
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Try to get table row count
|
||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||
result = await connection.execute(count_sql)
|
||||
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
return {"row_count": result.data[0]["row_count"]}
|
||||
return None
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
@@ -243,11 +258,24 @@ class DataGovernanceTools:
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get table column information"""
|
||||
try:
|
||||
# Build query conditions
|
||||
where_conditions = [f"table_name = '{table_name}'"]
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if db_name:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build parameterized query conditions
|
||||
params = [table_name]
|
||||
where_conditions = ["table_name = %s"]
|
||||
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
where_conditions.append("table_schema = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
@@ -263,30 +291,49 @@ class DataGovernanceTools:
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_sql)
|
||||
result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]:
|
||||
"""Analyze column completeness"""
|
||||
# SECURITY FIX: Get auth_context for security validation
|
||||
auth_context = get_auth_context()
|
||||
column_completeness = {}
|
||||
|
||||
for column in columns_info:
|
||||
column_name = column["column_name"]
|
||||
try:
|
||||
# SECURITY FIX: Validate column name before using in SQL
|
||||
try:
|
||||
validate_identifier(column_name, "column name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid column name rejected: {e}")
|
||||
column_completeness[column_name] = {
|
||||
"error": f"Invalid column name: {e}",
|
||||
"completeness_score": 0.0
|
||||
}
|
||||
continue
|
||||
|
||||
# Use quoted identifier for column name
|
||||
quoted_column = quote_identifier(column_name, "column name")
|
||||
|
||||
# Calculate null value statistics
|
||||
null_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
COUNT({column_name}) as non_null_count,
|
||||
COUNT(*) - COUNT({column_name}) as null_count
|
||||
COUNT({quoted_column}) as non_null_count,
|
||||
COUNT(*) - COUNT({quoted_column}) as null_count
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
result = await connection.execute(null_sql)
|
||||
result = await connection.execute(null_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
stats = result.data[0]
|
||||
total_count = stats["total_count"]
|
||||
@@ -304,6 +351,12 @@ class DataGovernanceTools:
|
||||
"completeness_score": round(completeness_score, 4)
|
||||
}
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for column {column_name}: {str(e)}")
|
||||
column_completeness[column_name] = {
|
||||
"error": str(e),
|
||||
"completeness_score": 0.0
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}")
|
||||
column_completeness[column_name] = {
|
||||
@@ -333,7 +386,8 @@ class DataGovernanceTools:
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
result = await connection.execute(compliance_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(compliance_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
stats = result.data[0]
|
||||
pass_count = stats["pass_count"] or 0
|
||||
@@ -378,7 +432,8 @@ class DataGovernanceTools:
|
||||
) t
|
||||
"""
|
||||
|
||||
result = await connection.execute(duplicate_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(duplicate_sql, auth_context=auth_context)
|
||||
if result.data and result.data[0]["duplicate_count"] > 0:
|
||||
issues.append({
|
||||
"type": "duplicate_primary_keys",
|
||||
@@ -456,10 +511,21 @@ class DataGovernanceTools:
|
||||
async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool:
|
||||
"""Verify if column exists"""
|
||||
try:
|
||||
# Simple verification method: try to query the column
|
||||
verify_sql = f"SELECT {column_name} FROM {table_name} LIMIT 1"
|
||||
await connection.execute(verify_sql)
|
||||
# SECURITY FIX: Validate and quote column name
|
||||
try:
|
||||
validate_identifier(column_name, "column name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid column name rejected: {e}")
|
||||
return False
|
||||
|
||||
safe_column = quote_identifier(column_name, "column name")
|
||||
# table_name is already safe (from _build_full_table_name)
|
||||
verify_sql = f"SELECT {safe_column} FROM {table_name} LIMIT 1"
|
||||
auth_context = get_auth_context()
|
||||
await connection.execute(verify_sql, auth_context=auth_context)
|
||||
return True
|
||||
except SQLSecurityError:
|
||||
return False
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -469,21 +535,34 @@ class DataGovernanceTools:
|
||||
source_chain = []
|
||||
|
||||
try:
|
||||
# SECURITY FIX: Validate table name and use parameterized-like approach
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return []
|
||||
|
||||
# Escape special characters for LIKE pattern
|
||||
safe_pattern = table_name_part.replace('%', r'\%').replace('_', r'\_')
|
||||
like_pattern = f"%{safe_pattern}%"
|
||||
|
||||
# Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range)
|
||||
auth_context = get_auth_context()
|
||||
audit_sql = """
|
||||
SELECT
|
||||
stmt as sql_statement,
|
||||
`time` as execution_time,
|
||||
`user` as user_name
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE stmt LIKE '%{}%'
|
||||
WHERE stmt LIKE %s
|
||||
AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%')
|
||||
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
||||
ORDER BY `time` DESC
|
||||
LIMIT 50
|
||||
""".format(table_name.split('.')[-1]) # Use the last part of table name
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
result = await connection.execute(audit_sql, params=(like_pattern,), auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
for i, log_entry in enumerate(result.data[:depth]):
|
||||
@@ -556,19 +635,33 @@ class DataGovernanceTools:
|
||||
downstream_usage = []
|
||||
|
||||
try:
|
||||
# SECURITY FIX: Validate inputs and use parameterized-like approach
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
validate_identifier(column_name, "column name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Escape special characters for LIKE pattern
|
||||
safe_table_pattern = f"%{table_name_part.replace('%', r'\\%').replace('_', r'\\_')}%"
|
||||
safe_column_pattern = f"%{column_name.replace('%', r'\\%').replace('_', r'\\_')}%"
|
||||
|
||||
# Find other tables that might use this field (through audit logs, one year range)
|
||||
auth_context = get_auth_context()
|
||||
usage_sql = """
|
||||
SELECT DISTINCT
|
||||
stmt as sql_statement
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE stmt LIKE '%{}%'
|
||||
AND stmt LIKE '%{}%'
|
||||
WHERE stmt LIKE %s
|
||||
AND stmt LIKE %s
|
||||
AND stmt LIKE '%SELECT%'
|
||||
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
||||
LIMIT 20
|
||||
""".format(table_name.split('.')[-1], column_name)
|
||||
"""
|
||||
|
||||
result = await connection.execute(usage_sql)
|
||||
result = await connection.execute(usage_sql, params=(safe_table_pattern, safe_column_pattern), auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
for entry in result.data:
|
||||
@@ -634,14 +727,20 @@ class DataGovernanceTools:
|
||||
async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]:
|
||||
"""Get list of all tables"""
|
||||
try:
|
||||
where_conditions = []
|
||||
auth_context = get_auth_context()
|
||||
params = []
|
||||
|
||||
# SECURITY FIX: Use parameterized query
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
where_clause = "table_schema = %s"
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||
where_clause = "table_schema = DATABASE()"
|
||||
|
||||
tables_sql = f"""
|
||||
SELECT table_name
|
||||
@@ -651,7 +750,7 @@ class DataGovernanceTools:
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_sql)
|
||||
result = await connection.execute(tables_sql, params=tuple(params) if params else None, auth_context=auth_context)
|
||||
return [row["table_name"] for row in result.data] if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
@@ -728,15 +827,23 @@ class DataGovernanceTools:
|
||||
async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get freshness from partition information"""
|
||||
try:
|
||||
# Query partition information (if table has partitions)
|
||||
partition_sql = f"""
|
||||
# SECURITY FIX: Validate and use parameterized query
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return None
|
||||
|
||||
auth_context = get_auth_context()
|
||||
partition_sql = """
|
||||
SELECT MAX(CREATE_TIME) as last_update
|
||||
FROM information_schema.partitions
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
WHERE table_name = %s
|
||||
AND CREATE_TIME IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(partition_sql)
|
||||
result = await connection.execute(partition_sql, params=(table_name_part,), auth_context=auth_context)
|
||||
if result.data and result.data[0]["last_update"]:
|
||||
return {
|
||||
"last_update": result.data[0]["last_update"],
|
||||
@@ -759,7 +866,8 @@ class DataGovernanceTools:
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
result = await connection.execute(max_time_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(max_time_sql, auth_context=auth_context)
|
||||
if result.data and result.data[0]["last_update"]:
|
||||
return {
|
||||
"last_update": result.data[0]["last_update"],
|
||||
@@ -773,15 +881,23 @@ class DataGovernanceTools:
|
||||
async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get freshness from table metadata"""
|
||||
try:
|
||||
# Query table's update time
|
||||
metadata_sql = f"""
|
||||
# SECURITY FIX: Validate and use parameterized query
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return None
|
||||
|
||||
auth_context = get_auth_context()
|
||||
metadata_sql = """
|
||||
SELECT UPDATE_TIME as last_update
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
WHERE table_name = %s
|
||||
AND UPDATE_TIME IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(metadata_sql)
|
||||
result = await connection.execute(metadata_sql, params=(table_name_part,), auth_context=auth_context)
|
||||
if result.data and result.data[0]["last_update"]:
|
||||
return {
|
||||
"last_update": result.data[0]["last_update"],
|
||||
@@ -795,10 +911,19 @@ class DataGovernanceTools:
|
||||
async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]:
|
||||
"""Find possible timestamp fields"""
|
||||
try:
|
||||
timestamp_sql = f"""
|
||||
# SECURITY FIX: Validate and use parameterized query
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return []
|
||||
|
||||
auth_context = get_auth_context()
|
||||
timestamp_sql = """
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
WHERE table_name = %s
|
||||
AND (
|
||||
data_type IN ('datetime', 'timestamp', 'date')
|
||||
OR column_name LIKE '%time%'
|
||||
@@ -815,7 +940,7 @@ class DataGovernanceTools:
|
||||
END
|
||||
"""
|
||||
|
||||
result = await connection.execute(timestamp_sql)
|
||||
result = await connection.execute(timestamp_sql, params=(table_name_part,), auth_context=auth_context)
|
||||
return [row["column_name"] for row in result.data] if result.data else []
|
||||
|
||||
except Exception:
|
||||
|
||||
@@ -31,6 +31,12 @@ from collections import Counter, defaultdict
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .config import DorisConfig
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -299,23 +305,26 @@ class DataQualityTools:
|
||||
# ===========================================
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||
"""Build full table name"""
|
||||
if catalog_name and db_name:
|
||||
return f"{catalog_name}.{db_name}.{table_name}"
|
||||
elif db_name:
|
||||
return f"{db_name}.{table_name}"
|
||||
else:
|
||||
return table_name
|
||||
"""Build full table name with security validation"""
|
||||
# SECURITY FIX: Use build_table_reference for safe identifier handling
|
||||
return build_table_reference(table_name, db_name, catalog_name)
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get basic table information"""
|
||||
try:
|
||||
# SECURITY FIX: table_name should already be validated by _build_full_table_name
|
||||
# But we add auth_context for security validation
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Try to get row count
|
||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||
result = await connection.execute(count_sql)
|
||||
result = await connection.execute(count_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
return {"row_count": result.data[0]["row_count"]}
|
||||
return None
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get table basic info: {str(e)}")
|
||||
return None
|
||||
@@ -323,9 +332,13 @@ class DataQualityTools:
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get table column information"""
|
||||
try:
|
||||
# Build DESCRIBE query
|
||||
describe_sql = f"DESCRIBE {self._build_full_table_name(table_name, catalog_name, db_name)}"
|
||||
result = await connection.execute(describe_sql)
|
||||
# SECURITY FIX: Build safe table reference and pass auth_context
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Build DESCRIBE query with safe table reference
|
||||
safe_table_ref = self._build_full_table_name(table_name, catalog_name, db_name)
|
||||
describe_sql = f"DESCRIBE {safe_table_ref}"
|
||||
result = await connection.execute(describe_sql, auth_context=auth_context)
|
||||
|
||||
columns_info = []
|
||||
if result.data:
|
||||
@@ -339,6 +352,9 @@ class DataQualityTools:
|
||||
})
|
||||
|
||||
return columns_info
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get table columns info: {str(e)}")
|
||||
return []
|
||||
@@ -346,7 +362,32 @@ class DataQualityTools:
|
||||
async def _get_table_partitions(self, connection, table_name: str, db_name: Optional[str] = None) -> List[Dict]:
|
||||
"""Get table partition information"""
|
||||
try:
|
||||
# Query partition information
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Validate table_name
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if db_name:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build parameterized query
|
||||
params = []
|
||||
where_conditions = []
|
||||
|
||||
if db_name:
|
||||
where_conditions.append("TABLE_SCHEMA = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("TABLE_SCHEMA = ''")
|
||||
|
||||
where_conditions.append("TABLE_NAME = %s")
|
||||
params.append(table_name)
|
||||
where_conditions.append("PARTITION_NAME IS NOT NULL")
|
||||
|
||||
partition_sql = f"""
|
||||
SELECT
|
||||
PARTITION_NAME,
|
||||
@@ -355,12 +396,10 @@ class DataQualityTools:
|
||||
DATA_LENGTH,
|
||||
INDEX_LENGTH
|
||||
FROM information_schema.PARTITIONS
|
||||
WHERE TABLE_SCHEMA = '{db_name or ""}'
|
||||
AND TABLE_NAME = '{table_name}'
|
||||
AND PARTITION_NAME IS NOT NULL
|
||||
WHERE {' AND '.join(where_conditions)}
|
||||
"""
|
||||
|
||||
result = await connection.execute(partition_sql)
|
||||
result = await connection.execute(partition_sql, params=tuple(params), auth_context=auth_context)
|
||||
partitions = []
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
@@ -373,6 +412,9 @@ class DataQualityTools:
|
||||
})
|
||||
|
||||
return partitions
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get table partitions: {str(e)}")
|
||||
return []
|
||||
@@ -417,7 +459,8 @@ class DataQualityTools:
|
||||
if db_name
|
||||
else f"SHOW CREATE TABLE {table_name}"
|
||||
)
|
||||
result = await connection.execute(query)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(query, auth_context=auth_context)
|
||||
if result.data:
|
||||
return result.data[0].get("Create Table")
|
||||
return None
|
||||
@@ -428,8 +471,16 @@ class DataQualityTools:
|
||||
async def _get_table_size_info(self, connection, table_name: str) -> Dict[str, Any]:
|
||||
"""Get table size information"""
|
||||
try:
|
||||
# Query table size information
|
||||
size_sql = f"""
|
||||
# SECURITY FIX: Validate and use parameterized query
|
||||
table_name_part = table_name.split('.')[-1]
|
||||
try:
|
||||
validate_identifier(table_name_part, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return {"engine": "Unknown", "estimated_rows": 0, "data_length": 0, "index_length": 0, "total_size": 0}
|
||||
|
||||
auth_context = get_auth_context()
|
||||
size_sql = """
|
||||
SELECT
|
||||
table_name,
|
||||
engine,
|
||||
@@ -438,10 +489,10 @@ class DataQualityTools:
|
||||
index_length,
|
||||
(data_length + index_length) as total_size
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
WHERE table_name = %s
|
||||
"""
|
||||
|
||||
result = await connection.execute(size_sql)
|
||||
result = await connection.execute(size_sql, params=(table_name_part,), auth_context=auth_context)
|
||||
if result.data and result.data[0]:
|
||||
row = result.data[0]
|
||||
return {
|
||||
@@ -582,7 +633,8 @@ class DataQualityTools:
|
||||
|
||||
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
||||
|
||||
result = await connection.execute(batch_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(batch_sql, auth_context=auth_context)
|
||||
if not result.data:
|
||||
return {"error": "No data returned from batch completeness query"}
|
||||
|
||||
@@ -664,7 +716,8 @@ class DataQualityTools:
|
||||
|
||||
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
||||
|
||||
result = await connection.execute(batch_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(batch_sql, auth_context=auth_context)
|
||||
if not result.data:
|
||||
return {}
|
||||
|
||||
@@ -705,7 +758,8 @@ class DataQualityTools:
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
result = await connection.execute(freq_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(freq_sql, auth_context=auth_context)
|
||||
frequencies = result.data if result.data else []
|
||||
|
||||
categorical_results[col_name] = {
|
||||
@@ -738,7 +792,8 @@ class DataQualityTools:
|
||||
|
||||
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
|
||||
|
||||
result = await connection.execute(batch_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(batch_sql, auth_context=auth_context)
|
||||
if not result.data:
|
||||
return {}
|
||||
|
||||
@@ -780,7 +835,8 @@ class DataQualityTools:
|
||||
FROM {table_expr}
|
||||
"""
|
||||
|
||||
result = await connection.execute(completeness_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(completeness_sql, auth_context=auth_context)
|
||||
if result.data:
|
||||
stats = result.data[0]
|
||||
total_count = stats["total_count"]
|
||||
@@ -906,7 +962,8 @@ class DataQualityTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||
if result.data and result.data[0]["non_null_count"] > 0:
|
||||
stats = result.data[0]
|
||||
numeric_analysis[col_name] = {
|
||||
@@ -945,7 +1002,8 @@ class DataQualityTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
cardinality_result = await connection.execute(cardinality_sql)
|
||||
auth_context = get_auth_context()
|
||||
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
|
||||
|
||||
if cardinality_result.data:
|
||||
stats = cardinality_result.data[0]
|
||||
@@ -969,7 +1027,7 @@ class DataQualityTools:
|
||||
LIMIT 10
|
||||
"""
|
||||
|
||||
top_values_result = await connection.execute(top_values_sql)
|
||||
top_values_result = await connection.execute(top_values_sql, auth_context=auth_context)
|
||||
if top_values_result.data:
|
||||
categorical_analysis[col_name]["top_values"] = [
|
||||
{"value": row[col_name], "count": row["count"]}
|
||||
@@ -998,7 +1056,8 @@ class DataQualityTools:
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(stats_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(stats_sql, auth_context=auth_context)
|
||||
if result.data and result.data[0]["non_null_count"] > 0:
|
||||
stats = result.data[0]
|
||||
temporal_analysis[col_name] = {
|
||||
|
||||
@@ -59,6 +59,7 @@ class QueryResult:
|
||||
metadata: dict[str, Any]
|
||||
execution_time: float
|
||||
row_count: int
|
||||
sql: str
|
||||
|
||||
|
||||
class DorisConnection:
|
||||
@@ -95,12 +96,14 @@ class DorisConnection:
|
||||
await cursor.execute(sql, params)
|
||||
|
||||
# Check if it's a query statement (statement that returns result set)
|
||||
# FIX for Issue #62 Bug 5: Added WITH support for Common Table Expressions (CTE)
|
||||
sql_upper = sql.strip().upper()
|
||||
if (sql_upper.startswith("SELECT") or
|
||||
sql_upper.startswith("SHOW") or
|
||||
sql_upper.startswith("DESCRIBE") or
|
||||
sql_upper.startswith("DESC") or
|
||||
sql_upper.startswith("EXPLAIN")):
|
||||
sql_upper.startswith("EXPLAIN") or
|
||||
sql_upper.startswith("WITH")): # FIX: Support CTE queries
|
||||
data = await cursor.fetchall()
|
||||
row_count = len(data)
|
||||
else:
|
||||
@@ -130,6 +133,7 @@ class DorisConnection:
|
||||
metadata=metadata,
|
||||
execution_time=execution_time,
|
||||
row_count=row_count,
|
||||
sql=sql
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -250,7 +254,23 @@ class DorisConnectionManager:
|
||||
self.logger = get_logger(__name__)
|
||||
self.security_manager = security_manager
|
||||
self.token_manager = token_manager # Token manager for token-bound DB config
|
||||
self.session_cache = DorisSessionCache(self)
|
||||
|
||||
# 🔧 FIX for multi-tenant concurrency: Per-token connection pool isolation
|
||||
# Each token gets its own connection pool to prevent configuration conflicts
|
||||
self.token_pools: Dict[str, Pool] = {} # token_hash -> pool
|
||||
self.token_configs: Dict[str, dict] = {} # token_hash -> db_config
|
||||
self._token_pool_locks: Dict[str, asyncio.Lock] = {} # token_hash -> lock
|
||||
self._token_pools_lock = asyncio.Lock() # Lock for managing token_pools dict
|
||||
|
||||
# FIX for Issue #58 Problem 1: Disable session caching to prevent connection sharing
|
||||
# Session caching causes multiple threads to share the same MySQL connection,
|
||||
# leading to race conditions and deadlocks in multi-threaded environments
|
||||
# By disabling caching, each request gets a fresh connection from the pool
|
||||
self.session_cache = DorisSessionCache(
|
||||
self,
|
||||
cache_system_session=False, # Disabled to prevent multi-thread issues
|
||||
cache_user_session=False # Disabled to prevent multi-thread issues
|
||||
)
|
||||
|
||||
# Store original database config for fallback
|
||||
self.original_db_config = {
|
||||
@@ -263,6 +283,7 @@ class DorisConnectionManager:
|
||||
}
|
||||
|
||||
# Current active database config (may be overridden by token-bound config)
|
||||
# NOTE: This is kept for backward compatibility with non-token requests
|
||||
self.active_db_config = self.original_db_config.copy()
|
||||
|
||||
# Connection pool state management
|
||||
@@ -346,6 +367,281 @@ class DorisConnectionManager:
|
||||
self.logger.error(f"Error finding available token: {e}")
|
||||
return ""
|
||||
|
||||
def _get_token_hash(self, token: str) -> str:
|
||||
"""Get hash of token for use as dictionary key"""
|
||||
import hashlib
|
||||
return hashlib.sha256(token.encode()).hexdigest()[:16]
|
||||
|
||||
def _get_current_token_db_config(self, token: str) -> dict | None:
|
||||
"""Get current database config for token from TokenManager
|
||||
|
||||
This is used to check if config has changed for hot reload support.
|
||||
"""
|
||||
if not self.token_manager:
|
||||
return None
|
||||
|
||||
token_db_config = self.token_manager.get_database_config_by_token(token)
|
||||
if token_db_config:
|
||||
return {
|
||||
'host': token_db_config.host,
|
||||
'port': token_db_config.port,
|
||||
'user': token_db_config.user,
|
||||
'password': token_db_config.password,
|
||||
'database': token_db_config.database,
|
||||
'charset': token_db_config.charset
|
||||
}
|
||||
return None
|
||||
|
||||
def _config_changed(self, old_config: dict, new_config: dict) -> bool:
|
||||
"""Check if database configuration has changed"""
|
||||
if old_config is None or new_config is None:
|
||||
return old_config != new_config
|
||||
|
||||
# Compare key fields
|
||||
for key in ['host', 'port', 'user', 'password', 'database']:
|
||||
if old_config.get(key) != new_config.get(key):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_pool_for_token(self, token: str) -> tuple[Pool, dict]:
|
||||
"""Get or create a dedicated connection pool for a specific token
|
||||
|
||||
This method implements per-token connection pool isolation to prevent
|
||||
concurrent requests from different tokens interfering with each other.
|
||||
|
||||
🔧 FIX: Supports hot reload - if tokens.json config changes,
|
||||
the old pool is closed and a new one is created automatically.
|
||||
|
||||
Args:
|
||||
token: Authentication token
|
||||
|
||||
Returns:
|
||||
(pool, db_config): The dedicated pool and its configuration
|
||||
|
||||
Raises:
|
||||
RuntimeError: If no valid database configuration is available
|
||||
"""
|
||||
token_hash = self._get_token_hash(token)
|
||||
|
||||
# Fast path: pool already exists
|
||||
if token_hash in self.token_pools:
|
||||
pool = self.token_pools[token_hash]
|
||||
cached_config = self.token_configs.get(token_hash)
|
||||
|
||||
# 🔧 FIX: Check if config has changed (hot reload support)
|
||||
current_config = self._get_current_token_db_config(token)
|
||||
if current_config and cached_config and self._config_changed(cached_config, current_config):
|
||||
self.logger.info(f"Token config changed (hash: {token_hash[:8]}...), recreating pool...")
|
||||
# Config changed, need to recreate pool
|
||||
async with self._token_pools_lock:
|
||||
# Close old pool
|
||||
old_pool = self.token_pools.pop(token_hash, None)
|
||||
if old_pool and not old_pool.closed:
|
||||
try:
|
||||
old_pool.close()
|
||||
await asyncio.wait_for(old_pool.wait_closed(), timeout=2.0)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing old pool during hot reload: {e}")
|
||||
self.token_configs.pop(token_hash, None)
|
||||
# Continue to slow path to create new pool
|
||||
elif pool and not pool.closed:
|
||||
return pool, cached_config
|
||||
|
||||
# Slow path: need to create pool (with lock to prevent race conditions)
|
||||
async with self._token_pools_lock:
|
||||
# Double-check after acquiring lock
|
||||
if token_hash in self.token_pools:
|
||||
pool = self.token_pools[token_hash]
|
||||
if pool and not pool.closed:
|
||||
return pool, self.token_configs[token_hash]
|
||||
|
||||
# Get database config for this token
|
||||
db_config = None
|
||||
config_source = "unknown"
|
||||
|
||||
if self.token_manager:
|
||||
token_db_config = self.token_manager.get_database_config_by_token(token)
|
||||
if token_db_config:
|
||||
db_config = {
|
||||
'host': token_db_config.host,
|
||||
'port': token_db_config.port,
|
||||
'user': token_db_config.user,
|
||||
'password': token_db_config.password,
|
||||
'database': token_db_config.database,
|
||||
'charset': token_db_config.charset
|
||||
}
|
||||
config_source = "token-bound"
|
||||
|
||||
# Fallback to global config if token has no specific config
|
||||
if not db_config or self._is_config_empty(db_config.get('host')) or self._is_config_empty(db_config.get('user')):
|
||||
if self._has_valid_global_config():
|
||||
db_config = self.original_db_config.copy()
|
||||
config_source = "global-env"
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"No valid database configuration available for token. "
|
||||
f"Please configure database in tokens.json or .env file."
|
||||
)
|
||||
|
||||
# Create dedicated pool for this token
|
||||
self.logger.info(f"Creating dedicated connection pool for token (hash: {token_hash[:8]}...) "
|
||||
f"using {config_source} config: {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||
|
||||
pool = await self._create_pool_with_config(db_config)
|
||||
|
||||
# Store pool and config
|
||||
self.token_pools[token_hash] = pool
|
||||
self.token_configs[token_hash] = db_config
|
||||
|
||||
# Create lock for this token if not exists
|
||||
if token_hash not in self._token_pool_locks:
|
||||
self._token_pool_locks[token_hash] = asyncio.Lock()
|
||||
|
||||
return pool, db_config
|
||||
|
||||
async def _create_pool_with_config(self, db_config: dict) -> Pool:
|
||||
"""Create a connection pool with specified configuration
|
||||
|
||||
Args:
|
||||
db_config: Database configuration dictionary
|
||||
|
||||
Returns:
|
||||
Created connection pool
|
||||
"""
|
||||
# Convert charset to aiomysql compatible format
|
||||
charset_map = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
|
||||
charset = charset_map.get(db_config['charset'].upper(), db_config['charset'].lower())
|
||||
|
||||
self.logger.debug(f"Creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}/{db_config['database']}")
|
||||
|
||||
try:
|
||||
pool = await asyncio.wait_for(
|
||||
aiomysql.create_pool(
|
||||
host=db_config['host'],
|
||||
port=db_config['port'],
|
||||
user=db_config['user'],
|
||||
password=db_config['password'],
|
||||
db=db_config['database'],
|
||||
charset=charset,
|
||||
minsize=0, # Don't pre-create connections
|
||||
maxsize=self.maxsize,
|
||||
connect_timeout=self.connect_timeout,
|
||||
autocommit=True,
|
||||
pool_recycle=self.pool_recycle
|
||||
),
|
||||
timeout=self.connect_timeout + 5 # Give extra time for pool creation
|
||||
)
|
||||
self.logger.info(f"Successfully created pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||
return pool
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.error(f"Timeout creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||
raise RuntimeError(f"Timeout creating connection pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to create pool for {db_config['user']}@{db_config['host']}:{db_config['port']}: {type(e).__name__}: {e}")
|
||||
raise
|
||||
|
||||
async def get_connection_for_token(self, token: str, session_id: str) -> 'DorisConnection':
|
||||
"""Get a connection from the token's dedicated pool
|
||||
|
||||
Args:
|
||||
token: Authentication token
|
||||
session_id: Session identifier for logging
|
||||
|
||||
Returns:
|
||||
DorisConnection wrapper
|
||||
"""
|
||||
pool, db_config = await self.get_pool_for_token(token)
|
||||
|
||||
try:
|
||||
connection = await asyncio.wait_for(
|
||||
pool.acquire(),
|
||||
timeout=self.connect_timeout
|
||||
)
|
||||
|
||||
self.logger.debug(f"Session {session_id}: Acquired connection from token pool "
|
||||
f"(user: {db_config['user']}@{db_config['host']})")
|
||||
|
||||
return DorisConnection(connection, session_id, self.security_manager)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Session {session_id}: Failed to acquire connection from token pool: {e}")
|
||||
raise
|
||||
|
||||
async def release_connection_for_token(self, token: str, connection: 'DorisConnection'):
|
||||
"""Release a connection back to the token's dedicated pool
|
||||
|
||||
Args:
|
||||
token: Authentication token
|
||||
connection: DorisConnection wrapper to release
|
||||
"""
|
||||
token_hash = self._get_token_hash(token)
|
||||
|
||||
if token_hash in self.token_pools:
|
||||
pool = self.token_pools[token_hash]
|
||||
if pool and not pool.closed:
|
||||
try:
|
||||
pool.release(connection.connection)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to release connection to token pool: {e}")
|
||||
|
||||
async def cleanup_token_pools(self, max_idle_time: int = 3600):
|
||||
"""Clean up idle token connection pools
|
||||
|
||||
Args:
|
||||
max_idle_time: Maximum idle time in seconds before closing a pool
|
||||
"""
|
||||
async with self._token_pools_lock:
|
||||
pools_to_remove = []
|
||||
|
||||
for token_hash, pool in self.token_pools.items():
|
||||
if pool and not pool.closed:
|
||||
# Check if pool is idle (no active connections)
|
||||
if pool.size == 0 and pool.freesize == 0:
|
||||
pools_to_remove.append(token_hash)
|
||||
elif pool and pool.closed:
|
||||
pools_to_remove.append(token_hash)
|
||||
|
||||
for token_hash in pools_to_remove:
|
||||
try:
|
||||
pool = self.token_pools.pop(token_hash, None)
|
||||
if pool and not pool.closed:
|
||||
pool.close()
|
||||
await pool.wait_closed()
|
||||
self.token_configs.pop(token_hash, None)
|
||||
self._token_pool_locks.pop(token_hash, None)
|
||||
self.logger.info(f"Cleaned up idle token pool (hash: {token_hash[:8]}...)")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error cleaning up token pool: {e}")
|
||||
|
||||
async def close_all_token_pools(self):
|
||||
"""Close all token connection pools (for shutdown)"""
|
||||
# Use timeout to prevent blocking on lock acquisition during shutdown
|
||||
try:
|
||||
async with asyncio.timeout(5): # 5 second timeout for lock
|
||||
async with self._token_pools_lock:
|
||||
for token_hash, pool in list(self.token_pools.items()):
|
||||
try:
|
||||
if pool and not pool.closed:
|
||||
pool.close()
|
||||
# Use timeout for wait_closed to prevent hanging
|
||||
try:
|
||||
await asyncio.wait_for(pool.wait_closed(), timeout=2.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning(f"Timeout waiting for token pool to close (hash: {token_hash[:8]}...)")
|
||||
self.logger.info(f"Closed token pool (hash: {token_hash[:8]}...)")
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error closing token pool: {e}")
|
||||
|
||||
self.token_pools.clear()
|
||||
self.token_configs.clear()
|
||||
self._token_pool_locks.clear()
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning("Timeout acquiring lock for token pool cleanup, forcing clear")
|
||||
# Force clear without lock
|
||||
self.token_pools.clear()
|
||||
self.token_configs.clear()
|
||||
self._token_pool_locks.clear()
|
||||
|
||||
async def configure_for_token(self, token: str) -> tuple[bool, str]:
|
||||
"""Configure connection manager for token with new priority logic
|
||||
|
||||
@@ -626,6 +922,213 @@ class DorisConnectionManager:
|
||||
self.logger.error(f"Failed to initialize connection pool: {e}")
|
||||
raise
|
||||
|
||||
async def initialize_for_stdio_mode(self, timeout: float = 30.0) -> None:
|
||||
"""
|
||||
Initialize connection pool for stdio mode with strict validation
|
||||
|
||||
stdio mode requires a working database connection because:
|
||||
- No HTTP authentication mechanism to support token-bound configs
|
||||
- All database operations depend on the global connection pool
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for connection establishment
|
||||
|
||||
Raises:
|
||||
RuntimeError: If configuration is invalid or connection fails
|
||||
"""
|
||||
try:
|
||||
# Validate that we have valid global configuration
|
||||
if not self._has_valid_global_config():
|
||||
error_msg = (
|
||||
"stdio mode requires valid global database configuration. "
|
||||
"Please set DORIS_HOST and DORIS_USER in environment variables or .env file. "
|
||||
f"Current config: host='{self.host}', user='{self.user}'"
|
||||
)
|
||||
self.logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
self.logger.info(f"stdio mode database config validated: {self.host}:{self.port}")
|
||||
|
||||
# Validate configuration format
|
||||
is_valid, error_message = self.validate_database_configuration()
|
||||
if not is_valid:
|
||||
error_msg = f"Database configuration validation failed: {error_message}"
|
||||
self.logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Test connectivity with timeout
|
||||
self.logger.info("Testing database connectivity for stdio mode...")
|
||||
if not await self._test_connectivity_with_timeout(timeout):
|
||||
error_msg = (
|
||||
f"Failed to connect to Doris database within {timeout} seconds. "
|
||||
f"Please check if Doris is running at {self.host}:{self.port} "
|
||||
f"and verify network connectivity."
|
||||
)
|
||||
self.logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Initialize the connection pool
|
||||
await self._create_connection_pool()
|
||||
|
||||
# Verify that we have a working connection pool
|
||||
if not self.pool:
|
||||
error_msg = "Database connection pool was not created successfully."
|
||||
self.logger.error(error_msg)
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Start background monitoring tasks
|
||||
self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor())
|
||||
self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor())
|
||||
|
||||
# Perform initial pool warmup
|
||||
await self._warmup_pool()
|
||||
|
||||
self.logger.info("Database connection established successfully for stdio mode")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"stdio mode database initialization failed: {e}")
|
||||
raise
|
||||
|
||||
async def initialize_for_http_mode(self) -> bool:
|
||||
"""
|
||||
Initialize connection pool for HTTP mode with graceful degradation
|
||||
|
||||
HTTP mode can work without global database configuration because:
|
||||
- Supports token-bound database configurations
|
||||
- Can handle authentication and use per-request database configs
|
||||
- Has fallback mechanisms for database operations
|
||||
|
||||
Returns:
|
||||
bool: True if global database pool was created, False if gracefully degraded
|
||||
"""
|
||||
try:
|
||||
# First validate configuration format if we have one
|
||||
if self._has_valid_global_config():
|
||||
is_valid, error_message = self.validate_database_configuration()
|
||||
if not is_valid:
|
||||
self.logger.warning(f"Global database configuration invalid: {error_message}")
|
||||
self.logger.info("HTTP mode will rely on token-bound database configurations")
|
||||
return False
|
||||
|
||||
# Try to establish global connection pool
|
||||
self.logger.info(f"Attempting to create global connection pool: {self.host}:{self.port}")
|
||||
|
||||
try:
|
||||
# Test connectivity with shorter timeout for HTTP mode
|
||||
if await self._test_connectivity_with_timeout(10.0):
|
||||
await self._create_connection_pool()
|
||||
|
||||
if self.pool:
|
||||
# Start background monitoring tasks
|
||||
self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor())
|
||||
self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor())
|
||||
|
||||
# Perform initial pool warmup
|
||||
await self._warmup_pool()
|
||||
|
||||
self.logger.info("Global database connection pool created successfully for HTTP mode")
|
||||
return True
|
||||
else:
|
||||
self.logger.warning("Global database connection test failed, will use token-bound configs")
|
||||
return False
|
||||
|
||||
except Exception as pool_error:
|
||||
self.logger.warning(f"Failed to create global connection pool: {pool_error}")
|
||||
self.logger.info("HTTP mode will rely on token-bound database configurations")
|
||||
return False
|
||||
else:
|
||||
self.logger.info("No valid global database config found, HTTP mode will use token-bound configurations")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"HTTP mode database initialization encountered error: {e}")
|
||||
self.logger.info("HTTP mode will rely on token-bound database configurations")
|
||||
return False
|
||||
|
||||
async def _test_connectivity_with_timeout(self, timeout: float) -> bool:
|
||||
"""
|
||||
Test database connectivity with timeout
|
||||
|
||||
Args:
|
||||
timeout: Maximum time to wait for connection test
|
||||
|
||||
Returns:
|
||||
bool: True if connection successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(self._test_basic_connectivity(), timeout=timeout)
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.error(f"Database connectivity test timed out after {timeout} seconds")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.logger.error(f"Database connectivity test failed: {e}")
|
||||
return False
|
||||
|
||||
async def _test_basic_connectivity(self) -> None:
|
||||
"""
|
||||
Test basic database connectivity without connection pool
|
||||
|
||||
Raises:
|
||||
Exception: If connection fails
|
||||
"""
|
||||
import aiomysql
|
||||
|
||||
conn = None
|
||||
try:
|
||||
conn = await aiomysql.connect(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
user=self.user,
|
||||
password=self.password,
|
||||
db=self.database,
|
||||
charset=self.charset,
|
||||
connect_timeout=self.connect_timeout,
|
||||
autocommit=True
|
||||
)
|
||||
|
||||
async with conn.cursor() as cursor:
|
||||
await cursor.execute("SELECT 1")
|
||||
result = await cursor.fetchone()
|
||||
if not result or result[0] != 1:
|
||||
raise RuntimeError("Database connectivity test query failed")
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Database connectivity test failed: {e}")
|
||||
finally:
|
||||
if conn:
|
||||
conn.close()
|
||||
|
||||
async def _create_connection_pool(self) -> None:
|
||||
"""
|
||||
Create the connection pool
|
||||
|
||||
Raises:
|
||||
Exception: If pool creation fails
|
||||
"""
|
||||
self.pool = await aiomysql.create_pool(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
user=self.user,
|
||||
password=self.password,
|
||||
db=self.database,
|
||||
charset=self.charset,
|
||||
minsize=self.minsize,
|
||||
maxsize=self.maxsize,
|
||||
pool_recycle=self.pool_recycle,
|
||||
connect_timeout=self.connect_timeout,
|
||||
autocommit=True
|
||||
)
|
||||
|
||||
# Test pool health
|
||||
if not await self._test_pool_health():
|
||||
# Clean up the pool if health test fails
|
||||
if self.pool:
|
||||
self.pool.close()
|
||||
await self.pool.wait_closed()
|
||||
self.pool = None
|
||||
raise RuntimeError("Connection pool health check failed")
|
||||
|
||||
async def _test_pool_health(self) -> bool:
|
||||
"""Test connection pool health"""
|
||||
try:
|
||||
@@ -872,7 +1375,26 @@ class DorisConnectionManager:
|
||||
|
||||
Uses only semaphore to prevent too many concurrent acquisitions.
|
||||
If the connection is successfully obtained, it will be added to the connection pool cache.
|
||||
|
||||
🔧 FIX for token isolation: Now automatically checks for auth_context from ContextVar
|
||||
and uses token-specific connection pool if available.
|
||||
"""
|
||||
# 🔧 FIX: Check for auth_context from global ContextVar
|
||||
# This ensures all tools using get_connection respect token-bound database configuration
|
||||
auth_context = None
|
||||
try:
|
||||
from .security import mcp_auth_context_var
|
||||
auth_context = mcp_auth_context_var.get()
|
||||
except Exception as e:
|
||||
self.logger.debug(f"get_connection: Could not get auth_context: {e}")
|
||||
|
||||
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
|
||||
# Use token-specific connection pool
|
||||
# SECURITY: Do NOT catch exceptions here - if token pool fails, don't fallback to global pool
|
||||
# This prevents privilege escalation
|
||||
self.logger.debug(f"get_connection: Using token-specific pool for session {session_id}")
|
||||
return await self.get_connection_for_token(auth_context.token, session_id)
|
||||
|
||||
cached_conn = self.session_cache.get(session_id)
|
||||
if cached_conn:
|
||||
return cached_conn
|
||||
@@ -1019,10 +1541,16 @@ class DorisConnectionManager:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
# Close connection pool
|
||||
# 🔧 FIX: Close all per-token connection pools
|
||||
await self.close_all_token_pools()
|
||||
|
||||
# Close global connection pool with timeout
|
||||
if self.pool:
|
||||
self.pool.close()
|
||||
await self.pool.wait_closed()
|
||||
try:
|
||||
await asyncio.wait_for(self.pool.wait_closed(), timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
self.logger.warning("Timeout waiting for global pool to close")
|
||||
|
||||
self.logger.info("Connection manager closed successfully")
|
||||
|
||||
@@ -1051,10 +1579,42 @@ class DorisConnectionManager:
|
||||
async def execute_query(
|
||||
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
|
||||
) -> QueryResult:
|
||||
"""Execute query - Simplified Strategy with automatic connection management"""
|
||||
"""Execute query - Enhanced Strategy with per-token connection pool isolation
|
||||
|
||||
FIX for multi-tenant concurrency: Each token now uses its own dedicated connection pool
|
||||
to prevent configuration conflicts between concurrent requests from different tokens.
|
||||
"""
|
||||
connection = None
|
||||
token = None
|
||||
|
||||
try:
|
||||
# Always get fresh connection from pool
|
||||
# Check if we have a token for per-token pool isolation
|
||||
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
|
||||
token = auth_context.token
|
||||
|
||||
try:
|
||||
# 🔧 FIX: Use dedicated connection pool for this token
|
||||
# This prevents concurrent requests from different tokens interfering
|
||||
connection = await self.get_connection_for_token(token, session_id)
|
||||
|
||||
# Get the config for logging
|
||||
token_hash = self._get_token_hash(token)
|
||||
if token_hash in self.token_configs:
|
||||
db_config = self.token_configs[token_hash]
|
||||
self.logger.info(f"Session {session_id}: Using dedicated pool for {db_config['user']}@{db_config['host']}")
|
||||
|
||||
except Exception as token_pool_error:
|
||||
# SECURITY: If token should have pool but creation fails, don't fallback
|
||||
# This prevents privilege escalation (using high-privilege default user)
|
||||
self.logger.error(f"Session {session_id}: Token pool error: {token_pool_error}")
|
||||
raise RuntimeError(
|
||||
f"Failed to get connection for authenticated token. "
|
||||
f"This is a security measure to prevent using default high-privilege credentials. "
|
||||
f"Error: {token_pool_error}"
|
||||
)
|
||||
else:
|
||||
# No token - use global pool (backward compatibility)
|
||||
self.logger.debug(f"Session {session_id}: No token, using global connection pool")
|
||||
connection = await self.get_connection(session_id)
|
||||
|
||||
# Execute query
|
||||
@@ -1066,8 +1626,11 @@ class DorisConnectionManager:
|
||||
self.logger.error(f"Query execution failed for session {session_id}: {e}")
|
||||
raise
|
||||
finally:
|
||||
# Always release connection back to pool
|
||||
# Always release connection back to the appropriate pool
|
||||
if connection:
|
||||
if token:
|
||||
await self.release_connection_for_token(token, connection)
|
||||
else:
|
||||
await self.release_connection(session_id, connection)
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -27,6 +27,13 @@ from collections import defaultdict, deque
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -122,10 +129,19 @@ class DependencyAnalysisTools:
|
||||
async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]:
|
||||
"""Get metadata for all tables and views"""
|
||||
try:
|
||||
# Build conditions for query
|
||||
# Build conditions for query with parameterized values
|
||||
where_conditions = []
|
||||
params = []
|
||||
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
# SECURITY FIX: Validate identifier and use parameterized query
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
where_conditions.append("table_schema = %s")
|
||||
params.append(db_name)
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
@@ -148,9 +164,18 @@ class DependencyAnalysisTools:
|
||||
ORDER BY table_schema, table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(metadata_sql)
|
||||
# SECURITY FIX: Get auth_context and pass to execute for security validation
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(
|
||||
metadata_sql,
|
||||
params=tuple(params) if params else None,
|
||||
auth_context=auth_context
|
||||
)
|
||||
return result.data if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed in _get_tables_metadata: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get tables metadata: {str(e)}")
|
||||
return []
|
||||
@@ -186,17 +211,31 @@ class DependencyAnalysisTools:
|
||||
|
||||
async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
||||
"""Analyze view definitions to extract table dependencies"""
|
||||
# Get auth_context once for all operations in this method
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
for table in tables_metadata:
|
||||
if table["table_type"] == "VIEW":
|
||||
table_name = table["table_name"]
|
||||
schema_name = table.get("schema_name", "")
|
||||
|
||||
# Get view definition
|
||||
view_def_sql = f"SHOW CREATE VIEW {schema_name}.{table_name}" if schema_name else f"SHOW CREATE VIEW {table_name}"
|
||||
# SECURITY FIX: Validate identifiers before using in SQL
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if schema_name:
|
||||
validate_identifier(schema_name, "schema name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected in view analysis: {e}")
|
||||
continue
|
||||
|
||||
# Build safe view reference using quoted identifiers
|
||||
view_ref = build_table_reference(table_name, schema_name) if schema_name else quote_identifier(table_name, "table name")
|
||||
view_def_sql = f"SHOW CREATE VIEW {view_ref}"
|
||||
|
||||
try:
|
||||
result = await connection.execute(view_def_sql)
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
result = await connection.execute(view_def_sql, auth_context=auth_context)
|
||||
if result.data and len(result.data) > 0:
|
||||
# Extract view definition from result
|
||||
view_definition = ""
|
||||
@@ -235,6 +274,9 @@ class DependencyAnalysisTools:
|
||||
|
||||
async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None:
|
||||
"""Analyze audit logs to discover runtime table dependencies"""
|
||||
# Get auth_context for security validation
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
# Get recent SQL statements from audit logs
|
||||
audit_sql = """
|
||||
@@ -252,7 +294,8 @@ class DependencyAnalysisTools:
|
||||
LIMIT 1000
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
result = await connection.execute(audit_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
@@ -274,6 +317,9 @@ class DependencyAnalysisTools:
|
||||
|
||||
async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
||||
"""Analyze foreign key constraints for explicit dependencies"""
|
||||
# Get auth_context for security validation
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
# Get foreign key information
|
||||
fk_sql = """
|
||||
@@ -288,7 +334,8 @@ class DependencyAnalysisTools:
|
||||
WHERE REFERENCED_TABLE_NAME IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(fk_sql)
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
result = await connection.execute(fk_sql, auth_context=auth_context)
|
||||
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
|
||||
@@ -28,6 +28,7 @@ from datetime import datetime
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -713,7 +714,8 @@ class DorisMonitoringTools:
|
||||
# Fallback to SHOW BACKENDS if no BE hosts configured
|
||||
logger.info("No BE hosts configured, using SHOW BACKENDS to discover BE nodes")
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
result = await connection.execute("SHOW BACKENDS")
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
|
||||
|
||||
be_nodes = []
|
||||
for row in result.data:
|
||||
|
||||
@@ -27,6 +27,13 @@ from collections import defaultdict, Counter
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier,
|
||||
build_table_reference,
|
||||
get_auth_context
|
||||
)
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -229,7 +236,8 @@ class PerformanceAnalyticsTools:
|
||||
ORDER BY query_date
|
||||
"""
|
||||
|
||||
result = await connection.execute(query_volume_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(query_volume_sql, auth_context=auth_context)
|
||||
daily_data = result.data if result.data else []
|
||||
|
||||
if not daily_data:
|
||||
@@ -304,7 +312,8 @@ class PerformanceAnalyticsTools:
|
||||
ORDER BY activity_date
|
||||
"""
|
||||
|
||||
result = await connection.execute(user_activity_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(user_activity_sql, auth_context=auth_context)
|
||||
daily_data = result.data if result.data else []
|
||||
|
||||
if not daily_data:
|
||||
@@ -383,7 +392,8 @@ class PerformanceAnalyticsTools:
|
||||
LIMIT 5000
|
||||
"""
|
||||
|
||||
result = await connection.execute(slow_query_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(slow_query_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
@@ -705,7 +715,8 @@ class PerformanceAnalyticsTools:
|
||||
ORDER BY size_mb DESC
|
||||
"""
|
||||
|
||||
db_result = await connection.execute(db_sizes_sql)
|
||||
auth_context = get_auth_context()
|
||||
db_result = await connection.execute(db_sizes_sql, auth_context=auth_context)
|
||||
|
||||
if not db_result.data:
|
||||
logger.warning("No database size information available")
|
||||
@@ -805,7 +816,16 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_database_table_details_from_schema(self, connection, db_name: str) -> List[Dict]:
|
||||
"""Get table details for a specific database using information_schema"""
|
||||
try:
|
||||
table_details_sql = f"""
|
||||
# SECURITY FIX: Validate db_name and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
|
||||
table_details_sql = """
|
||||
SELECT
|
||||
TABLE_SCHEMA as schema_name,
|
||||
TABLE_NAME as table_name,
|
||||
@@ -814,13 +834,13 @@ class PerformanceAnalyticsTools:
|
||||
CREATE_TIME as create_time,
|
||||
UPDATE_TIME as update_time
|
||||
FROM information_schema.tables
|
||||
WHERE TABLE_SCHEMA = '{db_name}'
|
||||
WHERE TABLE_SCHEMA = %s
|
||||
AND TABLE_TYPE = 'BASE TABLE'
|
||||
AND (COALESCE(DATA_LENGTH, 0) + COALESCE(INDEX_LENGTH, 0)) > 0
|
||||
ORDER BY size_mb DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(table_details_sql)
|
||||
result = await connection.execute(table_details_sql, params=(db_name,), auth_context=auth_context)
|
||||
|
||||
if not result.data:
|
||||
logger.warning(f"No table details found for database {db_name}")
|
||||
@@ -867,6 +887,13 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_database_table_details(self, connection, db_name: str) -> List[Dict]:
|
||||
"""Get table details for a specific database using session-consistent queries"""
|
||||
try:
|
||||
# SECURITY FIX: Validate db_name before using in SQL
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
|
||||
# Method 1: Try to use session-consistent approach with raw connection
|
||||
# This requires accessing the underlying connection to maintain session state
|
||||
|
||||
@@ -877,8 +904,9 @@ class PerformanceAnalyticsTools:
|
||||
# Use raw connection to maintain session state
|
||||
cursor = await raw_conn.cursor()
|
||||
try:
|
||||
# Execute USE and SHOW DATA in the same session
|
||||
await cursor.execute(f"USE {db_name}")
|
||||
# SECURITY FIX: Use quoted identifier for USE statement
|
||||
quoted_db = quote_identifier(db_name, "database name")
|
||||
await cursor.execute(f"USE {quoted_db}")
|
||||
await cursor.execute("SHOW DATA")
|
||||
|
||||
result = await cursor.fetchall()
|
||||
@@ -922,9 +950,19 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_database_table_details_fallback(self, connection, db_name: str) -> List[Dict]:
|
||||
"""Fallback method to get table details using individual queries"""
|
||||
try:
|
||||
# Get all tables in the database
|
||||
tables_sql = f"SHOW TABLES FROM {db_name}"
|
||||
tables_result = await connection.execute(tables_sql)
|
||||
# SECURITY FIX: Validate db_name and get auth_context
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return []
|
||||
|
||||
# Get all tables in the database using quoted identifier
|
||||
quoted_db = quote_identifier(db_name, "database name")
|
||||
tables_sql = f"SHOW TABLES FROM {quoted_db}"
|
||||
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
|
||||
if not tables_result.data:
|
||||
return []
|
||||
@@ -934,9 +972,11 @@ class PerformanceAnalyticsTools:
|
||||
table_name = table_row.get(f"Tables_in_{db_name}", "") or table_row.get("table_name", "")
|
||||
if table_name:
|
||||
try:
|
||||
# Use SHOW DATA FROM db.table for each table
|
||||
data_sql = f"SHOW DATA FROM {db_name}.{table_name}"
|
||||
data_result = await connection.execute(data_sql)
|
||||
# SECURITY FIX: Validate table_name and use safe reference
|
||||
validate_identifier(table_name, "table name")
|
||||
safe_table_ref = build_table_reference(table_name, db_name)
|
||||
data_sql = f"SHOW DATA FROM {safe_table_ref}"
|
||||
data_result = await connection.execute(data_sql, auth_context=auth_context)
|
||||
|
||||
if data_result.data:
|
||||
for row in data_result.data:
|
||||
@@ -1036,6 +1076,7 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_all_tables_info(self, connection) -> List[Dict]:
|
||||
"""Get basic information for all tables (fallback method)"""
|
||||
try:
|
||||
auth_context = get_auth_context()
|
||||
tables_sql = """
|
||||
SELECT
|
||||
table_schema,
|
||||
@@ -1053,7 +1094,7 @@ class PerformanceAnalyticsTools:
|
||||
ORDER BY (data_length + index_length) DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_sql)
|
||||
result = await connection.execute(tables_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
@@ -1120,23 +1161,37 @@ class PerformanceAnalyticsTools:
|
||||
async def _get_current_table_size(self, connection, full_table_name: str) -> Optional[Dict]:
|
||||
"""Get current table size"""
|
||||
try:
|
||||
# Try to query table size directly
|
||||
size_sql = f"""
|
||||
# SECURITY FIX: Get auth_context and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Extract table name for parameterized query
|
||||
table_name_only = full_table_name.split('.')[-1] if '.' in full_table_name else full_table_name
|
||||
|
||||
# Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name_only, "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return None
|
||||
|
||||
# Use parameterized query for safety
|
||||
size_sql = """
|
||||
SELECT
|
||||
COALESCE(ROUND((COALESCE(data_length, 0) + COALESCE(index_length, 0)) / 1024 / 1024, 2), 0) as size_mb,
|
||||
COALESCE(table_rows, 0) as `rows`
|
||||
FROM information_schema.tables
|
||||
WHERE CONCAT(table_schema, '.', table_name) = '{full_table_name}'
|
||||
OR table_name = '{full_table_name.split('.')[-1]}'
|
||||
WHERE CONCAT(table_schema, '.', table_name) = %s
|
||||
OR table_name = %s
|
||||
"""
|
||||
|
||||
result = await connection.execute(size_sql)
|
||||
result = await connection.execute(size_sql, params=(full_table_name, table_name_only), auth_context=auth_context)
|
||||
if result.data and result.data[0]:
|
||||
return result.data[0]
|
||||
|
||||
# If information_schema has no data, try COUNT query
|
||||
# full_table_name should already be validated by caller using build_table_reference
|
||||
count_sql = f"SELECT COUNT(*) as rows FROM {full_table_name}"
|
||||
count_result = await connection.execute(count_sql)
|
||||
count_result = await connection.execute(count_sql, auth_context=auth_context)
|
||||
if count_result.data:
|
||||
return {
|
||||
"size_mb": 0, # Cannot get exact size
|
||||
@@ -1145,6 +1200,9 @@ class PerformanceAnalyticsTools:
|
||||
|
||||
return None
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed for {full_table_name}: {str(e)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get current size for {full_table_name}: {str(e)}")
|
||||
return None
|
||||
@@ -1154,8 +1212,19 @@ class PerformanceAnalyticsTools:
|
||||
) -> List[Dict]:
|
||||
"""Get historical growth data based on partitions"""
|
||||
try:
|
||||
# Query partition information
|
||||
partition_sql = f"""
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if schema_name:
|
||||
validate_identifier(schema_name, "schema name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Use parameterized query for safety
|
||||
partition_sql = """
|
||||
SELECT
|
||||
partition_name,
|
||||
partition_description,
|
||||
@@ -1163,15 +1232,19 @@ class PerformanceAnalyticsTools:
|
||||
data_length,
|
||||
create_time
|
||||
FROM information_schema.partitions
|
||||
WHERE table_schema = '{schema_name or ""}'
|
||||
AND table_name = '{table_name}'
|
||||
WHERE table_schema = %s
|
||||
AND table_name = %s
|
||||
AND partition_name IS NOT NULL
|
||||
AND create_time IS NOT NULL
|
||||
AND create_time >= DATE_SUB(NOW(), INTERVAL {days} DAY)
|
||||
AND create_time >= DATE_SUB(NOW(), INTERVAL %s DAY)
|
||||
ORDER BY create_time DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(partition_sql)
|
||||
result = await connection.execute(
|
||||
partition_sql,
|
||||
params=(schema_name or "", table_name, days),
|
||||
auth_context=auth_context
|
||||
)
|
||||
if not result.data:
|
||||
return []
|
||||
|
||||
@@ -1210,6 +1283,9 @@ class PerformanceAnalyticsTools:
|
||||
) -> List[Dict]:
|
||||
"""Get historical growth data based on timestamp fields"""
|
||||
try:
|
||||
# SECURITY FIX: Get auth_context
|
||||
auth_context = get_auth_context()
|
||||
|
||||
# Find possible timestamp fields
|
||||
timestamp_columns = await self._find_timestamp_columns(connection, table_name, schema_name)
|
||||
if not timestamp_columns:
|
||||
@@ -1218,20 +1294,29 @@ class PerformanceAnalyticsTools:
|
||||
# Use best timestamp field for analysis
|
||||
time_column = timestamp_columns[0]
|
||||
|
||||
# Aggregate data by date
|
||||
# SECURITY FIX: Validate time_column before using in SQL
|
||||
try:
|
||||
validate_identifier(time_column, "column name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid column name rejected: {e}")
|
||||
return []
|
||||
|
||||
quoted_time_column = quote_identifier(time_column, "column name")
|
||||
|
||||
# Aggregate data by date (full_table_name should be validated by caller)
|
||||
growth_sql = f"""
|
||||
SELECT
|
||||
DATE({time_column}) as date,
|
||||
DATE({quoted_time_column}) as date,
|
||||
COUNT(*) as daily_records,
|
||||
COUNT(*) / SUM(COUNT(*)) OVER() * 100 as percentage
|
||||
FROM {full_table_name}
|
||||
WHERE {time_column} >= DATE_SUB(NOW(), INTERVAL {days} DAY)
|
||||
AND {time_column} IS NOT NULL
|
||||
GROUP BY DATE({time_column})
|
||||
WHERE {quoted_time_column} >= DATE_SUB(NOW(), INTERVAL %s DAY)
|
||||
AND {quoted_time_column} IS NOT NULL
|
||||
GROUP BY DATE({quoted_time_column})
|
||||
ORDER BY date DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(growth_sql)
|
||||
result = await connection.execute(growth_sql, params=(days,), auth_context=auth_context)
|
||||
if not result.data:
|
||||
return []
|
||||
|
||||
@@ -1257,11 +1342,22 @@ class PerformanceAnalyticsTools:
|
||||
async def _find_timestamp_columns(self, connection, table_name: str, schema_name: str) -> List[str]:
|
||||
"""Find timestamp fields in table"""
|
||||
try:
|
||||
timestamp_sql = f"""
|
||||
# SECURITY FIX: Validate identifiers and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if schema_name:
|
||||
validate_identifier(schema_name, "schema name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
timestamp_sql = """
|
||||
SELECT column_name, data_type
|
||||
FROM information_schema.columns
|
||||
WHERE table_schema = '{schema_name or ""}'
|
||||
AND table_name = '{table_name}'
|
||||
WHERE table_schema = %s
|
||||
AND table_name = %s
|
||||
AND (
|
||||
data_type IN ('datetime', 'timestamp', 'date')
|
||||
OR column_name REGEXP '(create|insert|update|modify).*time'
|
||||
@@ -1278,9 +1374,16 @@ class PerformanceAnalyticsTools:
|
||||
END
|
||||
"""
|
||||
|
||||
result = await connection.execute(timestamp_sql)
|
||||
result = await connection.execute(
|
||||
timestamp_sql,
|
||||
params=(schema_name or "", table_name),
|
||||
auth_context=auth_context
|
||||
)
|
||||
return [row["column_name"] for row in result.data] if result.data else []
|
||||
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Security validation failed: {str(e)}")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to find timestamp columns: {str(e)}")
|
||||
return []
|
||||
@@ -1290,8 +1393,22 @@ class PerformanceAnalyticsTools:
|
||||
) -> List[Dict]:
|
||||
"""Estimate growth data based on audit logs"""
|
||||
try:
|
||||
# SECURITY FIX: Validate table_name and use parameterized query
|
||||
auth_context = get_auth_context()
|
||||
|
||||
try:
|
||||
validate_identifier(table_name.split(".")[-1], "table name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid table name rejected: {e}")
|
||||
return []
|
||||
|
||||
# Extract just the table name for LIKE pattern
|
||||
table_name_only = table_name.split(".")[-1]
|
||||
like_pattern_full = f"%{table_name}%"
|
||||
like_pattern_short = f"%{table_name_only}%"
|
||||
|
||||
# Analyze operation history for this table
|
||||
audit_sql = f"""
|
||||
audit_sql = """
|
||||
SELECT
|
||||
DATE(`time`) as operation_date,
|
||||
COUNT(*) as operation_count,
|
||||
@@ -1299,17 +1416,21 @@ class PerformanceAnalyticsTools:
|
||||
SUM(CASE WHEN stmt LIKE 'UPDATE%' THEN 1 ELSE 0 END) as update_count,
|
||||
SUM(CASE WHEN stmt LIKE 'DELETE%' THEN 1 ELSE 0 END) as delete_count
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= DATE_SUB(NOW(), INTERVAL {days} DAY)
|
||||
WHERE `time` >= DATE_SUB(NOW(), INTERVAL %s DAY)
|
||||
AND stmt IS NOT NULL
|
||||
AND (
|
||||
stmt LIKE '%{table_name}%'
|
||||
OR stmt LIKE '%{table_name.split(".")[-1]}%'
|
||||
stmt LIKE %s
|
||||
OR stmt LIKE %s
|
||||
)
|
||||
GROUP BY DATE(`time`)
|
||||
ORDER BY operation_date DESC
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
result = await connection.execute(
|
||||
audit_sql,
|
||||
params=(days, like_pattern_full, like_pattern_short),
|
||||
auth_context=auth_context
|
||||
)
|
||||
if not result.data:
|
||||
return []
|
||||
|
||||
|
||||
@@ -33,8 +33,11 @@ from datetime import datetime, timedelta, date
|
||||
from typing import Any, Dict
|
||||
from decimal import Decimal
|
||||
|
||||
import sqlparse
|
||||
|
||||
from .db import DorisConnectionManager, QueryResult
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -467,6 +470,51 @@ class DorisQueryExecutor:
|
||||
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
|
||||
)
|
||||
|
||||
async def execute_batch_sqls_for_mcp(
|
||||
self, sqls: list[str],
|
||||
timeout: int = 30,
|
||||
session_id: str = "mcp_session",
|
||||
user_id: str = "mcp_user",
|
||||
auth_context=None
|
||||
) -> dict[str, Any]:
|
||||
"""Execute multiple sqls in batch"""
|
||||
if not sqls:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "SQL query is required",
|
||||
"data": None
|
||||
}
|
||||
query_requests = [
|
||||
QueryRequest(
|
||||
sql=sql,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
timeout=timeout,
|
||||
cache_enabled=False
|
||||
)
|
||||
for sql in sqls
|
||||
]
|
||||
query_results = await self.execute_batch_queries(query_requests, auth_context)
|
||||
# Serialize data for JSON response
|
||||
results = [
|
||||
{
|
||||
"data": [self._serialize_row_data(data) for data in result.data],
|
||||
"row_count": result.row_count,
|
||||
"execution_time": result.execution_time,
|
||||
"metadata": {
|
||||
"columns": result.metadata.get("columns", []),
|
||||
"query": result.sql
|
||||
}
|
||||
}
|
||||
for result in query_results
|
||||
]
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"multiple_results": True,
|
||||
"results": results
|
||||
}
|
||||
|
||||
async def execute_batch_queries(
|
||||
self, query_requests: list[QueryRequest], auth_context=None
|
||||
) -> list[QueryResult]:
|
||||
@@ -484,20 +532,24 @@ class DorisQueryExecutor:
|
||||
self.execute_query(request, auth_context) for request in query_requests
|
||||
]
|
||||
|
||||
try:
|
||||
query_results = []
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Batch query execution failed: {e}")
|
||||
raise
|
||||
for result in results:
|
||||
if isinstance(result, Exception):
|
||||
self.logger.error(f"Batch query execution failed: {result}")
|
||||
raise result
|
||||
else:
|
||||
query_results.append(result)
|
||||
|
||||
return results
|
||||
return query_results
|
||||
|
||||
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
|
||||
"""Get query execution plan"""
|
||||
explain_sql = f"EXPLAIN {sql}"
|
||||
|
||||
connection = await self.connection_manager.get_connection(session_id)
|
||||
result = await connection.execute(explain_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(explain_sql, auth_context=auth_context)
|
||||
|
||||
return {
|
||||
"query": sql,
|
||||
@@ -546,9 +598,13 @@ class DorisQueryExecutor:
|
||||
limit: int = 1000,
|
||||
timeout: int = 30,
|
||||
session_id: str = "mcp_session",
|
||||
user_id: str = "mcp_user"
|
||||
user_id: str = "mcp_user",
|
||||
auth_context = None # FIX for Issue #62 Bug 1: Accept auth_context with token
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute SQL query for MCP interface - unified method"""
|
||||
"""Execute SQL query for MCP interface - unified method
|
||||
|
||||
FIX for Issue #62 Bug 1: Now accepts auth_context parameter to support token-bound database configuration
|
||||
"""
|
||||
max_retries = 2
|
||||
retry_count = 0
|
||||
|
||||
@@ -564,19 +620,31 @@ class DorisQueryExecutor:
|
||||
# Import required security modules
|
||||
from .security import DorisSecurityManager, AuthContext, SecurityLevel
|
||||
|
||||
# Create proper auth context with read-only permissions
|
||||
# FIX: Use provided auth_context if available (contains token for DB config)
|
||||
# Otherwise create default auth context for backward compatibility
|
||||
if auth_context is None:
|
||||
auth_context = AuthContext(
|
||||
user_id=user_id,
|
||||
roles=["read_only_user"], # Restrictive role for MCP interface
|
||||
permissions=["read_data"], # Only read permissions
|
||||
session_id=session_id,
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
security_level=SecurityLevel.INTERNAL,
|
||||
token="" # No token in default context
|
||||
)
|
||||
else:
|
||||
# Use provided auth_context (may contain token for database configuration)
|
||||
self.logger.debug(f"Using provided auth_context with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||
|
||||
# Perform SQL security validation if enabled
|
||||
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
|
||||
if self.connection_manager.config.security.enable_security_check:
|
||||
try:
|
||||
# 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances
|
||||
# Creating new DorisSecurityManager each time causes multiple hot reload monitors
|
||||
security_manager = getattr(self.connection_manager, 'security_manager', None)
|
||||
if not security_manager:
|
||||
# Fallback: create new one only if not available (should rarely happen)
|
||||
self.logger.warning("No existing security_manager, creating new instance")
|
||||
security_manager = DorisSecurityManager(self.connection_manager.config)
|
||||
validation_result = await security_manager.validate_sql_security(sql, auth_context)
|
||||
|
||||
@@ -623,6 +691,15 @@ class DorisQueryExecutor:
|
||||
sql = sql[:-1]
|
||||
sql = f"{sql} LIMIT {limit}"
|
||||
|
||||
all_statements = [
|
||||
s.strip()
|
||||
for s in sqlparse.split(sql)
|
||||
if s.strip()
|
||||
]
|
||||
if len(all_statements) > 1:
|
||||
return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
|
||||
session_id=session_id, user_id=user_id,
|
||||
auth_context=auth_context)
|
||||
# Create query request
|
||||
query_request = QueryRequest(
|
||||
sql=sql,
|
||||
@@ -880,17 +957,20 @@ async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager
|
||||
|
||||
This function now includes security validation to ensure safe query execution.
|
||||
All queries are validated against the configured security policies before execution.
|
||||
|
||||
FIX for Issue #62 Bug 1: Now supports auth_context parameter for token-bound database configuration
|
||||
FIX for Issue #58 Problem 2: Removed executor.close() to prevent ClosedResourceError in multi-worker mode
|
||||
"""
|
||||
try:
|
||||
# Create query executor with the connection manager's configuration
|
||||
executor = DorisQueryExecutor(connection_manager)
|
||||
|
||||
try:
|
||||
# Extract parameters from kwargs or use defaults
|
||||
limit = kwargs.get("limit", 1000)
|
||||
timeout = kwargs.get("timeout", 30)
|
||||
session_id = kwargs.get("session_id", "mcp_session")
|
||||
user_id = kwargs.get("user_id", "mcp_user")
|
||||
auth_context = kwargs.get("auth_context", None) # FIX: Extract auth_context
|
||||
|
||||
# The execute_sql_for_mcp method now includes security validation
|
||||
result = await executor.execute_sql_for_mcp(
|
||||
@@ -898,11 +978,17 @@ async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager
|
||||
limit=limit,
|
||||
timeout=timeout,
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
auth_context=auth_context # FIX: Pass auth_context with token
|
||||
)
|
||||
|
||||
# FIX for Issue #58 Problem 2: Do NOT close executor here
|
||||
# In multi-worker mode, closing here causes ClosedResourceError
|
||||
# The executor's resources (cache, background tasks) will be managed
|
||||
# by the connection_manager lifecycle and Python's garbage collection
|
||||
# This prevents premature cleanup while MCP session manager is still processing
|
||||
|
||||
return result
|
||||
finally:
|
||||
await executor.close()
|
||||
|
||||
except Exception as e:
|
||||
return {
|
||||
|
||||
@@ -32,6 +32,11 @@ from datetime import datetime, timedelta
|
||||
|
||||
# Import unified logging configuration
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import (
|
||||
SQLSecurityError,
|
||||
validate_identifier,
|
||||
quote_identifier
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logger = get_logger(__name__)
|
||||
@@ -431,6 +436,16 @@ class MetadataExtractor:
|
||||
logger.warning("Database name not specified")
|
||||
return {}
|
||||
|
||||
# SECURITY FIX: Validate identifiers to prevent SQL injection
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected in get_table_schema: {e}")
|
||||
return {}
|
||||
|
||||
cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
@@ -536,6 +551,16 @@ class MetadataExtractor:
|
||||
logger.warning("Database name not specified")
|
||||
return ""
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return ""
|
||||
|
||||
cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
@@ -587,6 +612,16 @@ class MetadataExtractor:
|
||||
logger.warning("Database name not specified")
|
||||
return {}
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return {}
|
||||
|
||||
cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
@@ -643,17 +678,30 @@ class MetadataExtractor:
|
||||
logger.error("Database name not specified")
|
||||
return []
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
validate_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}"
|
||||
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
|
||||
return self.metadata_cache[cache_key]
|
||||
|
||||
try:
|
||||
# Build query with catalog prefix if specified
|
||||
# Build query with catalog prefix if specified (identifiers already validated)
|
||||
safe_table = quote_identifier(table_name, "table name")
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
if effective_catalog:
|
||||
query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`"
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
|
||||
logger.info(f"Using three-part naming for index query: {query}")
|
||||
else:
|
||||
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
|
||||
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
|
||||
|
||||
try:
|
||||
# NOTE: Deprecated sync path retained for compatibility; use async variant instead.
|
||||
@@ -1146,8 +1194,17 @@ class MetadataExtractor:
|
||||
"""
|
||||
try:
|
||||
if self.connection_manager:
|
||||
# FIX: Get auth_context from global ContextVar for token-bound database configuration
|
||||
# This ensures all query methods use the correct user's connection pool
|
||||
auth_context = None
|
||||
try:
|
||||
from .security import mcp_auth_context_var
|
||||
auth_context = mcp_auth_context_var.get()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Use the injected connection manager directly (async)
|
||||
result = await self.connection_manager.execute_query(self._session_id, query, None)
|
||||
result = await self.connection_manager.execute_query(self._session_id, query, None, auth_context)
|
||||
|
||||
# Extract data from QueryResult
|
||||
if hasattr(result, 'data'):
|
||||
@@ -1188,12 +1245,28 @@ class MetadataExtractor:
|
||||
try:
|
||||
# Use async query method
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
effective_db = db_name or self.db_name
|
||||
|
||||
# Build query statement
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`"
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build query statement using safe identifiers
|
||||
safe_table = quote_identifier(table_name, "table name")
|
||||
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"DESCRIBE {safe_catalog}.{safe_db}.{safe_table}"
|
||||
else:
|
||||
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`"
|
||||
query = f"DESCRIBE {safe_db}.{safe_table}"
|
||||
|
||||
# Execute async query
|
||||
result = await self._execute_query_async(query, db_name)
|
||||
@@ -1226,8 +1299,15 @@ class MetadataExtractor:
|
||||
try:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate catalog name if provided
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW DATABASES FROM `{effective_catalog}`"
|
||||
try:
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid catalog name rejected: {e}")
|
||||
return []
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW DATABASES FROM {safe_catalog}"
|
||||
else:
|
||||
query = "SHOW DATABASES"
|
||||
|
||||
@@ -1257,10 +1337,23 @@ class MetadataExtractor:
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
effective_db = db_name or self.db_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`"
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||
|
||||
if effective_catalog and effective_catalog != "internal":
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW TABLES FROM {safe_catalog}.{safe_db}"
|
||||
else:
|
||||
query = f"SHOW TABLES FROM `{effective_db}`"
|
||||
query = f"SHOW TABLES FROM {safe_db}"
|
||||
|
||||
result = await self._execute_query_async(query, effective_db)
|
||||
|
||||
@@ -1319,6 +1412,15 @@ class MetadataExtractor:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return ""
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
TABLE_COMMENT
|
||||
@@ -1343,6 +1445,15 @@ class MetadataExtractor:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return {}
|
||||
|
||||
query = f"""
|
||||
SELECT
|
||||
COLUMN_NAME,
|
||||
@@ -1373,12 +1484,27 @@ class MetadataExtractor:
|
||||
effective_db = db_name or self.db_name
|
||||
effective_catalog = catalog_name or self.catalog_name
|
||||
|
||||
# Build query with catalog prefix if specified
|
||||
# SECURITY FIX: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
if effective_db:
|
||||
validate_identifier(effective_db, "database name")
|
||||
if effective_catalog:
|
||||
query = f"SHOW INDEX FROM `{effective_catalog}`.`{effective_db}`.`{table_name}`"
|
||||
validate_identifier(effective_catalog, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid identifier rejected: {e}")
|
||||
return []
|
||||
|
||||
# Build query with catalog prefix if specified (using safe identifiers)
|
||||
safe_table = quote_identifier(table_name, "table name")
|
||||
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
|
||||
|
||||
if effective_catalog:
|
||||
safe_catalog = quote_identifier(effective_catalog, "catalog name")
|
||||
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
|
||||
logger.info(f"Using three-part naming for async index query: {query}")
|
||||
else:
|
||||
query = f"SHOW INDEX FROM `{effective_db}`.`{table_name}`"
|
||||
query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
|
||||
|
||||
rows = await self._execute_query_async(query, effective_db)
|
||||
indexes: List[Dict[str, Any]] = []
|
||||
@@ -1464,6 +1590,9 @@ class MetadataExtractor:
|
||||
"""
|
||||
Execute SQL query and return results, supports catalog federation queries
|
||||
Unified interface for MCP tools
|
||||
|
||||
FIX for Issue #62 Bug 1: Now retrieves auth_context from context variable to support token-bound database configuration
|
||||
FIX for Issue #62 Bug 3: Now uses db_name and catalog_name parameters to switch database context
|
||||
"""
|
||||
logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
|
||||
|
||||
@@ -1471,15 +1600,86 @@ class MetadataExtractor:
|
||||
if not sql:
|
||||
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
|
||||
|
||||
# FIX for Issue #62 Bug 3: Build context switching SQL if db_name or catalog_name is specified
|
||||
# SECURITY FIX: Validate catalog_name and db_name to prevent SQL injection
|
||||
final_sql = sql
|
||||
if catalog_name or db_name:
|
||||
context_statements = []
|
||||
|
||||
# Validate and sanitize catalog_name
|
||||
if catalog_name:
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid catalog name rejected: {e}")
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
# Use quote_identifier to safely escape the catalog name
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
context_statements.append(f"USE CATALOG {safe_catalog}")
|
||||
logger.debug(f"Switching to catalog: {catalog_name}")
|
||||
|
||||
# Validate and sanitize db_name
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
logger.warning(f"Invalid database name rejected: {e}")
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
# Use quote_identifier to safely escape the database name
|
||||
safe_db = quote_identifier(db_name, "database name")
|
||||
if catalog_name:
|
||||
safe_catalog = quote_identifier(catalog_name, "catalog name")
|
||||
context_statements.append(f"USE {safe_catalog}.{safe_db}")
|
||||
else:
|
||||
context_statements.append(f"USE {safe_db}")
|
||||
logger.debug(f"Switching to database: {db_name}")
|
||||
|
||||
# Combine context switching with original SQL
|
||||
if context_statements:
|
||||
# Remove trailing semicolon from context statements if present
|
||||
context_sql = "; ".join(context_statements)
|
||||
# Ensure original SQL doesn't start with semicolon
|
||||
sql_clean = sql.lstrip(";").strip()
|
||||
final_sql = f"{context_sql}; {sql_clean}"
|
||||
logger.debug(f"Modified SQL with context switching: {final_sql[:200]}...")
|
||||
|
||||
# FIX: Try to get auth_context from context variable (set by HTTP middleware)
|
||||
# This allows token-bound database configuration to work
|
||||
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
|
||||
auth_context = None
|
||||
try:
|
||||
from .security import mcp_auth_context_var
|
||||
|
||||
# Get auth_context from the global context variable
|
||||
# This will be set by the HTTP request handler in main.py
|
||||
auth_context = mcp_auth_context_var.get()
|
||||
|
||||
if auth_context:
|
||||
logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
|
||||
else:
|
||||
logger.debug("No auth_context found in context variable, using default")
|
||||
except Exception as ctx_error:
|
||||
logger.debug(f"Could not retrieve auth_context from context variable: {ctx_error}")
|
||||
auth_context = None
|
||||
|
||||
# Import query executor
|
||||
from .query_executor import execute_sql_query
|
||||
|
||||
# Call execute_sql_query to execute query
|
||||
# Call execute_sql_query to execute query with auth_context
|
||||
exec_result = await execute_sql_query(
|
||||
sql=sql,
|
||||
sql=final_sql, # Use modified SQL with context switching
|
||||
connection_manager=self.connection_manager,
|
||||
limit=max_rows,
|
||||
timeout=timeout
|
||||
timeout=timeout,
|
||||
auth_context=auth_context # FIX: Pass auth_context with token
|
||||
)
|
||||
|
||||
return exec_result
|
||||
@@ -1500,6 +1700,36 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers before processing
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
|
||||
@@ -1523,6 +1753,27 @@ class MetadataExtractor:
|
||||
"""Get list of all table names in specified database - MCP interface"""
|
||||
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=tables)
|
||||
@@ -1553,6 +1804,36 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comment)
|
||||
@@ -1572,6 +1853,36 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=comments)
|
||||
@@ -1591,6 +1902,36 @@ class MetadataExtractor:
|
||||
if not table_name:
|
||||
return self._format_response(success=False, error="Missing table_name parameter")
|
||||
|
||||
# SECURITY: Validate identifiers
|
||||
try:
|
||||
validate_identifier(table_name, "table name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid table name: {table_name}",
|
||||
message="Table name contains invalid characters"
|
||||
)
|
||||
|
||||
if db_name:
|
||||
try:
|
||||
validate_identifier(db_name, "database name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid database name: {db_name}",
|
||||
message="Database name contains invalid characters"
|
||||
)
|
||||
|
||||
if catalog_name and catalog_name != "internal":
|
||||
try:
|
||||
validate_identifier(catalog_name, "catalog name")
|
||||
except SQLSecurityError as e:
|
||||
return self._format_response(
|
||||
success=False,
|
||||
error=f"Invalid catalog name: {catalog_name}",
|
||||
message="Catalog name contains invalid characters"
|
||||
)
|
||||
|
||||
try:
|
||||
indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
|
||||
return self._format_response(success=True, result=indexes)
|
||||
|
||||
@@ -22,6 +22,7 @@ Implements enterprise-level authentication, authorization, SQL security validati
|
||||
|
||||
import logging
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
@@ -34,6 +35,10 @@ from sqlparse.tokens import Keyword, Name
|
||||
from .logger import get_logger
|
||||
from .config import DatabaseConfig
|
||||
|
||||
# Global ContextVar for auth_context - must be a single instance shared across all modules
|
||||
# This allows token-bound database configuration to work correctly in concurrent requests
|
||||
mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None)
|
||||
|
||||
|
||||
class SecurityLevel(Enum):
|
||||
"""Security level enumeration"""
|
||||
@@ -903,27 +908,47 @@ class SQLSecurityValidator:
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
try:
|
||||
# Parse SQL statement
|
||||
parsed = sqlparse.parse(sql)[0]
|
||||
# SECURITY FIX: Parse ALL SQL statements, not just the first one
|
||||
# This prevents bypassing security checks by injecting additional statements
|
||||
all_statements = sqlparse.parse(sql)
|
||||
|
||||
if not all_statements:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Empty or invalid SQL statement",
|
||||
risk_level="medium"
|
||||
)
|
||||
|
||||
# SECURITY FIX: Validate each statement individually
|
||||
for idx, parsed in enumerate(all_statements):
|
||||
# Skip empty statements (e.g., from trailing semicolons)
|
||||
if not parsed.tokens or str(parsed).strip() == '':
|
||||
continue
|
||||
|
||||
self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...")
|
||||
|
||||
# Check blocked operations first (more specific)
|
||||
keyword_result = await self._check_blocked_keywords(parsed)
|
||||
if not keyword_result.is_valid:
|
||||
keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}"
|
||||
return keyword_result
|
||||
|
||||
# Check SQL injection risks
|
||||
injection_result = await self._check_sql_injection(sql, parsed)
|
||||
if not injection_result.is_valid:
|
||||
injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}"
|
||||
return injection_result
|
||||
|
||||
# Check query complexity
|
||||
complexity_result = await self._check_query_complexity(parsed)
|
||||
if not complexity_result.is_valid:
|
||||
complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}"
|
||||
return complexity_result
|
||||
|
||||
# Check table access permissions
|
||||
table_result = await self._check_table_access(parsed, auth_context)
|
||||
if not table_result.is_valid:
|
||||
table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}"
|
||||
return table_result
|
||||
|
||||
return ValidationResult(is_valid=True)
|
||||
@@ -939,28 +964,69 @@ class SQLSecurityValidator:
|
||||
async def _check_sql_injection(
|
||||
self, sql: str, parsed: Statement
|
||||
) -> ValidationResult:
|
||||
"""Check SQL injection risks"""
|
||||
# Check common SQL injection patterns
|
||||
"""Check SQL injection risks with improved pattern detection
|
||||
|
||||
FIX for Issue #62 Bug 2: Improved patterns to reduce false positives
|
||||
Now better distinguishes between legitimate SQL (like BETWEEN...AND) and injection attempts
|
||||
"""
|
||||
# Improved injection patterns that are more specific and less prone to false positives
|
||||
injection_patterns = [
|
||||
r"(?i)(?<![A-Za-z0-9_])(union|select|insert|update|delete|drop|create|alter)(?![A-Za-z0-9_])\s+[\s\S]*?\s+(?<![A-Za-z0-9_])(union|select|insert|update|delete|drop|create|alter)(?![A-Za-z0-9_])",
|
||||
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
|
||||
r"(\s|^)(or|and)\s+['\"].*['\"]",
|
||||
r";\s*(drop|delete|truncate|alter|create)",
|
||||
r"(exec|execute|sp_|xp_)",
|
||||
r"(script|javascript|vbscript)",
|
||||
r"(char|ascii|substring|concat)\s*\(",
|
||||
# Stacked queries with dangerous operations (true injection risk)
|
||||
r";\s*(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE)\s+",
|
||||
|
||||
# UNION-based injection (but allow legitimate UNION queries)
|
||||
# Only flag if UNION is followed by suspicious patterns like SELECT with WHERE 1=1
|
||||
r"UNION\s+(ALL\s+)?SELECT\s+.*\s+(WHERE|AND|OR)\s+\d+\s*=\s*\d+",
|
||||
|
||||
# Boolean-based blind injection with comments (true injection pattern)
|
||||
r"(WHERE|AND|OR)\s+\d+\s*=\s*\d+\s*(--|#|/\*)",
|
||||
|
||||
# Quote-based injection attempts (but not in legitimate strings)
|
||||
r"(WHERE|AND|OR)\s+(['\"])[^\2]*\2\s*=\s*\2[^\2]*\2",
|
||||
|
||||
# Time-based blind injection
|
||||
r"(SLEEP|WAITFOR|BENCHMARK)\s*\(",
|
||||
|
||||
# System stored procedure injection
|
||||
r"(EXEC|EXECUTE|SP_|XP_)\s*\(",
|
||||
|
||||
# Script injection attempts
|
||||
r"<\s*(SCRIPT|JAVASCRIPT|VBSCRIPT)",
|
||||
]
|
||||
|
||||
sql_lower = sql.lower()
|
||||
# FIX: Don't flag legitimate SQL functions and keywords
|
||||
# These patterns are too broad and cause false positives:
|
||||
# - REMOVED: r"(char|ascii|substring|concat)\s*\(" - These are legitimate SQL functions
|
||||
# - REMOVED: r"(\s|^)(or|and)\s+\d+\s*=\s*\d+" - This flags BETWEEN...AND constructs
|
||||
# - REMOVED: r"(\s|^)(or|and)\s+['\"].*['\"]" - This is too broad
|
||||
|
||||
sql_upper = sql.upper()
|
||||
|
||||
# Special case: Allow BETWEEN...AND which is legitimate SQL
|
||||
# This prevents false positives like "WHERE dt BETWEEN '2025-01-01' AND '2025-01-31'"
|
||||
if "BETWEEN" in sql_upper and "AND" in sql_upper:
|
||||
# This is likely a BETWEEN clause, not injection
|
||||
# Check if AND appears in a BETWEEN context
|
||||
between_pattern = r"BETWEEN\s+[^\s]+\s+AND\s+[^\s]+"
|
||||
if re.search(between_pattern, sql_upper, re.IGNORECASE):
|
||||
# Remove BETWEEN clauses before checking other patterns
|
||||
sql_cleaned = re.sub(between_pattern, "BETWEEN_CLAUSE", sql_upper, flags=re.IGNORECASE)
|
||||
sql_to_check = sql_cleaned
|
||||
else:
|
||||
sql_to_check = sql_upper
|
||||
else:
|
||||
sql_to_check = sql_upper
|
||||
|
||||
for pattern in injection_patterns:
|
||||
if re.search(pattern, sql_lower, re.IGNORECASE):
|
||||
if re.search(pattern, sql_to_check, re.IGNORECASE):
|
||||
self.logger.warning(f"Potential SQL injection pattern detected: {pattern}")
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
error_message="Potential SQL injection risk detected",
|
||||
risk_level="high",
|
||||
)
|
||||
|
||||
# Check suspicious quotes and comments
|
||||
# Check suspicious quotes and comments (with improved detection)
|
||||
if self._has_suspicious_quotes_or_comments(sql):
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
@@ -971,20 +1037,68 @@ class SQLSecurityValidator:
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
|
||||
"""Check suspicious quote and comment patterns"""
|
||||
# Check unmatched quotes
|
||||
single_quotes = sql.count("'")
|
||||
double_quotes = sql.count('"')
|
||||
"""Check suspicious quote and comment patterns with improved detection
|
||||
|
||||
FIX for Issue #62 Bug 2: Improved detection to reduce false positives
|
||||
Now distinguishes between legitimate comments/strings and injection attempts
|
||||
"""
|
||||
try:
|
||||
# Use sqlparse to parse the SQL and distinguish between code and comments/strings
|
||||
import sqlparse
|
||||
from sqlparse.tokens import Comment, String
|
||||
|
||||
# Parse the SQL
|
||||
parsed = sqlparse.parse(sql)
|
||||
if not parsed:
|
||||
# If parsing fails, be conservative
|
||||
return True
|
||||
|
||||
statement = parsed[0]
|
||||
|
||||
# Check for unmatched quotes ONLY in non-string tokens
|
||||
# This prevents false positives from legitimate string content
|
||||
non_string_content = []
|
||||
has_string_tokens = False
|
||||
|
||||
for token in statement.flatten():
|
||||
if token.ttype in (String.Single, String.Double):
|
||||
has_string_tokens = True
|
||||
# Skip string content - quotes inside strings are legitimate
|
||||
continue
|
||||
elif token.ttype in (Comment.Single, Comment.Multi):
|
||||
# Comments are generally OK, but check for suspicious injection patterns
|
||||
comment_value = str(token).lower()
|
||||
# Check if comment contains dangerous SQL keywords
|
||||
dangerous_in_comments = ['drop', 'delete', 'insert', 'update', 'union', 'exec', 'execute']
|
||||
if any(keyword in comment_value for keyword in dangerous_in_comments):
|
||||
self.logger.warning(f"Suspicious SQL keyword in comment: {token}")
|
||||
return True
|
||||
# Normal comments are OK
|
||||
continue
|
||||
else:
|
||||
# Accumulate non-string, non-comment content
|
||||
non_string_content.append(str(token))
|
||||
|
||||
# Check for unmatched quotes in non-string content
|
||||
non_string_text = ''.join(non_string_content)
|
||||
single_quotes = non_string_text.count("'")
|
||||
double_quotes = non_string_text.count('"')
|
||||
|
||||
# Only flag if there are unmatched quotes in actual SQL code (not in strings)
|
||||
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
|
||||
return True
|
||||
|
||||
# Check SQL comments
|
||||
if "--" in sql or "/*" in sql:
|
||||
return True
|
||||
# FIX: Don't flag legitimate SQL comments
|
||||
# Comments are OK as long as they don't contain dangerous patterns (already checked above)
|
||||
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.debug(f"SQL parsing error in quote/comment check: {e}")
|
||||
# On parsing error, fall back to conservative check
|
||||
# But be more lenient than before
|
||||
return False # Don't flag on parse errors to reduce false positives
|
||||
|
||||
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
|
||||
"""Check blocked keywords"""
|
||||
blocked_operations = []
|
||||
@@ -1045,6 +1159,10 @@ class SQLSecurityValidator:
|
||||
self, parsed: Statement, auth_context: AuthContext
|
||||
) -> ValidationResult:
|
||||
"""Check table access permissions"""
|
||||
# If no auth_context, skip table access checks (rely on other security checks)
|
||||
if auth_context is None:
|
||||
return ValidationResult(is_valid=True)
|
||||
|
||||
# Extract table names from query
|
||||
tables = self._extract_table_names(parsed)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from collections import Counter, defaultdict
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
from .sql_security_utils import get_auth_context
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -192,7 +193,9 @@ class SecurityAnalyticsTools:
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
# SECURITY FIX: Pass auth_context to execute
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(audit_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
@@ -215,7 +218,8 @@ class SecurityAnalyticsTools:
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
result = await connection.execute(simple_audit_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(simple_audit_sql, auth_context=auth_context)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e2:
|
||||
@@ -498,7 +502,8 @@ class SecurityAnalyticsTools:
|
||||
FROM mysql.user
|
||||
"""
|
||||
|
||||
result = await connection.execute(roles_sql)
|
||||
auth_context = get_auth_context()
|
||||
result = await connection.execute(roles_sql, auth_context=auth_context)
|
||||
|
||||
user_roles = defaultdict(list)
|
||||
if result.data:
|
||||
|
||||
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
301
doris_mcp_server/utils/sql_security_utils.py
Normal file
@@ -0,0 +1,301 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
SQL Security Utilities Module
|
||||
|
||||
Provides SQL identifier validation, escaping, and safe query building utilities
|
||||
to prevent SQL injection attacks.
|
||||
"""
|
||||
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import Optional, Tuple, List, Any
|
||||
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Context variable for auth_context (set by HTTP middleware)
|
||||
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
|
||||
|
||||
|
||||
class SQLSecurityError(Exception):
|
||||
"""Exception raised for SQL security validation failures"""
|
||||
pass
|
||||
|
||||
|
||||
class SQLSecurityUtils:
|
||||
"""
|
||||
SQL Security Utilities for preventing SQL injection attacks.
|
||||
|
||||
Provides:
|
||||
- Identifier validation (database names, table names, column names)
|
||||
- Safe identifier quoting with backticks
|
||||
- Safe table reference building
|
||||
- Auth context retrieval from context variables
|
||||
"""
|
||||
|
||||
# Valid SQL identifier pattern: letters, numbers, underscores
|
||||
# Must start with letter or underscore, not a number
|
||||
# Supports Unicode letters for international database/table names
|
||||
IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$')
|
||||
|
||||
# Maximum identifier length (MySQL/Doris standard)
|
||||
MAX_IDENTIFIER_LENGTH = 64
|
||||
|
||||
# SQL reserved keywords that should be quoted
|
||||
SQL_KEYWORDS = {
|
||||
'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP',
|
||||
'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND',
|
||||
'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN',
|
||||
'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER',
|
||||
'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
|
||||
'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY',
|
||||
'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT'
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||
"""
|
||||
Validate a SQL identifier (database name, table name, column name, etc.)
|
||||
|
||||
Args:
|
||||
name: The identifier to validate
|
||||
identifier_type: Type description for error messages (e.g., "database name", "table name")
|
||||
|
||||
Returns:
|
||||
The validated identifier (unchanged if valid)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If the identifier is invalid
|
||||
"""
|
||||
if not name:
|
||||
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||
|
||||
if not isinstance(name, str):
|
||||
raise SQLSecurityError(f"Invalid {identifier_type}: must be a string, got {type(name).__name__}")
|
||||
|
||||
# Strip whitespace
|
||||
name = name.strip()
|
||||
|
||||
if not name:
|
||||
raise SQLSecurityError(f"Empty {identifier_type} is not allowed")
|
||||
|
||||
# Check length
|
||||
if len(name) > cls.MAX_IDENTIFIER_LENGTH:
|
||||
raise SQLSecurityError(
|
||||
f"Invalid {identifier_type}: '{name[:20]}...' exceeds maximum length of {cls.MAX_IDENTIFIER_LENGTH} characters"
|
||||
)
|
||||
|
||||
# Check for dangerous characters that could be SQL injection
|
||||
dangerous_chars = ["'", '"', ';', '--', '/*', '*/', '\\', '\x00']
|
||||
for char in dangerous_chars:
|
||||
if char in name:
|
||||
raise SQLSecurityError(
|
||||
f"Invalid {identifier_type}: '{name}' contains forbidden character '{char}'"
|
||||
)
|
||||
|
||||
# Validate pattern
|
||||
if not cls.IDENTIFIER_PATTERN.match(name):
|
||||
raise SQLSecurityError(
|
||||
f"Invalid {identifier_type}: '{name}' contains invalid characters. "
|
||||
f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore."
|
||||
)
|
||||
|
||||
logger.debug(f"Validated {identifier_type}: {name}")
|
||||
return name
|
||||
|
||||
@classmethod
|
||||
def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
|
||||
"""
|
||||
Safely quote a SQL identifier using backticks.
|
||||
|
||||
Args:
|
||||
name: The identifier to quote
|
||||
identifier_type: Type description for error messages
|
||||
|
||||
Returns:
|
||||
The quoted identifier (e.g., `table_name`)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If the identifier is invalid
|
||||
"""
|
||||
# First validate the identifier
|
||||
validated_name = cls.validate_identifier(name, identifier_type)
|
||||
|
||||
# Escape any backticks within the name (double them)
|
||||
escaped_name = validated_name.replace('`', '``')
|
||||
|
||||
return f"`{escaped_name}`"
|
||||
|
||||
@classmethod
|
||||
def build_table_reference(
|
||||
cls,
|
||||
table_name: str,
|
||||
db_name: Optional[str] = None,
|
||||
catalog_name: Optional[str] = None,
|
||||
quote: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Build a safe, fully-qualified table reference.
|
||||
|
||||
Args:
|
||||
table_name: The table name (required)
|
||||
db_name: The database name (optional)
|
||||
catalog_name: The catalog name (optional)
|
||||
quote: Whether to quote identifiers with backticks (default: True)
|
||||
|
||||
Returns:
|
||||
A safe table reference string (e.g., `catalog`.`db`.`table`)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If any identifier is invalid
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if catalog_name:
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(catalog_name, "catalog name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(catalog_name, "catalog name"))
|
||||
|
||||
if db_name:
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(db_name, "database name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(db_name, "database name"))
|
||||
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||
|
||||
return '.'.join(parts)
|
||||
|
||||
@classmethod
|
||||
def build_column_reference(
|
||||
cls,
|
||||
column_name: str,
|
||||
table_name: Optional[str] = None,
|
||||
quote: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Build a safe column reference.
|
||||
|
||||
Args:
|
||||
column_name: The column name (required)
|
||||
table_name: The table name (optional, for qualified references)
|
||||
quote: Whether to quote identifiers with backticks (default: True)
|
||||
|
||||
Returns:
|
||||
A safe column reference string (e.g., `table`.`column`)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If any identifier is invalid
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if table_name:
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(table_name, "table name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(table_name, "table name"))
|
||||
|
||||
if quote:
|
||||
parts.append(cls.quote_identifier(column_name, "column name"))
|
||||
else:
|
||||
parts.append(cls.validate_identifier(column_name, "column name"))
|
||||
|
||||
return '.'.join(parts)
|
||||
|
||||
@classmethod
|
||||
def validate_and_build_where_condition(
|
||||
cls,
|
||||
column_name: str,
|
||||
operator: str = "=",
|
||||
use_param: bool = True
|
||||
) -> Tuple[str, bool]:
|
||||
"""
|
||||
Build a safe WHERE condition for a column.
|
||||
|
||||
Args:
|
||||
column_name: The column name
|
||||
operator: The comparison operator (=, !=, <, >, <=, >=, LIKE, IN)
|
||||
use_param: Whether to use parameterized placeholder (%s)
|
||||
|
||||
Returns:
|
||||
Tuple of (condition_string, needs_param)
|
||||
e.g., ("`column` = %s", True) or ("`column` = DATABASE()", False)
|
||||
|
||||
Raises:
|
||||
SQLSecurityError: If column name is invalid or operator is not allowed
|
||||
"""
|
||||
# Validate column name
|
||||
quoted_column = cls.quote_identifier(column_name, "column name")
|
||||
|
||||
# Validate operator
|
||||
allowed_operators = {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'}
|
||||
if operator.upper() not in allowed_operators:
|
||||
raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed_operators}")
|
||||
|
||||
if use_param:
|
||||
return f"{quoted_column} {operator} %s", True
|
||||
else:
|
||||
return f"{quoted_column} {operator}", False
|
||||
|
||||
@staticmethod
|
||||
def get_auth_context():
|
||||
"""
|
||||
Get auth_context from the context variable.
|
||||
|
||||
This retrieves the auth_context that was set by the HTTP middleware
|
||||
during request processing.
|
||||
|
||||
Returns:
|
||||
The auth_context object, or None if not available
|
||||
"""
|
||||
try:
|
||||
auth_context = auth_context_var.get()
|
||||
if auth_context:
|
||||
logger.debug(f"Retrieved auth_context from context variable")
|
||||
return auth_context
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not retrieve auth_context: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def set_auth_context(auth_context):
|
||||
"""
|
||||
Set auth_context in the context variable.
|
||||
|
||||
This is typically called by the HTTP middleware during request processing.
|
||||
|
||||
Args:
|
||||
auth_context: The auth_context object to set
|
||||
"""
|
||||
auth_context_var.set(auth_context)
|
||||
logger.debug("Set auth_context in context variable")
|
||||
|
||||
|
||||
# Convenience functions for direct use
|
||||
validate_identifier = SQLSecurityUtils.validate_identifier
|
||||
quote_identifier = SQLSecurityUtils.quote_identifier
|
||||
build_table_reference = SQLSecurityUtils.build_table_reference
|
||||
build_column_reference = SQLSecurityUtils.build_column_reference
|
||||
get_auth_context = SQLSecurityUtils.get_auth_context
|
||||
set_auth_context = SQLSecurityUtils.set_auth_context
|
||||
|
||||
@@ -20,7 +20,7 @@ build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "doris-mcp-server"
|
||||
version = "0.6.0"
|
||||
version = "0.6.1"
|
||||
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
|
||||
authors = [
|
||||
{name = "Yijia Su", email = "freeoneplus@apache.org"}
|
||||
|
||||
@@ -64,9 +64,10 @@ else
|
||||
fi
|
||||
|
||||
# Set HTTP-specific environment variables
|
||||
# FIX for Issue #62 Bug 4: Use SERVER_PORT instead of MCP_PORT for consistency with code
|
||||
export MCP_TRANSPORT_TYPE="http"
|
||||
export MCP_HOST="${MCP_HOST:-0.0.0.0}"
|
||||
export MCP_PORT="${MCP_PORT:-3000}"
|
||||
export SERVER_PORT="${SERVER_PORT:-3000}" # Changed from MCP_PORT to SERVER_PORT
|
||||
export WORKERS="${WORKERS:-1}"
|
||||
export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}"
|
||||
export LOG_LEVEL="${LOG_LEVEL:-info}"
|
||||
@@ -77,15 +78,15 @@ export MCP_DEBUG_ADAPTER="true"
|
||||
export PYTHONPATH="$(pwd):$PYTHONPATH"
|
||||
|
||||
echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}"
|
||||
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${MCP_PORT}/health${NC}"
|
||||
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Local access: http://localhost:${MCP_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${SERVER_PORT}/health${NC}"
|
||||
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Local access: http://localhost:${SERVER_PORT}/mcp${NC}"
|
||||
echo -e "${YELLOW}Workers: ${WORKERS}${NC}"
|
||||
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}"
|
||||
|
||||
# Start the server in HTTP mode (Streamable HTTP)
|
||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${MCP_PORT} --workers ${WORKERS}
|
||||
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${SERVER_PORT} --workers ${WORKERS}
|
||||
|
||||
# Check exit status
|
||||
if [ $? -ne 0 ]; then
|
||||
@@ -97,4 +98,4 @@ fi
|
||||
echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}"
|
||||
echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}"
|
||||
echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}"
|
||||
echo -e "${CYAN} curl -X POST http://localhost:${MCP_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"
|
||||
echo -e "${CYAN} curl -X POST http://localhost:${SERVER_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"
|
||||
367
test/security/test_sql_injection.py
Normal file
367
test/security/test_sql_injection.py
Normal file
@@ -0,0 +1,367 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
SQL Security Test Suite for Apache Doris MCP Server
|
||||
|
||||
Tests for:
|
||||
1. SQL injection prevention via identifier validation
|
||||
2. Multi-statement SQL parsing in security validator
|
||||
3. auth_context enforcement
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
|
||||
class TestSQLSecurityUtils:
|
||||
"""Test cases for sql_security_utils module"""
|
||||
|
||||
def test_validate_identifier_accepts_valid_names(self):
|
||||
"""Test that valid identifiers are accepted"""
|
||||
from doris_mcp_server.utils.sql_security_utils import validate_identifier
|
||||
|
||||
valid_names = [
|
||||
"users",
|
||||
"my_table",
|
||||
"Table123",
|
||||
"_private_table",
|
||||
"CamelCaseTable",
|
||||
"table_with_numbers_123",
|
||||
]
|
||||
|
||||
for name in valid_names:
|
||||
result = validate_identifier(name, "table")
|
||||
assert result == name, f"Valid identifier '{name}' should be accepted"
|
||||
|
||||
def test_validate_identifier_rejects_sql_injection(self):
|
||||
"""Test that SQL injection attempts are rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
injection_attempts = [
|
||||
# Basic SQL injection
|
||||
"'; DROP TABLE users; --",
|
||||
"table' OR '1'='1",
|
||||
"table'; DELETE FROM users; --",
|
||||
|
||||
# Union-based injection
|
||||
"table' UNION SELECT * FROM passwords --",
|
||||
|
||||
# Comment injection
|
||||
"table/**/OR/**/1=1",
|
||||
"table--comment",
|
||||
|
||||
# Special characters
|
||||
"table`; DROP TABLE users;",
|
||||
'table"; DROP TABLE users;',
|
||||
"table\"; DELETE FROM",
|
||||
|
||||
# Backtick escape attempt
|
||||
"analytics`; SELECT * FROM sensitive_table;--",
|
||||
|
||||
# Whitespace injection
|
||||
"table name with spaces",
|
||||
"table\ttab",
|
||||
"table\nnewline",
|
||||
]
|
||||
|
||||
for injection in injection_attempts:
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(injection, "table")
|
||||
|
||||
def test_validate_identifier_rejects_empty(self):
|
||||
"""Test that empty identifiers are rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier("", "table")
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(None, "table")
|
||||
|
||||
def test_validate_identifier_rejects_too_long(self):
|
||||
"""Test that identifiers exceeding max length are rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
# Doris identifier max length is typically 64 characters
|
||||
long_name = "a" * 100
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(long_name, "table")
|
||||
|
||||
def test_quote_identifier_adds_backticks(self):
|
||||
"""Test that quote_identifier properly escapes identifiers"""
|
||||
from doris_mcp_server.utils.sql_security_utils import quote_identifier
|
||||
|
||||
assert quote_identifier("my_table", "table") == "`my_table`"
|
||||
assert quote_identifier("users", "table") == "`users`"
|
||||
assert quote_identifier("Table123", "table") == "`Table123`"
|
||||
|
||||
def test_quote_identifier_validates_first(self):
|
||||
"""Test that quote_identifier validates before quoting"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
quote_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
quote_identifier("'; DROP TABLE users; --", "table")
|
||||
|
||||
|
||||
class TestSQLSecurityValidator:
|
||||
"""Test cases for SQLSecurityValidator multi-statement parsing"""
|
||||
|
||||
@pytest.fixture
|
||||
def dict_config(self):
|
||||
"""Create dictionary configuration"""
|
||||
return {
|
||||
"blocked_keywords": [
|
||||
"DROP", "CREATE", "ALTER", "TRUNCATE",
|
||||
"DELETE", "INSERT", "UPDATE",
|
||||
"GRANT", "REVOKE", "EXEC", "EXECUTE"
|
||||
],
|
||||
"max_query_complexity": 100,
|
||||
"enable_security_check": True
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_context(self):
|
||||
"""Create mock auth context"""
|
||||
from doris_mcp_server.utils.security import AuthContext, SecurityLevel
|
||||
return AuthContext(
|
||||
user_id="test_user",
|
||||
roles=["user"],
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validates_all_statements(self, dict_config, mock_auth_context):
|
||||
"""Test that validator checks ALL SQL statements, not just the first"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
# Multi-statement with injection in second statement
|
||||
# This should be BLOCKED
|
||||
malicious_sql = "SELECT 1; DROP TABLE users; SELECT 2"
|
||||
|
||||
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||
|
||||
assert not result.is_valid, "Multi-statement injection should be blocked"
|
||||
# Check for either DROP keyword detection or SQL injection detection
|
||||
error_upper = result.error_message.upper()
|
||||
assert ("DROP" in error_upper or
|
||||
"INJECTION" in error_upper or
|
||||
"BLOCKED" in error_upper), f"Expected DROP/injection/blocked in: {result.error_message}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_hidden_dangerous_statement(self, dict_config, mock_auth_context):
|
||||
"""Test that dangerous statements hidden after safe ones are blocked"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
# Safe statement followed by dangerous one
|
||||
malicious_sql = """
|
||||
SELECT * FROM users WHERE id = 1;
|
||||
DELETE FROM audit_log;
|
||||
SELECT 1;
|
||||
"""
|
||||
|
||||
result = await validator.validate(malicious_sql, mock_auth_context)
|
||||
|
||||
assert not result.is_valid, "Hidden DELETE statement should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_safe_multi_statement(self, dict_config, mock_auth_context):
|
||||
"""Test that multiple safe SELECT statements are allowed"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
safe_sql = """
|
||||
SELECT * FROM users;
|
||||
SELECT COUNT(*) FROM orders;
|
||||
SELECT id, name FROM products;
|
||||
"""
|
||||
|
||||
result = await validator.validate(safe_sql, mock_auth_context)
|
||||
|
||||
assert result.is_valid, f"Multiple safe SELECT statements should be allowed, got: {result.error_message}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_context_switch_injection_blocked(self, dict_config, mock_auth_context):
|
||||
"""Test that context switch SQL injection is blocked"""
|
||||
from doris_mcp_server.utils.security import SQLSecurityValidator
|
||||
|
||||
validator = SQLSecurityValidator(dict_config)
|
||||
|
||||
# Simulating the exec_query_for_mcp attack vector
|
||||
injected_sql = """
|
||||
USE `analytics`; SELECT * FROM sensitive_table;-- `;
|
||||
SELECT * FROM public_table;
|
||||
"""
|
||||
|
||||
result = await validator.validate(injected_sql, mock_auth_context)
|
||||
|
||||
# The validator should process all statements
|
||||
# Even if USE is allowed, subsequent unauthorized access should be caught
|
||||
# by table access checks (if configured)
|
||||
|
||||
|
||||
class TestExecQueryForMCP:
|
||||
"""Test cases for exec_query_for_mcp function"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_malicious_db_name(self):
|
||||
"""Test that malicious db_name is rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
# The attack vector from security report
|
||||
malicious_db_name = "analytics`; SELECT * FROM sensitive_table;--"
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(malicious_db_name, "database name")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_malicious_catalog_name(self):
|
||||
"""Test that malicious catalog_name is rejected"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
malicious_catalog_name = "internal'; DROP DATABASE production;--"
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(malicious_catalog_name, "catalog name")
|
||||
|
||||
|
||||
class TestDependencyAnalysisTools:
|
||||
"""Test cases for dependency_analysis_tools security fixes"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tables_metadata_rejects_injection(self):
|
||||
"""Test that _get_tables_metadata rejects SQL injection"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
# The attack vector from security report
|
||||
injection_db_name = "test_db' OR '1'='1' --"
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(injection_db_name, "database name")
|
||||
|
||||
|
||||
class TestAuthContextEnforcement:
|
||||
"""Test cases for auth_context enforcement"""
|
||||
|
||||
def test_execute_requires_auth_context_for_security(self):
|
||||
"""Test that security checks require auth_context"""
|
||||
# This test documents the expected behavior:
|
||||
# When auth_context is None, security checks are skipped
|
||||
# When auth_context is provided, security checks are performed
|
||||
|
||||
# The fix ensures all execute() calls pass auth_context
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_auth_context_returns_context(self):
|
||||
"""Test that get_auth_context retrieves context from ContextVar"""
|
||||
from doris_mcp_server.utils.sql_security_utils import get_auth_context
|
||||
|
||||
# When no context is set, should return None
|
||||
result = get_auth_context()
|
||||
# This is expected - context is set by HTTP middleware
|
||||
assert result is None or hasattr(result, 'user_id')
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Integration test scenarios for security fixes"""
|
||||
|
||||
def test_attack_scenario_1_permission_bypass(self):
|
||||
"""
|
||||
Attack Scenario 1: Permission Bypass in Multi-Tenant Environment
|
||||
|
||||
Expected: User can only query their own database (db_name="tenant_a_db")
|
||||
Attack: Inject "tenant_a_db' OR '1'='1' --" to query ALL databases
|
||||
Result: Should be BLOCKED by validate_identifier()
|
||||
"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier("tenant_a_db' OR '1'='1' --", "database name")
|
||||
|
||||
def test_attack_scenario_2_union_injection(self):
|
||||
"""
|
||||
Attack Scenario 2: UNION-based Information Disclosure
|
||||
|
||||
Attack: Inject UNION SELECT to extract sensitive data
|
||||
Result: Should be BLOCKED
|
||||
"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(
|
||||
"test' UNION SELECT password FROM users --",
|
||||
"database name"
|
||||
)
|
||||
|
||||
def test_attack_scenario_3_backtick_escape(self):
|
||||
"""
|
||||
Attack Scenario 3: Backtick Escape Attempt
|
||||
|
||||
Attack: Use backticks to break out of quoted identifier
|
||||
Result: Should be BLOCKED
|
||||
"""
|
||||
from doris_mcp_server.utils.sql_security_utils import (
|
||||
validate_identifier,
|
||||
SQLSecurityError
|
||||
)
|
||||
|
||||
with pytest.raises(SQLSecurityError):
|
||||
validate_identifier(
|
||||
"analytics`; SELECT * FROM sensitive_table;--",
|
||||
"database name"
|
||||
)
|
||||
|
||||
|
||||
# Run tests with: pytest tests/test_sql_security.py -v
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
|
||||
871
test/security/test_sql_injection_api.py
Normal file
871
test/security/test_sql_injection_api.py
Normal file
@@ -0,0 +1,871 @@
|
||||
#!/usr/bin/env python3
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
SQL Injection API Integration Tests
|
||||
|
||||
This module tests SQL injection prevention through the MCP HTTP API.
|
||||
It sends malicious payloads and verifies they are properly blocked.
|
||||
|
||||
Prerequisites:
|
||||
- MCP server running on localhost:3000
|
||||
- Run with: pytest test/security/test_sql_injection_api.py -v
|
||||
|
||||
Usage:
|
||||
# Start server first
|
||||
bash start_server.sh
|
||||
|
||||
# Run tests
|
||||
pytest test/security/test_sql_injection_api.py -v --no-cov
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Optional
|
||||
|
||||
|
||||
# Server configuration
|
||||
MCP_BASE_URL = "http://localhost:3000"
|
||||
MCP_ENDPOINT = f"{MCP_BASE_URL}/mcp"
|
||||
HEALTH_ENDPOINT = f"{MCP_BASE_URL}/health"
|
||||
TIMEOUT = 30.0
|
||||
|
||||
|
||||
class MCPClient:
|
||||
"""Simple MCP HTTP client for testing"""
|
||||
|
||||
def __init__(self, base_url: str = MCP_BASE_URL):
|
||||
self.base_url = base_url
|
||||
self.mcp_endpoint = f"{base_url}/mcp"
|
||||
self.session_id: Optional[str] = None
|
||||
self.request_id = 0
|
||||
self.client = httpx.AsyncClient(timeout=TIMEOUT)
|
||||
|
||||
async def close(self):
|
||||
await self.client.aclose()
|
||||
|
||||
def _next_id(self) -> int:
|
||||
self.request_id += 1
|
||||
return self.request_id
|
||||
|
||||
async def initialize(self) -> dict:
|
||||
"""Initialize MCP session"""
|
||||
response = await self.client.post(
|
||||
self.mcp_endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream"
|
||||
},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "initialize",
|
||||
"params": {
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {},
|
||||
"clientInfo": {
|
||||
"name": "sql-injection-test",
|
||||
"version": "1.0.0"
|
||||
}
|
||||
},
|
||||
"id": self._next_id()
|
||||
}
|
||||
)
|
||||
|
||||
# Extract session ID from response header
|
||||
self.session_id = response.headers.get("mcp-session-id")
|
||||
return self._parse_response(response.text)
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict) -> dict:
|
||||
"""Call an MCP tool"""
|
||||
if not self.session_id:
|
||||
await self.initialize()
|
||||
|
||||
response = await self.client.post(
|
||||
self.mcp_endpoint,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json, text/event-stream",
|
||||
"mcp-session-id": self.session_id
|
||||
},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "tools/call",
|
||||
"params": {
|
||||
"name": tool_name,
|
||||
"arguments": arguments
|
||||
},
|
||||
"id": self._next_id()
|
||||
}
|
||||
)
|
||||
|
||||
return self._parse_response(response.text)
|
||||
|
||||
def _parse_response(self, text: str) -> dict:
|
||||
"""Parse JSON response"""
|
||||
try:
|
||||
return json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
# Try SSE format
|
||||
lines = text.strip().split("\n")
|
||||
for line in lines:
|
||||
if line.startswith("data: "):
|
||||
try:
|
||||
return json.loads(line[6:])
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
return {"raw": text}
|
||||
|
||||
|
||||
def print_result(test_name: str, payload: dict, result: dict):
|
||||
"""Print test result in a readable format"""
|
||||
print(f"\n{'='*60}")
|
||||
print(f"TEST: {test_name}")
|
||||
print(f"{'='*60}")
|
||||
print(f"PAYLOAD: {json.dumps(payload, ensure_ascii=False)}")
|
||||
print(f"{'-'*60}")
|
||||
|
||||
# Extract inner result content
|
||||
if "result" in result and "content" in result.get("result", {}):
|
||||
for item in result["result"]["content"]:
|
||||
if item.get("type") == "text":
|
||||
try:
|
||||
inner = json.loads(item["text"])
|
||||
print("RESPONSE:")
|
||||
print(f" success: {inner.get('success')}")
|
||||
if inner.get('error'):
|
||||
print(f" error: {inner.get('error')}")
|
||||
if inner.get('error_type'):
|
||||
print(f" error_type: {inner.get('error_type')}")
|
||||
if inner.get('risk_level'):
|
||||
print(f" risk_level: {inner.get('risk_level')}")
|
||||
if inner.get('message'):
|
||||
print(f" message: {inner.get('message')}")
|
||||
if inner.get('data') is not None and inner.get('success'):
|
||||
data_str = json.dumps(inner.get('data'), ensure_ascii=False)
|
||||
if len(data_str) > 200:
|
||||
data_str = data_str[:200] + "..."
|
||||
print(f" data: {data_str}")
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
print(f"RESPONSE (raw): {item.get('text', '')[:500]}")
|
||||
elif "error" in result:
|
||||
print(f"RESPONSE ERROR: {result['error']}")
|
||||
else:
|
||||
print(f"RESPONSE (raw): {json.dumps(result, ensure_ascii=False)[:500]}")
|
||||
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
|
||||
class TestSQLInjectionAPI:
|
||||
"""Test SQL injection prevention through MCP API"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.fixture
|
||||
def is_server_running(self):
|
||||
"""Check if MCP server is running"""
|
||||
import httpx
|
||||
try:
|
||||
response = httpx.get(HEALTH_ENDPOINT, timeout=5.0)
|
||||
return response.status_code == 200
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_server_health(self):
|
||||
"""Test that MCP server is running and healthy"""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(HEALTH_ENDPOINT)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "healthy"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_drop_injection(self, mcp_client):
|
||||
"""Test exec_query rejects DROP TABLE injection"""
|
||||
# Classic SQL injection: append DROP TABLE
|
||||
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("DROP TABLE Injection", payload, result)
|
||||
|
||||
# Should return error, not execute the DROP
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"DROP TABLE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_union_injection(self, mcp_client):
|
||||
"""Test exec_query blocks UNION-based injection attempts"""
|
||||
# UNION injection to extract data from other tables
|
||||
payload = {"sql": "SELECT id FROM users UNION SELECT password FROM admin_users"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("UNION Injection", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_delete_injection(self, mcp_client):
|
||||
"""Test exec_query rejects DELETE injection"""
|
||||
payload = {"sql": "SELECT 1; DELETE FROM users WHERE 1=1; SELECT 2"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("DELETE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"DELETE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_with_update_injection(self, mcp_client):
|
||||
"""Test exec_query rejects UPDATE injection"""
|
||||
payload = {"sql": "SELECT 1; UPDATE users SET role='admin' WHERE id=1; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("UPDATE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"UPDATE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_db_name_injection(self, mcp_client):
|
||||
"""Test exec_query rejects SQL injection via db_name parameter"""
|
||||
# Attack vector: inject SQL via db_name parameter
|
||||
payload = {"sql": "SELECT 1", "db_name": "test'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("db_name Parameter Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_query_catalog_name_injection(self, mcp_client):
|
||||
"""Test exec_query rejects SQL injection via catalog_name parameter"""
|
||||
# Attack vector: inject SQL via catalog_name parameter
|
||||
payload = {"sql": "SELECT 1", "catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("catalog_name Parameter Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"catalog_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_schema_injection(self, mcp_client):
|
||||
"""Test get_table_schema rejects SQL injection via table_name"""
|
||||
# Attack vector: inject SQL via table_name parameter
|
||||
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("table_name Injection (get_table_schema)", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_schema_db_injection(self, mcp_client):
|
||||
"""Test get_table_schema rejects SQL injection via db_name"""
|
||||
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("db_name Injection (get_table_schema)", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"db_name injection in get_table_schema should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_dependencies_injection(self, mcp_client):
|
||||
"""Test analyze_dependencies rejects SQL injection"""
|
||||
# This was the original vulnerability reported
|
||||
payload = {"table_name": "users", "db_name": "test_db' OR '1'='1' --"}
|
||||
result = await mcp_client.call_tool("analyze_dependencies", payload)
|
||||
print_result("analyze_dependencies Injection (Original Report)", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"analyze_dependencies db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_stacked_queries_injection(self, mcp_client):
|
||||
"""Test that stacked queries (multiple statements) are blocked"""
|
||||
# Multiple statements injection
|
||||
payload = {"sql": "SELECT * FROM users WHERE id = 1; INSERT INTO audit_log VALUES (NULL, 'hacked', NOW()); SELECT 1;"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Stacked Queries (INSERT) Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"Stacked queries with INSERT should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_comment_based_injection(self, mcp_client):
|
||||
"""Test that comment-based injection is blocked"""
|
||||
# Using comments to bypass filters
|
||||
payload = {"sql": "SELECT * FROM users WHERE id = 1/**/OR/**/1=1"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Comment-based Injection", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hex_encoded_injection(self, mcp_client):
|
||||
"""Test that hex-encoded injection attempts are handled"""
|
||||
# Hex-encoded 'DROP' attempt
|
||||
payload = {"sql": "SELECT 0x44524F50205441424C4520757365727320"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hex Encoded Injection", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backtick_escape_injection(self, mcp_client):
|
||||
"""Test backtick escape injection is blocked"""
|
||||
# Attempt to escape backtick quoting
|
||||
payload = {"sql": "SELECT 1", "db_name": "analytics`; SELECT * FROM sensitive_table;--"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Backtick Escape Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
f"Backtick escape injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_query_succeeds(self, mcp_client):
|
||||
"""Test that valid queries still work"""
|
||||
# Simple valid query should work
|
||||
payload = {"sql": "SELECT 1 AS test_value"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Valid Query (should succeed)", payload, result)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_show_databases(self, mcp_client):
|
||||
"""Test that SHOW DATABASES works"""
|
||||
payload = {"sql": "SHOW DATABASES"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("SHOW DATABASES (should succeed)", payload, result)
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
|
||||
# Check for JSON-RPC error
|
||||
if "error" in result:
|
||||
return True
|
||||
|
||||
# Check for error in result content
|
||||
if "result" in result:
|
||||
result_content = result.get("result", {})
|
||||
if isinstance(result_content, dict):
|
||||
# Check for isError flag
|
||||
if result_content.get("isError"):
|
||||
return True
|
||||
# Check content array for error messages
|
||||
content = result_content.get("content", [])
|
||||
for item in content:
|
||||
if isinstance(item, dict):
|
||||
text = item.get("text", "")
|
||||
# Parse the JSON text content
|
||||
try:
|
||||
text_data = json.loads(text)
|
||||
# Check for success: false
|
||||
if text_data.get("success") is False:
|
||||
return True
|
||||
# Check for error field
|
||||
if text_data.get("error"):
|
||||
return True
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
# Check text for security keywords
|
||||
if any(keyword in text.lower() for keyword in [
|
||||
"error", "blocked", "invalid", "security",
|
||||
"injection", "denied", "forbidden", "not allowed",
|
||||
"security_violation", "risk_level"
|
||||
]):
|
||||
return True
|
||||
|
||||
# Check raw text response
|
||||
raw = result.get("raw", "")
|
||||
if isinstance(raw, str) and any(keyword in raw.lower() for keyword in [
|
||||
"error", "blocked", "invalid", "security"
|
||||
]):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
class TestIdentifierInjectionAPI:
|
||||
"""Test identifier-based SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_name_with_semicolon(self, mcp_client):
|
||||
"""Test table name containing semicolon is rejected"""
|
||||
payload = {"table_name": "users; DROP TABLE users"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("Table Name with Semicolon", payload, result)
|
||||
|
||||
# Should be blocked by identifier validation
|
||||
assert self._contains_error_indicator(result), \
|
||||
f"Table name with semicolon should be rejected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_name_with_quotes(self, mcp_client):
|
||||
"""Test table name containing quotes is rejected"""
|
||||
payload = {"table_name": "users' OR '1'='1"}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result("Table Name with Quotes", payload, result)
|
||||
|
||||
assert self._contains_error_indicator(result), \
|
||||
f"Table name with quotes should be rejected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_db_name_with_special_chars(self, mcp_client):
|
||||
"""Test database name with special characters is rejected"""
|
||||
special_chars = [
|
||||
"test;db",
|
||||
"test'db",
|
||||
"test\"db",
|
||||
"test`db",
|
||||
"test--db",
|
||||
"test/*db*/",
|
||||
]
|
||||
|
||||
for db_name in special_chars:
|
||||
payload = {"table_name": "users", "db_name": db_name}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result(f"Special Char in db_name: {db_name}", payload, result)
|
||||
|
||||
assert self._contains_error_indicator(result), \
|
||||
f"db_name '{db_name}' should be rejected"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_identifiers_accepted(self, mcp_client):
|
||||
"""Test that valid identifiers are accepted"""
|
||||
valid_names = [
|
||||
"users",
|
||||
"my_table",
|
||||
"Table123",
|
||||
"_internal_table",
|
||||
]
|
||||
|
||||
for table_name in valid_names:
|
||||
payload = {"table_name": table_name}
|
||||
result = await mcp_client.call_tool("get_table_schema", payload)
|
||||
print_result(f"Valid Identifier: {table_name}", payload, result)
|
||||
|
||||
def _contains_error_indicator(self, result: dict) -> bool:
|
||||
"""Check if result contains error indicators"""
|
||||
if not result:
|
||||
return True
|
||||
|
||||
# Check for JSON-RPC error
|
||||
if "error" in result:
|
||||
return True
|
||||
|
||||
# Check result content
|
||||
result_str = json.dumps(result).lower()
|
||||
error_keywords = [
|
||||
"error", "invalid", "illegal", "blocked",
|
||||
"security", "injection", "denied", "forbidden"
|
||||
]
|
||||
|
||||
return any(keyword in result_str for keyword in error_keywords)
|
||||
|
||||
|
||||
class TestMultiStatementInjectionAPI:
|
||||
"""Test multi-statement SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hidden_drop_after_select(self, mcp_client):
|
||||
"""Test DROP hidden after legitimate SELECT is blocked"""
|
||||
payload = {"sql": "SELECT id, name FROM users WHERE status = 'active'; DROP TABLE audit_log; SELECT 1;"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hidden DROP after SELECT", payload, result)
|
||||
|
||||
assert self._is_dangerous_blocked(result), \
|
||||
f"Hidden DROP statement should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hidden_truncate_after_select(self, mcp_client):
|
||||
"""Test TRUNCATE hidden after SELECT is blocked"""
|
||||
payload = {"sql": "SELECT 1; TRUNCATE TABLE users"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hidden TRUNCATE after SELECT", payload, result)
|
||||
|
||||
assert self._is_dangerous_blocked(result), \
|
||||
f"Hidden TRUNCATE should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hidden_grant_after_select(self, mcp_client):
|
||||
"""Test GRANT hidden after SELECT is blocked"""
|
||||
payload = {"sql": "SELECT 1; GRANT ALL ON *.* TO 'hacker'@'%'"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Hidden GRANT after SELECT", payload, result)
|
||||
|
||||
assert self._is_dangerous_blocked(result), \
|
||||
f"Hidden GRANT should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_safe_selects_allowed(self, mcp_client):
|
||||
"""Test that multiple SELECT statements may be allowed"""
|
||||
payload = {"sql": "SELECT 1; SELECT 2; SELECT 3;"}
|
||||
result = await mcp_client.call_tool("exec_query", payload)
|
||||
print_result("Multiple Safe SELECTs", payload, result)
|
||||
|
||||
def _is_dangerous_blocked(self, result: dict) -> bool:
|
||||
"""Check if dangerous operation was blocked"""
|
||||
if not result:
|
||||
return True
|
||||
|
||||
# Check for error
|
||||
if "error" in result:
|
||||
return True
|
||||
|
||||
# Check result content for blocking indicators
|
||||
result_str = json.dumps(result).lower()
|
||||
block_indicators = [
|
||||
"drop", "truncate", "grant", "revoke",
|
||||
"blocked", "denied", "forbidden", "not allowed",
|
||||
"security", "error"
|
||||
]
|
||||
|
||||
return any(indicator in result_str for indicator in block_indicators)
|
||||
|
||||
|
||||
class TestADBCQueryInjectionAPI:
|
||||
"""Test ADBC query SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_adbc_query_drop_injection(self, mcp_client):
|
||||
"""Test exec_adbc_query rejects DROP TABLE injection"""
|
||||
payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||
print_result("ADBC DROP TABLE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"ADBC DROP TABLE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_adbc_query_delete_injection(self, mcp_client):
|
||||
"""Test exec_adbc_query rejects DELETE injection"""
|
||||
payload = {"sql": "SELECT 1; DELETE FROM users; --"}
|
||||
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||
print_result("ADBC DELETE Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"ADBC DELETE injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exec_adbc_query_valid(self, mcp_client):
|
||||
"""Test exec_adbc_query allows valid queries"""
|
||||
payload = {"sql": "SELECT 1 AS test"}
|
||||
result = await mcp_client.call_tool("exec_adbc_query", payload)
|
||||
print_result("ADBC Valid Query", payload, result)
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestMetadataToolsInjectionAPI:
|
||||
"""Test metadata tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_table_list_db_injection(self, mcp_client):
|
||||
"""Test get_db_table_list rejects db_name injection"""
|
||||
payload = {"db_name": "test'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||
print_result("get_db_table_list db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_table_list_catalog_injection(self, mcp_client):
|
||||
"""Test get_db_table_list rejects catalog_name injection"""
|
||||
payload = {"catalog_name": "internal`; SELECT * FROM mysql.user; --"}
|
||||
result = await mcp_client.call_tool("get_db_table_list", payload)
|
||||
print_result("get_db_table_list catalog_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"catalog_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_comment_injection(self, mcp_client):
|
||||
"""Test get_table_comment rejects table_name injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_table_comment", payload)
|
||||
print_result("get_table_comment table_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_column_comments_injection(self, mcp_client):
|
||||
"""Test get_table_column_comments rejects injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --", "db_name": "test"}
|
||||
result = await mcp_client.call_tool("get_table_column_comments", payload)
|
||||
print_result("get_table_column_comments Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_indexes_injection(self, mcp_client):
|
||||
"""Test get_table_indexes rejects table_name injection"""
|
||||
payload = {"table_name": "users; DROP TABLE users", "db_name": "test"}
|
||||
result = await mcp_client.call_tool("get_table_indexes", payload)
|
||||
print_result("get_table_indexes Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestAnalyticsToolsInjectionAPI:
|
||||
"""Test analytics tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_columns_table_injection(self, mcp_client):
|
||||
"""Test analyze_columns rejects table_name injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||
print_result("analyze_columns table_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_columns_db_injection(self, mcp_client):
|
||||
"""Test analyze_columns rejects db_name injection"""
|
||||
payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
|
||||
result = await mcp_client.call_tool("analyze_columns", payload)
|
||||
print_result("analyze_columns db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_basic_info_injection(self, mcp_client):
|
||||
"""Test get_table_basic_info rejects injection"""
|
||||
payload = {"table_name": "users; DROP TABLE audit_log"}
|
||||
result = await mcp_client.call_tool("get_table_basic_info", payload)
|
||||
print_result("get_table_basic_info Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_table_storage_injection(self, mcp_client):
|
||||
"""Test analyze_table_storage rejects injection"""
|
||||
payload = {"table_name": "users`; SELECT * FROM sensitive; --"}
|
||||
result = await mcp_client.call_tool("analyze_table_storage", payload)
|
||||
print_result("analyze_table_storage Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sql_explain_injection(self, mcp_client):
|
||||
"""Test get_sql_explain rejects SQL injection"""
|
||||
payload = {"sql": "SELECT 1; DROP TABLE users; --"}
|
||||
result = await mcp_client.call_tool("get_sql_explain", payload)
|
||||
print_result("get_sql_explain SQL Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"SQL injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_sql_profile_injection(self, mcp_client):
|
||||
"""Test get_sql_profile rejects SQL injection"""
|
||||
payload = {"sql": "SELECT 1; DELETE FROM audit_log; --"}
|
||||
result = await mcp_client.call_tool("get_sql_profile", payload)
|
||||
print_result("get_sql_profile SQL Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"SQL injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestGovernanceToolsInjectionAPI:
|
||||
"""Test data governance tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_column_lineage_table_injection(self, mcp_client):
|
||||
"""Test trace_column_lineage rejects table_name injection"""
|
||||
payload = {"table_name": "users'; DROP TABLE users; --", "column_name": "id"}
|
||||
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||
print_result("trace_column_lineage table_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trace_column_lineage_column_injection(self, mcp_client):
|
||||
"""Test trace_column_lineage rejects column_name injection"""
|
||||
payload = {"table_name": "users", "column_name": "id; DROP TABLE users"}
|
||||
result = await mcp_client.call_tool("trace_column_lineage", payload)
|
||||
print_result("trace_column_lineage column_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"column_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_monitor_data_freshness_injection(self, mcp_client):
|
||||
"""Test monitor_data_freshness rejects table_name injection"""
|
||||
payload = {"table_name": "users`; SELECT * FROM passwords; --"}
|
||||
result = await mcp_client.call_tool("monitor_data_freshness", payload)
|
||||
print_result("monitor_data_freshness Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_data_access_patterns_injection(self, mcp_client):
|
||||
"""Test analyze_data_access_patterns rejects injection"""
|
||||
payload = {"table_name": "users' UNION SELECT password FROM admin --"}
|
||||
result = await mcp_client.call_tool("analyze_data_access_patterns", payload)
|
||||
print_result("analyze_data_access_patterns Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
class TestPerformanceToolsInjectionAPI:
|
||||
"""Test performance analytics tools SQL injection prevention"""
|
||||
|
||||
@pytest.fixture
|
||||
async def mcp_client(self):
|
||||
"""Create MCP client instance"""
|
||||
client = MCPClient()
|
||||
yield client
|
||||
await client.close()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_slow_queries_db_injection(self, mcp_client):
|
||||
"""Test analyze_slow_queries_topn rejects db_name injection"""
|
||||
payload = {"db_name": "test'; DROP TABLE audit_log; --"}
|
||||
result = await mcp_client.call_tool("analyze_slow_queries_topn", payload)
|
||||
print_result("analyze_slow_queries_topn db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyze_resource_growth_db_injection(self, mcp_client):
|
||||
"""Test analyze_resource_growth_curves rejects db_name injection"""
|
||||
payload = {"db_name": "test`; GRANT ALL ON *.* TO 'hacker'; --"}
|
||||
result = await mcp_client.call_tool("analyze_resource_growth_curves", payload)
|
||||
print_result("analyze_resource_growth_curves db_name Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"db_name injection should be blocked"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_data_size_injection(self, mcp_client):
|
||||
"""Test get_table_data_size rejects table_name injection"""
|
||||
payload = {"table_name": "users; TRUNCATE TABLE logs"}
|
||||
result = await mcp_client.call_tool("get_table_data_size", payload)
|
||||
print_result("get_table_data_size Injection", payload, result)
|
||||
|
||||
assert self._is_blocked_or_error(result), \
|
||||
"table_name injection should be blocked"
|
||||
|
||||
def _is_blocked_or_error(self, result: dict) -> bool:
|
||||
"""Check if result indicates blocked or error"""
|
||||
if not result:
|
||||
return True
|
||||
if "error" in result:
|
||||
return True
|
||||
result_str = json.dumps(result).lower()
|
||||
return any(kw in result_str for kw in ["error", "blocked", "invalid", "security", "injection"])
|
||||
|
||||
|
||||
# Pytest configuration for async tests
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
"""Create event loop for async tests"""
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short", "-x"])
|
||||
|
||||
@@ -201,3 +201,73 @@ class TestDorisQueryExecutor:
|
||||
if result["success"]:
|
||||
assert "data" in result
|
||||
assert "row_count" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_multi_sql_statements(self, query_executor):
|
||||
"""Test execution of multiple SQL statements"""
|
||||
from doris_mcp_server.utils.query_executor import QueryResult
|
||||
|
||||
# Disable security check for this test
|
||||
query_executor.connection_manager.config.security.enable_security_check = False
|
||||
|
||||
with patch.object(query_executor, 'execute_query') as mock_execute:
|
||||
# Mock results for three SQL statements
|
||||
mock_execute.side_effect = [
|
||||
QueryResult(
|
||||
data=[{"id": 1, "name": "张三"}],
|
||||
row_count=1,
|
||||
execution_time=0.1,
|
||||
sql="SELECT id, name FROM users WHERE id = 1",
|
||||
metadata={"columns": ["id", "name"]}
|
||||
),
|
||||
QueryResult(
|
||||
data=[{"id": 2, "name": "李四"}],
|
||||
row_count=1,
|
||||
execution_time=0.12,
|
||||
sql="SELECT id, name FROM users WHERE id = 2",
|
||||
metadata={"columns": ["id", "name"]}
|
||||
),
|
||||
QueryResult(
|
||||
data=[{"count": 100}],
|
||||
row_count=1,
|
||||
execution_time=0.08,
|
||||
sql="SELECT COUNT(*) as count FROM users",
|
||||
metadata={"columns": ["count"]}
|
||||
)
|
||||
]
|
||||
|
||||
# Execute multiple SQL statements separated by semicolons
|
||||
multi_sql = """
|
||||
SELECT id, name FROM users WHERE id = 1;
|
||||
SELECT id, name FROM users WHERE id = 2;
|
||||
SELECT COUNT(*) as count FROM users;
|
||||
"""
|
||||
|
||||
result = await query_executor.execute_sql_for_mcp(multi_sql)
|
||||
|
||||
# Verify the result structure for multiple statements
|
||||
assert result["success"] is True
|
||||
assert result["multiple_results"] is True
|
||||
assert "results" in result
|
||||
assert len(result["results"]) == 3
|
||||
|
||||
# Verify first query result
|
||||
assert result["results"][0]["data"] == [{"id": 1, "name": "张三"}]
|
||||
assert result["results"][0]["row_count"] == 1
|
||||
assert result["results"][0]["metadata"]["columns"] == ["id", "name"]
|
||||
assert result["results"][0]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 1"
|
||||
|
||||
# Verify second query result
|
||||
assert result["results"][1]["data"] == [{"id": 2, "name": "李四"}]
|
||||
assert result["results"][1]["row_count"] == 1
|
||||
assert result["results"][1]["metadata"]["columns"] == ["id", "name"]
|
||||
assert result["results"][1]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 2"
|
||||
|
||||
# Verify third query result
|
||||
assert result["results"][2]["data"] == [{"count": 100}]
|
||||
assert result["results"][2]["row_count"] == 1
|
||||
assert result["results"][2]["metadata"]["columns"] == ["count"]
|
||||
assert result["results"][2]["metadata"]["query"] == "SELECT COUNT(*) as count FROM users"
|
||||
|
||||
# Verify execute_query was called three times
|
||||
assert mock_execute.call_count == 3
|
||||
|
||||
Reference in New Issue
Block a user