init doris mcp 0.2.0
This commit is contained in:
1
doris_mcp_server/utils/__init__.py
Normal file
1
doris_mcp_server/utils/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
# Mark directory as a package
|
||||
100
doris_mcp_server/utils/db.py
Normal file
100
doris_mcp_server/utils/db.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import os
|
||||
import json
|
||||
import pymysql
|
||||
import pandas as pd
|
||||
from typing import Dict, List, Optional, Any
|
||||
from dotenv import load_dotenv
|
||||
import re
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Database configuration
|
||||
DB_CONFIG = {
|
||||
"host": os.getenv("DB_HOST", "localhost"),
|
||||
"port": int(os.getenv("DB_PORT", "9030")),
|
||||
"user": os.getenv("DB_USER", "root"),
|
||||
"password": os.getenv("DB_PASSWORD", ""),
|
||||
"database": os.getenv("DB_DATABASE", ""),
|
||||
"charset": "utf8mb4",
|
||||
"cursorclass": pymysql.cursors.DictCursor
|
||||
}
|
||||
|
||||
def get_db_connection(db_name: Optional[str] = None):
|
||||
"""
|
||||
Get database connection
|
||||
|
||||
Args:
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Database connection
|
||||
"""
|
||||
if db_name:
|
||||
# Use default config but override database name
|
||||
config = DB_CONFIG.copy()
|
||||
config["database"] = db_name
|
||||
return pymysql.connect(**config)
|
||||
else:
|
||||
# Use default config
|
||||
return pymysql.connect(**DB_CONFIG)
|
||||
|
||||
def get_db_name() -> str:
|
||||
"""Get the currently configured default database name"""
|
||||
return DB_CONFIG["database"] or os.getenv("DB_DATABASE", "")
|
||||
|
||||
def execute_query(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
Query results
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
return result
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def execute_query_df(sql, db_name: Optional[str] = None):
|
||||
"""
|
||||
Execute SQL query and return pandas DataFrame
|
||||
|
||||
Args:
|
||||
sql: SQL query statement
|
||||
db_name: Specify the database name to connect to, use default config if None
|
||||
|
||||
Returns:
|
||||
pandas DataFrame
|
||||
"""
|
||||
conn = get_db_connection(db_name)
|
||||
try:
|
||||
# Use a temporary cursor to execute the query and get results
|
||||
with conn.cursor() as cursor:
|
||||
# Set connection character set to utf8 before executing query
|
||||
cursor.execute("SET NAMES utf8")
|
||||
|
||||
# Execute the actual query
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# If no results, return empty DataFrame
|
||||
if not result:
|
||||
return pd.DataFrame()
|
||||
|
||||
# Manually convert dict results to DataFrame
|
||||
df = pd.DataFrame(result)
|
||||
return df
|
||||
finally:
|
||||
conn.close()
|
||||
226
doris_mcp_server/utils/logger.py
Normal file
226
doris_mcp_server/utils/logger.py
Normal file
@@ -0,0 +1,226 @@
|
||||
"""
|
||||
Unified Logging Configuration Module
|
||||
|
||||
Provides unified logging configuration, including:
|
||||
- General logs: Record all program execution information
|
||||
- Audit logs: Record JSON data for key operations and processing results
|
||||
- Error logs: Specifically record program exceptions and errors
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
import logging.handlers
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
from datetime import datetime
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Load environment variables
|
||||
load_dotenv(override=True)
|
||||
|
||||
# Get project root directory
|
||||
PROJECT_ROOT = Path(__file__).parents[2].absolute()
|
||||
|
||||
# Get log configuration from environment variables
|
||||
LOG_DIR = os.getenv("LOG_DIR", str(PROJECT_ROOT / "logs"))
|
||||
LOG_PREFIX = os.getenv("LOG_PREFIX", "doris_mcp")
|
||||
LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper()
|
||||
LOG_MAX_DAYS = int(os.getenv("LOG_MAX_DAYS", "30"))
|
||||
# Whether to output logs to the console (should be disabled when running as a service)
|
||||
CONSOLE_LOGGING = os.getenv("CONSOLE_LOGGING", "false").lower() == "true"
|
||||
# Whether stdio transport mode is being used
|
||||
STDIO_MODE = os.getenv("MCP_TRANSPORT_TYPE", "").lower() == "stdio"
|
||||
|
||||
def purge_old_logs():
|
||||
"""Clean up expired log files"""
|
||||
# --- Only perform cleanup in non-Stdio mode ---
|
||||
if STDIO_MODE:
|
||||
return
|
||||
try:
|
||||
now = datetime.now()
|
||||
log_dir = Path(LOG_DIR)
|
||||
# Check if directory exists and is readable/writable
|
||||
if not log_dir.is_dir() or not os.access(LOG_DIR, os.W_OK):
|
||||
if not STDIO_MODE: # Avoid printing to stdout in stdio mode
|
||||
print(f"Warning: Log directory {LOG_DIR} not accessible, skipping log purge.", file=sys.stderr)
|
||||
return
|
||||
|
||||
for log_file in log_dir.glob(f"{LOG_PREFIX}*.20*"):
|
||||
# Parse date
|
||||
file_name = log_file.name
|
||||
date_str = None
|
||||
|
||||
# Try to find the date part
|
||||
parts = file_name.split('.')
|
||||
for part in parts:
|
||||
if part.startswith('20') and len(part) == 8: # 20YYMMDD format
|
||||
date_str = part
|
||||
break
|
||||
|
||||
if date_str:
|
||||
try:
|
||||
file_date = datetime.strptime(date_str, '%Y%m%d')
|
||||
days_old = (now - file_date).days
|
||||
|
||||
if days_old > LOG_MAX_DAYS:
|
||||
os.remove(log_file)
|
||||
if not STDIO_MODE:
|
||||
print(f"Deleted expired log file: {log_file}")
|
||||
except (ValueError, OSError) as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error processing log file {file_name}: {e}", file=sys.stderr)
|
||||
except Exception as e:
|
||||
if not STDIO_MODE:
|
||||
print(f"Error cleaning up logs: {e}", file=sys.stderr)
|
||||
|
||||
# Force disable console log output if in stdio mode
|
||||
if STDIO_MODE:
|
||||
CONSOLE_LOGGING = False
|
||||
|
||||
# --- Only create log directory and clean old logs in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
try:
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
# Clean up expired logs on startup (also moved here, as it only handles file logs)
|
||||
purge_old_logs()
|
||||
except OSError as e:
|
||||
# If directory creation fails (e.g., permission issue), print warning but continue to avoid startup failure
|
||||
print(f"Warning: Failed to create log directory {LOG_DIR} or purge logs: {e}", file=sys.stderr)
|
||||
|
||||
# Log file paths (definition still needed, but files might not be created/used)
|
||||
LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.log")
|
||||
AUDIT_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.audit")
|
||||
ERROR_LOG_FILE = os.path.join(LOG_DIR, f"{LOG_PREFIX}.error")
|
||||
|
||||
# Log level mapping
|
||||
LOG_LEVELS = {
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARNING": logging.WARNING,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL
|
||||
}
|
||||
|
||||
# Log format
|
||||
LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
AUDIT_FORMAT = '%(asctime)s - %(name)s - %(message)s'
|
||||
ERROR_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(pathname)s:%(lineno)d - %(message)s'
|
||||
|
||||
# Dedicated audit log level
|
||||
AUDIT = 25 # Level between INFO and WARNING
|
||||
logging.addLevelName(AUDIT, "AUDIT")
|
||||
|
||||
# Logger object cache
|
||||
_loggers: Dict[str, logging.Logger] = {}
|
||||
|
||||
# Handler type mapping, used to ensure no duplicates are added
|
||||
_handler_types = {
|
||||
'console': logging.StreamHandler,
|
||||
'file': logging.handlers.TimedRotatingFileHandler,
|
||||
'audit': logging.handlers.TimedRotatingFileHandler,
|
||||
'error': logging.handlers.TimedRotatingFileHandler
|
||||
}
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger with the specified name
|
||||
|
||||
Args:
|
||||
name: Logger name
|
||||
|
||||
Returns:
|
||||
logging.Logger: Configured logger
|
||||
"""
|
||||
if name in _loggers:
|
||||
return _loggers[name]
|
||||
|
||||
# Create logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(LOG_LEVELS.get(LOG_LEVEL, logging.INFO))
|
||||
|
||||
# Avoid duplicate logs caused by propagation
|
||||
logger.propagate = False
|
||||
|
||||
# Check if handlers already exist to avoid duplicates
|
||||
handler_types = set(type(h) for h in logger.handlers)
|
||||
|
||||
# Add audit log method
|
||||
def audit(self, message, *args, **kwargs):
|
||||
self.log(AUDIT, message, *args, **kwargs)
|
||||
|
||||
logger.audit = audit.__get__(logger)
|
||||
|
||||
# General log handler - output to console (only if enabled)
|
||||
if CONSOLE_LOGGING and _handler_types['console'] not in handler_types:
|
||||
# Use stderr instead of stdout to avoid conflicts with MCP communication
|
||||
console_handler = logging.StreamHandler(sys.stderr)
|
||||
console_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
# --- Only add file handlers in non-Stdio mode ---
|
||||
if not STDIO_MODE:
|
||||
# General log handler - daily rotating file
|
||||
if _handler_types['file'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
file_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
file_handler.setFormatter(logging.Formatter(LOG_FORMAT))
|
||||
file_handler.suffix = "%Y%m%d"
|
||||
logger.addHandler(file_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add file log handler for {LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Audit log handler - only logs AUDIT level
|
||||
if _handler_types['audit'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
audit_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
AUDIT_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
audit_handler.setFormatter(logging.Formatter(AUDIT_FORMAT))
|
||||
audit_handler.suffix = "%Y%m%d"
|
||||
audit_handler.setLevel(AUDIT)
|
||||
audit_handler.addFilter(lambda record: record.levelno == AUDIT)
|
||||
logger.addHandler(audit_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add audit log handler for {AUDIT_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Error log handler - only logs ERROR level and above
|
||||
if _handler_types['error'] not in handler_types:
|
||||
try: # Add try-except block
|
||||
error_handler = logging.handlers.TimedRotatingFileHandler(
|
||||
ERROR_LOG_FILE,
|
||||
when='midnight',
|
||||
interval=1,
|
||||
backupCount=LOG_MAX_DAYS,
|
||||
encoding='utf-8'
|
||||
)
|
||||
error_handler.setFormatter(logging.Formatter(ERROR_FORMAT))
|
||||
error_handler.suffix = "%Y%m%d"
|
||||
error_handler.setLevel(logging.ERROR)
|
||||
logger.addHandler(error_handler)
|
||||
except OSError as e:
|
||||
print(f"Warning: Failed to add error log handler for {ERROR_LOG_FILE}: {e}", file=sys.stderr)
|
||||
|
||||
# Cache logger
|
||||
_loggers[name] = logger
|
||||
|
||||
return logger
|
||||
|
||||
# Default logger
|
||||
logger = get_logger('doris_mcp')
|
||||
|
||||
# Audit logger - for recording processing results, business operations, etc.
|
||||
audit_logger = get_logger('audit')
|
||||
|
||||
# Call to clean logs moved after directory creation, and added non-stdio check
|
||||
1013
doris_mcp_server/utils/schema_extractor.py
Normal file
1013
doris_mcp_server/utils/schema_extractor.py
Normal file
File diff suppressed because it is too large
Load Diff
349
doris_mcp_server/utils/sql_executor_tools.py
Normal file
349
doris_mcp_server/utils/sql_executor_tools.py
Normal file
@@ -0,0 +1,349 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
SQL Execution Tool
|
||||
|
||||
Responsible for executing SQL queries and handling results
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
import re
|
||||
import datetime
|
||||
from decimal import Decimal
|
||||
|
||||
# Get logger
|
||||
logger = logging.getLogger("doris-mcp.sql-executor")
|
||||
|
||||
# Add environment variable control for whether to perform SQL security checks
|
||||
ENABLE_SQL_SECURITY_CHECK = os.environ.get('ENABLE_SQL_SECURITY_CHECK', 'true').lower() == 'true'
|
||||
|
||||
async def execute_sql_query(ctx) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query and return results
|
||||
|
||||
Args:
|
||||
ctx: Context object or dictionary containing request parameters
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Execution result
|
||||
"""
|
||||
try:
|
||||
# Support the case where the passed argument is a dictionary
|
||||
if isinstance(ctx, dict) and 'params' in ctx:
|
||||
params = ctx['params']
|
||||
else:
|
||||
params = ctx.params
|
||||
|
||||
sql = params.get("sql")
|
||||
db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))
|
||||
max_rows = params.get("max_rows", 1000) # Maximum number of rows to return
|
||||
timeout = params.get("timeout", 30) # Timeout in seconds
|
||||
|
||||
if not sql:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "Missing SQL parameter",
|
||||
"message": "Please provide the SQL query to execute"
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# First check SQL security
|
||||
security_result = await _check_sql_security(sql)
|
||||
if not security_result.get("is_safe", False):
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "SQL security check failed",
|
||||
"message": "Query contains unsafe operations and cannot be executed",
|
||||
"security_issues": security_result.get("security_issues", [])
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Import database connection tool
|
||||
from doris_mcp_server.utils.db import execute_query
|
||||
|
||||
if not sql:
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": False,
|
||||
"error": "Missing SQL parameter",
|
||||
"message": "Please provide the SQL query to execute"
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Ensure SELECT statements include a LIMIT clause
|
||||
sql_lower = sql.lower().strip()
|
||||
if sql_lower.startswith("select") and "limit" not in sql_lower:
|
||||
sql = sql.rstrip(";") + f" LIMIT {max_rows};"
|
||||
|
||||
# Start timer
|
||||
start_time = time.time()
|
||||
|
||||
# Execute query
|
||||
try:
|
||||
result = execute_query(sql, db_name)
|
||||
|
||||
# Calculate execution time
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
# Build return result
|
||||
if isinstance(result, list):
|
||||
# Handle list of query results
|
||||
row_count = len(result)
|
||||
|
||||
# Extract column names
|
||||
if hasattr(result[0], "_fields"):
|
||||
# If it's a named tuple
|
||||
columns = list(result[0]._fields)
|
||||
else:
|
||||
# Otherwise, assume it's a dictionary
|
||||
columns = list(result[0].keys()) if isinstance(result[0], dict) else []
|
||||
|
||||
# Convert results to serializable format
|
||||
data = []
|
||||
for row in result:
|
||||
row_dict = {}
|
||||
if hasattr(row, "_asdict"):
|
||||
# If it's a named tuple
|
||||
row_dict = row._asdict()
|
||||
elif isinstance(row, dict):
|
||||
# If it's a dictionary
|
||||
row_dict = row
|
||||
else:
|
||||
# If it's a list or tuple
|
||||
row_dict = dict(zip(columns, row)) if columns else row
|
||||
|
||||
# Handle special types to make them JSON serializable
|
||||
serialized_row = _serialize_row_data(row_dict)
|
||||
data.append(serialized_row)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps({
|
||||
"success": True,
|
||||
"sql": sql,
|
||||
"row_count": row_count,
|
||||
"columns": columns,
|
||||
"data": data[:max_rows], # Limit returned rows
|
||||
"execution_time": execution_time,
|
||||
"truncated": row_count > max_rows
|
||||
}, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
else:
|
||||
# Handle other types of results
|
||||
other_response = {
|
||||
"success": True,
|
||||
"sql": sql,
|
||||
"result": str(result),
|
||||
"execution_time": execution_time
|
||||
}
|
||||
other_response = _serialize_row_data(other_response)
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(other_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as db_error:
|
||||
error_message = str(db_error)
|
||||
|
||||
# Try to get more detailed error information
|
||||
error_details = {}
|
||||
if "timeout" in error_message.lower():
|
||||
error_details["type"] = "timeout"
|
||||
error_details["suggestion"] = "Query timed out, please optimize SQL or increase timeout"
|
||||
elif "syntax" in error_message.lower():
|
||||
error_details["type"] = "syntax"
|
||||
error_details["suggestion"] = "SQL syntax error, please check syntax"
|
||||
elif "not found" in error_message.lower() or "doesn't exist" in error_message.lower():
|
||||
error_details["type"] = "not_found"
|
||||
error_details["suggestion"] = "Table or column not found, please check table and column names"
|
||||
else:
|
||||
error_details["type"] = "unknown"
|
||||
error_details["suggestion"] = "Please check the SQL statement and try simplifying the query"
|
||||
|
||||
# Create error response
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": error_message,
|
||||
"error_details": error_details,
|
||||
"sql": sql,
|
||||
"db_name": db_name
|
||||
}
|
||||
|
||||
# Ensure error response is also serializable
|
||||
error_response = _serialize_row_data(error_response)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(error_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to execute SQL query: {str(e)}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
error_response = {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "Error occurred while executing SQL query"
|
||||
}
|
||||
|
||||
# Ensure error response is also serializable
|
||||
error_response = _serialize_row_data(error_response)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(error_response, ensure_ascii=False)
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Helper function
|
||||
async def _check_sql_security(sql: str) -> Dict[str, Any]:
|
||||
"""Check SQL security"""
|
||||
# If environment variable is set to disable security check, return safe immediately
|
||||
if not ENABLE_SQL_SECURITY_CHECK:
|
||||
return {
|
||||
"is_safe": True,
|
||||
"security_issues": []
|
||||
}
|
||||
|
||||
# Check if SQL contains dangerous operations
|
||||
sql_lower = sql.lower()
|
||||
|
||||
# Check if it's a read-only query type
|
||||
is_read_only = sql_lower.strip().startswith(("select ", "show ", "desc ", "describe ", "explain "))
|
||||
|
||||
# Define list of dangerous operations (checked for both read-only and non-read-only queries)
|
||||
dangerous_operations = [
|
||||
(r'\bdelete\b', "DELETE operation"),
|
||||
(r'\bdrop\b', "DROP TABLE/DATABASE operation"),
|
||||
(r'\btruncate\b', "TRUNCATE TABLE operation"),
|
||||
(r'\bupdate\b', "UPDATE operation"),
|
||||
(r'\binsert\b', "INSERT operation"),
|
||||
(r'\balter\b', "ALTER TABLE structure operation"),
|
||||
(r'\bcreate\b', "CREATE TABLE/DATABASE operation"),
|
||||
(r'\bgrant\b', "GRANT operation"),
|
||||
(r'\brevoke\b', "REVOKE permission operation"),
|
||||
(r'\bexec\b', "EXECUTE stored procedure"),
|
||||
(r'\bxp_', "Extended stored procedure, potential security risk"),
|
||||
(r'\bshutdown\b', "SHUTDOWN database operation"),
|
||||
(r'\bunion\s+all\s+select\b', "UNION statement, potential SQL injection"),
|
||||
(r'\bunion\s+select\b', "UNION statement, potential SQL injection"),
|
||||
(r'\binto\s+outfile\b', "Write to file operation"),
|
||||
(r'\bload_file\b', "Load file operation")
|
||||
]
|
||||
|
||||
# Dangerous operations checked only for non-read-only queries
|
||||
non_readonly_operations = []
|
||||
if not is_read_only:
|
||||
non_readonly_operations = [
|
||||
(r'--', "SQL comment, potential SQL injection"),
|
||||
(r'/\*', "SQL block comment, potential SQL injection")
|
||||
]
|
||||
|
||||
# Check if dangerous operations are included
|
||||
security_issues = []
|
||||
|
||||
# Check dangerous operations applicable to all queries
|
||||
for operation, description in dangerous_operations:
|
||||
if re.search(operation, sql_lower):
|
||||
# For specific keywords in read-only queries, differentiate if used as independent operations
|
||||
if is_read_only and operation in [r'\bcreate\b', r'\bdrop\b', r'\bdelete\b', r'\binsert\b', r'\bupdate\b', r'\balter\b']:
|
||||
# Check if used as DDL/DML keyword, e.g., CREATE TABLE, DROP DATABASE
|
||||
pattern = operation + r'\s+(?:table|database|view|index|procedure|function|trigger|event)'
|
||||
if re.search(pattern, sql_lower):
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "High"
|
||||
})
|
||||
else:
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "High"
|
||||
})
|
||||
|
||||
# Check dangerous operations specific to non-read-only queries
|
||||
for operation, description in non_readonly_operations:
|
||||
if re.search(operation, sql_lower):
|
||||
security_issues.append({
|
||||
"operation": operation.replace(r'\b', '').replace(r'\s+', ' '),
|
||||
"description": description,
|
||||
"severity": "Medium"
|
||||
})
|
||||
|
||||
return {
|
||||
"is_safe": len(security_issues) == 0,
|
||||
"security_issues": security_issues
|
||||
}
|
||||
|
||||
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert special types in row data (like date, time, Decimal) to JSON serializable format
|
||||
|
||||
Args:
|
||||
row_data: Row data dictionary
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Processed serializable dictionary
|
||||
"""
|
||||
serialized_data = {}
|
||||
for key, value in row_data.items():
|
||||
if value is None:
|
||||
serialized_data[key] = None
|
||||
elif isinstance(value, (datetime.date, datetime.datetime)):
|
||||
# Convert date and time types to ISO format string
|
||||
serialized_data[key] = value.isoformat()
|
||||
elif isinstance(value, Decimal):
|
||||
# Convert Decimal type to float
|
||||
serialized_data[key] = float(value)
|
||||
elif isinstance(value, (list, tuple)):
|
||||
# Recursively process elements in list or tuple
|
||||
serialized_data[key] = [
|
||||
_serialize_row_data(item) if isinstance(item, dict) else item
|
||||
for item in value
|
||||
]
|
||||
elif isinstance(value, dict):
|
||||
# Recursively process nested dictionaries
|
||||
serialized_data[key] = _serialize_row_data(value)
|
||||
else:
|
||||
serialized_data[key] = value
|
||||
return serialized_data
|
||||
Reference in New Issue
Block a user