6 Commits

Author SHA1 Message Date
The Apache Software Foundation
a6f893628b Set up default protection ruleset for default and release branches 2026-05-15 15:28:08 -05:00
bingquanzhao
81305ffbf9 [fix]fix token auth (#69)
* fix token auth

* Further fixes to the token overwriting issue and restoration of hot reloading of tokens.json.
2025-12-24 20:39:16 +08:00
zzzzwc
43143f0b30 feat: add batch SQL execution support for MCP (#70)
* feat: add batch SQL execution support for MCP

- Add sql field to QueryResult to track executed query
- Implement execute_batch_sqls_for_mcp for executing multiple SQL statements
- Use sqlparse to split and execute multiple SQL statements in a single request
- Improve error handling in execute_batch_queries
- Return multiple results format when batch queries are detected

* test: add multi-SQL statements test for query executor
2025-12-24 12:45:29 +08:00
bingquanzhao
e58361e04b fix some security issues (#68) 2025-12-10 09:11:03 +08:00
Yijia Su
a125a2f5f8 [fix]Fixed five known issues, including token authentication and multi-worker operation. (#63)
* 0.6.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode.

* Add Tokens Management

* change version

* fix stdio start bug

* fix stdio start bug

* fix stdio start bug
2025-11-04 14:45:38 +08:00
Yijia Su
2613912df3 [Performance]Optimize Stdio and Streamable HTTP startup solutions (#60)
* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add a controllable MCP Server DB Pool permission authentication system, connect it with the Doris permission system, and provide it to enterprise-level applications concurrently with the multi-Worker mode.

* Add Tokens Management

* change version

* fix stdio start bug

* fix stdio start bug
2025-09-23 12:21:30 +08:00
24 changed files with 3535 additions and 323 deletions

View File

@@ -26,13 +26,26 @@ github:
- mcp - mcp
- ai - ai
enabled_merge_buttons: enabled_merge_buttons:
squash: true squash: true
merge: false merge: false
rebase: false rebase: false
features: features:
issues: true issues: true
projects: true projects: true
rulesets:
- name: "Default Branch Protection"
type: branch
branches:
includes:
- "~DEFAULT_BRANCH"
- "release/*"
- "rel/*"
excludes: []
bypass_teams:
- root
restrict_deletion: true
restrict_force_push: true
notifications: notifications:
issues: commits@doris.apache.org issues: commits@doris.apache.org
commits: commits@doris.apache.org commits: commits@doris.apache.org
pullrequests: commits@doris.apache.org pullrequests: commits@doris.apache.org

View File

@@ -432,9 +432,9 @@ class DorisServer:
await self.security_manager.initialize() await self.security_manager.initialize()
self.logger.info("Security manager initialization completed") self.logger.info("Security manager initialization completed")
# Ensure connection manager is initialized # For stdio mode, we must establish a working database connection
await self.connection_manager.initialize() # Use the dedicated stdio mode initialization method
self.logger.info("Connection manager initialization completed") await self.connection_manager.initialize_for_stdio_mode()
# Start stdio server - using compatible import approach # Start stdio server - using compatible import approach
try: try:
@@ -502,8 +502,12 @@ class DorisServer:
await self.security_manager.initialize() await self.security_manager.initialize()
self.logger.info("Security manager initialization completed") self.logger.info("Security manager initialization completed")
# Ensure connection manager is initialized # For HTTP mode, try to initialize global connection pool with graceful degradation
await self.connection_manager.initialize() global_pool_created = await self.connection_manager.initialize_for_http_mode()
if global_pool_created:
self.logger.info("Global database connection pool available for HTTP mode")
else:
self.logger.info("HTTP mode running without global database pool, will use token-bound configurations")
# Use Starlette and StreamableHTTPSessionManager according to official example # Use Starlette and StreamableHTTPSessionManager according to official example
import uvicorn import uvicorn
@@ -630,14 +634,24 @@ class DorisServer:
try: try:
# Extract authentication information # Extract authentication information
auth_info = await self._extract_auth_info_from_scope(scope, headers) auth_info = await self._extract_auth_info_from_scope(scope, headers)
# Authenticate the request # Authenticate the request
auth_context = await self.security_manager.authenticate_request(auth_info) auth_context = await self.security_manager.authenticate_request(auth_info)
self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}") self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
# Store auth context in scope for potential use by tools/resources # Store auth context in scope for potential use by tools/resources
scope["auth_context"] = auth_context scope["auth_context"] = auth_context
# FIX for Issue #62 Bug 1: Set auth_context in context variable
# This allows tools to access token information for token-bound database configuration
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
try:
from .utils.security import mcp_auth_context_var
mcp_auth_context_var.set(auth_context)
self.logger.debug(f"Set auth_context in context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
except Exception as ctx_error:
self.logger.warning(f"Failed to set auth_context in context variable: {ctx_error}")
except Exception as auth_error: except Exception as auth_error:
self.logger.error(f"MCP authentication failed: {auth_error}") self.logger.error(f"MCP authentication failed: {auth_error}")
# Return 401 Unauthorized # Return 401 Unauthorized

View File

@@ -31,6 +31,7 @@ from mcp.types import (
) )
from ..utils.db import DorisConnectionManager from ..utils.db import DorisConnectionManager
from ..utils.sql_security_utils import get_auth_context
class PromptTemplate: class PromptTemplate:
@@ -422,7 +423,8 @@ Please generate accurate and efficient SQL queries based on the above requiremen
AND table_type = 'BASE TABLE' AND table_type = 'BASE TABLE'
""" """
db_result = await connection.execute(db_info_sql) auth_context = get_auth_context()
db_result = await connection.execute(db_info_sql, auth_context=auth_context)
db_info = db_result.data[0] if db_result.data else {} db_info = db_result.data[0] if db_result.data else {}
# Get main table list # Get main table list
@@ -438,7 +440,7 @@ Please generate accurate and efficient SQL queries based on the above requiremen
LIMIT 10 LIMIT 10
""" """
tables_result = await connection.execute(tables_sql) tables_result = await connection.execute(tables_sql, auth_context=auth_context)
context = f"""Current database statistics: context = f"""Current database statistics:
- Total number of tables: {db_info.get("table_count", 0)} - Total number of tables: {db_info.get("table_count", 0)}

View File

@@ -26,6 +26,7 @@ from typing import Any
from mcp.types import Resource from mcp.types import Resource
from ..utils.db import DorisConnectionManager from ..utils.db import DorisConnectionManager
from ..utils.sql_security_utils import get_auth_context
class TableMetadata: class TableMetadata:
@@ -169,7 +170,8 @@ class DorisResourcesManager:
ORDER BY table_name ORDER BY table_name
""" """
result = await connection.execute(tables_query) auth_context = get_auth_context()
result = await connection.execute(tables_query, auth_context=auth_context)
tables = [] tables = []
for row in result.data: for row in result.data:
@@ -204,7 +206,8 @@ class DorisResourcesManager:
ORDER BY ordinal_position ORDER BY ordinal_position
""" """
result = await connection.execute(columns_query, (table_name,)) auth_context = get_auth_context()
result = await connection.execute(columns_query, params=(table_name,), auth_context=auth_context)
return [dict(row) for row in result.data] return [dict(row) for row in result.data]
async def _get_view_metadata(self) -> list[ViewMetadata]: async def _get_view_metadata(self) -> list[ViewMetadata]:
@@ -226,7 +229,8 @@ class DorisResourcesManager:
ORDER BY table_name ORDER BY table_name
""" """
result = await connection.execute(views_query) auth_context = get_auth_context()
result = await connection.execute(views_query, auth_context=auth_context)
views = [] views = []
for row in result.data: for row in result.data:
@@ -257,7 +261,8 @@ class DorisResourcesManager:
AND table_name = %s AND table_name = %s
""" """
table_result = await connection.execute(table_info_query, (table_name,)) auth_context = get_auth_context()
table_result = await connection.execute(table_info_query, params=(table_name,), auth_context=auth_context)
if not table_result.data: if not table_result.data:
raise ValueError(f"Table {table_name} does not exist") raise ValueError(f"Table {table_name} does not exist")
@@ -295,7 +300,8 @@ class DorisResourcesManager:
ORDER BY index_name, seq_in_index ORDER BY index_name, seq_in_index
""" """
result = await connection.execute(indexes_query, (table_name,)) auth_context = get_auth_context()
result = await connection.execute(indexes_query, params=(table_name,), auth_context=auth_context)
return [dict(row) for row in result.data] return [dict(row) for row in result.data]
async def _get_view_definition(self, view_name: str) -> str: async def _get_view_definition(self, view_name: str) -> str:
@@ -312,7 +318,8 @@ class DorisResourcesManager:
AND table_name = %s AND table_name = %s
""" """
result = await connection.execute(view_query, (view_name,)) auth_context = get_auth_context()
result = await connection.execute(view_query, params=(view_name,), auth_context=auth_context)
if not result.data: if not result.data:
raise ValueError(f"View {view_name} does not exist") raise ValueError(f"View {view_name} does not exist")
@@ -340,7 +347,8 @@ class DorisResourcesManager:
AND table_type = 'BASE TABLE' AND table_type = 'BASE TABLE'
""" """
table_result = await connection.execute(table_stats_query) auth_context = get_auth_context()
table_result = await connection.execute(table_stats_query, auth_context=auth_context)
table_stats = table_result.data[0] if table_result.data else {} table_stats = table_result.data[0] if table_result.data else {}
# Get view statistics # Get view statistics
@@ -350,7 +358,7 @@ class DorisResourcesManager:
WHERE table_schema = DATABASE() WHERE table_schema = DATABASE()
""" """
view_result = await connection.execute(view_stats_query) view_result = await connection.execute(view_stats_query, auth_context=auth_context)
view_stats = view_result.data[0] if view_result.data else {} view_stats = view_result.data[0] if view_result.data else {}
stats_info = { stats_info = {

View File

@@ -28,6 +28,7 @@ from typing import Any, Dict, List, Optional
from ..utils.logger import get_logger from ..utils.logger import get_logger
from ..utils.db import DorisConnectionManager from ..utils.db import DorisConnectionManager
from ..utils.sql_security_utils import get_auth_context
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -277,7 +278,8 @@ class DorisADBCQueryTools:
# Get BE nodes via SHOW BACKENDS # Get BE nodes via SHOW BACKENDS
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS") logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
connection = await self.connection_manager.get_connection("query") connection = await self.connection_manager.get_connection("query")
result = await connection.execute("SHOW BACKENDS") auth_context = get_auth_context()
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
be_hosts = [] be_hosts = []
for row in result.data: for row in result.data:
@@ -383,6 +385,20 @@ class DorisADBCQueryTools:
"error_type": "no_connection" "error_type": "no_connection"
} }
# SECURITY FIX: Perform SQL security validation before executing
auth_context = get_auth_context()
if self.connection_manager.security_manager:
# Always perform security validation, even without auth_context
# Use a default context for basic SQL security checks
validation_result = await self.connection_manager.security_manager.validate_sql_security(sql, auth_context)
if not validation_result.is_valid:
return {
"success": False,
"error": f"SQL security validation failed: {validation_result.error_message}",
"error_type": "security_violation",
"risk_level": validation_result.risk_level
}
cursor = self.adbc_client.cursor() cursor = self.adbc_client.cursor()
start_time = time.time() start_time = time.time()

View File

@@ -29,6 +29,13 @@ from pathlib import Path
from .db import DorisConnectionManager from .db import DorisConnectionManager
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -46,10 +53,17 @@ class TableAnalyzer:
sample_size: int = 10 sample_size: int = 10
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Get table summary information""" """Get table summary information"""
# SECURITY FIX: Validate table_name and get auth_context
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
raise ValueError(f"Invalid table name: {e}")
auth_context = get_auth_context()
connection = await self.connection_manager.get_connection("query") connection = await self.connection_manager.get_connection("query")
# Get table basic information # Get table basic information using parameterized query
table_info_sql = f""" table_info_sql = """
SELECT SELECT
table_name, table_name,
table_comment, table_comment,
@@ -58,17 +72,17 @@ class TableAnalyzer:
engine engine
FROM information_schema.tables FROM information_schema.tables
WHERE table_schema = DATABASE() WHERE table_schema = DATABASE()
AND table_name = '{table_name}' AND table_name = %s
""" """
table_info_result = await connection.execute(table_info_sql) table_info_result = await connection.execute(table_info_sql, params=(table_name,), auth_context=auth_context)
if not table_info_result.data: if not table_info_result.data:
raise ValueError(f"Table {table_name} does not exist") raise ValueError(f"Table {table_name} does not exist")
table_info = table_info_result.data[0] table_info = table_info_result.data[0]
# Get column information # Get column information using parameterized query
columns_sql = f""" columns_sql = """
SELECT SELECT
column_name, column_name,
data_type, data_type,
@@ -76,11 +90,11 @@ class TableAnalyzer:
column_comment column_comment
FROM information_schema.columns FROM information_schema.columns
WHERE table_schema = DATABASE() WHERE table_schema = DATABASE()
AND table_name = '{table_name}' AND table_name = %s
ORDER BY ordinal_position ORDER BY ordinal_position
""" """
columns_result = await connection.execute(columns_sql) columns_result = await connection.execute(columns_sql, params=(table_name,), auth_context=auth_context)
summary = { summary = {
"table_name": table_info["table_name"], "table_name": table_info["table_name"],
@@ -92,10 +106,11 @@ class TableAnalyzer:
"columns": columns_result.data, "columns": columns_result.data,
} }
# Get sample data # Get sample data using quoted identifier
if include_sample and sample_size > 0: if include_sample and sample_size > 0:
sample_sql = f"SELECT * FROM {table_name} LIMIT {sample_size}" quoted_table = quote_identifier(table_name, "table name")
sample_result = await connection.execute(sample_sql) sample_sql = f"SELECT * FROM {quoted_table} LIMIT {sample_size}"
sample_result = await connection.execute(sample_sql, auth_context=auth_context)
summary["sample_data"] = sample_result.data summary["sample_data"] = sample_result.data
return summary return summary
@@ -120,7 +135,8 @@ class TableAnalyzer:
FROM {table_name} FROM {table_name}
""" """
basic_result = await connection.execute(basic_stats_sql) auth_context = get_auth_context()
basic_result = await connection.execute(basic_stats_sql, auth_context=auth_context)
if not basic_result.data: if not basic_result.data:
return { return {
"success": False, "success": False,
@@ -144,7 +160,7 @@ class TableAnalyzer:
LIMIT 20 LIMIT 20
""" """
distribution_result = await connection.execute(distribution_sql) distribution_result = await connection.execute(distribution_sql, auth_context=auth_context)
analysis["value_distribution"] = distribution_result.data analysis["value_distribution"] = distribution_result.data
if analysis_type == "detailed": if analysis_type == "detailed":
@@ -159,7 +175,7 @@ class TableAnalyzer:
WHERE {column_name} IS NOT NULL WHERE {column_name} IS NOT NULL
""" """
numeric_result = await connection.execute(numeric_stats_sql) numeric_result = await connection.execute(numeric_stats_sql, auth_context=auth_context)
if numeric_result.data: if numeric_result.data:
analysis.update(numeric_result.data[0]) analysis.update(numeric_result.data[0])
except Exception: except Exception:
@@ -196,7 +212,8 @@ class TableAnalyzer:
AND table_name = '{table_name}' AND table_name = '{table_name}'
""" """
table_result = await connection.execute(table_info_sql) auth_context = get_auth_context()
table_result = await connection.execute(table_info_sql, auth_context=auth_context)
if not table_result.data: if not table_result.data:
raise ValueError(f"Table {table_name} does not exist") raise ValueError(f"Table {table_name} does not exist")
@@ -211,7 +228,7 @@ class TableAnalyzer:
AND table_name != %s AND table_name != %s
""" """
all_tables_result = await connection.execute(all_tables_sql, (table_name,)) all_tables_result = await connection.execute(all_tables_sql, params=(table_name,), auth_context=auth_context)
return { return {
"center_table": table_result.data[0], "center_table": table_result.data[0],
@@ -291,7 +308,8 @@ class PerformanceMonitor:
LIMIT 20 LIMIT 20
""" """
tables_result = await connection.execute(tables_sql) auth_context = get_auth_context()
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
stats = { stats = {
"metric_type": "tables", "metric_type": "tables",
"time_range": time_range, "time_range": time_range,
@@ -379,9 +397,23 @@ class SQLAnalyzer:
logger.info(f"Generating SQL explain for query ID: {query_id}") logger.info(f"Generating SQL explain for query ID: {query_id}")
# 🔧 FIX: Get auth_context for token-bound database configuration
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception:
pass
# Switch database if specified # Switch database if specified
# SECURITY FIX: Validate and quote db_name
if db_name: if db_name:
await self.connection_manager.execute_query("explain_session", f"USE {db_name}") try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid database name: {e}"}
safe_db = quote_identifier(db_name, "database name")
await self.connection_manager.execute_query("explain_session", f"USE {safe_db}", None, auth_context)
# Construct EXPLAIN query # Construct EXPLAIN query
explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN" explain_type = "EXPLAIN VERBOSE" if verbose else "EXPLAIN"
@@ -390,7 +422,7 @@ class SQLAnalyzer:
logger.info(f"Executing explain query: {explain_sql}") logger.info(f"Executing explain query: {explain_sql}")
# Execute explain query # Execute explain query
result = await self.connection_manager.execute_query("explain_session", explain_sql) result = await self.connection_manager.execute_query("explain_session", explain_sql, None, auth_context)
# Format explain output # Format explain output
explain_content = [] explain_content = []
@@ -515,24 +547,36 @@ class SQLAnalyzer:
try: try:
# Switch to specified database/catalog if provided # Switch to specified database/catalog if provided
# SECURITY FIX: Validate identifiers before using in SQL
if catalog_name: if catalog_name:
await connection.execute(f"SWITCH `{catalog_name}`") try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid catalog name: {e}"}
safe_catalog = quote_identifier(catalog_name, "catalog name")
auth_context = get_auth_context()
await connection.execute(f"SWITCH {safe_catalog}", auth_context=auth_context)
if db_name: if db_name:
await connection.execute(f"USE `{db_name}`") try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return {"success": False, "error": f"Invalid database name: {e}"}
safe_db = quote_identifier(db_name, "database name")
await connection.execute(f"USE {safe_db}", auth_context=auth_context)
# Set trace ID for the session using session variable # Set trace ID for the session using session variable
# According to official docs: set session_context="trace_id:your_trace_id" # According to official docs: set session_context="trace_id:your_trace_id"
await connection.execute(f'set session_context="trace_id:{trace_id}"') await connection.execute(f'set session_context="trace_id:{trace_id}"', auth_context=auth_context)
logger.info(f"Set trace ID: {trace_id}") logger.info(f"Set trace ID: {trace_id}")
# Enable profile # Enable profile
await connection.execute(f'set enable_profile=true') await connection.execute(f'set enable_profile=true', auth_context=auth_context)
logger.info(f"Enabled profile") logger.info(f"Enabled profile")
# Execute the SQL statement # Execute the SQL statement
logger.info(f"Executing SQL with trace ID: {sql}") logger.info(f"Executing SQL with trace ID: {sql}")
start_time = time.time() start_time = time.time()
sql_result = await connection.execute(sql) sql_result = await connection.execute(sql, auth_context=auth_context)
execution_time = time.time() - start_time execution_time = time.time() - start_time
logger.info(f"SQL execution completed in {execution_time:.3f}s") logger.info(f"SQL execution completed in {execution_time:.3f}s")

View File

@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional, Union
from .db import DorisConnectionManager from .db import DorisConnectionManager
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -43,24 +50,30 @@ class DataExplorationTools:
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str: def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
"""Build full table name with catalog and database using three-part naming convention""" """Build full table name with catalog and database using three-part naming convention"""
# Default catalog for internal tables # SECURITY FIX: Use build_table_reference for safe identifier handling
effective_catalog = catalog_name if catalog_name else "internal" effective_catalog = catalog_name if catalog_name else "internal"
if db_name: if db_name:
return f"{effective_catalog}.{db_name}.{table_name}" return build_table_reference(table_name, db_name, effective_catalog)
else: else:
# If no db_name provided, need to determine the current database return build_table_reference(table_name, catalog_name=effective_catalog)
return f"{effective_catalog}.{table_name}"
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]: async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get basic table information including row count""" """Get basic table information including row count"""
try: try:
# SECURITY FIX: Get auth_context for security validation
# table_name should already be validated by _build_full_table_name
auth_context = get_auth_context()
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}" count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
result = await connection.execute(count_sql) result = await connection.execute(count_sql, auth_context=auth_context)
if result.data: if result.data:
return {"row_count": result.data[0]["row_count"]} return {"row_count": result.data[0]["row_count"]}
return None return None
except SQLSecurityError as e:
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
return {"row_count": 0}
except Exception as e: except Exception as e:
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}") logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
return {"row_count": 0} return {"row_count": 0}
@@ -68,10 +81,24 @@ class DataExplorationTools:
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]: async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get detailed column information""" """Get detailed column information"""
try: try:
where_conditions = [f"table_name = '{table_name}'"] # SECURITY FIX: Validate identifiers and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(table_name, "table name")
if db_name:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build parameterized query
params = [table_name]
where_conditions = ["table_name = %s"]
if db_name: if db_name:
where_conditions.append(f"table_schema = '{db_name}'") where_conditions.append("table_schema = %s")
params.append(db_name)
else: else:
where_conditions.append("table_schema = DATABASE()") where_conditions.append("table_schema = DATABASE()")
@@ -87,9 +114,12 @@ class DataExplorationTools:
ORDER BY ordinal_position ORDER BY ordinal_position
""" """
result = await connection.execute(columns_sql) result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
return result.data if result.data else [] return result.data if result.data else []
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e: except Exception as e:
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}") logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
return [] return []
@@ -177,7 +207,8 @@ class DataExplorationTools:
WHERE {col_name} IS NOT NULL WHERE {col_name} IS NOT NULL
""" """
stats_result = await connection.execute(stats_sql) auth_context = get_auth_context()
stats_result = await connection.execute(stats_sql, auth_context=auth_context)
if stats_result.data and stats_result.data[0]["count"] > 0: if stats_result.data and stats_result.data[0]["count"] > 0:
stats = stats_result.data[0] stats = stats_result.data[0]
@@ -229,7 +260,8 @@ class DataExplorationTools:
WHERE {col_name} IS NOT NULL WHERE {col_name} IS NOT NULL
""" """
result = await connection.execute(percentile_sql) auth_context = get_auth_context()
result = await connection.execute(percentile_sql, auth_context=auth_context)
if result.data: if result.data:
data = result.data[0] data = result.data[0]
@@ -268,7 +300,8 @@ class DataExplorationTools:
WHERE {col_name} IS NOT NULL WHERE {col_name} IS NOT NULL
""" """
result = await connection.execute(outlier_sql) auth_context = get_auth_context()
result = await connection.execute(outlier_sql, auth_context=auth_context)
if result.data: if result.data:
data = result.data[0] data = result.data[0]
@@ -359,7 +392,8 @@ class DataExplorationTools:
{sampling_info.get('sample_query_suffix', '')} {sampling_info.get('sample_query_suffix', '')}
""" """
cardinality_result = await connection.execute(cardinality_sql) auth_context = get_auth_context()
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
if cardinality_result.data: if cardinality_result.data:
cardinality_data = cardinality_result.data[0] cardinality_data = cardinality_result.data[0]
@@ -408,7 +442,8 @@ class DataExplorationTools:
LIMIT 20 LIMIT 20
""" """
result = await connection.execute(distribution_sql) auth_context = get_auth_context()
result = await connection.execute(distribution_sql, auth_context=auth_context)
if result.data: if result.data:
distribution = [] distribution = []
@@ -458,7 +493,8 @@ class DataExplorationTools:
WHERE {col_name} IS NOT NULL WHERE {col_name} IS NOT NULL
""" """
range_result = await connection.execute(range_sql) auth_context = get_auth_context()
range_result = await connection.execute(range_sql, auth_context=auth_context)
if range_result.data and range_result.data[0]["non_null_count"] > 0: if range_result.data and range_result.data[0]["non_null_count"] > 0:
range_data = range_result.data[0] range_data = range_result.data[0]
@@ -539,7 +575,8 @@ class DataExplorationTools:
ORDER BY day_of_week ORDER BY day_of_week
""" """
weekly_result = await connection.execute(weekly_pattern_sql) auth_context = get_auth_context()
weekly_result = await connection.execute(weekly_pattern_sql, auth_context=auth_context)
weekly_pattern = [] weekly_pattern = []
if weekly_result.data: if weekly_result.data:
@@ -561,7 +598,7 @@ class DataExplorationTools:
LIMIT 12 LIMIT 12
""" """
monthly_result = await connection.execute(monthly_trend_sql) monthly_result = await connection.execute(monthly_trend_sql, auth_context=auth_context)
monthly_trend = "stable" # Simplified trend analysis monthly_trend = "stable" # Simplified trend analysis
if monthly_result.data and len(monthly_result.data) > 3: if monthly_result.data and len(monthly_result.data) > 3:
@@ -646,7 +683,8 @@ class DataExplorationTools:
FROM {table_expr} FROM {table_expr}
""" """
result = await connection.execute(null_sql) auth_context = get_auth_context()
result = await connection.execute(null_sql, auth_context=auth_context)
if result.data: if result.data:
data = result.data[0] data = result.data[0]
total_count = data["total_count"] total_count = data["total_count"]

View File

@@ -26,6 +26,13 @@ from typing import Any, Dict, List, Optional
from .db import DorisConnectionManager from .db import DorisConnectionManager
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -216,26 +223,34 @@ class DataGovernanceTools:
# ==================== Private Helper Methods ==================== # ==================== Private Helper Methods ====================
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str: def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
"""Build full table name - use three-level naming convention""" """Build full table name - use three-level naming convention with security validation"""
# SECURITY FIX: Use build_table_reference for safe identifier handling
# Default catalog is internal for internal tables # Default catalog is internal for internal tables
effective_catalog = catalog_name if catalog_name else "internal" effective_catalog = catalog_name if catalog_name else "internal"
if db_name: if db_name:
return f"{effective_catalog}.{db_name}.{table_name}" return build_table_reference(table_name, db_name, effective_catalog)
else: else:
# If db_name is not provided, need to determine current database # If db_name is not provided, need to determine current database
return f"{effective_catalog}.{table_name}" return build_table_reference(table_name, catalog_name=effective_catalog)
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]: async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get table basic information""" """Get table basic information"""
try: try:
# SECURITY FIX: Get auth_context for security validation
# table_name should already be validated by _build_full_table_name
auth_context = get_auth_context()
# Try to get table row count # Try to get table row count
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}" count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
result = await connection.execute(count_sql) result = await connection.execute(count_sql, auth_context=auth_context)
if result.data: if result.data:
return {"row_count": result.data[0]["row_count"]} return {"row_count": result.data[0]["row_count"]}
return None return None
except SQLSecurityError as e:
logger.warning(f"Security validation failed for table {table_name}: {str(e)}")
return {"row_count": 0}
except Exception as e: except Exception as e:
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}") logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
return {"row_count": 0} return {"row_count": 0}
@@ -243,11 +258,24 @@ class DataGovernanceTools:
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]: async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get table column information""" """Get table column information"""
try: try:
# Build query conditions # SECURITY FIX: Validate identifiers and use parameterized query
where_conditions = [f"table_name = '{table_name}'"] auth_context = get_auth_context()
try:
validate_identifier(table_name, "table name")
if db_name:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build parameterized query conditions
params = [table_name]
where_conditions = ["table_name = %s"]
if db_name: if db_name:
where_conditions.append(f"table_schema = '{db_name}'") where_conditions.append("table_schema = %s")
params.append(db_name)
else: else:
where_conditions.append("table_schema = DATABASE()") where_conditions.append("table_schema = DATABASE()")
@@ -263,30 +291,49 @@ class DataGovernanceTools:
ORDER BY ordinal_position ORDER BY ordinal_position
""" """
result = await connection.execute(columns_sql) result = await connection.execute(columns_sql, params=tuple(params), auth_context=auth_context)
return result.data if result.data else [] return result.data if result.data else []
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e: except Exception as e:
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}") logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
return [] return []
async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]: async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]:
"""Analyze column completeness""" """Analyze column completeness"""
# SECURITY FIX: Get auth_context for security validation
auth_context = get_auth_context()
column_completeness = {} column_completeness = {}
for column in columns_info: for column in columns_info:
column_name = column["column_name"] column_name = column["column_name"]
try: try:
# SECURITY FIX: Validate column name before using in SQL
try:
validate_identifier(column_name, "column name")
except SQLSecurityError as e:
logger.warning(f"Invalid column name rejected: {e}")
column_completeness[column_name] = {
"error": f"Invalid column name: {e}",
"completeness_score": 0.0
}
continue
# Use quoted identifier for column name
quoted_column = quote_identifier(column_name, "column name")
# Calculate null value statistics # Calculate null value statistics
null_sql = f""" null_sql = f"""
SELECT SELECT
COUNT(*) as total_count, COUNT(*) as total_count,
COUNT({column_name}) as non_null_count, COUNT({quoted_column}) as non_null_count,
COUNT(*) - COUNT({column_name}) as null_count COUNT(*) - COUNT({quoted_column}) as null_count
FROM {table_name} FROM {table_name}
""" """
result = await connection.execute(null_sql) result = await connection.execute(null_sql, auth_context=auth_context)
if result.data: if result.data:
stats = result.data[0] stats = result.data[0]
total_count = stats["total_count"] total_count = stats["total_count"]
@@ -304,6 +351,12 @@ class DataGovernanceTools:
"completeness_score": round(completeness_score, 4) "completeness_score": round(completeness_score, 4)
} }
except SQLSecurityError as e:
logger.warning(f"Security validation failed for column {column_name}: {str(e)}")
column_completeness[column_name] = {
"error": str(e),
"completeness_score": 0.0
}
except Exception as e: except Exception as e:
logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}") logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}")
column_completeness[column_name] = { column_completeness[column_name] = {
@@ -333,7 +386,8 @@ class DataGovernanceTools:
FROM {table_name} FROM {table_name}
""" """
result = await connection.execute(compliance_sql) auth_context = get_auth_context()
result = await connection.execute(compliance_sql, auth_context=auth_context)
if result.data: if result.data:
stats = result.data[0] stats = result.data[0]
pass_count = stats["pass_count"] or 0 pass_count = stats["pass_count"] or 0
@@ -378,7 +432,8 @@ class DataGovernanceTools:
) t ) t
""" """
result = await connection.execute(duplicate_sql) auth_context = get_auth_context()
result = await connection.execute(duplicate_sql, auth_context=auth_context)
if result.data and result.data[0]["duplicate_count"] > 0: if result.data and result.data[0]["duplicate_count"] > 0:
issues.append({ issues.append({
"type": "duplicate_primary_keys", "type": "duplicate_primary_keys",
@@ -456,10 +511,21 @@ class DataGovernanceTools:
async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool: async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool:
"""Verify if column exists""" """Verify if column exists"""
try: try:
# Simple verification method: try to query the column # SECURITY FIX: Validate and quote column name
verify_sql = f"SELECT {column_name} FROM {table_name} LIMIT 1" try:
await connection.execute(verify_sql) validate_identifier(column_name, "column name")
except SQLSecurityError as e:
logger.warning(f"Invalid column name rejected: {e}")
return False
safe_column = quote_identifier(column_name, "column name")
# table_name is already safe (from _build_full_table_name)
verify_sql = f"SELECT {safe_column} FROM {table_name} LIMIT 1"
auth_context = get_auth_context()
await connection.execute(verify_sql, auth_context=auth_context)
return True return True
except SQLSecurityError:
return False
except Exception: except Exception:
return False return False
@@ -469,21 +535,34 @@ class DataGovernanceTools:
source_chain = [] source_chain = []
try: try:
# SECURITY FIX: Validate table name and use parameterized-like approach
table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return []
# Escape special characters for LIKE pattern
safe_pattern = table_name_part.replace('%', r'\%').replace('_', r'\_')
like_pattern = f"%{safe_pattern}%"
# Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range) # Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range)
auth_context = get_auth_context()
audit_sql = """ audit_sql = """
SELECT SELECT
stmt as sql_statement, stmt as sql_statement,
`time` as execution_time, `time` as execution_time,
`user` as user_name `user` as user_name
FROM internal.__internal_schema.audit_log FROM internal.__internal_schema.audit_log
WHERE stmt LIKE '%{}%' WHERE stmt LIKE %s
AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%') AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%')
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR) AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
ORDER BY `time` DESC ORDER BY `time` DESC
LIMIT 50 LIMIT 50
""".format(table_name.split('.')[-1]) # Use the last part of table name """
result = await connection.execute(audit_sql) result = await connection.execute(audit_sql, params=(like_pattern,), auth_context=auth_context)
if result.data: if result.data:
for i, log_entry in enumerate(result.data[:depth]): for i, log_entry in enumerate(result.data[:depth]):
@@ -556,19 +635,33 @@ class DataGovernanceTools:
downstream_usage = [] downstream_usage = []
try: try:
# SECURITY FIX: Validate inputs and use parameterized-like approach
table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
validate_identifier(column_name, "column name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Escape special characters for LIKE pattern
safe_table_pattern = f"%{table_name_part.replace('%', r'\\%').replace('_', r'\\_')}%"
safe_column_pattern = f"%{column_name.replace('%', r'\\%').replace('_', r'\\_')}%"
# Find other tables that might use this field (through audit logs, one year range) # Find other tables that might use this field (through audit logs, one year range)
auth_context = get_auth_context()
usage_sql = """ usage_sql = """
SELECT DISTINCT SELECT DISTINCT
stmt as sql_statement stmt as sql_statement
FROM internal.__internal_schema.audit_log FROM internal.__internal_schema.audit_log
WHERE stmt LIKE '%{}%' WHERE stmt LIKE %s
AND stmt LIKE '%{}%' AND stmt LIKE %s
AND stmt LIKE '%SELECT%' AND stmt LIKE '%SELECT%'
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR) AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
LIMIT 20 LIMIT 20
""".format(table_name.split('.')[-1], column_name) """
result = await connection.execute(usage_sql) result = await connection.execute(usage_sql, params=(safe_table_pattern, safe_column_pattern), auth_context=auth_context)
if result.data: if result.data:
for entry in result.data: for entry in result.data:
@@ -634,14 +727,20 @@ class DataGovernanceTools:
async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]: async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]:
"""Get list of all tables""" """Get list of all tables"""
try: try:
where_conditions = [] auth_context = get_auth_context()
params = []
# SECURITY FIX: Use parameterized query
if db_name: if db_name:
where_conditions.append(f"table_schema = '{db_name}'") try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
where_clause = "table_schema = %s"
params.append(db_name)
else: else:
where_conditions.append("table_schema = DATABASE()") where_clause = "table_schema = DATABASE()"
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
tables_sql = f""" tables_sql = f"""
SELECT table_name SELECT table_name
@@ -651,7 +750,7 @@ class DataGovernanceTools:
ORDER BY table_name ORDER BY table_name
""" """
result = await connection.execute(tables_sql) result = await connection.execute(tables_sql, params=tuple(params) if params else None, auth_context=auth_context)
return [row["table_name"] for row in result.data] if result.data else [] return [row["table_name"] for row in result.data] if result.data else []
except Exception as e: except Exception as e:
@@ -728,15 +827,23 @@ class DataGovernanceTools:
async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]: async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from partition information""" """Get freshness from partition information"""
try: try:
# Query partition information (if table has partitions) # SECURITY FIX: Validate and use parameterized query
partition_sql = f""" table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return None
auth_context = get_auth_context()
partition_sql = """
SELECT MAX(CREATE_TIME) as last_update SELECT MAX(CREATE_TIME) as last_update
FROM information_schema.partitions FROM information_schema.partitions
WHERE table_name = '{table_name.split('.')[-1]}' WHERE table_name = %s
AND CREATE_TIME IS NOT NULL AND CREATE_TIME IS NOT NULL
""" """
result = await connection.execute(partition_sql) result = await connection.execute(partition_sql, params=(table_name_part,), auth_context=auth_context)
if result.data and result.data[0]["last_update"]: if result.data and result.data[0]["last_update"]:
return { return {
"last_update": result.data[0]["last_update"], "last_update": result.data[0]["last_update"],
@@ -759,7 +866,8 @@ class DataGovernanceTools:
FROM {table_name} FROM {table_name}
""" """
result = await connection.execute(max_time_sql) auth_context = get_auth_context()
result = await connection.execute(max_time_sql, auth_context=auth_context)
if result.data and result.data[0]["last_update"]: if result.data and result.data[0]["last_update"]:
return { return {
"last_update": result.data[0]["last_update"], "last_update": result.data[0]["last_update"],
@@ -773,15 +881,23 @@ class DataGovernanceTools:
async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]: async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from table metadata""" """Get freshness from table metadata"""
try: try:
# Query table's update time # SECURITY FIX: Validate and use parameterized query
metadata_sql = f""" table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return None
auth_context = get_auth_context()
metadata_sql = """
SELECT UPDATE_TIME as last_update SELECT UPDATE_TIME as last_update
FROM information_schema.tables FROM information_schema.tables
WHERE table_name = '{table_name.split('.')[-1]}' WHERE table_name = %s
AND UPDATE_TIME IS NOT NULL AND UPDATE_TIME IS NOT NULL
""" """
result = await connection.execute(metadata_sql) result = await connection.execute(metadata_sql, params=(table_name_part,), auth_context=auth_context)
if result.data and result.data[0]["last_update"]: if result.data and result.data[0]["last_update"]:
return { return {
"last_update": result.data[0]["last_update"], "last_update": result.data[0]["last_update"],
@@ -795,10 +911,19 @@ class DataGovernanceTools:
async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]: async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]:
"""Find possible timestamp fields""" """Find possible timestamp fields"""
try: try:
timestamp_sql = f""" # SECURITY FIX: Validate and use parameterized query
table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return []
auth_context = get_auth_context()
timestamp_sql = """
SELECT column_name SELECT column_name
FROM information_schema.columns FROM information_schema.columns
WHERE table_name = '{table_name.split('.')[-1]}' WHERE table_name = %s
AND ( AND (
data_type IN ('datetime', 'timestamp', 'date') data_type IN ('datetime', 'timestamp', 'date')
OR column_name LIKE '%time%' OR column_name LIKE '%time%'
@@ -815,7 +940,7 @@ class DataGovernanceTools:
END END
""" """
result = await connection.execute(timestamp_sql) result = await connection.execute(timestamp_sql, params=(table_name_part,), auth_context=auth_context)
return [row["column_name"] for row in result.data] if result.data else [] return [row["column_name"] for row in result.data] if result.data else []
except Exception: except Exception:

View File

@@ -31,6 +31,12 @@ from collections import Counter, defaultdict
from .db import DorisConnectionManager from .db import DorisConnectionManager
from .logger import get_logger from .logger import get_logger
from .config import DorisConfig from .config import DorisConfig
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -299,23 +305,26 @@ class DataQualityTools:
# =========================================== # ===========================================
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str: def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
"""Build full table name""" """Build full table name with security validation"""
if catalog_name and db_name: # SECURITY FIX: Use build_table_reference for safe identifier handling
return f"{catalog_name}.{db_name}.{table_name}" return build_table_reference(table_name, db_name, catalog_name)
elif db_name:
return f"{db_name}.{table_name}"
else:
return table_name
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]: async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get basic table information""" """Get basic table information"""
try: try:
# SECURITY FIX: table_name should already be validated by _build_full_table_name
# But we add auth_context for security validation
auth_context = get_auth_context()
# Try to get row count # Try to get row count
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}" count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
result = await connection.execute(count_sql) result = await connection.execute(count_sql, auth_context=auth_context)
if result.data: if result.data:
return {"row_count": result.data[0]["row_count"]} return {"row_count": result.data[0]["row_count"]}
return None return None
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return None
except Exception as e: except Exception as e:
logger.warning(f"Failed to get table basic info: {str(e)}") logger.warning(f"Failed to get table basic info: {str(e)}")
return None return None
@@ -323,9 +332,13 @@ class DataQualityTools:
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]: async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get table column information""" """Get table column information"""
try: try:
# Build DESCRIBE query # SECURITY FIX: Build safe table reference and pass auth_context
describe_sql = f"DESCRIBE {self._build_full_table_name(table_name, catalog_name, db_name)}" auth_context = get_auth_context()
result = await connection.execute(describe_sql)
# Build DESCRIBE query with safe table reference
safe_table_ref = self._build_full_table_name(table_name, catalog_name, db_name)
describe_sql = f"DESCRIBE {safe_table_ref}"
result = await connection.execute(describe_sql, auth_context=auth_context)
columns_info = [] columns_info = []
if result.data: if result.data:
@@ -339,6 +352,9 @@ class DataQualityTools:
}) })
return columns_info return columns_info
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e: except Exception as e:
logger.warning(f"Failed to get table columns info: {str(e)}") logger.warning(f"Failed to get table columns info: {str(e)}")
return [] return []
@@ -346,7 +362,32 @@ class DataQualityTools:
async def _get_table_partitions(self, connection, table_name: str, db_name: Optional[str] = None) -> List[Dict]: async def _get_table_partitions(self, connection, table_name: str, db_name: Optional[str] = None) -> List[Dict]:
"""Get table partition information""" """Get table partition information"""
try: try:
# Query partition information # SECURITY FIX: Validate identifiers and use parameterized query
auth_context = get_auth_context()
# Validate table_name
try:
validate_identifier(table_name, "table name")
if db_name:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build parameterized query
params = []
where_conditions = []
if db_name:
where_conditions.append("TABLE_SCHEMA = %s")
params.append(db_name)
else:
where_conditions.append("TABLE_SCHEMA = ''")
where_conditions.append("TABLE_NAME = %s")
params.append(table_name)
where_conditions.append("PARTITION_NAME IS NOT NULL")
partition_sql = f""" partition_sql = f"""
SELECT SELECT
PARTITION_NAME, PARTITION_NAME,
@@ -355,12 +396,10 @@ class DataQualityTools:
DATA_LENGTH, DATA_LENGTH,
INDEX_LENGTH INDEX_LENGTH
FROM information_schema.PARTITIONS FROM information_schema.PARTITIONS
WHERE TABLE_SCHEMA = '{db_name or ""}' WHERE {' AND '.join(where_conditions)}
AND TABLE_NAME = '{table_name}'
AND PARTITION_NAME IS NOT NULL
""" """
result = await connection.execute(partition_sql) result = await connection.execute(partition_sql, params=tuple(params), auth_context=auth_context)
partitions = [] partitions = []
if result.data: if result.data:
for row in result.data: for row in result.data:
@@ -373,6 +412,9 @@ class DataQualityTools:
}) })
return partitions return partitions
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e: except Exception as e:
logger.warning(f"Failed to get table partitions: {str(e)}") logger.warning(f"Failed to get table partitions: {str(e)}")
return [] return []
@@ -417,7 +459,8 @@ class DataQualityTools:
if db_name if db_name
else f"SHOW CREATE TABLE {table_name}" else f"SHOW CREATE TABLE {table_name}"
) )
result = await connection.execute(query) auth_context = get_auth_context()
result = await connection.execute(query, auth_context=auth_context)
if result.data: if result.data:
return result.data[0].get("Create Table") return result.data[0].get("Create Table")
return None return None
@@ -428,8 +471,16 @@ class DataQualityTools:
async def _get_table_size_info(self, connection, table_name: str) -> Dict[str, Any]: async def _get_table_size_info(self, connection, table_name: str) -> Dict[str, Any]:
"""Get table size information""" """Get table size information"""
try: try:
# Query table size information # SECURITY FIX: Validate and use parameterized query
size_sql = f""" table_name_part = table_name.split('.')[-1]
try:
validate_identifier(table_name_part, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return {"engine": "Unknown", "estimated_rows": 0, "data_length": 0, "index_length": 0, "total_size": 0}
auth_context = get_auth_context()
size_sql = """
SELECT SELECT
table_name, table_name,
engine, engine,
@@ -438,10 +489,10 @@ class DataQualityTools:
index_length, index_length,
(data_length + index_length) as total_size (data_length + index_length) as total_size
FROM information_schema.tables FROM information_schema.tables
WHERE table_name = '{table_name.split('.')[-1]}' WHERE table_name = %s
""" """
result = await connection.execute(size_sql) result = await connection.execute(size_sql, params=(table_name_part,), auth_context=auth_context)
if result.data and result.data[0]: if result.data and result.data[0]:
row = result.data[0] row = result.data[0]
return { return {
@@ -582,7 +633,8 @@ class DataQualityTools:
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}" batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
result = await connection.execute(batch_sql) auth_context = get_auth_context()
result = await connection.execute(batch_sql, auth_context=auth_context)
if not result.data: if not result.data:
return {"error": "No data returned from batch completeness query"} return {"error": "No data returned from batch completeness query"}
@@ -664,7 +716,8 @@ class DataQualityTools:
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}" batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
result = await connection.execute(batch_sql) auth_context = get_auth_context()
result = await connection.execute(batch_sql, auth_context=auth_context)
if not result.data: if not result.data:
return {} return {}
@@ -705,7 +758,8 @@ class DataQualityTools:
LIMIT 10 LIMIT 10
""" """
result = await connection.execute(freq_sql) auth_context = get_auth_context()
result = await connection.execute(freq_sql, auth_context=auth_context)
frequencies = result.data if result.data else [] frequencies = result.data if result.data else []
categorical_results[col_name] = { categorical_results[col_name] = {
@@ -738,7 +792,8 @@ class DataQualityTools:
batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}" batch_sql = f"SELECT {', '.join(select_clauses)} FROM {table_expr}"
result = await connection.execute(batch_sql) auth_context = get_auth_context()
result = await connection.execute(batch_sql, auth_context=auth_context)
if not result.data: if not result.data:
return {} return {}
@@ -780,7 +835,8 @@ class DataQualityTools:
FROM {table_expr} FROM {table_expr}
""" """
result = await connection.execute(completeness_sql) auth_context = get_auth_context()
result = await connection.execute(completeness_sql, auth_context=auth_context)
if result.data: if result.data:
stats = result.data[0] stats = result.data[0]
total_count = stats["total_count"] total_count = stats["total_count"]
@@ -906,7 +962,8 @@ class DataQualityTools:
WHERE {col_name} IS NOT NULL WHERE {col_name} IS NOT NULL
""" """
result = await connection.execute(stats_sql) auth_context = get_auth_context()
result = await connection.execute(stats_sql, auth_context=auth_context)
if result.data and result.data[0]["non_null_count"] > 0: if result.data and result.data[0]["non_null_count"] > 0:
stats = result.data[0] stats = result.data[0]
numeric_analysis[col_name] = { numeric_analysis[col_name] = {
@@ -945,7 +1002,8 @@ class DataQualityTools:
WHERE {col_name} IS NOT NULL WHERE {col_name} IS NOT NULL
""" """
cardinality_result = await connection.execute(cardinality_sql) auth_context = get_auth_context()
cardinality_result = await connection.execute(cardinality_sql, auth_context=auth_context)
if cardinality_result.data: if cardinality_result.data:
stats = cardinality_result.data[0] stats = cardinality_result.data[0]
@@ -969,7 +1027,7 @@ class DataQualityTools:
LIMIT 10 LIMIT 10
""" """
top_values_result = await connection.execute(top_values_sql) top_values_result = await connection.execute(top_values_sql, auth_context=auth_context)
if top_values_result.data: if top_values_result.data:
categorical_analysis[col_name]["top_values"] = [ categorical_analysis[col_name]["top_values"] = [
{"value": row[col_name], "count": row["count"]} {"value": row[col_name], "count": row["count"]}
@@ -998,7 +1056,8 @@ class DataQualityTools:
WHERE {col_name} IS NOT NULL WHERE {col_name} IS NOT NULL
""" """
result = await connection.execute(stats_sql) auth_context = get_auth_context()
result = await connection.execute(stats_sql, auth_context=auth_context)
if result.data and result.data[0]["non_null_count"] > 0: if result.data and result.data[0]["non_null_count"] > 0:
stats = result.data[0] stats = result.data[0]
temporal_analysis[col_name] = { temporal_analysis[col_name] = {

View File

@@ -59,6 +59,7 @@ class QueryResult:
metadata: dict[str, Any] metadata: dict[str, Any]
execution_time: float execution_time: float
row_count: int row_count: int
sql: str
class DorisConnection: class DorisConnection:
@@ -95,12 +96,14 @@ class DorisConnection:
await cursor.execute(sql, params) await cursor.execute(sql, params)
# Check if it's a query statement (statement that returns result set) # Check if it's a query statement (statement that returns result set)
# FIX for Issue #62 Bug 5: Added WITH support for Common Table Expressions (CTE)
sql_upper = sql.strip().upper() sql_upper = sql.strip().upper()
if (sql_upper.startswith("SELECT") or if (sql_upper.startswith("SELECT") or
sql_upper.startswith("SHOW") or sql_upper.startswith("SHOW") or
sql_upper.startswith("DESCRIBE") or sql_upper.startswith("DESCRIBE") or
sql_upper.startswith("DESC") or sql_upper.startswith("DESC") or
sql_upper.startswith("EXPLAIN")): sql_upper.startswith("EXPLAIN") or
sql_upper.startswith("WITH")): # FIX: Support CTE queries
data = await cursor.fetchall() data = await cursor.fetchall()
row_count = len(data) row_count = len(data)
else: else:
@@ -130,6 +133,7 @@ class DorisConnection:
metadata=metadata, metadata=metadata,
execution_time=execution_time, execution_time=execution_time,
row_count=row_count, row_count=row_count,
sql=sql
) )
except Exception as e: except Exception as e:
@@ -250,7 +254,23 @@ class DorisConnectionManager:
self.logger = get_logger(__name__) self.logger = get_logger(__name__)
self.security_manager = security_manager self.security_manager = security_manager
self.token_manager = token_manager # Token manager for token-bound DB config self.token_manager = token_manager # Token manager for token-bound DB config
self.session_cache = DorisSessionCache(self)
# 🔧 FIX for multi-tenant concurrency: Per-token connection pool isolation
# Each token gets its own connection pool to prevent configuration conflicts
self.token_pools: Dict[str, Pool] = {} # token_hash -> pool
self.token_configs: Dict[str, dict] = {} # token_hash -> db_config
self._token_pool_locks: Dict[str, asyncio.Lock] = {} # token_hash -> lock
self._token_pools_lock = asyncio.Lock() # Lock for managing token_pools dict
# FIX for Issue #58 Problem 1: Disable session caching to prevent connection sharing
# Session caching causes multiple threads to share the same MySQL connection,
# leading to race conditions and deadlocks in multi-threaded environments
# By disabling caching, each request gets a fresh connection from the pool
self.session_cache = DorisSessionCache(
self,
cache_system_session=False, # Disabled to prevent multi-thread issues
cache_user_session=False # Disabled to prevent multi-thread issues
)
# Store original database config for fallback # Store original database config for fallback
self.original_db_config = { self.original_db_config = {
@@ -263,6 +283,7 @@ class DorisConnectionManager:
} }
# Current active database config (may be overridden by token-bound config) # Current active database config (may be overridden by token-bound config)
# NOTE: This is kept for backward compatibility with non-token requests
self.active_db_config = self.original_db_config.copy() self.active_db_config = self.original_db_config.copy()
# Connection pool state management # Connection pool state management
@@ -346,6 +367,281 @@ class DorisConnectionManager:
self.logger.error(f"Error finding available token: {e}") self.logger.error(f"Error finding available token: {e}")
return "" return ""
def _get_token_hash(self, token: str) -> str:
    """Derive the dictionary key used to index per-token pool state.

    Args:
        token: Raw authentication token.

    Returns:
        The first 16 hexadecimal characters of the token's SHA-256 digest.
    """
    import hashlib

    full_digest = hashlib.sha256(token.encode()).hexdigest()
    return full_digest[:16]
def _get_current_token_db_config(self, token: str) -> dict | None:
"""Get current database config for token from TokenManager
This is used to check if config has changed for hot reload support.
"""
if not self.token_manager:
return None
token_db_config = self.token_manager.get_database_config_by_token(token)
if token_db_config:
return {
'host': token_db_config.host,
'port': token_db_config.port,
'user': token_db_config.user,
'password': token_db_config.password,
'database': token_db_config.database,
'charset': token_db_config.charset
}
return None
def _config_changed(self, old_config: dict, new_config: dict) -> bool:
    """Tell whether two database configurations differ in a connection-relevant field.

    Only ``host``, ``port``, ``user``, ``password`` and ``database`` are
    compared; ``charset`` is not part of the comparison.

    Args:
        old_config: Previously cached configuration, or ``None``.
        new_config: Freshly loaded configuration, or ``None``.

    Returns:
        True when the configurations differ, False otherwise. If either side
        is ``None`` the result is True unless both are ``None``.
    """
    if old_config is None or new_config is None:
        return old_config != new_config

    compared_fields = ('host', 'port', 'user', 'password', 'database')
    return any(
        old_config.get(field) != new_config.get(field)
        for field in compared_fields
    )
async def get_pool_for_token(self, token: str) -> tuple[Pool, dict]:
    """Get or create the dedicated connection pool for a specific token.

    Each token gets its own pool so that concurrent requests from different
    tokens never share connection settings. Hot reload is supported: when
    the configuration bound to the token changes, the old pool is closed
    and a new one is created.

    The staleness check compares the cached configuration with the token's
    *effective* configuration (after the global fallback has been applied),
    and is repeated under the lock so that a coroutine that waited for the
    lock never closes a pool another coroutine has just created.

    Args:
        token: Authentication token.

    Returns:
        (pool, db_config): The dedicated pool and the configuration it was
        created with.

    Raises:
        RuntimeError: If neither the token nor the global settings provide
            a usable database configuration.
    """
    token_hash = self._get_token_hash(token)

    # Fast path (no lock): the pool exists, is open and is still current.
    pool = self.token_pools.get(token_hash)
    if pool and not pool.closed and not self._token_pool_is_stale(token, token_hash):
        return pool, self.token_configs.get(token_hash)

    # Slow path: create or replace the pool while holding the lock.
    async with self._token_pools_lock:
        # Re-read the state: it may have changed while waiting for the lock.
        pool = self.token_pools.get(token_hash)
        if pool and not pool.closed:
            if not self._token_pool_is_stale(token, token_hash):
                return pool, self.token_configs[token_hash]

            self.logger.info(f"Token config changed (hash: {token_hash[:8]}...), recreating pool...")
            self.token_pools.pop(token_hash, None)
            try:
                pool.close()
                # Bounded wait: connections still checked out must not
                # stall every other request that needs this lock.
                await asyncio.wait_for(pool.wait_closed(), timeout=2.0)
            except Exception as e:
                self.logger.warning(f"Error closing old pool during hot reload: {e}")
            self.token_configs.pop(token_hash, None)

        # Resolve the configuration: token-bound first, global as fallback.
        db_config = self._get_current_token_db_config(token)
        config_source = "token-bound" if db_config else "unknown"

        token_config_unusable = (
            not db_config
            or self._is_config_empty(db_config.get('host'))
            or self._is_config_empty(db_config.get('user'))
        )
        if token_config_unusable:
            if not self._has_valid_global_config():
                raise RuntimeError(
                    "No valid database configuration available for token. "
                    "Please configure database in tokens.json or .env file."
                )
            db_config = self.original_db_config.copy()
            config_source = "global-env"

        self.logger.info(f"Creating dedicated connection pool for token (hash: {token_hash[:8]}...) "
                         f"using {config_source} config: {db_config['user']}@{db_config['host']}:{db_config['port']}")

        pool = await self._create_pool_with_config(db_config)

        self.token_pools[token_hash] = pool
        self.token_configs[token_hash] = db_config
        self._token_pool_locks.setdefault(token_hash, asyncio.Lock())

        return pool, db_config


def _token_pool_is_stale(self, token: str, token_hash: str) -> bool:
    """Return True when the cached pool no longer matches the token's effective config.

    Args:
        token: Authentication token.
        token_hash: Key of the token in ``token_pools`` / ``token_configs``.

    Returns:
        False when there is nothing to compare (no cached configuration or
        no token-bound entry); otherwise the result of ``_config_changed``
        on the cached and the effective configuration.
    """
    cached_config = self.token_configs.get(token_hash)
    current_config = self._get_current_token_db_config(token)
    if not cached_config or not current_config:
        return False

    if (self._is_config_empty(current_config.get('host'))
            or self._is_config_empty(current_config.get('user'))):
        # An incomplete token entry resolves to the global settings when the
        # pool is created, so that is what the cached copy must be compared
        # with; comparing with the raw entry would report a change on every
        # request and recreate the pool each time.
        current_config = self.original_db_config

    return self._config_changed(cached_config, current_config)
async def _create_pool_with_config(self, db_config: dict) -> Pool:
    """Open an aiomysql connection pool for the given configuration.

    Args:
        db_config: Database configuration dictionary with ``host``,
            ``port``, ``user``, ``password``, ``database`` and ``charset``.

    Returns:
        Created connection pool.

    Raises:
        RuntimeError: If the pool could not be created within
            ``connect_timeout + 5`` seconds.
        Exception: Any other error from pool creation is logged and re-raised.
    """
    # Normalize the charset name to the lower-case spelling handed to aiomysql.
    requested_charset = db_config['charset']
    known_charsets = {"UTF8": "utf8", "UTF8MB4": "utf8mb4"}
    charset = known_charsets.get(requested_charset.upper(), requested_charset.lower())

    self.logger.debug(f"Creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}/{db_config['database']}")

    try:
        pending_pool = aiomysql.create_pool(
            host=db_config['host'],
            port=db_config['port'],
            user=db_config['user'],
            password=db_config['password'],
            db=db_config['database'],
            charset=charset,
            minsize=0,  # connections are opened lazily, none up front
            maxsize=self.maxsize,
            connect_timeout=self.connect_timeout,
            autocommit=True,
            pool_recycle=self.pool_recycle,
        )
        # Allow a little more than the connect timeout for pool setup itself.
        created_pool = await asyncio.wait_for(pending_pool, timeout=self.connect_timeout + 5)
        self.logger.info(f"Successfully created pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
        return created_pool
    except asyncio.TimeoutError:
        self.logger.error(f"Timeout creating pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
        raise RuntimeError(f"Timeout creating connection pool for {db_config['user']}@{db_config['host']}:{db_config['port']}")
    except Exception as e:
        self.logger.error(f"Failed to create pool for {db_config['user']}@{db_config['host']}:{db_config['port']}: {type(e).__name__}: {e}")
        raise
async def get_connection_for_token(self, token: str, session_id: str) -> 'DorisConnection':
    """Get a connection from the token's dedicated pool.

    If anything fails after the raw connection has been acquired (logging
    or wrapping it), the connection is handed back to the pool before the
    error is re-raised, so a failed call never leaves a connection checked
    out.

    Args:
        token: Authentication token.
        session_id: Session identifier for logging.

    Returns:
        DorisConnection wrapper around the acquired connection.

    Raises:
        Exception: Any error from pool lookup, acquisition (including the
            acquisition timeout) or wrapping is logged and re-raised.
    """
    pool, db_config = await self.get_pool_for_token(token)

    try:
        raw_connection = await asyncio.wait_for(
            pool.acquire(),
            timeout=self.connect_timeout
        )
    except Exception as e:
        self.logger.error(f"Session {session_id}: Failed to acquire connection from token pool: {e}")
        raise

    try:
        self.logger.debug(f"Session {session_id}: Acquired connection from token pool "
                          f"(user: {db_config['user']}@{db_config['host']})")
        return DorisConnection(raw_connection, session_id, self.security_manager)
    except Exception as e:
        # The connection is already checked out; give it back so the pool
        # does not lose a slot for the rest of its lifetime.
        try:
            pool.release(raw_connection)
        except Exception as release_error:
            self.logger.warning(f"Session {session_id}: Failed to return connection to token pool: {release_error}")
        self.logger.error(f"Session {session_id}: Failed to acquire connection from token pool: {e}")
        raise
async def release_connection_for_token(self, token: str, connection: 'DorisConnection'):
    """Release a connection back to the token's dedicated pool.

    When the connection cannot be returned (the token's pool is gone or
    closed, or it was replaced by a hot reload and rejects a connection it
    did not hand out), the underlying connection is closed instead, so it
    is never left open without an owner.

    Args:
        token: Authentication token.
        connection: DorisConnection wrapper to release.
    """
    raw_connection = getattr(connection, 'connection', None)
    if raw_connection is None:
        return

    pool = self.token_pools.get(self._get_token_hash(token))
    if pool and not pool.closed:
        try:
            pool.release(raw_connection)
            return
        except Exception as e:
            self.logger.warning(f"Failed to release connection to token pool: {e}")

    # No pool accepted the connection: close it rather than leak the socket.
    try:
        raw_connection.close()
    except Exception as e:
        self.logger.warning(f"Failed to close orphaned token connection: {e}")
async def cleanup_token_pools(self, max_idle_time: int = 3600):
    """Drop token pools that are closed or currently hold no connections.

    Args:
        max_idle_time: Maximum idle time in seconds before closing a pool.
            The current implementation does not consult this value: a pool
            is removed as soon as it is closed, or reports ``size == 0``
            and ``freesize == 0``.
    """
    def _is_disposable(candidate) -> bool:
        # Missing entries are left alone; closed pools always go; open
        # pools go only when they hold no connections at all.
        if not candidate:
            return False
        if candidate.closed:
            return True
        return candidate.size == 0 and candidate.freesize == 0

    async with self._token_pools_lock:
        disposable_hashes = [
            key for key, candidate in self.token_pools.items()
            if _is_disposable(candidate)
        ]

        for token_hash in disposable_hashes:
            try:
                removed = self.token_pools.pop(token_hash, None)
                if removed and not removed.closed:
                    removed.close()
                    await removed.wait_closed()
                self.token_configs.pop(token_hash, None)
                self._token_pool_locks.pop(token_hash, None)
                self.logger.info(f"Cleaned up idle token pool (hash: {token_hash[:8]}...)")
            except Exception as e:
                self.logger.warning(f"Error cleaning up token pool: {e}")
async def close_all_token_pools(self):
"""Close all token connection pools (for shutdown)"""
# Use timeout to prevent blocking on lock acquisition during shutdown
try:
async with asyncio.timeout(5): # 5 second timeout for lock
async with self._token_pools_lock:
for token_hash, pool in list(self.token_pools.items()):
try:
if pool and not pool.closed:
pool.close()
# Use timeout for wait_closed to prevent hanging
try:
await asyncio.wait_for(pool.wait_closed(), timeout=2.0)
except asyncio.TimeoutError:
self.logger.warning(f"Timeout waiting for token pool to close (hash: {token_hash[:8]}...)")
self.logger.info(f"Closed token pool (hash: {token_hash[:8]}...)")
except Exception as e:
self.logger.warning(f"Error closing token pool: {e}")
self.token_pools.clear()
self.token_configs.clear()
self._token_pool_locks.clear()
except asyncio.TimeoutError:
self.logger.warning("Timeout acquiring lock for token pool cleanup, forcing clear")
# Force clear without lock
self.token_pools.clear()
self.token_configs.clear()
self._token_pool_locks.clear()
async def configure_for_token(self, token: str) -> tuple[bool, str]: async def configure_for_token(self, token: str) -> tuple[bool, str]:
"""Configure connection manager for token with new priority logic """Configure connection manager for token with new priority logic
@@ -626,6 +922,213 @@ class DorisConnectionManager:
self.logger.error(f"Failed to initialize connection pool: {e}") self.logger.error(f"Failed to initialize connection pool: {e}")
raise raise
async def initialize_for_stdio_mode(self, timeout: float = 30.0) -> None:
    """
    Bring up the global connection pool for stdio mode, failing on any problem.

    stdio mode requires a working database connection because:
    - No HTTP authentication mechanism to support token-bound configs
    - All database operations depend on the global connection pool

    The steps run in order: check that a global configuration is present,
    validate it, test connectivity, create the pool, start the health and
    cleanup monitor tasks, and warm the pool up.

    Args:
        timeout: Maximum time to wait for connection establishment

    Raises:
        RuntimeError: If configuration is invalid or connection fails
    """
    def _abort(error_msg: str) -> None:
        # Log the specific failure, then surface it to the caller.
        self.logger.error(error_msg)
        raise RuntimeError(error_msg)

    try:
        # A global configuration is mandatory in this mode.
        if not self._has_valid_global_config():
            _abort(
                "stdio mode requires valid global database configuration. "
                "Please set DORIS_HOST and DORIS_USER in environment variables or .env file. "
                f"Current config: host='{self.host}', user='{self.user}'"
            )

        self.logger.info(f"stdio mode database config validated: {self.host}:{self.port}")

        is_valid, error_message = self.validate_database_configuration()
        if not is_valid:
            _abort(f"Database configuration validation failed: {error_message}")

        self.logger.info("Testing database connectivity for stdio mode...")
        reachable = await self._test_connectivity_with_timeout(timeout)
        if not reachable:
            _abort(
                f"Failed to connect to Doris database within {timeout} seconds. "
                f"Please check if Doris is running at {self.host}:{self.port} "
                f"and verify network connectivity."
            )

        await self._create_connection_pool()
        if not self.pool:
            _abort("Database connection pool was not created successfully.")

        # Background monitors are started before the warmup runs.
        self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor())
        self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor())

        await self._warmup_pool()

        self.logger.info("Database connection established successfully for stdio mode")

    except Exception as e:
        self.logger.error(f"stdio mode database initialization failed: {e}")
        raise
async def initialize_for_http_mode(self) -> bool:
    """
    Initialize connection pool for HTTP mode with graceful degradation

    HTTP mode can work without global database configuration because:
    - Supports token-bound database configurations
    - Can handle authentication and use per-request database configs
    - Has fallback mechanisms for database operations

    Every failure is logged and reported through the return value; this
    method does not raise. In particular, a pool that is still unset after
    pool creation now yields an explicit ``False`` instead of falling off
    the end of the function and returning ``None``.

    Returns:
        bool: True if global database pool was created, False if gracefully degraded
    """
    try:
        if not self._has_valid_global_config():
            self.logger.info("No valid global database config found, HTTP mode will use token-bound configurations")
            return False

        is_valid, error_message = self.validate_database_configuration()
        if not is_valid:
            self.logger.warning(f"Global database configuration invalid: {error_message}")
            self.logger.info("HTTP mode will rely on token-bound database configurations")
            return False

        self.logger.info(f"Attempting to create global connection pool: {self.host}:{self.port}")

        try:
            # Shorter connectivity timeout than stdio mode: a missing global
            # database is not fatal here.
            if not await self._test_connectivity_with_timeout(10.0):
                self.logger.warning("Global database connection test failed, will use token-bound configs")
                return False

            await self._create_connection_pool()

            if not self.pool:
                self.logger.warning("Global connection pool was not created, will use token-bound configs")
                return False

            # Start background monitoring tasks
            self.pool_health_check_task = asyncio.create_task(self._pool_health_monitor())
            self.pool_cleanup_task = asyncio.create_task(self._pool_cleanup_monitor())

            # Perform initial pool warmup
            await self._warmup_pool()

            self.logger.info("Global database connection pool created successfully for HTTP mode")
            return True

        except Exception as pool_error:
            self.logger.warning(f"Failed to create global connection pool: {pool_error}")
            self.logger.info("HTTP mode will rely on token-bound database configurations")
            return False

    except Exception as e:
        self.logger.warning(f"HTTP mode database initialization encountered error: {e}")
        self.logger.info("HTTP mode will rely on token-bound database configurations")
        return False
async def _test_connectivity_with_timeout(self, timeout: float) -> bool:
    """
    Run the basic connectivity test under a time limit.

    Failures are logged and reported through the return value instead of
    being raised.

    Args:
        timeout: Maximum time to wait for connection test

    Returns:
        bool: True if connection successful, False otherwise
    """
    reachable = False
    try:
        await asyncio.wait_for(self._test_basic_connectivity(), timeout=timeout)
        reachable = True
    except asyncio.TimeoutError:
        self.logger.error(f"Database connectivity test timed out after {timeout} seconds")
    except Exception as e:
        self.logger.error(f"Database connectivity test failed: {e}")
    return reachable
async def _test_basic_connectivity(self) -> None:
    """
    Open one direct connection (no pool) and run ``SELECT 1`` on it.

    The connection is always closed afterwards, whether the probe
    succeeded or not.

    Raises:
        RuntimeError: If connecting fails, the query fails, or the query
            does not return ``1``. The original error text is embedded in
            the message.
    """
    import aiomysql

    probe_connection = None
    try:
        probe_connection = await aiomysql.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            db=self.database,
            charset=self.charset,
            connect_timeout=self.connect_timeout,
            autocommit=True,
        )

        async with probe_connection.cursor() as probe_cursor:
            await probe_cursor.execute("SELECT 1")
            first_row = await probe_cursor.fetchone()
            answered_one = bool(first_row) and first_row[0] == 1
            if not answered_one:
                raise RuntimeError("Database connectivity test query failed")

    except Exception as e:
        raise RuntimeError(f"Database connectivity test failed: {e}")
    finally:
        if probe_connection:
            probe_connection.close()
async def _create_connection_pool(self) -> None:
    """
    Create the global connection pool and verify that it is healthy.

    On a failed health check the freshly created pool is closed, awaited
    and ``self.pool`` is reset to ``None`` before the error is raised.

    Raises:
        RuntimeError: If the pool health check fails.
        Exception: If pool creation fails
    """
    pool_settings = dict(
        host=self.host,
        port=self.port,
        user=self.user,
        password=self.password,
        db=self.database,
        charset=self.charset,
        minsize=self.minsize,
        maxsize=self.maxsize,
        pool_recycle=self.pool_recycle,
        connect_timeout=self.connect_timeout,
        autocommit=True,
    )
    self.pool = await aiomysql.create_pool(**pool_settings)

    if await self._test_pool_health():
        return

    # Unhealthy pool: tear it down so no half-working pool stays registered.
    if self.pool:
        self.pool.close()
        await self.pool.wait_closed()
        self.pool = None
    raise RuntimeError("Connection pool health check failed")
async def _test_pool_health(self) -> bool: async def _test_pool_health(self) -> bool:
"""Test connection pool health""" """Test connection pool health"""
try: try:
@@ -872,7 +1375,26 @@ class DorisConnectionManager:
Uses only semaphore to prevent too many concurrent acquisitions. Uses only semaphore to prevent too many concurrent acquisitions.
If the connection is successfully obtained, it will be added to the connection pool cache. If the connection is successfully obtained, it will be added to the connection pool cache.
🔧 FIX for token isolation: Now automatically checks for auth_context from ContextVar
and uses token-specific connection pool if available.
""" """
# 🔧 FIX: Check for auth_context from global ContextVar
# This ensures all tools using get_connection respect token-bound database configuration
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception as e:
self.logger.debug(f"get_connection: Could not get auth_context: {e}")
if auth_context and hasattr(auth_context, 'token') and auth_context.token:
# Use token-specific connection pool
# SECURITY: Do NOT catch exceptions here - if token pool fails, don't fallback to global pool
# This prevents privilege escalation
self.logger.debug(f"get_connection: Using token-specific pool for session {session_id}")
return await self.get_connection_for_token(auth_context.token, session_id)
cached_conn = self.session_cache.get(session_id) cached_conn = self.session_cache.get(session_id)
if cached_conn: if cached_conn:
return cached_conn return cached_conn
@@ -1019,10 +1541,16 @@ class DorisConnectionManager:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
# Close connection pool # 🔧 FIX: Close all per-token connection pools
await self.close_all_token_pools()
# Close global connection pool with timeout
if self.pool: if self.pool:
self.pool.close() self.pool.close()
await self.pool.wait_closed() try:
await asyncio.wait_for(self.pool.wait_closed(), timeout=5.0)
except asyncio.TimeoutError:
self.logger.warning("Timeout waiting for global pool to close")
self.logger.info("Connection manager closed successfully") self.logger.info("Connection manager closed successfully")
@@ -1051,24 +1579,59 @@ class DorisConnectionManager:
async def execute_query( async def execute_query(
self, session_id: str, sql: str, params: tuple | None = None, auth_context=None self, session_id: str, sql: str, params: tuple | None = None, auth_context=None
) -> QueryResult: ) -> QueryResult:
"""Execute query - Simplified Strategy with automatic connection management""" """Execute query - Enhanced Strategy with per-token connection pool isolation
FIX for multi-tenant concurrency: Each token now uses its own dedicated connection pool
to prevent configuration conflicts between concurrent requests from different tokens.
"""
connection = None connection = None
token = None
try: try:
# Always get fresh connection from pool # Check if we have a token for per-token pool isolation
connection = await self.get_connection(session_id) if auth_context and hasattr(auth_context, 'token') and auth_context.token:
token = auth_context.token
try:
# 🔧 FIX: Use dedicated connection pool for this token
# This prevents concurrent requests from different tokens interfering
connection = await self.get_connection_for_token(token, session_id)
# Get the config for logging
token_hash = self._get_token_hash(token)
if token_hash in self.token_configs:
db_config = self.token_configs[token_hash]
self.logger.info(f"Session {session_id}: Using dedicated pool for {db_config['user']}@{db_config['host']}")
except Exception as token_pool_error:
# SECURITY: If token should have pool but creation fails, don't fallback
# This prevents privilege escalation (using high-privilege default user)
self.logger.error(f"Session {session_id}: Token pool error: {token_pool_error}")
raise RuntimeError(
f"Failed to get connection for authenticated token. "
f"This is a security measure to prevent using default high-privilege credentials. "
f"Error: {token_pool_error}"
)
else:
# No token - use global pool (backward compatibility)
self.logger.debug(f"Session {session_id}: No token, using global connection pool")
connection = await self.get_connection(session_id)
# Execute query # Execute query
result = await connection.execute(sql, params, auth_context) result = await connection.execute(sql, params, auth_context)
return result return result
except Exception as e: except Exception as e:
self.logger.error(f"Query execution failed for session {session_id}: {e}") self.logger.error(f"Query execution failed for session {session_id}: {e}")
raise raise
finally: finally:
# Always release connection back to pool # Always release connection back to the appropriate pool
if connection: if connection:
await self.release_connection(session_id, connection) if token:
await self.release_connection_for_token(token, connection)
else:
await self.release_connection(session_id, connection)
@asynccontextmanager @asynccontextmanager
async def get_connection_context(self, session_id: str): async def get_connection_context(self, session_id: str):

View File

@@ -27,6 +27,13 @@ from collections import defaultdict, deque
from .db import DorisConnectionManager from .db import DorisConnectionManager
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -122,10 +129,19 @@ class DependencyAnalysisTools:
async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]: async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]:
"""Get metadata for all tables and views""" """Get metadata for all tables and views"""
try: try:
# Build conditions for query # Build conditions for query with parameterized values
where_conditions = [] where_conditions = []
params = []
if db_name: if db_name:
where_conditions.append(f"table_schema = '{db_name}'") # SECURITY FIX: Validate identifier and use parameterized query
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
where_conditions.append("table_schema = %s")
params.append(db_name)
else: else:
where_conditions.append("table_schema = DATABASE()") where_conditions.append("table_schema = DATABASE()")
@@ -148,9 +164,18 @@ class DependencyAnalysisTools:
ORDER BY table_schema, table_name ORDER BY table_schema, table_name
""" """
result = await connection.execute(metadata_sql) # SECURITY FIX: Get auth_context and pass to execute for security validation
auth_context = get_auth_context()
result = await connection.execute(
metadata_sql,
params=tuple(params) if params else None,
auth_context=auth_context
)
return result.data if result.data else [] return result.data if result.data else []
except SQLSecurityError as e:
logger.warning(f"Security validation failed in _get_tables_metadata: {str(e)}")
return []
except Exception as e: except Exception as e:
logger.warning(f"Failed to get tables metadata: {str(e)}") logger.warning(f"Failed to get tables metadata: {str(e)}")
return [] return []
@@ -186,17 +211,31 @@ class DependencyAnalysisTools:
async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None: async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
"""Analyze view definitions to extract table dependencies""" """Analyze view definitions to extract table dependencies"""
# Get auth_context once for all operations in this method
auth_context = get_auth_context()
try: try:
for table in tables_metadata: for table in tables_metadata:
if table["table_type"] == "VIEW": if table["table_type"] == "VIEW":
table_name = table["table_name"] table_name = table["table_name"]
schema_name = table.get("schema_name", "") schema_name = table.get("schema_name", "")
# Get view definition # SECURITY FIX: Validate identifiers before using in SQL
view_def_sql = f"SHOW CREATE VIEW {schema_name}.{table_name}" if schema_name else f"SHOW CREATE VIEW {table_name}" try:
validate_identifier(table_name, "table name")
if schema_name:
validate_identifier(schema_name, "schema name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected in view analysis: {e}")
continue
# Build safe view reference using quoted identifiers
view_ref = build_table_reference(table_name, schema_name) if schema_name else quote_identifier(table_name, "table name")
view_def_sql = f"SHOW CREATE VIEW {view_ref}"
try: try:
result = await connection.execute(view_def_sql) # SECURITY FIX: Pass auth_context to execute
result = await connection.execute(view_def_sql, auth_context=auth_context)
if result.data and len(result.data) > 0: if result.data and len(result.data) > 0:
# Extract view definition from result # Extract view definition from result
view_definition = "" view_definition = ""
@@ -235,6 +274,9 @@ class DependencyAnalysisTools:
async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None: async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None:
"""Analyze audit logs to discover runtime table dependencies""" """Analyze audit logs to discover runtime table dependencies"""
# Get auth_context for security validation
auth_context = get_auth_context()
try: try:
# Get recent SQL statements from audit logs # Get recent SQL statements from audit logs
audit_sql = """ audit_sql = """
@@ -252,7 +294,8 @@ class DependencyAnalysisTools:
LIMIT 1000 LIMIT 1000
""" """
result = await connection.execute(audit_sql) # SECURITY FIX: Pass auth_context to execute
result = await connection.execute(audit_sql, auth_context=auth_context)
if result.data: if result.data:
for row in result.data: for row in result.data:
@@ -274,6 +317,9 @@ class DependencyAnalysisTools:
async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None: async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
"""Analyze foreign key constraints for explicit dependencies""" """Analyze foreign key constraints for explicit dependencies"""
# Get auth_context for security validation
auth_context = get_auth_context()
try: try:
# Get foreign key information # Get foreign key information
fk_sql = """ fk_sql = """
@@ -288,7 +334,8 @@ class DependencyAnalysisTools:
WHERE REFERENCED_TABLE_NAME IS NOT NULL WHERE REFERENCED_TABLE_NAME IS NOT NULL
""" """
result = await connection.execute(fk_sql) # SECURITY FIX: Pass auth_context to execute
result = await connection.execute(fk_sql, auth_context=auth_context)
if result.data: if result.data:
for row in result.data: for row in result.data:

View File

@@ -28,6 +28,7 @@ from datetime import datetime
from .db import DorisConnectionManager from .db import DorisConnectionManager
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import get_auth_context
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -713,7 +714,8 @@ class DorisMonitoringTools:
# Fallback to SHOW BACKENDS if no BE hosts configured # Fallback to SHOW BACKENDS if no BE hosts configured
logger.info("No BE hosts configured, using SHOW BACKENDS to discover BE nodes") logger.info("No BE hosts configured, using SHOW BACKENDS to discover BE nodes")
connection = await self.connection_manager.get_connection("query") connection = await self.connection_manager.get_connection("query")
result = await connection.execute("SHOW BACKENDS") auth_context = get_auth_context()
result = await connection.execute("SHOW BACKENDS", auth_context=auth_context)
be_nodes = [] be_nodes = []
for row in result.data: for row in result.data:

View File

@@ -27,6 +27,13 @@ from collections import defaultdict, Counter
from .db import DorisConnectionManager from .db import DorisConnectionManager
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier,
build_table_reference,
get_auth_context
)
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -229,7 +236,8 @@ class PerformanceAnalyticsTools:
ORDER BY query_date ORDER BY query_date
""" """
result = await connection.execute(query_volume_sql) auth_context = get_auth_context()
result = await connection.execute(query_volume_sql, auth_context=auth_context)
daily_data = result.data if result.data else [] daily_data = result.data if result.data else []
if not daily_data: if not daily_data:
@@ -304,7 +312,8 @@ class PerformanceAnalyticsTools:
ORDER BY activity_date ORDER BY activity_date
""" """
result = await connection.execute(user_activity_sql) auth_context = get_auth_context()
result = await connection.execute(user_activity_sql, auth_context=auth_context)
daily_data = result.data if result.data else [] daily_data = result.data if result.data else []
if not daily_data: if not daily_data:
@@ -383,7 +392,8 @@ class PerformanceAnalyticsTools:
LIMIT 5000 LIMIT 5000
""" """
result = await connection.execute(slow_query_sql) auth_context = get_auth_context()
result = await connection.execute(slow_query_sql, auth_context=auth_context)
return result.data if result.data else [] return result.data if result.data else []
except Exception as e: except Exception as e:
@@ -705,7 +715,8 @@ class PerformanceAnalyticsTools:
ORDER BY size_mb DESC ORDER BY size_mb DESC
""" """
db_result = await connection.execute(db_sizes_sql) auth_context = get_auth_context()
db_result = await connection.execute(db_sizes_sql, auth_context=auth_context)
if not db_result.data: if not db_result.data:
logger.warning("No database size information available") logger.warning("No database size information available")
@@ -805,7 +816,16 @@ class PerformanceAnalyticsTools:
async def _get_database_table_details_from_schema(self, connection, db_name: str) -> List[Dict]: async def _get_database_table_details_from_schema(self, connection, db_name: str) -> List[Dict]:
"""Get table details for a specific database using information_schema""" """Get table details for a specific database using information_schema"""
try: try:
table_details_sql = f""" # SECURITY FIX: Validate db_name and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
table_details_sql = """
SELECT SELECT
TABLE_SCHEMA as schema_name, TABLE_SCHEMA as schema_name,
TABLE_NAME as table_name, TABLE_NAME as table_name,
@@ -814,13 +834,13 @@ class PerformanceAnalyticsTools:
CREATE_TIME as create_time, CREATE_TIME as create_time,
UPDATE_TIME as update_time UPDATE_TIME as update_time
FROM information_schema.tables FROM information_schema.tables
WHERE TABLE_SCHEMA = '{db_name}' WHERE TABLE_SCHEMA = %s
AND TABLE_TYPE = 'BASE TABLE' AND TABLE_TYPE = 'BASE TABLE'
AND (COALESCE(DATA_LENGTH, 0) + COALESCE(INDEX_LENGTH, 0)) > 0 AND (COALESCE(DATA_LENGTH, 0) + COALESCE(INDEX_LENGTH, 0)) > 0
ORDER BY size_mb DESC ORDER BY size_mb DESC
""" """
result = await connection.execute(table_details_sql) result = await connection.execute(table_details_sql, params=(db_name,), auth_context=auth_context)
if not result.data: if not result.data:
logger.warning(f"No table details found for database {db_name}") logger.warning(f"No table details found for database {db_name}")
@@ -867,6 +887,13 @@ class PerformanceAnalyticsTools:
async def _get_database_table_details(self, connection, db_name: str) -> List[Dict]: async def _get_database_table_details(self, connection, db_name: str) -> List[Dict]:
"""Get table details for a specific database using session-consistent queries""" """Get table details for a specific database using session-consistent queries"""
try: try:
# SECURITY FIX: Validate db_name before using in SQL
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
# Method 1: Try to use session-consistent approach with raw connection # Method 1: Try to use session-consistent approach with raw connection
# This requires accessing the underlying connection to maintain session state # This requires accessing the underlying connection to maintain session state
@@ -877,8 +904,9 @@ class PerformanceAnalyticsTools:
# Use raw connection to maintain session state # Use raw connection to maintain session state
cursor = await raw_conn.cursor() cursor = await raw_conn.cursor()
try: try:
# Execute USE and SHOW DATA in the same session # SECURITY FIX: Use quoted identifier for USE statement
await cursor.execute(f"USE {db_name}") quoted_db = quote_identifier(db_name, "database name")
await cursor.execute(f"USE {quoted_db}")
await cursor.execute("SHOW DATA") await cursor.execute("SHOW DATA")
result = await cursor.fetchall() result = await cursor.fetchall()
@@ -922,9 +950,19 @@ class PerformanceAnalyticsTools:
async def _get_database_table_details_fallback(self, connection, db_name: str) -> List[Dict]: async def _get_database_table_details_fallback(self, connection, db_name: str) -> List[Dict]:
"""Fallback method to get table details using individual queries""" """Fallback method to get table details using individual queries"""
try: try:
# Get all tables in the database # SECURITY FIX: Validate db_name and get auth_context
tables_sql = f"SHOW TABLES FROM {db_name}" auth_context = get_auth_context()
tables_result = await connection.execute(tables_sql)
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return []
# Get all tables in the database using quoted identifier
quoted_db = quote_identifier(db_name, "database name")
tables_sql = f"SHOW TABLES FROM {quoted_db}"
tables_result = await connection.execute(tables_sql, auth_context=auth_context)
if not tables_result.data: if not tables_result.data:
return [] return []
@@ -934,9 +972,11 @@ class PerformanceAnalyticsTools:
table_name = table_row.get(f"Tables_in_{db_name}", "") or table_row.get("table_name", "") table_name = table_row.get(f"Tables_in_{db_name}", "") or table_row.get("table_name", "")
if table_name: if table_name:
try: try:
# Use SHOW DATA FROM db.table for each table # SECURITY FIX: Validate table_name and use safe reference
data_sql = f"SHOW DATA FROM {db_name}.{table_name}" validate_identifier(table_name, "table name")
data_result = await connection.execute(data_sql) safe_table_ref = build_table_reference(table_name, db_name)
data_sql = f"SHOW DATA FROM {safe_table_ref}"
data_result = await connection.execute(data_sql, auth_context=auth_context)
if data_result.data: if data_result.data:
for row in data_result.data: for row in data_result.data:
@@ -1036,6 +1076,7 @@ class PerformanceAnalyticsTools:
async def _get_all_tables_info(self, connection) -> List[Dict]: async def _get_all_tables_info(self, connection) -> List[Dict]:
"""Get basic information for all tables (fallback method)""" """Get basic information for all tables (fallback method)"""
try: try:
auth_context = get_auth_context()
tables_sql = """ tables_sql = """
SELECT SELECT
table_schema, table_schema,
@@ -1053,7 +1094,7 @@ class PerformanceAnalyticsTools:
ORDER BY (data_length + index_length) DESC ORDER BY (data_length + index_length) DESC
""" """
result = await connection.execute(tables_sql) result = await connection.execute(tables_sql, auth_context=auth_context)
return result.data if result.data else [] return result.data if result.data else []
except Exception as e: except Exception as e:
@@ -1120,23 +1161,37 @@ class PerformanceAnalyticsTools:
async def _get_current_table_size(self, connection, full_table_name: str) -> Optional[Dict]: async def _get_current_table_size(self, connection, full_table_name: str) -> Optional[Dict]:
"""Get current table size""" """Get current table size"""
try: try:
# Try to query table size directly # SECURITY FIX: Get auth_context and use parameterized query
size_sql = f""" auth_context = get_auth_context()
# Extract table name for parameterized query
table_name_only = full_table_name.split('.')[-1] if '.' in full_table_name else full_table_name
# Validate identifiers
try:
validate_identifier(table_name_only, "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return None
# Use parameterized query for safety
size_sql = """
SELECT SELECT
COALESCE(ROUND((COALESCE(data_length, 0) + COALESCE(index_length, 0)) / 1024 / 1024, 2), 0) as size_mb, COALESCE(ROUND((COALESCE(data_length, 0) + COALESCE(index_length, 0)) / 1024 / 1024, 2), 0) as size_mb,
COALESCE(table_rows, 0) as `rows` COALESCE(table_rows, 0) as `rows`
FROM information_schema.tables FROM information_schema.tables
WHERE CONCAT(table_schema, '.', table_name) = '{full_table_name}' WHERE CONCAT(table_schema, '.', table_name) = %s
OR table_name = '{full_table_name.split('.')[-1]}' OR table_name = %s
""" """
result = await connection.execute(size_sql) result = await connection.execute(size_sql, params=(full_table_name, table_name_only), auth_context=auth_context)
if result.data and result.data[0]: if result.data and result.data[0]:
return result.data[0] return result.data[0]
# If information_schema has no data, try COUNT query # If information_schema has no data, try COUNT query
# full_table_name should already be validated by caller using build_table_reference
count_sql = f"SELECT COUNT(*) as rows FROM {full_table_name}" count_sql = f"SELECT COUNT(*) as rows FROM {full_table_name}"
count_result = await connection.execute(count_sql) count_result = await connection.execute(count_sql, auth_context=auth_context)
if count_result.data: if count_result.data:
return { return {
"size_mb": 0, # Cannot get exact size "size_mb": 0, # Cannot get exact size
@@ -1145,6 +1200,9 @@ class PerformanceAnalyticsTools:
return None return None
except SQLSecurityError as e:
logger.warning(f"Security validation failed for {full_table_name}: {str(e)}")
return None
except Exception as e: except Exception as e:
logger.warning(f"Failed to get current size for {full_table_name}: {str(e)}") logger.warning(f"Failed to get current size for {full_table_name}: {str(e)}")
return None return None
@@ -1154,8 +1212,19 @@ class PerformanceAnalyticsTools:
) -> List[Dict]: ) -> List[Dict]:
"""Get historical growth data based on partitions""" """Get historical growth data based on partitions"""
try: try:
# Query partition information # SECURITY FIX: Validate identifiers and use parameterized query
partition_sql = f""" auth_context = get_auth_context()
try:
validate_identifier(table_name, "table name")
if schema_name:
validate_identifier(schema_name, "schema name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Use parameterized query for safety
partition_sql = """
SELECT SELECT
partition_name, partition_name,
partition_description, partition_description,
@@ -1163,15 +1232,19 @@ class PerformanceAnalyticsTools:
data_length, data_length,
create_time create_time
FROM information_schema.partitions FROM information_schema.partitions
WHERE table_schema = '{schema_name or ""}' WHERE table_schema = %s
AND table_name = '{table_name}' AND table_name = %s
AND partition_name IS NOT NULL AND partition_name IS NOT NULL
AND create_time IS NOT NULL AND create_time IS NOT NULL
AND create_time >= DATE_SUB(NOW(), INTERVAL {days} DAY) AND create_time >= DATE_SUB(NOW(), INTERVAL %s DAY)
ORDER BY create_time DESC ORDER BY create_time DESC
""" """
result = await connection.execute(partition_sql) result = await connection.execute(
partition_sql,
params=(schema_name or "", table_name, days),
auth_context=auth_context
)
if not result.data: if not result.data:
return [] return []
@@ -1210,6 +1283,9 @@ class PerformanceAnalyticsTools:
) -> List[Dict]: ) -> List[Dict]:
"""Get historical growth data based on timestamp fields""" """Get historical growth data based on timestamp fields"""
try: try:
# SECURITY FIX: Get auth_context
auth_context = get_auth_context()
# Find possible timestamp fields # Find possible timestamp fields
timestamp_columns = await self._find_timestamp_columns(connection, table_name, schema_name) timestamp_columns = await self._find_timestamp_columns(connection, table_name, schema_name)
if not timestamp_columns: if not timestamp_columns:
@@ -1218,20 +1294,29 @@ class PerformanceAnalyticsTools:
# Use best timestamp field for analysis # Use best timestamp field for analysis
time_column = timestamp_columns[0] time_column = timestamp_columns[0]
# Aggregate data by date # SECURITY FIX: Validate time_column before using in SQL
try:
validate_identifier(time_column, "column name")
except SQLSecurityError as e:
logger.warning(f"Invalid column name rejected: {e}")
return []
quoted_time_column = quote_identifier(time_column, "column name")
# Aggregate data by date (full_table_name should be validated by caller)
growth_sql = f""" growth_sql = f"""
SELECT SELECT
DATE({time_column}) as date, DATE({quoted_time_column}) as date,
COUNT(*) as daily_records, COUNT(*) as daily_records,
COUNT(*) / SUM(COUNT(*)) OVER() * 100 as percentage COUNT(*) / SUM(COUNT(*)) OVER() * 100 as percentage
FROM {full_table_name} FROM {full_table_name}
WHERE {time_column} >= DATE_SUB(NOW(), INTERVAL {days} DAY) WHERE {quoted_time_column} >= DATE_SUB(NOW(), INTERVAL %s DAY)
AND {time_column} IS NOT NULL AND {quoted_time_column} IS NOT NULL
GROUP BY DATE({time_column}) GROUP BY DATE({quoted_time_column})
ORDER BY date DESC ORDER BY date DESC
""" """
result = await connection.execute(growth_sql) result = await connection.execute(growth_sql, params=(days,), auth_context=auth_context)
if not result.data: if not result.data:
return [] return []
@@ -1257,11 +1342,22 @@ class PerformanceAnalyticsTools:
async def _find_timestamp_columns(self, connection, table_name: str, schema_name: str) -> List[str]: async def _find_timestamp_columns(self, connection, table_name: str, schema_name: str) -> List[str]:
"""Find timestamp fields in table""" """Find timestamp fields in table"""
try: try:
timestamp_sql = f""" # SECURITY FIX: Validate identifiers and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(table_name, "table name")
if schema_name:
validate_identifier(schema_name, "schema name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
timestamp_sql = """
SELECT column_name, data_type SELECT column_name, data_type
FROM information_schema.columns FROM information_schema.columns
WHERE table_schema = '{schema_name or ""}' WHERE table_schema = %s
AND table_name = '{table_name}' AND table_name = %s
AND ( AND (
data_type IN ('datetime', 'timestamp', 'date') data_type IN ('datetime', 'timestamp', 'date')
OR column_name REGEXP '(create|insert|update|modify).*time' OR column_name REGEXP '(create|insert|update|modify).*time'
@@ -1278,9 +1374,16 @@ class PerformanceAnalyticsTools:
END END
""" """
result = await connection.execute(timestamp_sql) result = await connection.execute(
timestamp_sql,
params=(schema_name or "", table_name),
auth_context=auth_context
)
return [row["column_name"] for row in result.data] if result.data else [] return [row["column_name"] for row in result.data] if result.data else []
except SQLSecurityError as e:
logger.warning(f"Security validation failed: {str(e)}")
return []
except Exception as e: except Exception as e:
logger.warning(f"Failed to find timestamp columns: {str(e)}") logger.warning(f"Failed to find timestamp columns: {str(e)}")
return [] return []
@@ -1290,8 +1393,22 @@ class PerformanceAnalyticsTools:
) -> List[Dict]: ) -> List[Dict]:
"""Estimate growth data based on audit logs""" """Estimate growth data based on audit logs"""
try: try:
# SECURITY FIX: Validate table_name and use parameterized query
auth_context = get_auth_context()
try:
validate_identifier(table_name.split(".")[-1], "table name")
except SQLSecurityError as e:
logger.warning(f"Invalid table name rejected: {e}")
return []
# Extract just the table name for LIKE pattern
table_name_only = table_name.split(".")[-1]
like_pattern_full = f"%{table_name}%"
like_pattern_short = f"%{table_name_only}%"
# Analyze operation history for this table # Analyze operation history for this table
audit_sql = f""" audit_sql = """
SELECT SELECT
DATE(`time`) as operation_date, DATE(`time`) as operation_date,
COUNT(*) as operation_count, COUNT(*) as operation_count,
@@ -1299,17 +1416,21 @@ class PerformanceAnalyticsTools:
SUM(CASE WHEN stmt LIKE 'UPDATE%' THEN 1 ELSE 0 END) as update_count, SUM(CASE WHEN stmt LIKE 'UPDATE%' THEN 1 ELSE 0 END) as update_count,
SUM(CASE WHEN stmt LIKE 'DELETE%' THEN 1 ELSE 0 END) as delete_count SUM(CASE WHEN stmt LIKE 'DELETE%' THEN 1 ELSE 0 END) as delete_count
FROM internal.__internal_schema.audit_log FROM internal.__internal_schema.audit_log
WHERE `time` >= DATE_SUB(NOW(), INTERVAL {days} DAY) WHERE `time` >= DATE_SUB(NOW(), INTERVAL %s DAY)
AND stmt IS NOT NULL AND stmt IS NOT NULL
AND ( AND (
stmt LIKE '%{table_name}%' stmt LIKE %s
OR stmt LIKE '%{table_name.split(".")[-1]}%' OR stmt LIKE %s
) )
GROUP BY DATE(`time`) GROUP BY DATE(`time`)
ORDER BY operation_date DESC ORDER BY operation_date DESC
""" """
result = await connection.execute(audit_sql) result = await connection.execute(
audit_sql,
params=(days, like_pattern_full, like_pattern_short),
auth_context=auth_context
)
if not result.data: if not result.data:
return [] return []

View File

@@ -33,8 +33,11 @@ from datetime import datetime, timedelta, date
from typing import Any, Dict from typing import Any, Dict
from decimal import Decimal from decimal import Decimal
import sqlparse
from .db import DorisConnectionManager, QueryResult from .db import DorisConnectionManager, QueryResult
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import get_auth_context
@dataclass @dataclass
@@ -467,6 +470,51 @@ class DorisQueryExecutor:
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)" f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
) )
async def execute_batch_sqls_for_mcp(
    self, sqls: list[str],
    timeout: int = 30,
    session_id: str = "mcp_session",
    user_id: str = "mcp_user",
    auth_context=None
) -> dict[str, Any]:
    """Execute several SQL statements for the MCP interface and collect every result.

    Each statement is wrapped in a ``QueryRequest`` with caching disabled and
    the whole list is handed to ``execute_batch_queries``.

    Args:
        sqls: SQL statements to run, one statement per entry.
        timeout: Timeout value stored on every request.
        session_id: Session identifier stored on every request.
        user_id: User identifier stored on every request.
        auth_context: Optional auth context, forwarded unchanged to
            ``execute_batch_queries``.

    Returns:
        For empty input, ``{"success": False, "error": ..., "data": None}``.
        Otherwise a dict with ``success`` and ``multiple_results`` set to
        ``True`` and ``results`` holding one entry per query result
        (serialized rows, row count, execution time and metadata).

    Exceptions raised by ``execute_batch_queries`` are not caught here.
    """
    if not sqls:
        return {
            "success": False,
            "error": "SQL query is required",
            "data": None
        }

    query_requests = [
        QueryRequest(
            sql=sql,
            session_id=session_id,
            user_id=user_id,
            timeout=timeout,
            cache_enabled=False
        )
        for sql in sqls
    ]
    query_results = await self.execute_batch_queries(query_requests, auth_context)

    # Serialize data for JSON response
    results = []
    for result in query_results:
        # A result that carries no rows or no metadata is reported as empty
        # instead of aborting the whole batch with TypeError/AttributeError.
        rows = result.data or []
        metadata = result.metadata or {}
        results.append(
            {
                "data": [self._serialize_row_data(row) for row in rows],
                "row_count": result.row_count,
                "execution_time": result.execution_time,
                "metadata": {
                    "columns": metadata.get("columns", []),
                    "query": result.sql
                }
            }
        )

    return {
        "success": True,
        "multiple_results": True,
        "results": results
    }
async def execute_batch_queries( async def execute_batch_queries(
self, query_requests: list[QueryRequest], auth_context=None self, query_requests: list[QueryRequest], auth_context=None
) -> list[QueryResult]: ) -> list[QueryResult]:
@@ -484,20 +532,24 @@ class DorisQueryExecutor:
self.execute_query(request, auth_context) for request in query_requests self.execute_query(request, auth_context) for request in query_requests
] ]
try: query_results = []
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e: for result in results:
self.logger.error(f"Batch query execution failed: {e}") if isinstance(result, Exception):
raise self.logger.error(f"Batch query execution failed: {result}")
raise result
else:
query_results.append(result)
return results return query_results
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]: async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
"""Get query execution plan""" """Get query execution plan"""
explain_sql = f"EXPLAIN {sql}" explain_sql = f"EXPLAIN {sql}"
connection = await self.connection_manager.get_connection(session_id) connection = await self.connection_manager.get_connection(session_id)
result = await connection.execute(explain_sql) auth_context = get_auth_context()
result = await connection.execute(explain_sql, auth_context=auth_context)
return { return {
"query": sql, "query": sql,
@@ -541,17 +593,21 @@ class DorisQueryExecutor:
await self.query_cache.clear_all() await self.query_cache.clear_all()
async def execute_sql_for_mcp( async def execute_sql_for_mcp(
self, self,
sql: str, sql: str,
limit: int = 1000, limit: int = 1000,
timeout: int = 30, timeout: int = 30,
session_id: str = "mcp_session", session_id: str = "mcp_session",
user_id: str = "mcp_user" user_id: str = "mcp_user",
auth_context = None # FIX for Issue #62 Bug 1: Accept auth_context with token
) -> Dict[str, Any]: ) -> Dict[str, Any]:
"""Execute SQL query for MCP interface - unified method""" """Execute SQL query for MCP interface - unified method
FIX for Issue #62 Bug 1: Now accepts auth_context parameter to support token-bound database configuration
"""
max_retries = 2 max_retries = 2
retry_count = 0 retry_count = 0
while retry_count <= max_retries: while retry_count <= max_retries:
try: try:
if not sql: if not sql:
@@ -564,22 +620,34 @@ class DorisQueryExecutor:
# Import required security modules # Import required security modules
from .security import DorisSecurityManager, AuthContext, SecurityLevel from .security import DorisSecurityManager, AuthContext, SecurityLevel
# Create proper auth context with read-only permissions # FIX: Use provided auth_context if available (contains token for DB config)
auth_context = AuthContext( # Otherwise create default auth context for backward compatibility
user_id=user_id, if auth_context is None:
roles=["read_only_user"], # Restrictive role for MCP interface auth_context = AuthContext(
permissions=["read_data"], # Only read permissions user_id=user_id,
session_id=session_id, roles=["read_only_user"], # Restrictive role for MCP interface
security_level=SecurityLevel.INTERNAL permissions=["read_data"], # Only read permissions
) session_id=session_id,
security_level=SecurityLevel.INTERNAL,
token="" # No token in default context
)
else:
# Use provided auth_context (may contain token for database configuration)
self.logger.debug(f"Using provided auth_context with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
# Perform SQL security validation if enabled # Perform SQL security validation if enabled
if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'): if hasattr(self.connection_manager, 'config') and hasattr(self.connection_manager.config, 'security'):
if self.connection_manager.config.security.enable_security_check: if self.connection_manager.config.security.enable_security_check:
try: try:
security_manager = DorisSecurityManager(self.connection_manager.config) # 🔧 FIX: Use existing security_manager to avoid creating multiple TokenManager instances
# Creating new DorisSecurityManager each time causes multiple hot reload monitors
security_manager = getattr(self.connection_manager, 'security_manager', None)
if not security_manager:
# Fallback: create new one only if not available (should rarely happen)
self.logger.warning("No existing security_manager, creating new instance")
security_manager = DorisSecurityManager(self.connection_manager.config)
validation_result = await security_manager.validate_sql_security(sql, auth_context) validation_result = await security_manager.validate_sql_security(sql, auth_context)
if not validation_result.is_valid: if not validation_result.is_valid:
self.logger.warning(f"SQL security validation failed for query: {sql[:100]}...") self.logger.warning(f"SQL security validation failed for query: {sql[:100]}...")
return { return {
@@ -623,6 +691,15 @@ class DorisQueryExecutor:
sql = sql[:-1] sql = sql[:-1]
sql = f"{sql} LIMIT {limit}" sql = f"{sql} LIMIT {limit}"
all_statements = [
s.strip()
for s in sqlparse.split(sql)
if s.strip()
]
if len(all_statements) > 1:
return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
session_id=session_id, user_id=user_id,
auth_context=auth_context)
# Create query request # Create query request
query_request = QueryRequest( query_request = QueryRequest(
sql=sql, sql=sql,
@@ -877,33 +954,42 @@ class QueryPerformanceMonitor:
# Unified convenience function for MCP integration # Unified convenience function for MCP integration
async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]: async def execute_sql_query(sql: str, connection_manager: DorisConnectionManager, **kwargs) -> Dict[str, Any]:
"""Execute SQL query - unified convenience function for MCP tools """Execute SQL query - unified convenience function for MCP tools
This function now includes security validation to ensure safe query execution. This function now includes security validation to ensure safe query execution.
All queries are validated against the configured security policies before execution. All queries are validated against the configured security policies before execution.
FIX for Issue #62 Bug 1: Now supports auth_context parameter for token-bound database configuration
FIX for Issue #58 Problem 2: Removed executor.close() to prevent ClosedResourceError in multi-worker mode
""" """
try: try:
# Create query executor with the connection manager's configuration # Create query executor with the connection manager's configuration
executor = DorisQueryExecutor(connection_manager) executor = DorisQueryExecutor(connection_manager)
try: # Extract parameters from kwargs or use defaults
# Extract parameters from kwargs or use defaults limit = kwargs.get("limit", 1000)
limit = kwargs.get("limit", 1000) timeout = kwargs.get("timeout", 30)
timeout = kwargs.get("timeout", 30) session_id = kwargs.get("session_id", "mcp_session")
session_id = kwargs.get("session_id", "mcp_session") user_id = kwargs.get("user_id", "mcp_user")
user_id = kwargs.get("user_id", "mcp_user") auth_context = kwargs.get("auth_context", None) # FIX: Extract auth_context
# The execute_sql_for_mcp method now includes security validation # The execute_sql_for_mcp method now includes security validation
result = await executor.execute_sql_for_mcp( result = await executor.execute_sql_for_mcp(
sql=sql, sql=sql,
limit=limit, limit=limit,
timeout=timeout, timeout=timeout,
session_id=session_id, session_id=session_id,
user_id=user_id user_id=user_id,
) auth_context=auth_context # FIX: Pass auth_context with token
return result )
finally:
await executor.close() # FIX for Issue #58 Problem 2: Do NOT close executor here
# In multi-worker mode, closing here causes ClosedResourceError
# The executor's resources (cache, background tasks) will be managed
# by the connection_manager lifecycle and Python's garbage collection
# This prevents premature cleanup while MCP session manager is still processing
return result
except Exception as e: except Exception as e:
return { return {
"success": False, "success": False,

View File

@@ -32,6 +32,11 @@ from datetime import datetime, timedelta
# Import unified logging configuration # Import unified logging configuration
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import (
SQLSecurityError,
validate_identifier,
quote_identifier
)
# Configure logging # Configure logging
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -431,6 +436,16 @@ class MetadataExtractor:
logger.warning("Database name not specified") logger.warning("Database name not specified")
return {} return {}
# SECURITY FIX: Validate identifiers to prevent SQL injection
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected in get_table_schema: {e}")
return {}
cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}" cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key] return self.metadata_cache[cache_key]
@@ -536,6 +551,16 @@ class MetadataExtractor:
logger.warning("Database name not specified") logger.warning("Database name not specified")
return "" return ""
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return ""
cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}" cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key] return self.metadata_cache[cache_key]
@@ -587,6 +612,16 @@ class MetadataExtractor:
logger.warning("Database name not specified") logger.warning("Database name not specified")
return {} return {}
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return {}
cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}" cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key] return self.metadata_cache[cache_key]
@@ -643,17 +678,30 @@ class MetadataExtractor:
logger.error("Database name not specified") logger.error("Database name not specified")
return [] return []
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
validate_identifier(db_name, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}" cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key] return self.metadata_cache[cache_key]
try: try:
# Build query with catalog prefix if specified # Build query with catalog prefix if specified (identifiers already validated)
safe_table = quote_identifier(table_name, "table name")
safe_db = quote_identifier(db_name, "database name")
if effective_catalog: if effective_catalog:
query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`" safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
logger.info(f"Using three-part naming for index query: {query}") logger.info(f"Using three-part naming for index query: {query}")
else: else:
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`" query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
try: try:
# NOTE: Deprecated sync path retained for compatibility; use async variant instead. # NOTE: Deprecated sync path retained for compatibility; use async variant instead.
@@ -1146,8 +1194,17 @@ class MetadataExtractor:
""" """
try: try:
if self.connection_manager: if self.connection_manager:
# FIX: Get auth_context from global ContextVar for token-bound database configuration
# This ensures all query methods use the correct user's connection pool
auth_context = None
try:
from .security import mcp_auth_context_var
auth_context = mcp_auth_context_var.get()
except Exception:
pass
# Use the injected connection manager directly (async) # Use the injected connection manager directly (async)
result = await self.connection_manager.execute_query(self._session_id, query, None) result = await self.connection_manager.execute_query(self._session_id, query, None, auth_context)
# Extract data from QueryResult # Extract data from QueryResult
if hasattr(result, 'data'): if hasattr(result, 'data'):
@@ -1188,12 +1245,28 @@ class MetadataExtractor:
try: try:
# Use async query method # Use async query method
effective_catalog = catalog_name or self.catalog_name effective_catalog = catalog_name or self.catalog_name
effective_db = db_name or self.db_name
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
if effective_catalog and effective_catalog != "internal":
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build query statement using safe identifiers
safe_table = quote_identifier(table_name, "table name")
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
# Build query statement
if effective_catalog and effective_catalog != "internal": if effective_catalog and effective_catalog != "internal":
query = f"DESCRIBE `{effective_catalog}`.`{db_name or self.db_name}`.`{table_name}`" safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"DESCRIBE {safe_catalog}.{safe_db}.{safe_table}"
else: else:
query = f"DESCRIBE `{db_name or self.db_name}`.`{table_name}`" query = f"DESCRIBE {safe_db}.{safe_table}"
# Execute async query # Execute async query
result = await self._execute_query_async(query, db_name) result = await self._execute_query_async(query, db_name)
@@ -1226,8 +1299,15 @@ class MetadataExtractor:
try: try:
effective_catalog = catalog_name or self.catalog_name effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate catalog name if provided
if effective_catalog and effective_catalog != "internal": if effective_catalog and effective_catalog != "internal":
query = f"SHOW DATABASES FROM `{effective_catalog}`" try:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid catalog name rejected: {e}")
return []
safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW DATABASES FROM {safe_catalog}"
else: else:
query = "SHOW DATABASES" query = "SHOW DATABASES"
@@ -1257,10 +1337,23 @@ class MetadataExtractor:
effective_catalog = catalog_name or self.catalog_name effective_catalog = catalog_name or self.catalog_name
effective_db = db_name or self.db_name effective_db = db_name or self.db_name
# SECURITY FIX: Validate identifiers
try:
if effective_db:
validate_identifier(effective_db, "database name")
if effective_catalog and effective_catalog != "internal":
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
if effective_catalog and effective_catalog != "internal": if effective_catalog and effective_catalog != "internal":
query = f"SHOW TABLES FROM `{effective_catalog}`.`{effective_db}`" safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW TABLES FROM {safe_catalog}.{safe_db}"
else: else:
query = f"SHOW TABLES FROM `{effective_db}`" query = f"SHOW TABLES FROM {safe_db}"
result = await self._execute_query_async(query, effective_db) result = await self._execute_query_async(query, effective_db)
@@ -1319,6 +1412,15 @@ class MetadataExtractor:
effective_db = db_name or self.db_name effective_db = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return ""
query = f""" query = f"""
SELECT SELECT
TABLE_COMMENT TABLE_COMMENT
@@ -1343,6 +1445,15 @@ class MetadataExtractor:
effective_db = db_name or self.db_name effective_db = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name effective_catalog = catalog_name or self.catalog_name
# SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return {}
query = f""" query = f"""
SELECT SELECT
COLUMN_NAME, COLUMN_NAME,
@@ -1373,12 +1484,27 @@ class MetadataExtractor:
effective_db = db_name or self.db_name effective_db = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name effective_catalog = catalog_name or self.catalog_name
# Build query with catalog prefix if specified # SECURITY FIX: Validate identifiers
try:
validate_identifier(table_name, "table name")
if effective_db:
validate_identifier(effective_db, "database name")
if effective_catalog:
validate_identifier(effective_catalog, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid identifier rejected: {e}")
return []
# Build query with catalog prefix if specified (using safe identifiers)
safe_table = quote_identifier(table_name, "table name")
safe_db = quote_identifier(effective_db, "database name") if effective_db else None
if effective_catalog: if effective_catalog:
query = f"SHOW INDEX FROM `{effective_catalog}`.`{effective_db}`.`{table_name}`" safe_catalog = quote_identifier(effective_catalog, "catalog name")
query = f"SHOW INDEX FROM {safe_catalog}.{safe_db}.{safe_table}"
logger.info(f"Using three-part naming for async index query: {query}") logger.info(f"Using three-part naming for async index query: {query}")
else: else:
query = f"SHOW INDEX FROM `{effective_db}`.`{table_name}`" query = f"SHOW INDEX FROM {safe_db}.{safe_table}"
rows = await self._execute_query_async(query, effective_db) rows = await self._execute_query_async(query, effective_db)
indexes: List[Dict[str, Any]] = [] indexes: List[Dict[str, Any]] = []
@@ -1454,32 +1580,106 @@ class MetadataExtractor:
return response_data return response_data
async def exec_query_for_mcp( async def exec_query_for_mcp(
self, self,
sql: str, sql: str,
db_name: str = None, db_name: str = None,
catalog_name: str = None, catalog_name: str = None,
max_rows: int = 100, max_rows: int = 100,
timeout: int = 30 timeout: int = 30
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Execute SQL query and return results, supports catalog federation queries Execute SQL query and return results, supports catalog federation queries
Unified interface for MCP tools Unified interface for MCP tools
FIX for Issue #62 Bug 1: Now retrieves auth_context from context variable to support token-bound database configuration
FIX for Issue #62 Bug 3: Now uses db_name and catalog_name parameters to switch database context
""" """
logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}") logger.info(f"Executing SQL query: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}")
try: try:
if not sql: if not sql:
return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute") return self._format_response(success=False, error="No SQL statement provided", message="Please provide SQL statement to execute")
# FIX for Issue #62 Bug 3: Build context switching SQL if db_name or catalog_name is specified
# SECURITY FIX: Validate catalog_name and db_name to prevent SQL injection
final_sql = sql
if catalog_name or db_name:
context_statements = []
# Validate and sanitize catalog_name
if catalog_name:
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
logger.warning(f"Invalid catalog name rejected: {e}")
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
# Use quote_identifier to safely escape the catalog name
safe_catalog = quote_identifier(catalog_name, "catalog name")
context_statements.append(f"USE CATALOG {safe_catalog}")
logger.debug(f"Switching to catalog: {catalog_name}")
# Validate and sanitize db_name
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
logger.warning(f"Invalid database name rejected: {e}")
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
# Use quote_identifier to safely escape the database name
safe_db = quote_identifier(db_name, "database name")
if catalog_name:
safe_catalog = quote_identifier(catalog_name, "catalog name")
context_statements.append(f"USE {safe_catalog}.{safe_db}")
else:
context_statements.append(f"USE {safe_db}")
logger.debug(f"Switching to database: {db_name}")
# Combine context switching with original SQL
if context_statements:
# Remove trailing semicolon from context statements if present
context_sql = "; ".join(context_statements)
# Ensure original SQL doesn't start with semicolon
sql_clean = sql.lstrip(";").strip()
final_sql = f"{context_sql}; {sql_clean}"
logger.debug(f"Modified SQL with context switching: {final_sql[:200]}...")
# FIX: Try to get auth_context from context variable (set by HTTP middleware)
# This allows token-bound database configuration to work
# CRITICAL: Use the global ContextVar from security.py to ensure same instance is used everywhere
auth_context = None
try:
from .security import mcp_auth_context_var
# Get auth_context from the global context variable
# This will be set by the HTTP request handler in main.py
auth_context = mcp_auth_context_var.get()
if auth_context:
logger.debug(f"Retrieved auth_context from context variable with token: {bool(hasattr(auth_context, 'token') and auth_context.token)}")
else:
logger.debug("No auth_context found in context variable, using default")
except Exception as ctx_error:
logger.debug(f"Could not retrieve auth_context from context variable: {ctx_error}")
auth_context = None
# Import query executor # Import query executor
from .query_executor import execute_sql_query from .query_executor import execute_sql_query
# Call execute_sql_query to execute query # Call execute_sql_query to execute query with auth_context
exec_result = await execute_sql_query( exec_result = await execute_sql_query(
sql=sql, sql=final_sql, # Use modified SQL with context switching
connection_manager=self.connection_manager, connection_manager=self.connection_manager,
limit=max_rows, limit=max_rows,
timeout=timeout timeout=timeout,
auth_context=auth_context # FIX: Pass auth_context with token
) )
return exec_result return exec_result
@@ -1500,6 +1700,36 @@ class MetadataExtractor:
if not table_name: if not table_name:
return self._format_response(success=False, error="Missing table_name parameter") return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers before processing
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try: try:
schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) schema = await self.get_table_schema_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
@@ -1523,6 +1753,27 @@ class MetadataExtractor:
"""Get list of all table names in specified database - MCP interface""" """Get list of all table names in specified database - MCP interface"""
logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}") logger.info(f"Getting database table list: DB: {db_name}, Catalog: {catalog_name}")
# SECURITY: Validate identifiers
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try: try:
tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name) tables = await self.get_database_tables_async(db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=tables) return self._format_response(success=True, result=tables)
@@ -1553,6 +1804,36 @@ class MetadataExtractor:
if not table_name: if not table_name:
return self._format_response(success=False, error="Missing table_name parameter") return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try: try:
comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=comment) return self._format_response(success=True, result=comment)
@@ -1572,6 +1853,36 @@ class MetadataExtractor:
if not table_name: if not table_name:
return self._format_response(success=False, error="Missing table_name parameter") return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try: try:
comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=comments) return self._format_response(success=True, result=comments)
@@ -1591,6 +1902,36 @@ class MetadataExtractor:
if not table_name: if not table_name:
return self._format_response(success=False, error="Missing table_name parameter") return self._format_response(success=False, error="Missing table_name parameter")
# SECURITY: Validate identifiers
try:
validate_identifier(table_name, "table name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid table name: {table_name}",
message="Table name contains invalid characters"
)
if db_name:
try:
validate_identifier(db_name, "database name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid database name: {db_name}",
message="Database name contains invalid characters"
)
if catalog_name and catalog_name != "internal":
try:
validate_identifier(catalog_name, "catalog name")
except SQLSecurityError as e:
return self._format_response(
success=False,
error=f"Invalid catalog name: {catalog_name}",
message="Catalog name contains invalid characters"
)
try: try:
indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name)
return self._format_response(success=True, result=indexes) return self._format_response(success=True, result=indexes)

View File

@@ -22,6 +22,7 @@ Implements enterprise-level authentication, authorization, SQL security validati
import logging import logging
import re import re
from contextvars import ContextVar
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@@ -34,6 +35,10 @@ from sqlparse.tokens import Keyword, Name
from .logger import get_logger from .logger import get_logger
from .config import DatabaseConfig from .config import DatabaseConfig
# Global ContextVar for auth_context - must be a single instance shared across all modules
# This allows token-bound database configuration to work correctly in concurrent requests
mcp_auth_context_var: ContextVar['AuthContext'] = ContextVar('mcp_auth_context', default=None)
class SecurityLevel(Enum): class SecurityLevel(Enum):
"""Security level enumeration""" """Security level enumeration"""
@@ -901,30 +906,50 @@ class SQLSecurityValidator:
if not self.enable_security_check: if not self.enable_security_check:
self.logger.debug("SQL security check is disabled, allowing all queries") self.logger.debug("SQL security check is disabled, allowing all queries")
return ValidationResult(is_valid=True) return ValidationResult(is_valid=True)
try: try:
# Parse SQL statement # SECURITY FIX: Parse ALL SQL statements, not just the first one
parsed = sqlparse.parse(sql)[0] # This prevents bypassing security checks by injecting additional statements
all_statements = sqlparse.parse(sql)
# Check blocked operations first (more specific) if not all_statements:
keyword_result = await self._check_blocked_keywords(parsed) return ValidationResult(
if not keyword_result.is_valid: is_valid=False,
return keyword_result error_message="Empty or invalid SQL statement",
risk_level="medium"
)
# Check SQL injection risks # SECURITY FIX: Validate each statement individually
injection_result = await self._check_sql_injection(sql, parsed) for idx, parsed in enumerate(all_statements):
if not injection_result.is_valid: # Skip empty statements (e.g., from trailing semicolons)
return injection_result if not parsed.tokens or str(parsed).strip() == '':
continue
# Check query complexity self.logger.debug(f"Validating SQL statement {idx + 1}/{len(all_statements)}: {str(parsed)[:100]}...")
complexity_result = await self._check_query_complexity(parsed)
if not complexity_result.is_valid:
return complexity_result
# Check table access permissions # Check blocked operations first (more specific)
table_result = await self._check_table_access(parsed, auth_context) keyword_result = await self._check_blocked_keywords(parsed)
if not table_result.is_valid: if not keyword_result.is_valid:
return table_result keyword_result.error_message = f"Statement {idx + 1}: {keyword_result.error_message}"
return keyword_result
# Check SQL injection risks
injection_result = await self._check_sql_injection(sql, parsed)
if not injection_result.is_valid:
injection_result.error_message = f"Statement {idx + 1}: {injection_result.error_message}"
return injection_result
# Check query complexity
complexity_result = await self._check_query_complexity(parsed)
if not complexity_result.is_valid:
complexity_result.error_message = f"Statement {idx + 1}: {complexity_result.error_message}"
return complexity_result
# Check table access permissions
table_result = await self._check_table_access(parsed, auth_context)
if not table_result.is_valid:
table_result.error_message = f"Statement {idx + 1}: {table_result.error_message}"
return table_result
return ValidationResult(is_valid=True) return ValidationResult(is_valid=True)
@@ -939,28 +964,69 @@ class SQLSecurityValidator:
async def _check_sql_injection( async def _check_sql_injection(
self, sql: str, parsed: Statement self, sql: str, parsed: Statement
) -> ValidationResult: ) -> ValidationResult:
"""Check SQL injection risks""" """Check SQL injection risks with improved pattern detection
# Check common SQL injection patterns
FIX for Issue #62 Bug 2: Improved patterns to reduce false positives
Now better distinguishes between legitimate SQL (like BETWEEN...AND) and injection attempts
"""
# Improved injection patterns that are more specific and less prone to false positives
injection_patterns = [ injection_patterns = [
r"(?i)(?<![A-Za-z0-9_])(union|select|insert|update|delete|drop|create|alter)(?![A-Za-z0-9_])\s+[\s\S]*?\s+(?<![A-Za-z0-9_])(union|select|insert|update|delete|drop|create|alter)(?![A-Za-z0-9_])", # Stacked queries with dangerous operations (true injection risk)
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+", r";\s*(DROP|DELETE|TRUNCATE|ALTER|CREATE|INSERT|UPDATE)\s+",
r"(\s|^)(or|and)\s+['\"].*['\"]",
r";\s*(drop|delete|truncate|alter|create)", # UNION-based injection (but allow legitimate UNION queries)
r"(exec|execute|sp_|xp_)", # Only flag if UNION is followed by suspicious patterns like SELECT with WHERE 1=1
r"(script|javascript|vbscript)", r"UNION\s+(ALL\s+)?SELECT\s+.*\s+(WHERE|AND|OR)\s+\d+\s*=\s*\d+",
r"(char|ascii|substring|concat)\s*\(",
# Boolean-based blind injection with comments (true injection pattern)
r"(WHERE|AND|OR)\s+\d+\s*=\s*\d+\s*(--|#|/\*)",
# Quote-based injection attempts (but not in legitimate strings)
r"(WHERE|AND|OR)\s+(['\"])[^\2]*\2\s*=\s*\2[^\2]*\2",
# Time-based blind injection
r"(SLEEP|WAITFOR|BENCHMARK)\s*\(",
# System stored procedure injection
r"(EXEC|EXECUTE|SP_|XP_)\s*\(",
# Script injection attempts
r"<\s*(SCRIPT|JAVASCRIPT|VBSCRIPT)",
] ]
sql_lower = sql.lower() # FIX: Don't flag legitimate SQL functions and keywords
# These patterns are too broad and cause false positives:
# - REMOVED: r"(char|ascii|substring|concat)\s*\(" - These are legitimate SQL functions
# - REMOVED: r"(\s|^)(or|and)\s+\d+\s*=\s*\d+" - This flags BETWEEN...AND constructs
# - REMOVED: r"(\s|^)(or|and)\s+['\"].*['\"]" - This is too broad
sql_upper = sql.upper()
# Special case: Allow BETWEEN...AND which is legitimate SQL
# This prevents false positives like "WHERE dt BETWEEN '2025-01-01' AND '2025-01-31'"
if "BETWEEN" in sql_upper and "AND" in sql_upper:
# This is likely a BETWEEN clause, not injection
# Check if AND appears in a BETWEEN context
between_pattern = r"BETWEEN\s+[^\s]+\s+AND\s+[^\s]+"
if re.search(between_pattern, sql_upper, re.IGNORECASE):
# Remove BETWEEN clauses before checking other patterns
sql_cleaned = re.sub(between_pattern, "BETWEEN_CLAUSE", sql_upper, flags=re.IGNORECASE)
sql_to_check = sql_cleaned
else:
sql_to_check = sql_upper
else:
sql_to_check = sql_upper
for pattern in injection_patterns: for pattern in injection_patterns:
if re.search(pattern, sql_lower, re.IGNORECASE): if re.search(pattern, sql_to_check, re.IGNORECASE):
self.logger.warning(f"Potential SQL injection pattern detected: {pattern}")
return ValidationResult( return ValidationResult(
is_valid=False, is_valid=False,
error_message="Potential SQL injection risk detected", error_message="Potential SQL injection risk detected",
risk_level="high", risk_level="high",
) )
# Check suspicious quotes and comments # Check suspicious quotes and comments (with improved detection)
if self._has_suspicious_quotes_or_comments(sql): if self._has_suspicious_quotes_or_comments(sql):
return ValidationResult( return ValidationResult(
is_valid=False, is_valid=False,
@@ -971,19 +1037,67 @@ class SQLSecurityValidator:
return ValidationResult(is_valid=True) return ValidationResult(is_valid=True)
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool: def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
"""Check suspicious quote and comment patterns""" """Check suspicious quote and comment patterns with improved detection
# Check unmatched quotes
single_quotes = sql.count("'")
double_quotes = sql.count('"')
if single_quotes % 2 != 0 or double_quotes % 2 != 0: FIX for Issue #62 Bug 2: Improved detection to reduce false positives
return True Now distinguishes between legitimate comments/strings and injection attempts
"""
try:
# Use sqlparse to parse the SQL and distinguish between code and comments/strings
import sqlparse
from sqlparse.tokens import Comment, String
# Check SQL comments # Parse the SQL
if "--" in sql or "/*" in sql: parsed = sqlparse.parse(sql)
return True if not parsed:
# If parsing fails, be conservative
return True
return False statement = parsed[0]
# Check for unmatched quotes ONLY in non-string tokens
# This prevents false positives from legitimate string content
non_string_content = []
has_string_tokens = False
for token in statement.flatten():
if token.ttype in (String.Single, String.Double):
has_string_tokens = True
# Skip string content - quotes inside strings are legitimate
continue
elif token.ttype in (Comment.Single, Comment.Multi):
# Comments are generally OK, but check for suspicious injection patterns
comment_value = str(token).lower()
# Check if comment contains dangerous SQL keywords
dangerous_in_comments = ['drop', 'delete', 'insert', 'update', 'union', 'exec', 'execute']
if any(keyword in comment_value for keyword in dangerous_in_comments):
self.logger.warning(f"Suspicious SQL keyword in comment: {token}")
return True
# Normal comments are OK
continue
else:
# Accumulate non-string, non-comment content
non_string_content.append(str(token))
# Check for unmatched quotes in non-string content
non_string_text = ''.join(non_string_content)
single_quotes = non_string_text.count("'")
double_quotes = non_string_text.count('"')
# Only flag if there are unmatched quotes in actual SQL code (not in strings)
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
return True
# FIX: Don't flag legitimate SQL comments
# Comments are OK as long as they don't contain dangerous patterns (already checked above)
return False
except Exception as e:
self.logger.debug(f"SQL parsing error in quote/comment check: {e}")
# On parsing error, fall back to conservative check
# But be more lenient than before
return False # Don't flag on parse errors to reduce false positives
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult: async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
"""Check blocked keywords""" """Check blocked keywords"""
@@ -1045,6 +1159,10 @@ class SQLSecurityValidator:
self, parsed: Statement, auth_context: AuthContext self, parsed: Statement, auth_context: AuthContext
) -> ValidationResult: ) -> ValidationResult:
"""Check table access permissions""" """Check table access permissions"""
# If no auth_context, skip table access checks (rely on other security checks)
if auth_context is None:
return ValidationResult(is_valid=True)
# Extract table names from query # Extract table names from query
tables = self._extract_table_names(parsed) tables = self._extract_table_names(parsed)

View File

@@ -26,6 +26,7 @@ from collections import Counter, defaultdict
from .db import DorisConnectionManager from .db import DorisConnectionManager
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import get_auth_context
logger = get_logger(__name__) logger = get_logger(__name__)
@@ -192,7 +193,9 @@ class SecurityAnalyticsTools:
LIMIT 10000 LIMIT 10000
""" """
result = await connection.execute(audit_sql) # SECURITY FIX: Pass auth_context to execute
auth_context = get_auth_context()
result = await connection.execute(audit_sql, auth_context=auth_context)
return result.data if result.data else [] return result.data if result.data else []
except Exception as e: except Exception as e:
@@ -215,7 +218,8 @@ class SecurityAnalyticsTools:
LIMIT 10000 LIMIT 10000
""" """
result = await connection.execute(simple_audit_sql) auth_context = get_auth_context()
result = await connection.execute(simple_audit_sql, auth_context=auth_context)
return result.data if result.data else [] return result.data if result.data else []
except Exception as e2: except Exception as e2:
@@ -498,7 +502,8 @@ class SecurityAnalyticsTools:
FROM mysql.user FROM mysql.user
""" """
result = await connection.execute(roles_sql) auth_context = get_auth_context()
result = await connection.execute(roles_sql, auth_context=auth_context)
user_roles = defaultdict(list) user_roles = defaultdict(list)
if result.data: if result.data:

View File

@@ -0,0 +1,301 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL Security Utilities Module
Provides SQL identifier validation, escaping, and safe query building utilities
to prevent SQL injection attacks.
"""
import re
from contextvars import ContextVar
from typing import Optional, Tuple, List, Any
from .logger import get_logger
logger = get_logger(__name__)
# Context variable for auth_context (set by HTTP middleware)
# NOTE(review): security.py also constructs a ContextVar named 'mcp_auth_context'
# (mcp_auth_context_var). ContextVar state belongs to the object, not to the name
# string, so a value set on one is not readable through the other. Confirm which
# instance the middleware actually sets and whether this module should reuse it.
auth_context_var: ContextVar = ContextVar('mcp_auth_context', default=None)
class SQLSecurityError(Exception):
    """Signal that a SQL identifier or fragment failed security validation.

    Raised by the validation helpers in this module; callers catch it to
    reject the offending input instead of building a query from it.
    """
class SQLSecurityUtils:
"""
SQL Security Utilities for preventing SQL injection attacks.
Provides:
- Identifier validation (database names, table names, column names)
- Safe identifier quoting with backticks
- Safe table reference building
- Auth context retrieval from context variables
"""
# Valid SQL identifier pattern: letters, numbers, underscores
# Must start with letter or underscore, not a number
# Besides ASCII, only CJK Unified Ideographs (U+4E00-U+9FFF) are accepted; other non-ASCII letters are rejected
IDENTIFIER_PATTERN = re.compile(r'^[a-zA-Z_\u4e00-\u9fff][a-zA-Z0-9_\u4e00-\u9fff]*$')
# Maximum identifier length (MySQL/Doris standard)
MAX_IDENTIFIER_LENGTH = 64
# SQL reserved keywords that should be quoted
SQL_KEYWORDS = {
'SELECT', 'FROM', 'WHERE', 'INSERT', 'UPDATE', 'DELETE', 'DROP',
'CREATE', 'ALTER', 'TABLE', 'DATABASE', 'INDEX', 'VIEW', 'AND',
'OR', 'NOT', 'NULL', 'TRUE', 'FALSE', 'IN', 'LIKE', 'BETWEEN',
'JOIN', 'LEFT', 'RIGHT', 'INNER', 'OUTER', 'ON', 'AS', 'ORDER',
'BY', 'GROUP', 'HAVING', 'LIMIT', 'OFFSET', 'UNION', 'ALL',
'DISTINCT', 'INTO', 'VALUES', 'SET', 'DEFAULT', 'PRIMARY', 'KEY',
'FOREIGN', 'REFERENCES', 'CHECK', 'UNIQUE', 'CONSTRAINT'
}
@classmethod
def validate_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
    """
    Check that ``name`` is acceptable as a SQL identifier.

    The checks run in this order: falsy input, non-string input,
    whitespace-only input, length above ``MAX_IDENTIFIER_LENGTH``,
    presence of a forbidden token, and finally ``IDENTIFIER_PATTERN``.

    Args:
        name: Candidate identifier (database, table, column name, ...).
        identifier_type: Label used in error and log messages,
            e.g. "database name" or "table name".

    Returns:
        The identifier with surrounding whitespace stripped. Note that this
        can differ from the value passed in.

    Raises:
        SQLSecurityError: If any of the checks above fails.
    """
    empty_message = f"Empty {identifier_type} is not allowed"

    # Falsy input (None, "", 0, ...) is reported as empty before the type is examined.
    if not name:
        raise SQLSecurityError(empty_message)
    if not isinstance(name, str):
        raise SQLSecurityError(
            f"Invalid {identifier_type}: must be a string, got {type(name).__name__}"
        )

    # Everything below operates on the stripped form, which is also what is returned.
    candidate = name.strip()
    if not candidate:
        raise SQLSecurityError(empty_message)

    limit = cls.MAX_IDENTIFIER_LENGTH
    if len(candidate) > limit:
        raise SQLSecurityError(
            f"Invalid {identifier_type}: '{candidate[:20]}...' exceeds maximum length of {limit} characters"
        )

    # Quote, statement-separator, comment, escape and NUL tokens get a specific
    # error message naming the first one found, in this fixed order.
    forbidden_tokens = ("'", '"', ';', '--', '/*', '*/', '\\', '\x00')
    offender = next((token for token in forbidden_tokens if token in candidate), None)
    if offender is not None:
        raise SQLSecurityError(
            f"Invalid {identifier_type}: '{candidate}' contains forbidden character '{offender}'"
        )

    if cls.IDENTIFIER_PATTERN.match(candidate) is None:
        raise SQLSecurityError(
            f"Invalid {identifier_type}: '{candidate}' contains invalid characters. "
            f"Only letters, numbers, and underscores are allowed, and must start with a letter or underscore."
        )

    logger.debug(f"Validated {identifier_type}: {candidate}")
    return candidate
@classmethod
def quote_identifier(cls, name: str, identifier_type: str = "identifier") -> str:
    """
    Validate ``name`` and wrap it in backticks.

    Args:
        name: Identifier to quote.
        identifier_type: Label forwarded to ``validate_identifier`` for
            its error messages.

    Returns:
        The validated identifier enclosed in backticks, e.g. `table_name`.

    Raises:
        SQLSecurityError: Propagated from ``validate_identifier`` when the
            identifier is rejected.
    """
    checked = cls.validate_identifier(name, identifier_type)
    # Any backtick in the validated name is doubled before wrapping.
    return "`" + checked.replace('`', '``') + "`"
@classmethod
def build_table_reference(
    cls,
    table_name: str,
    db_name: Optional[str] = None,
    catalog_name: Optional[str] = None,
    quote: bool = True
) -> str:
    """
    Assemble a dot-separated table reference from validated identifiers.

    Args:
        table_name: The table name (always included)
        db_name: The database name; skipped when empty or None
        catalog_name: The catalog name; skipped when empty or None
        quote: When True each part is backtick-quoted, otherwise each part
            is only validated (default: True)

    Returns:
        The joined reference (e.g., `catalog`.`db`.`table`)

    Raises:
        SQLSecurityError: If any supplied identifier is invalid
    """
    # quote_identifier validates internally, so both modes reject bad input.
    render = cls.quote_identifier if quote else cls.validate_identifier
    optional_parts = (
        (catalog_name, "catalog name"),
        (db_name, "database name"),
    )
    # Parts are processed catalog -> database -> table, so the first invalid
    # one in that order is the one reported.
    segments = [render(value, label) for value, label in optional_parts if value]
    segments.append(render(table_name, "table name"))
    return ".".join(segments)
@classmethod
def build_column_reference(
    cls,
    column_name: str,
    table_name: Optional[str] = None,
    quote: bool = True
) -> str:
    """
    Assemble a validated column reference, optionally qualified by a table.

    Args:
        column_name: The column name (always included)
        table_name: The table name; skipped when empty or None
        quote: When True each part is backtick-quoted, otherwise each part
            is only validated (default: True)

    Returns:
        The joined reference (e.g., `table`.`column`)

    Raises:
        SQLSecurityError: If any supplied identifier is invalid
    """
    if quote:
        to_sql = cls.quote_identifier
    else:
        to_sql = cls.validate_identifier
    # The table qualifier is handled before the column, so an invalid table
    # name is reported first.
    qualifier = [to_sql(table_name, "table name")] if table_name else []
    return ".".join(qualifier + [to_sql(column_name, "column name")])
@classmethod
def validate_and_build_where_condition(
    cls,
    column_name: str,
    operator: str = "=",
    use_param: bool = True
) -> Tuple[str, bool]:
    """
    Produce a WHERE-clause fragment for one column.

    Args:
        column_name: The column name; validated and backtick-quoted
        operator: The comparison operator. Accepted (case-insensitively):
            =, !=, <>, <, >, <=, >=, LIKE, IN, IS. It is emitted exactly as
            passed in, without case normalization.
        use_param: Whether to append a parameter placeholder (%s)

    Returns:
        Tuple of (condition_string, needs_param)
        e.g., ("`column` = %s", True), or ("`column` =", False) when
        use_param is False and the caller supplies the right-hand side

    Raises:
        SQLSecurityError: If column name is invalid or operator is not allowed
    """
    # The column is checked before the operator, so a bad column wins.
    column_sql = cls.quote_identifier(column_name, "column name")

    allowed_operators = {'=', '!=', '<>', '<', '>', '<=', '>=', 'LIKE', 'IN', 'IS'}
    if operator.upper() not in allowed_operators:
        raise SQLSecurityError(f"Invalid operator: '{operator}'. Allowed: {allowed_operators}")

    if not use_param:
        return f"{column_sql} {operator}", False
    return f"{column_sql} {operator} %s", True
@staticmethod
def get_auth_context():
    """
    Read the current auth_context from the module's context variable.

    The value is whatever was last stored through set_auth_context() in the
    current context.

    Returns:
        The stored auth_context object, or None if reading it failed
    """
    try:
        current = auth_context_var.get()
        if not current:
            # Nothing truthy stored: hand the value back without logging.
            return current
        logger.debug("Retrieved auth_context from context variable")
        return current
    except Exception as exc:
        # Best effort: a failed lookup is reported as "no context".
        logger.debug(f"Could not retrieve auth_context: {exc}")
        return None
@staticmethod
def set_auth_context(auth_context):
    """
    Store auth_context in the module's context variable.

    Intended to be called by the HTTP middleware while it processes a
    request, so later code can read it back via get_auth_context().

    Args:
        auth_context: The auth_context object to store
    """
    # The value returned by .set() is intentionally not kept.
    auth_context_var.set(auth_context)
    logger.debug("Set auth_context in context variable")
# Convenience functions for direct use: module-level aliases of the
# SQLSecurityUtils helpers, so callers can import them by name from this
# module instead of going through the class.
validate_identifier = SQLSecurityUtils.validate_identifier
quote_identifier = SQLSecurityUtils.quote_identifier
build_table_reference = SQLSecurityUtils.build_table_reference
build_column_reference = SQLSecurityUtils.build_column_reference
get_auth_context = SQLSecurityUtils.get_auth_context
set_auth_context = SQLSecurityUtils.set_auth_context

View File

@@ -20,7 +20,7 @@ build-backend = "hatchling.build"
[project] [project]
name = "doris-mcp-server" name = "doris-mcp-server"
version = "0.6.0" version = "0.6.1"
description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris" description = "Enterprise-grade Model Context Protocol (MCP) server implementation for Apache Doris"
authors = [ authors = [
{name = "Yijia Su", email = "freeoneplus@apache.org"} {name = "Yijia Su", email = "freeoneplus@apache.org"}

View File

@@ -64,9 +64,10 @@ else
fi fi
# Set HTTP-specific environment variables # Set HTTP-specific environment variables
# FIX for Issue #62 Bug 4: Use SERVER_PORT instead of MCP_PORT for consistency with code
export MCP_TRANSPORT_TYPE="http" export MCP_TRANSPORT_TYPE="http"
export MCP_HOST="${MCP_HOST:-0.0.0.0}" export MCP_HOST="${MCP_HOST:-0.0.0.0}"
export MCP_PORT="${MCP_PORT:-3000}" export SERVER_PORT="${SERVER_PORT:-3000}" # Changed from MCP_PORT to SERVER_PORT
export WORKERS="${WORKERS:-1}" export WORKERS="${WORKERS:-1}"
export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}" export ALLOWED_ORIGINS="${ALLOWED_ORIGINS:-*}"
export LOG_LEVEL="${LOG_LEVEL:-info}" export LOG_LEVEL="${LOG_LEVEL:-info}"
@@ -77,15 +78,15 @@ export MCP_DEBUG_ADAPTER="true"
export PYTHONPATH="$(pwd):$PYTHONPATH" export PYTHONPATH="$(pwd):$PYTHONPATH"
echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}" echo -e "${GREEN}Starting MCP server (Streamable HTTP mode)...${NC}"
echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${MCP_PORT}/mcp${NC}" echo -e "${YELLOW}Service will run on http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${MCP_PORT}/health${NC}" echo -e "${YELLOW}Health Check: http://${MCP_HOST}:${SERVER_PORT}/health${NC}"
echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${MCP_PORT}/mcp${NC}" echo -e "${YELLOW}MCP Endpoint: http://${MCP_HOST}:${SERVER_PORT}/mcp${NC}"
echo -e "${YELLOW}Local access: http://localhost:${MCP_PORT}/mcp${NC}" echo -e "${YELLOW}Local access: http://localhost:${SERVER_PORT}/mcp${NC}"
echo -e "${YELLOW}Workers: ${WORKERS}${NC}" echo -e "${YELLOW}Workers: ${WORKERS}${NC}"
echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}" echo -e "${YELLOW}Use Ctrl+C to stop the service${NC}"
# Start the server in HTTP mode (Streamable HTTP) # Start the server in HTTP mode (Streamable HTTP)
python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${MCP_PORT} --workers ${WORKERS} python -m doris_mcp_server.main --transport http --host ${MCP_HOST} --port ${SERVER_PORT} --workers ${WORKERS}
# Check exit status # Check exit status
if [ $? -ne 0 ]; then if [ $? -ne 0 ]; then
@@ -97,4 +98,4 @@ fi
echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}" echo -e "${YELLOW}Tip: If the page displays abnormally, please clear your browser cache or use incognito mode${NC}"
echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}" echo -e "${YELLOW}Chrome browser clear cache shortcut: Ctrl+Shift+Del (Windows) or Cmd+Shift+Del (Mac)${NC}"
echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}" echo -e "${CYAN}For testing HTTP endpoints, you can use:${NC}"
echo -e "${CYAN} curl -X POST http://localhost:${MCP_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}" echo -e "${CYAN} curl -X POST http://localhost:${SERVER_PORT}/mcp -H 'Content-Type: application/json' -d '{\"method\":\"tools/list\"}'${NC}"

View File

@@ -0,0 +1,367 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL Security Test Suite for Apache Doris MCP Server
Tests for:
1. SQL injection prevention via identifier validation
2. Multi-statement SQL parsing in security validator
3. auth_context enforcement
"""
import pytest
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch
class TestSQLSecurityUtils:
    """Unit tests for the identifier helpers exported by sql_security_utils."""

    def test_validate_identifier_accepts_valid_names(self):
        """Well-formed identifiers are returned unchanged."""
        from doris_mcp_server.utils.sql_security_utils import validate_identifier

        accepted = (
            "users",
            "my_table",
            "Table123",
            "_private_table",
            "CamelCaseTable",
            "table_with_numbers_123",
        )
        for candidate in accepted:
            outcome = validate_identifier(candidate, "table")
            assert outcome == candidate, f"Valid identifier '{candidate}' should be accepted"

    def test_validate_identifier_rejects_sql_injection(self):
        """Every injection-shaped identifier raises SQLSecurityError."""
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            validate_identifier,
        )

        payloads = (
            # Basic SQL injection
            "'; DROP TABLE users; --",
            "table' OR '1'='1",
            "table'; DELETE FROM users; --",
            # Union-based injection
            "table' UNION SELECT * FROM passwords --",
            # Comment injection
            "table/**/OR/**/1=1",
            "table--comment",
            # Special characters
            "table`; DROP TABLE users;",
            'table"; DROP TABLE users;',
            "table\"; DELETE FROM",
            # Backtick escape attempt
            "analytics`; SELECT * FROM sensitive_table;--",
            # Whitespace injection
            "table name with spaces",
            "table\ttab",
            "table\nnewline",
        )
        for payload in payloads:
            with pytest.raises(SQLSecurityError):
                validate_identifier(payload, "table")

    def test_validate_identifier_rejects_empty(self):
        """Both the empty string and None are refused."""
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            validate_identifier,
        )

        for missing in ("", None):
            with pytest.raises(SQLSecurityError):
                validate_identifier(missing, "table")

    def test_validate_identifier_rejects_too_long(self):
        """An identifier of 100 characters exceeds the allowed length."""
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            validate_identifier,
        )

        # Doris identifier max length is typically 64 characters
        oversized = "a" * 100
        with pytest.raises(SQLSecurityError):
            validate_identifier(oversized, "table")

    def test_quote_identifier_adds_backticks(self):
        """quote_identifier wraps a valid name in backticks."""
        from doris_mcp_server.utils.sql_security_utils import quote_identifier

        for raw in ("my_table", "users", "Table123"):
            assert quote_identifier(raw, "table") == "`" + raw + "`"

    def test_quote_identifier_validates_first(self):
        """quote_identifier rejects an invalid name instead of quoting it."""
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            quote_identifier,
        )

        hostile = "'; DROP TABLE users; --"
        with pytest.raises(SQLSecurityError):
            quote_identifier(hostile, "table")
class TestSQLSecurityValidator:
    """Test cases for SQLSecurityValidator multi-statement parsing"""

    @pytest.fixture
    def dict_config(self):
        """Return a dict-style validator configuration that blocks DDL/DML keywords."""
        return {
            "blocked_keywords": [
                "DROP", "CREATE", "ALTER", "TRUNCATE",
                "DELETE", "INSERT", "UPDATE",
                "GRANT", "REVOKE", "EXEC", "EXECUTE"
            ],
            "max_query_complexity": 100,
            "enable_security_check": True
        }

    @pytest.fixture
    def mock_auth_context(self):
        """Build an AuthContext for a regular user at INTERNAL security level."""
        from doris_mcp_server.utils.security import AuthContext, SecurityLevel
        return AuthContext(
            user_id="test_user",
            roles=["user"],
            security_level=SecurityLevel.INTERNAL
        )

    @pytest.mark.asyncio
    async def test_validates_all_statements(self, dict_config, mock_auth_context):
        """Test that validator checks ALL SQL statements, not just the first"""
        from doris_mcp_server.utils.security import SQLSecurityValidator
        validator = SQLSecurityValidator(dict_config)
        # Multi-statement with injection in second statement
        # This should be BLOCKED
        malicious_sql = "SELECT 1; DROP TABLE users; SELECT 2"
        result = await validator.validate(malicious_sql, mock_auth_context)
        assert not result.is_valid, "Multi-statement injection should be blocked"
        # Check for either DROP keyword detection or SQL injection detection
        error_upper = result.error_message.upper()
        assert ("DROP" in error_upper or
                "INJECTION" in error_upper or
                "BLOCKED" in error_upper), f"Expected DROP/injection/blocked in: {result.error_message}"

    @pytest.mark.asyncio
    async def test_blocks_hidden_dangerous_statement(self, dict_config, mock_auth_context):
        """Test that dangerous statements hidden after safe ones are blocked"""
        from doris_mcp_server.utils.security import SQLSecurityValidator
        validator = SQLSecurityValidator(dict_config)
        # Safe statement followed by dangerous one
        malicious_sql = """
        SELECT * FROM users WHERE id = 1;
        DELETE FROM audit_log;
        SELECT 1;
        """
        result = await validator.validate(malicious_sql, mock_auth_context)
        assert not result.is_valid, "Hidden DELETE statement should be blocked"

    @pytest.mark.asyncio
    async def test_allows_safe_multi_statement(self, dict_config, mock_auth_context):
        """Test that multiple safe SELECT statements are allowed"""
        from doris_mcp_server.utils.security import SQLSecurityValidator
        validator = SQLSecurityValidator(dict_config)
        safe_sql = """
        SELECT * FROM users;
        SELECT COUNT(*) FROM orders;
        SELECT id, name FROM products;
        """
        result = await validator.validate(safe_sql, mock_auth_context)
        assert result.is_valid, f"Multiple safe SELECT statements should be allowed, got: {result.error_message}"

    @pytest.mark.asyncio
    async def test_context_switch_injection_blocked(self, dict_config, mock_auth_context):
        """Test that the validator handles the context switch injection payload.

        No accept/reject verdict is pinned here: whether this input is
        rejected depends on checks (e.g. table access) that this
        configuration may not enable.
        """
        from doris_mcp_server.utils.security import SQLSecurityValidator
        validator = SQLSecurityValidator(dict_config)
        # Simulating the exec_query_for_mcp attack vector
        injected_sql = """
        USE `analytics`; SELECT * FROM sensitive_table;-- `;
        SELECT * FROM public_table;
        """
        result = await validator.validate(injected_sql, mock_auth_context)
        # The validator should process all statements
        # Even if USE is allowed, subsequent unauthorized access should be caught
        # by table access checks (if configured)
        # Previously this test asserted nothing; at minimum the validator
        # must finish without raising and hand back a usable result object.
        assert result is not None, "Validator should return a result for multi-statement input"
        assert hasattr(result, "is_valid"), "Validation result should expose is_valid"
class TestExecQueryForMCP:
    """Identifier checks backing the exec_query_for_mcp db/catalog parameters."""

    @pytest.mark.asyncio
    async def test_rejects_malicious_db_name(self):
        """A backtick-escaping db_name is refused by validate_identifier."""
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            validate_identifier,
        )

        # The attack vector from security report
        with pytest.raises(SQLSecurityError):
            validate_identifier(
                "analytics`; SELECT * FROM sensitive_table;--",
                "database name",
            )

    @pytest.mark.asyncio
    async def test_rejects_malicious_catalog_name(self):
        """A quote-escaping catalog_name is refused by validate_identifier."""
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            validate_identifier,
        )

        with pytest.raises(SQLSecurityError):
            validate_identifier(
                "internal'; DROP DATABASE production;--",
                "catalog name",
            )
class TestDependencyAnalysisTools:
    """Identifier checks related to the dependency_analysis_tools security fixes."""

    @pytest.mark.asyncio
    async def test_get_tables_metadata_rejects_injection(self):
        """The reported db_name injection payload is refused by validate_identifier."""
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            validate_identifier,
        )

        # The attack vector from security report
        with pytest.raises(SQLSecurityError):
            validate_identifier("test_db' OR '1'='1' --", "database name")
class TestAuthContextEnforcement:
    """Tests around auth_context availability and enforcement."""

    def test_execute_requires_auth_context_for_security(self):
        """Placeholder documenting the expected behavior; it asserts nothing.

        When auth_context is None, security checks are skipped.
        When auth_context is provided, security checks are performed.
        The fix ensures all execute() calls pass auth_context.
        """

    @pytest.mark.asyncio
    async def test_get_auth_context_returns_context(self):
        """get_auth_context yields None or an object exposing user_id."""
        from doris_mcp_server.utils.sql_security_utils import get_auth_context

        # Outside a request nothing has been stored, so None is expected;
        # the context is set by HTTP middleware.
        context = get_auth_context()
        if context is not None:
            assert hasattr(context, 'user_id')
class TestIntegrationScenarios:
    """Attack scenarios replayed against validate_identifier."""

    def test_attack_scenario_1_permission_bypass(self):
        """
        Attack Scenario 1: Permission Bypass in Multi-Tenant Environment

        Expected: User can only query their own database (db_name="tenant_a_db")
        Attack: Inject "tenant_a_db' OR '1'='1' --" to query ALL databases
        Result: Should be BLOCKED by validate_identifier()
        """
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            validate_identifier,
        )

        tampered_db = "tenant_a_db' OR '1'='1' --"
        with pytest.raises(SQLSecurityError):
            validate_identifier(tampered_db, "database name")

    def test_attack_scenario_2_union_injection(self):
        """
        Attack Scenario 2: UNION-based Information Disclosure

        Attack: Inject UNION SELECT to extract sensitive data
        Result: Should be BLOCKED
        """
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            validate_identifier,
        )

        union_payload = "test' UNION SELECT password FROM users --"
        with pytest.raises(SQLSecurityError):
            validate_identifier(union_payload, "database name")

    def test_attack_scenario_3_backtick_escape(self):
        """
        Attack Scenario 3: Backtick Escape Attempt

        Attack: Use backticks to break out of quoted identifier
        Result: Should be BLOCKED
        """
        from doris_mcp_server.utils.sql_security_utils import (
            SQLSecurityError,
            validate_identifier,
        )

        escape_payload = "analytics`; SELECT * FROM sensitive_table;--"
        with pytest.raises(SQLSecurityError):
            validate_identifier(escape_payload, "database name")
# Run tests with: pytest tests/test_sql_security.py -v
if __name__ == "__main__":
    # pytest.main() returns the run's exit code; raise it so that executing
    # this file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))

View File

@@ -0,0 +1,871 @@
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL Injection API Integration Tests
This module tests SQL injection prevention through the MCP HTTP API.
It sends malicious payloads and verifies they are properly blocked.
Prerequisites:
- MCP server running on localhost:3000
- Run with: pytest test/security/test_sql_injection_api.py -v
Usage:
# Start server first
bash start_server.sh
# Run tests
pytest test/security/test_sql_injection_api.py -v --no-cov
"""
import pytest
import httpx
import json
import asyncio
from typing import Optional
# Server configuration
# Root URL of the MCP server under test.
MCP_BASE_URL = "http://localhost:3000"
# Endpoints derived from the root URL.
MCP_ENDPOINT = MCP_BASE_URL + "/mcp"
HEALTH_ENDPOINT = MCP_BASE_URL + "/health"
# Timeout value handed to the HTTP client.
TIMEOUT = 30.0
class MCPClient:
    """Simple MCP HTTP client for testing.

    Posts JSON-RPC 2.0 requests to ``<base_url>/mcp`` and remembers the
    ``mcp-session-id`` response header returned by ``initialize``.
    """

    def __init__(self, base_url: str = MCP_BASE_URL):
        """Create the client.

        Args:
            base_url: Root URL of the MCP server (no trailing ``/mcp``)
        """
        self.base_url = base_url
        self.mcp_endpoint = f"{base_url}/mcp"
        self.session_id: Optional[str] = None
        self.request_id = 0
        self.client = httpx.AsyncClient(timeout=TIMEOUT)

    async def close(self):
        """Close the underlying HTTP client."""
        await self.client.aclose()

    def _next_id(self) -> int:
        """Return the next JSON-RPC request id (1, 2, 3, ...)."""
        self.request_id += 1
        return self.request_id

    def _headers(self, include_session: bool) -> dict:
        """Build request headers, adding the session id only when one is known."""
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json, text/event-stream"
        }
        # A None header value is rejected by httpx, so the header is omitted
        # entirely when the server never issued a session id.
        if include_session and self.session_id:
            headers["mcp-session-id"] = self.session_id
        return headers

    async def initialize(self) -> dict:
        """Initialize MCP session

        Returns:
            The parsed response body of the ``initialize`` request
        """
        response = await self.client.post(
            self.mcp_endpoint,
            headers=self._headers(include_session=False),
            json={
                "jsonrpc": "2.0",
                "method": "initialize",
                "params": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {},
                    "clientInfo": {
                        "name": "sql-injection-test",
                        "version": "1.0.0"
                    }
                },
                "id": self._next_id()
            }
        )
        # Extract session ID from response header
        self.session_id = response.headers.get("mcp-session-id")
        return self._parse_response(response.text)

    async def call_tool(self, tool_name: str, arguments: dict) -> dict:
        """Call an MCP tool

        Initializes the session first when no session id is stored yet.

        Args:
            tool_name: Name of the tool to invoke
            arguments: Tool arguments, sent as the JSON-RPC ``arguments`` object

        Returns:
            The parsed response body of the ``tools/call`` request
        """
        if not self.session_id:
            await self.initialize()
        response = await self.client.post(
            self.mcp_endpoint,
            headers=self._headers(include_session=True),
            json={
                "jsonrpc": "2.0",
                "method": "tools/call",
                "params": {
                    "name": tool_name,
                    "arguments": arguments
                },
                "id": self._next_id()
            }
        )
        return self._parse_response(response.text)

    def _parse_response(self, text: str) -> dict:
        """Parse JSON response

        Tries the whole body as JSON first, then falls back to the first
        Server-Sent-Events ``data`` line that holds valid JSON.

        Returns:
            The decoded JSON value, or ``{"raw": text}`` when nothing parses
        """
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            pass
        # Try SSE format. The space after "data:" is optional in the event
        # stream format, and json.loads ignores surrounding whitespace.
        prefix = "data:"
        for line in text.strip().split("\n"):
            if not line.startswith(prefix):
                continue
            try:
                return json.loads(line[len(prefix):])
            except json.JSONDecodeError:
                continue
        return {"raw": text}
def print_result(test_name: str, payload: dict, result: dict):
    """Print test result in a readable format

    Malformed responses (a non-dict ``result``, content items without a
    string ``text``, or text that is JSON but not an object) are printed
    raw instead of raising, so a diagnostic print never fails a test.

    Args:
        test_name: Label shown in the banner
        payload: Tool arguments that were sent
        result: Parsed response as returned by the MCP client
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print(f"TEST: {test_name}")
    print(banner)
    print(f"PAYLOAD: {json.dumps(payload, ensure_ascii=False)}")
    print("-" * 60)

    tool_result = result.get("result") if isinstance(result, dict) else None
    if isinstance(tool_result, dict) and "content" in tool_result:
        # Extract inner result content
        for item in tool_result["content"] or []:
            if not isinstance(item, dict) or item.get("type") != "text":
                continue
            text = item.get("text")
            try:
                inner = json.loads(text)
            except (json.JSONDecodeError, TypeError):
                inner = None
            if not isinstance(inner, dict):
                # Not a JSON object: show the (truncated) text as-is.
                raw_text = "" if text is None else str(text)
                print(f"RESPONSE (raw): {raw_text[:500]}")
                continue
            print("RESPONSE:")
            print(f" success: {inner.get('success')}")
            # Optional fields are printed only when present and truthy.
            for field in ("error", "error_type", "risk_level", "message"):
                if inner.get(field):
                    print(f" {field}: {inner.get(field)}")
            if inner.get('data') is not None and inner.get('success'):
                data_str = json.dumps(inner.get('data'), ensure_ascii=False)
                if len(data_str) > 200:
                    data_str = data_str[:200] + "..."
                print(f" data: {data_str}")
    elif isinstance(result, dict) and "error" in result:
        print(f"RESPONSE ERROR: {result['error']}")
    else:
        print(f"RESPONSE (raw): {json.dumps(result, ensure_ascii=False)[:500]}")
    print(f"{banner}\n")
class TestSQLInjectionAPI:
    """Test SQL injection prevention through MCP API

    These tests talk to a live server at the module's configured endpoint.
    Several of them only print the response for manual inspection and make
    no assertion; that is noted on each such test.
    """

    @pytest.fixture
    async def mcp_client(self):
        """Create MCP client instance and close it after the test"""
        client = MCPClient()
        yield client
        await client.close()

    @pytest.fixture
    def is_server_running(self):
        """Check if MCP server is running (True when /health answers 200)"""
        try:
            response = httpx.get(HEALTH_ENDPOINT, timeout=5.0)
            return response.status_code == 200
        except Exception:
            return False

    @pytest.mark.asyncio
    async def test_server_health(self):
        """Test that MCP server is running and healthy"""
        async with httpx.AsyncClient() as client:
            response = await client.get(HEALTH_ENDPOINT)
            assert response.status_code == 200
            data = response.json()
            assert data["status"] == "healthy"

    @pytest.mark.asyncio
    async def test_exec_query_with_drop_injection(self, mcp_client):
        """Test exec_query rejects DROP TABLE injection"""
        # Classic SQL injection: append DROP TABLE
        payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("DROP TABLE Injection", payload, result)
        # Should return error, not execute the DROP
        assert self._is_blocked_or_error(result), \
            "DROP TABLE injection should be blocked"

    @pytest.mark.asyncio
    async def test_exec_query_with_union_injection(self, mcp_client):
        """Send a UNION-based injection attempt (response is printed, not asserted)"""
        # UNION injection to extract data from other tables
        payload = {"sql": "SELECT id FROM users UNION SELECT password FROM admin_users"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("UNION Injection", payload, result)

    @pytest.mark.asyncio
    async def test_exec_query_with_delete_injection(self, mcp_client):
        """Test exec_query rejects DELETE injection"""
        payload = {"sql": "SELECT 1; DELETE FROM users WHERE 1=1; SELECT 2"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("DELETE Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "DELETE injection should be blocked"

    @pytest.mark.asyncio
    async def test_exec_query_with_update_injection(self, mcp_client):
        """Test exec_query rejects UPDATE injection"""
        payload = {"sql": "SELECT 1; UPDATE users SET role='admin' WHERE id=1; --"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("UPDATE Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "UPDATE injection should be blocked"

    @pytest.mark.asyncio
    async def test_exec_query_db_name_injection(self, mcp_client):
        """Test exec_query rejects SQL injection via db_name parameter"""
        # Attack vector: inject SQL via db_name parameter
        payload = {"sql": "SELECT 1", "db_name": "test'; DROP TABLE users; --"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("db_name Parameter Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "db_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_exec_query_catalog_name_injection(self, mcp_client):
        """Test exec_query rejects SQL injection via catalog_name parameter"""
        # Attack vector: inject SQL via catalog_name parameter
        payload = {"sql": "SELECT 1", "catalog_name": "internal`; SELECT * FROM mysql.user; --"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("catalog_name Parameter Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "catalog_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_table_schema_injection(self, mcp_client):
        """Test get_table_schema rejects SQL injection via table_name"""
        # Attack vector: inject SQL via table_name parameter
        payload = {"table_name": "users'; DROP TABLE users; --"}
        result = await mcp_client.call_tool("get_table_schema", payload)
        print_result("table_name Injection (get_table_schema)", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_table_schema_db_injection(self, mcp_client):
        """Test get_table_schema rejects SQL injection via db_name"""
        payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
        result = await mcp_client.call_tool("get_table_schema", payload)
        print_result("db_name Injection (get_table_schema)", payload, result)
        assert self._is_blocked_or_error(result), \
            "db_name injection in get_table_schema should be blocked"

    @pytest.mark.asyncio
    async def test_analyze_dependencies_injection(self, mcp_client):
        """Test analyze_dependencies rejects SQL injection"""
        # This was the original vulnerability reported
        payload = {"table_name": "users", "db_name": "test_db' OR '1'='1' --"}
        result = await mcp_client.call_tool("analyze_dependencies", payload)
        print_result("analyze_dependencies Injection (Original Report)", payload, result)
        assert self._is_blocked_or_error(result), \
            "analyze_dependencies db_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_stacked_queries_injection(self, mcp_client):
        """Test that stacked queries (multiple statements) are blocked"""
        # Multiple statements injection
        payload = {"sql": "SELECT * FROM users WHERE id = 1; INSERT INTO audit_log VALUES (NULL, 'hacked', NOW()); SELECT 1;"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("Stacked Queries (INSERT) Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "Stacked queries with INSERT should be blocked"

    @pytest.mark.asyncio
    async def test_comment_based_injection(self, mcp_client):
        """Send a comment-based injection attempt (response is printed, not asserted)"""
        # Using comments to bypass filters
        payload = {"sql": "SELECT * FROM users WHERE id = 1/**/OR/**/1=1"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("Comment-based Injection", payload, result)

    @pytest.mark.asyncio
    async def test_hex_encoded_injection(self, mcp_client):
        """Send a hex-encoded injection attempt (response is printed, not asserted)"""
        # Hex-encoded 'DROP' attempt
        payload = {"sql": "SELECT 0x44524F50205441424C4520757365727320"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("Hex Encoded Injection", payload, result)

    @pytest.mark.asyncio
    async def test_backtick_escape_injection(self, mcp_client):
        """Test backtick escape injection is blocked"""
        # Attempt to escape backtick quoting
        payload = {"sql": "SELECT 1", "db_name": "analytics`; SELECT * FROM sensitive_table;--"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("Backtick Escape Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "Backtick escape injection should be blocked"

    @pytest.mark.asyncio
    async def test_valid_query_succeeds(self, mcp_client):
        """Send a simple valid query (response is printed, not asserted)"""
        # Simple valid query should work
        payload = {"sql": "SELECT 1 AS test_value"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("Valid Query (should succeed)", payload, result)

    @pytest.mark.asyncio
    async def test_valid_show_databases(self, mcp_client):
        """Send SHOW DATABASES (response is printed, not asserted)"""
        payload = {"sql": "SHOW DATABASES"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("SHOW DATABASES (should succeed)", payload, result)

    def _is_blocked_or_error(self, result: dict) -> bool:
        """Check if result indicates blocked or error

        An empty result, a JSON-RPC ``error`` member, an ``isError`` flag, a
        content item reporting ``success: false`` or a truthy ``error``, or
        any content/raw text containing one of the security keywords all
        count as blocked. Content items that are not dicts, lack string
        text, or hold JSON that is not an object are tolerated rather than
        raising.
        """
        if not result:
            return True
        # Check for JSON-RPC error
        if "error" in result:
            return True
        # Check for error in result content
        if "result" in result:
            result_content = result.get("result", {})
            if isinstance(result_content, dict):
                # Check for isError flag
                if result_content.get("isError"):
                    return True
                # Check content array for error messages
                for item in result_content.get("content") or []:
                    if not isinstance(item, dict):
                        continue
                    text = item.get("text") or ""
                    if not isinstance(text, str):
                        text = str(text)
                    # Parse the JSON text content
                    try:
                        text_data = json.loads(text)
                    except (json.JSONDecodeError, TypeError):
                        text_data = None
                    # Only JSON objects carry success/error fields; lists
                    # and scalars fall through to the keyword scan.
                    if isinstance(text_data, dict):
                        # Check for success: false
                        if text_data.get("success") is False:
                            return True
                        # Check for error field
                        if text_data.get("error"):
                            return True
                    # Check text for security keywords
                    if any(keyword in text.lower() for keyword in [
                        "error", "blocked", "invalid", "security",
                        "injection", "denied", "forbidden", "not allowed",
                        "security_violation", "risk_level"
                    ]):
                        return True
        # Check raw text response
        raw = result.get("raw", "")
        if isinstance(raw, str) and any(keyword in raw.lower() for keyword in [
            "error", "blocked", "invalid", "security"
        ]):
            return True
        return False
class TestIdentifierInjectionAPI:
    """Identifier-parameter injection checks run against the live MCP API."""

    @pytest.fixture
    async def mcp_client(self):
        """Provide an MCPClient and close it once the test is done."""
        api = MCPClient()
        yield api
        await api.close()

    @pytest.mark.asyncio
    async def test_table_name_with_semicolon(self, mcp_client):
        """A table name carrying a semicolon must produce an error response."""
        arguments = {"table_name": "users; DROP TABLE users"}
        response = await mcp_client.call_tool("get_table_schema", arguments)
        print_result("Table Name with Semicolon", arguments, response)
        # Should be blocked by identifier validation
        assert self._contains_error_indicator(response), \
            "Table name with semicolon should be rejected"

    @pytest.mark.asyncio
    async def test_table_name_with_quotes(self, mcp_client):
        """A table name carrying single quotes must produce an error response."""
        arguments = {"table_name": "users' OR '1'='1"}
        response = await mcp_client.call_tool("get_table_schema", arguments)
        print_result("Table Name with Quotes", arguments, response)
        assert self._contains_error_indicator(response), \
            "Table name with quotes should be rejected"

    @pytest.mark.asyncio
    async def test_db_name_with_special_chars(self, mcp_client):
        """Each db_name containing a special character must produce an error response."""
        hostile_db_names = (
            "test;db",
            "test'db",
            "test\"db",
            "test`db",
            "test--db",
            "test/*db*/",
        )
        for db_name in hostile_db_names:
            arguments = {"table_name": "users", "db_name": db_name}
            response = await mcp_client.call_tool("get_table_schema", arguments)
            print_result(f"Special Char in db_name: {db_name}", arguments, response)
            assert self._contains_error_indicator(response), \
                f"db_name '{db_name}' should be rejected"

    @pytest.mark.asyncio
    async def test_valid_identifiers_accepted(self, mcp_client):
        """Send well-formed table names; responses are printed, not asserted."""
        for table_name in ("users", "my_table", "Table123", "_internal_table"):
            arguments = {"table_name": table_name}
            response = await mcp_client.call_tool("get_table_schema", arguments)
            print_result(f"Valid Identifier: {table_name}", arguments, response)

    def _contains_error_indicator(self, result: dict) -> bool:
        """Report whether a response looks like an error.

        Empty responses and JSON-RPC ``error`` members count as errors;
        otherwise the lower-cased JSON dump of the response is scanned for
        error-related keywords.
        """
        # Empty response or JSON-RPC error member
        if not result or "error" in result:
            return True
        # Keyword scan over the serialized response
        haystack = json.dumps(result).lower()
        for marker in (
            "error", "invalid", "illegal", "blocked",
            "security", "injection", "denied", "forbidden",
        ):
            if marker in haystack:
                return True
        return False
class TestMultiStatementInjectionAPI:
    """Test multi-statement SQL injection prevention"""

    @pytest.fixture
    async def mcp_client(self):
        """Create an MCP client for one test and close it afterwards"""
        client = MCPClient()
        yield client
        await client.close()

    @pytest.mark.asyncio
    async def test_hidden_drop_after_select(self, mcp_client):
        """Test DROP hidden after legitimate SELECT is blocked"""
        payload = {"sql": "SELECT id, name FROM users WHERE status = 'active'; DROP TABLE audit_log; SELECT 1;"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("Hidden DROP after SELECT", payload, result)
        assert self._is_dangerous_blocked(result), \
            "Hidden DROP statement should be blocked"

    @pytest.mark.asyncio
    async def test_hidden_truncate_after_select(self, mcp_client):
        """Test TRUNCATE hidden after SELECT is blocked"""
        payload = {"sql": "SELECT 1; TRUNCATE TABLE users"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("Hidden TRUNCATE after SELECT", payload, result)
        assert self._is_dangerous_blocked(result), \
            "Hidden TRUNCATE should be blocked"

    @pytest.mark.asyncio
    async def test_hidden_grant_after_select(self, mcp_client):
        """Test GRANT hidden after SELECT is blocked"""
        payload = {"sql": "SELECT 1; GRANT ALL ON *.* TO 'hacker'@'%'"}
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("Hidden GRANT after SELECT", payload, result)
        assert self._is_dangerous_blocked(result), \
            "Hidden GRANT should be blocked"

    @pytest.mark.asyncio
    async def test_multiple_safe_selects_allowed(self, mcp_client):
        """Test that multiple SELECT statements may be allowed"""
        payload = {"sql": "SELECT 1; SELECT 2; SELECT 3;"}
        # No assertion on the response: the call only has to complete
        # without raising, and its result is printed for inspection.
        result = await mcp_client.call_tool("exec_query", payload)
        print_result("Multiple Safe SELECTs", payload, result)

    def _is_dangerous_blocked(self, result: dict) -> bool:
        """Return True when the response shows the dangerous statement was blocked.

        An empty response or a top-level ``"error"`` key counts as blocked.
        Otherwise the response is walked recursively:

        * ``"success": False`` in any dict counts as blocked;
        * a dict key whose name contains a blocking indicator counts only
          when its value is truthy, so ``"isError": false`` no longer marks a
          successful response as blocked;
        * string values are scanned for the indicators; a string holding a
          JSON object or array is decoded and walked the same way.

        The SQL verbs ("drop", "truncate", "grant", "revoke") are deliberately
        not indicators: they occur in the submitted statement itself, so a
        response that merely echoes the executed SQL would be reported as
        blocked.

        Args:
            result: Response dictionary returned by the MCP client.

        Returns:
            True if a blocking indicator was found, otherwise False.
        """
        block_indicators = (
            "blocked", "denied", "forbidden", "not allowed",
            "security", "error",
        )

        def signals_block(node) -> bool:
            if isinstance(node, dict):
                if node.get("success") is False:
                    return True
                for key, value in node.items():
                    if value and any(ind in str(key).lower() for ind in block_indicators):
                        return True
                return any(signals_block(value) for value in node.values())
            if isinstance(node, (list, tuple)):
                return any(signals_block(item) for item in node)
            if isinstance(node, str):
                try:
                    decoded = json.loads(node)
                except json.JSONDecodeError:
                    decoded = None
                if isinstance(decoded, (dict, list)):
                    return signals_block(decoded)
                lowered = node.lower()
                return any(ind in lowered for ind in block_indicators)
            return False

        if not result:
            return True
        # Check for error
        if "error" in result:
            return True
        # Check result content for blocking indicators
        return signals_block(result)
class TestADBCQueryInjectionAPI:
    """Test ADBC query SQL injection prevention"""

    @pytest.fixture
    async def mcp_client(self):
        """Create an MCP client for one test and close it afterwards"""
        client = MCPClient()
        yield client
        await client.close()

    @pytest.mark.asyncio
    async def test_exec_adbc_query_drop_injection(self, mcp_client):
        """Test exec_adbc_query rejects DROP TABLE injection"""
        payload = {"sql": "SELECT * FROM users; DROP TABLE users; --"}
        result = await mcp_client.call_tool("exec_adbc_query", payload)
        print_result("ADBC DROP TABLE Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "ADBC DROP TABLE injection should be blocked"

    @pytest.mark.asyncio
    async def test_exec_adbc_query_delete_injection(self, mcp_client):
        """Test exec_adbc_query rejects DELETE injection"""
        payload = {"sql": "SELECT 1; DELETE FROM users; --"}
        result = await mcp_client.call_tool("exec_adbc_query", payload)
        print_result("ADBC DELETE Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "ADBC DELETE injection should be blocked"

    @pytest.mark.asyncio
    async def test_exec_adbc_query_valid(self, mcp_client):
        """Test exec_adbc_query allows valid queries"""
        payload = {"sql": "SELECT 1 AS test"}
        # No assertion on the response: the call only has to complete
        # without raising, and its result is printed for inspection.
        result = await mcp_client.call_tool("exec_adbc_query", payload)
        print_result("ADBC Valid Query", payload, result)

    def _is_blocked_or_error(self, result: dict) -> bool:
        """Return True when the response indicates a blocked call or an error.

        An empty response or a top-level ``"error"`` key counts as an error.
        Otherwise the response is walked recursively:

        * ``"success": False`` in any dict counts as an error;
        * a dict key whose name contains a keyword counts only when its value
          is truthy, so ``"isError": false`` or ``"error": null`` no longer
          mark a successful response as blocked;
        * string values are scanned for the keywords; a string holding a JSON
          object or array is decoded and walked the same way.

        Args:
            result: Response dictionary returned by the MCP client.

        Returns:
            True if a blocked/error indicator was found, otherwise False.
        """
        keywords = ("error", "blocked", "invalid", "security", "injection")

        def signals_failure(node) -> bool:
            if isinstance(node, dict):
                if node.get("success") is False:
                    return True
                for key, value in node.items():
                    if value and any(kw in str(key).lower() for kw in keywords):
                        return True
                return any(signals_failure(value) for value in node.values())
            if isinstance(node, (list, tuple)):
                return any(signals_failure(item) for item in node)
            if isinstance(node, str):
                try:
                    decoded = json.loads(node)
                except json.JSONDecodeError:
                    decoded = None
                if isinstance(decoded, (dict, list)):
                    return signals_failure(decoded)
                lowered = node.lower()
                return any(kw in lowered for kw in keywords)
            return False

        if not result:
            return True
        if "error" in result:
            return True
        return signals_failure(result)
class TestMetadataToolsInjectionAPI:
    """Test metadata tools SQL injection prevention"""

    @pytest.fixture
    async def mcp_client(self):
        """Create an MCP client for one test and close it afterwards"""
        client = MCPClient()
        yield client
        await client.close()

    @pytest.mark.asyncio
    async def test_get_db_table_list_db_injection(self, mcp_client):
        """Test get_db_table_list rejects db_name injection"""
        payload = {"db_name": "test'; DROP TABLE users; --"}
        result = await mcp_client.call_tool("get_db_table_list", payload)
        print_result("get_db_table_list db_name Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "db_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_db_table_list_catalog_injection(self, mcp_client):
        """Test get_db_table_list rejects catalog_name injection"""
        payload = {"catalog_name": "internal`; SELECT * FROM mysql.user; --"}
        result = await mcp_client.call_tool("get_db_table_list", payload)
        print_result("get_db_table_list catalog_name Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "catalog_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_table_comment_injection(self, mcp_client):
        """Test get_table_comment rejects table_name injection"""
        payload = {"table_name": "users'; DROP TABLE users; --"}
        result = await mcp_client.call_tool("get_table_comment", payload)
        print_result("get_table_comment table_name Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_table_column_comments_injection(self, mcp_client):
        """Test get_table_column_comments rejects injection"""
        payload = {"table_name": "users'; DROP TABLE users; --", "db_name": "test"}
        result = await mcp_client.call_tool("get_table_column_comments", payload)
        print_result("get_table_column_comments Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_table_indexes_injection(self, mcp_client):
        """Test get_table_indexes rejects table_name injection"""
        payload = {"table_name": "users; DROP TABLE users", "db_name": "test"}
        result = await mcp_client.call_tool("get_table_indexes", payload)
        print_result("get_table_indexes Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    def _is_blocked_or_error(self, result: dict) -> bool:
        """Return True when the response indicates a blocked call or an error.

        An empty response or a top-level ``"error"`` key counts as an error.
        Otherwise the response is walked recursively:

        * ``"success": False`` in any dict counts as an error;
        * a dict key whose name contains a keyword counts only when its value
          is truthy, so ``"isError": false`` or ``"error": null`` no longer
          mark a successful response as blocked;
        * string values are scanned for the keywords; a string holding a JSON
          object or array is decoded and walked the same way.

        Args:
            result: Response dictionary returned by the MCP client.

        Returns:
            True if a blocked/error indicator was found, otherwise False.
        """
        keywords = ("error", "blocked", "invalid", "security", "injection")

        def signals_failure(node) -> bool:
            if isinstance(node, dict):
                if node.get("success") is False:
                    return True
                for key, value in node.items():
                    if value and any(kw in str(key).lower() for kw in keywords):
                        return True
                return any(signals_failure(value) for value in node.values())
            if isinstance(node, (list, tuple)):
                return any(signals_failure(item) for item in node)
            if isinstance(node, str):
                try:
                    decoded = json.loads(node)
                except json.JSONDecodeError:
                    decoded = None
                if isinstance(decoded, (dict, list)):
                    return signals_failure(decoded)
                lowered = node.lower()
                return any(kw in lowered for kw in keywords)
            return False

        if not result:
            return True
        if "error" in result:
            return True
        return signals_failure(result)
class TestAnalyticsToolsInjectionAPI:
    """Test analytics tools SQL injection prevention"""

    @pytest.fixture
    async def mcp_client(self):
        """Create an MCP client for one test and close it afterwards"""
        client = MCPClient()
        yield client
        await client.close()

    @pytest.mark.asyncio
    async def test_analyze_columns_table_injection(self, mcp_client):
        """Test analyze_columns rejects table_name injection"""
        payload = {"table_name": "users'; DROP TABLE users; --"}
        result = await mcp_client.call_tool("analyze_columns", payload)
        print_result("analyze_columns table_name Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_analyze_columns_db_injection(self, mcp_client):
        """Test analyze_columns rejects db_name injection"""
        payload = {"table_name": "users", "db_name": "test' OR '1'='1"}
        result = await mcp_client.call_tool("analyze_columns", payload)
        print_result("analyze_columns db_name Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "db_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_table_basic_info_injection(self, mcp_client):
        """Test get_table_basic_info rejects injection"""
        payload = {"table_name": "users; DROP TABLE audit_log"}
        result = await mcp_client.call_tool("get_table_basic_info", payload)
        print_result("get_table_basic_info Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_analyze_table_storage_injection(self, mcp_client):
        """Test analyze_table_storage rejects injection"""
        payload = {"table_name": "users`; SELECT * FROM sensitive; --"}
        result = await mcp_client.call_tool("analyze_table_storage", payload)
        print_result("analyze_table_storage Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_sql_explain_injection(self, mcp_client):
        """Test get_sql_explain rejects SQL injection"""
        payload = {"sql": "SELECT 1; DROP TABLE users; --"}
        result = await mcp_client.call_tool("get_sql_explain", payload)
        print_result("get_sql_explain SQL Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "SQL injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_sql_profile_injection(self, mcp_client):
        """Test get_sql_profile rejects SQL injection"""
        payload = {"sql": "SELECT 1; DELETE FROM audit_log; --"}
        result = await mcp_client.call_tool("get_sql_profile", payload)
        print_result("get_sql_profile SQL Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "SQL injection should be blocked"

    def _is_blocked_or_error(self, result: dict) -> bool:
        """Return True when the response indicates a blocked call or an error.

        An empty response or a top-level ``"error"`` key counts as an error.
        Otherwise the response is walked recursively:

        * ``"success": False`` in any dict counts as an error;
        * a dict key whose name contains a keyword counts only when its value
          is truthy, so ``"isError": false`` or ``"error": null`` no longer
          mark a successful response as blocked;
        * string values are scanned for the keywords; a string holding a JSON
          object or array is decoded and walked the same way.

        Args:
            result: Response dictionary returned by the MCP client.

        Returns:
            True if a blocked/error indicator was found, otherwise False.
        """
        keywords = ("error", "blocked", "invalid", "security", "injection")

        def signals_failure(node) -> bool:
            if isinstance(node, dict):
                if node.get("success") is False:
                    return True
                for key, value in node.items():
                    if value and any(kw in str(key).lower() for kw in keywords):
                        return True
                return any(signals_failure(value) for value in node.values())
            if isinstance(node, (list, tuple)):
                return any(signals_failure(item) for item in node)
            if isinstance(node, str):
                try:
                    decoded = json.loads(node)
                except json.JSONDecodeError:
                    decoded = None
                if isinstance(decoded, (dict, list)):
                    return signals_failure(decoded)
                lowered = node.lower()
                return any(kw in lowered for kw in keywords)
            return False

        if not result:
            return True
        if "error" in result:
            return True
        return signals_failure(result)
class TestGovernanceToolsInjectionAPI:
    """Test data governance tools SQL injection prevention"""

    @pytest.fixture
    async def mcp_client(self):
        """Create an MCP client for one test and close it afterwards"""
        client = MCPClient()
        yield client
        await client.close()

    @pytest.mark.asyncio
    async def test_trace_column_lineage_table_injection(self, mcp_client):
        """Test trace_column_lineage rejects table_name injection"""
        payload = {"table_name": "users'; DROP TABLE users; --", "column_name": "id"}
        result = await mcp_client.call_tool("trace_column_lineage", payload)
        print_result("trace_column_lineage table_name Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_trace_column_lineage_column_injection(self, mcp_client):
        """Test trace_column_lineage rejects column_name injection"""
        payload = {"table_name": "users", "column_name": "id; DROP TABLE users"}
        result = await mcp_client.call_tool("trace_column_lineage", payload)
        print_result("trace_column_lineage column_name Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "column_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_monitor_data_freshness_injection(self, mcp_client):
        """Test monitor_data_freshness rejects table_name injection"""
        payload = {"table_name": "users`; SELECT * FROM passwords; --"}
        result = await mcp_client.call_tool("monitor_data_freshness", payload)
        print_result("monitor_data_freshness Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_analyze_data_access_patterns_injection(self, mcp_client):
        """Test analyze_data_access_patterns rejects injection"""
        payload = {"table_name": "users' UNION SELECT password FROM admin --"}
        result = await mcp_client.call_tool("analyze_data_access_patterns", payload)
        print_result("analyze_data_access_patterns Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    def _is_blocked_or_error(self, result: dict) -> bool:
        """Return True when the response indicates a blocked call or an error.

        An empty response or a top-level ``"error"`` key counts as an error.
        Otherwise the response is walked recursively:

        * ``"success": False`` in any dict counts as an error;
        * a dict key whose name contains a keyword counts only when its value
          is truthy, so ``"isError": false`` or ``"error": null`` no longer
          mark a successful response as blocked;
        * string values are scanned for the keywords; a string holding a JSON
          object or array is decoded and walked the same way.

        Args:
            result: Response dictionary returned by the MCP client.

        Returns:
            True if a blocked/error indicator was found, otherwise False.
        """
        keywords = ("error", "blocked", "invalid", "security", "injection")

        def signals_failure(node) -> bool:
            if isinstance(node, dict):
                if node.get("success") is False:
                    return True
                for key, value in node.items():
                    if value and any(kw in str(key).lower() for kw in keywords):
                        return True
                return any(signals_failure(value) for value in node.values())
            if isinstance(node, (list, tuple)):
                return any(signals_failure(item) for item in node)
            if isinstance(node, str):
                try:
                    decoded = json.loads(node)
                except json.JSONDecodeError:
                    decoded = None
                if isinstance(decoded, (dict, list)):
                    return signals_failure(decoded)
                lowered = node.lower()
                return any(kw in lowered for kw in keywords)
            return False

        if not result:
            return True
        if "error" in result:
            return True
        return signals_failure(result)
class TestPerformanceToolsInjectionAPI:
    """Test performance analytics tools SQL injection prevention"""

    @pytest.fixture
    async def mcp_client(self):
        """Create an MCP client for one test and close it afterwards"""
        client = MCPClient()
        yield client
        await client.close()

    @pytest.mark.asyncio
    async def test_analyze_slow_queries_db_injection(self, mcp_client):
        """Test analyze_slow_queries_topn rejects db_name injection"""
        payload = {"db_name": "test'; DROP TABLE audit_log; --"}
        result = await mcp_client.call_tool("analyze_slow_queries_topn", payload)
        print_result("analyze_slow_queries_topn db_name Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "db_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_analyze_resource_growth_db_injection(self, mcp_client):
        """Test analyze_resource_growth_curves rejects db_name injection"""
        payload = {"db_name": "test`; GRANT ALL ON *.* TO 'hacker'; --"}
        result = await mcp_client.call_tool("analyze_resource_growth_curves", payload)
        print_result("analyze_resource_growth_curves db_name Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "db_name injection should be blocked"

    @pytest.mark.asyncio
    async def test_get_table_data_size_injection(self, mcp_client):
        """Test get_table_data_size rejects table_name injection"""
        payload = {"table_name": "users; TRUNCATE TABLE logs"}
        result = await mcp_client.call_tool("get_table_data_size", payload)
        print_result("get_table_data_size Injection", payload, result)
        assert self._is_blocked_or_error(result), \
            "table_name injection should be blocked"

    def _is_blocked_or_error(self, result: dict) -> bool:
        """Return True when the response indicates a blocked call or an error.

        An empty response or a top-level ``"error"`` key counts as an error.
        Otherwise the response is walked recursively:

        * ``"success": False`` in any dict counts as an error;
        * a dict key whose name contains a keyword counts only when its value
          is truthy, so ``"isError": false`` or ``"error": null`` no longer
          mark a successful response as blocked;
        * string values are scanned for the keywords; a string holding a JSON
          object or array is decoded and walked the same way.

        Args:
            result: Response dictionary returned by the MCP client.

        Returns:
            True if a blocked/error indicator was found, otherwise False.
        """
        keywords = ("error", "blocked", "invalid", "security", "injection")

        def signals_failure(node) -> bool:
            if isinstance(node, dict):
                if node.get("success") is False:
                    return True
                for key, value in node.items():
                    if value and any(kw in str(key).lower() for kw in keywords):
                        return True
                return any(signals_failure(value) for value in node.values())
            if isinstance(node, (list, tuple)):
                return any(signals_failure(item) for item in node)
            if isinstance(node, str):
                try:
                    decoded = json.loads(node)
                except json.JSONDecodeError:
                    decoded = None
                if isinstance(decoded, (dict, list)):
                    return signals_failure(decoded)
                lowered = node.lower()
                return any(kw in lowered for kw in keywords)
            return False

        if not result:
            return True
        if "error" in result:
            return True
        return signals_failure(result)
# Async test support: event loop fixture with session scope
@pytest.fixture(scope="session")
def event_loop():
    """Yield a newly created asyncio event loop and close it on teardown"""
    session_loop = asyncio.new_event_loop()
    yield session_loop
    session_loop.close()
if __name__ == "__main__":
    # pytest.main() returns the exit code of the run; hand it to the
    # interpreter so running this file directly reports test failures
    # to the calling shell or CI job instead of always exiting with 0.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short", "-x"]))

View File

@@ -201,3 +201,73 @@ class TestDorisQueryExecutor:
if result["success"]: if result["success"]:
assert "data" in result assert "data" in result
assert "row_count" in result assert "row_count" in result
@pytest.mark.asyncio
async def test_execute_multi_sql_statements(self, query_executor):
    """Run three semicolon-separated statements and check each returned result"""
    from doris_mcp_server.utils.query_executor import QueryResult

    # Security checking is switched off for this test
    security_config = query_executor.connection_manager.config.security
    security_config.enable_security_check = False

    # One row per statement: (sql, rows, column names, mocked execution time)
    expected = [
        ("SELECT id, name FROM users WHERE id = 1", [{"id": 1, "name": "张三"}], ["id", "name"], 0.1),
        ("SELECT id, name FROM users WHERE id = 2", [{"id": 2, "name": "李四"}], ["id", "name"], 0.12),
        ("SELECT COUNT(*) as count FROM users", [{"count": 100}], ["count"], 0.08),
    ]

    with patch.object(query_executor, 'execute_query') as mock_execute:
        # The mock hands back one canned QueryResult per call, in order.
        # Rows and columns are copied so the expectations above stay untouched.
        mock_execute.side_effect = [
            QueryResult(
                data=[dict(row) for row in rows],
                row_count=1,
                execution_time=elapsed,
                sql=sql,
                metadata={"columns": list(columns)}
            )
            for sql, rows, columns, elapsed in expected
        ]

        # Three statements separated by semicolons, sent as a single request
        multi_sql = (
            "\n"
            "SELECT id, name FROM users WHERE id = 1;\n"
            "SELECT id, name FROM users WHERE id = 2;\n"
            "SELECT COUNT(*) as count FROM users;\n"
        )
        result = await query_executor.execute_sql_for_mcp(multi_sql)

        # Overall shape of a multi-statement response
        assert result["success"] is True
        assert result["multiple_results"] is True
        assert "results" in result
        assert len(result["results"]) == 3

        # Each entry must match its statement, in submission order
        for position, (sql, rows, columns, _elapsed) in enumerate(expected):
            entry = result["results"][position]
            assert entry["data"] == rows
            assert entry["row_count"] == 1
            assert entry["metadata"]["columns"] == columns
            assert entry["metadata"]["query"] == sql

        # execute_query must have been invoked once per statement
        assert mock_execute.call_count == 3

2
uv.lock generated
View File

@@ -562,7 +562,7 @@ wheels = [
[[package]] [[package]]
name = "doris-mcp-server" name = "doris-mcp-server"
version = "0.5.1" version = "0.6.1"
source = { editable = "." } source = { editable = "." }
dependencies = [ dependencies = [
{ name = "adbc-driver-flightsql" }, { name = "adbc-driver-flightsql" },