[Feature]Add 9 New Tools (#23)

release 0.5.0
This commit is contained in:
Yijia Su
2025-07-11 12:03:13 +08:00
committed by GitHub
parent d12dfbd014
commit 54572d0861
15 changed files with 7297 additions and 215 deletions

View File

@@ -0,0 +1,526 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Apache Doris ADBC Query Tools
High-performance data querying using Apache Arrow Flight SQL protocol
"""
import os
import socket
import time
from datetime import datetime
from typing import Any, Dict, List, Optional
from ..utils.logger import get_logger
from ..utils.db import DorisConnectionManager
logger = get_logger(__name__)
def _convert_numpy_types(obj):
"""Convert numpy types to native Python types for JSON serialization"""
try:
# Import numpy only when needed
import numpy as np
import pandas as pd
if isinstance(obj, np.integer):
return int(obj)
elif isinstance(obj, np.floating):
return float(obj)
elif isinstance(obj, np.bool_):
return bool(obj)
elif isinstance(obj, np.ndarray):
return obj.tolist()
elif isinstance(obj, (pd.Timestamp, pd.NaT.__class__)):
return str(obj)
elif pd.isna(obj):
return None
else:
return obj
except ImportError:
# If numpy/pandas not available, return as-is
return obj
def _convert_dataframe_to_json_serializable(df):
"""Convert DataFrame to JSON serializable format"""
try:
import pandas as pd
import numpy as np
# Convert DataFrame to records
records = df.to_dict('records')
# Convert each record's values
converted_records = []
for record in records:
converted_record = {}
for key, value in record.items():
converted_record[key] = _convert_numpy_types(value)
converted_records.append(converted_record)
return converted_records
except ImportError:
# Fallback to basic dict conversion
return df.to_dict('records')
class DorisADBCQueryTools:
"""ADBC Query Tools for high-performance data transfer using Arrow Flight SQL"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
self.adbc_client = None
self.flight_sql_module = None
self.adbc_manager_module = None
async def exec_adbc_query(
self,
sql: str,
max_rows: int | None = None,
timeout: int | None = None,
return_format: str | None = None
) -> Dict[str, Any]:
"""
Execute SQL query using ADBC (Arrow Flight SQL) protocol
Args:
sql: SQL statement to execute
max_rows: Maximum number of rows to return (uses config default if None)
timeout: Query timeout in seconds (uses config default if None)
return_format: Format for returned data ("arrow", "pandas", "dict", uses config default if None)
Returns:
Query results in specified format with metadata
"""
try:
start_time = time.time()
# Use configuration defaults if parameters not specified
adbc_config = self.connection_manager.config.adbc
max_rows = max_rows if max_rows is not None else adbc_config.default_max_rows
timeout = timeout if timeout is not None else adbc_config.default_timeout
return_format = return_format if return_format is not None else adbc_config.default_return_format
# Step 1: Check environment variables and port availability
port_check_result = await self._check_arrow_flight_ports()
if not port_check_result["success"]:
return port_check_result
# Step 2: Import required ADBC modules
import_result = await self._import_adbc_modules()
if not import_result["success"]:
return import_result
# Step 3: Create ADBC connection
connection_result = await self._create_adbc_connection()
if not connection_result["success"]:
return connection_result
# Step 4: Execute query using ADBC
query_result = await self._execute_query_with_adbc(
sql, max_rows, timeout, return_format
)
execution_time = time.time() - start_time
if query_result["success"]:
query_result["execution_time"] = round(execution_time, 3)
query_result["protocol"] = "ADBC_Arrow_Flight_SQL"
query_result["timestamp"] = datetime.now().isoformat()
return query_result
except Exception as e:
logger.error(f"ADBC query execution failed: {str(e)}")
return {
"success": False,
"error": f"ADBC query execution failed: {str(e)}",
"error_type": "execution_error",
"timestamp": datetime.now().isoformat()
}
async def _check_arrow_flight_ports(self) -> Dict[str, Any]:
"""Check Arrow Flight SQL port configuration and availability"""
try:
# Check environment variables
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
if not fe_port:
return {
"success": False,
"error": "Missing environment variable FE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL FE port in .env file",
"error_type": "missing_fe_port_config"
}
if not be_port:
return {
"success": False,
"error": "Missing environment variable BE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL BE port in .env file",
"error_type": "missing_be_port_config"
}
# Convert to integer and validate
try:
fe_port = int(fe_port)
be_port = int(be_port)
except ValueError:
return {
"success": False,
"error": "Invalid Arrow Flight SQL port configuration, please ensure FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT are valid numbers",
"error_type": "invalid_port_format"
}
# Get host address
db_config = self.connection_manager.config.database
fe_host = db_config.host
# Check FE Arrow Flight SQL port availability
fe_available = self._check_port_connectivity(fe_host, fe_port)
if not fe_available:
return {
"success": False,
"error": f"Cannot connect to FE Arrow Flight SQL port {fe_host}:{fe_port}, please check if service is running",
"error_type": "fe_port_unavailable",
"fe_host": fe_host,
"fe_port": fe_port
}
# Get BE host list
be_hosts = await self._get_be_hosts()
if not be_hosts:
return {
"success": False,
"error": "Cannot get BE node information, please check cluster status",
"error_type": "no_be_hosts"
}
# Check at least one BE Arrow Flight SQL port availability
be_available_count = 0
be_check_results = []
for be_host in be_hosts[:3]: # Check first 3 BE nodes
be_available = self._check_port_connectivity(be_host, be_port)
be_check_results.append({
"host": be_host,
"port": be_port,
"available": be_available
})
if be_available:
be_available_count += 1
if be_available_count == 0:
return {
"success": False,
"error": f"Cannot connect to any BE Arrow Flight SQL port (port: {be_port}), please check if BE services are running",
"error_type": "no_be_ports_available",
"be_check_results": be_check_results
}
return {
"success": True,
"fe_host": fe_host,
"fe_port": fe_port,
"be_port": be_port,
"be_hosts": be_hosts,
"be_available_count": be_available_count,
"be_check_results": be_check_results
}
except Exception as e:
logger.error(f"Arrow Flight port check failed: {str(e)}")
return {
"success": False,
"error": f"Arrow Flight port check failed: {str(e)}",
"error_type": "port_check_error"
}
def _check_port_connectivity(self, host: str, port: int, timeout: int | None = None) -> bool:
"""Check port connectivity"""
try:
# Use config timeout if not specified
if timeout is None:
timeout = self.connection_manager.config.adbc.connection_timeout
with socket.create_connection((host, port), timeout=timeout):
return True
except (socket.timeout, socket.error, OSError):
return False
async def _get_be_hosts(self) -> List[str]:
"""Get BE host list"""
try:
db_config = self.connection_manager.config.database
# Use configured BE hosts first
if db_config.be_hosts:
logger.info(f"Using configured BE hosts: {db_config.be_hosts}")
return db_config.be_hosts
# Get BE nodes via SHOW BACKENDS
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
connection = await self.connection_manager.get_connection("query")
result = await connection.execute("SHOW BACKENDS")
be_hosts = []
for row in result.data:
host = row.get("Host")
alive = row.get("Alive", "").lower()
if host and alive == "true":
be_hosts.append(host)
logger.info(f"Got {len(be_hosts)} active BE nodes from SHOW BACKENDS")
return be_hosts
except Exception as e:
logger.error(f"Failed to get BE hosts: {str(e)}")
return []
async def _import_adbc_modules(self) -> Dict[str, Any]:
"""Import ADBC related modules"""
try:
# Import ADBC Driver Manager
try:
import adbc_driver_manager
self.adbc_manager_module = adbc_driver_manager
except ImportError:
return {
"success": False,
"error": "Missing adbc_driver_manager module, please install: pip install adbc_driver_manager",
"error_type": "missing_adbc_manager"
}
# Import ADBC Flight SQL Driver
try:
import adbc_driver_flightsql.dbapi as flight_sql
self.flight_sql_module = flight_sql
except ImportError:
return {
"success": False,
"error": "Missing adbc_driver_flightsql module, please install: pip install adbc_driver_flightsql",
"error_type": "missing_flight_sql_driver"
}
return {
"success": True,
"adbc_manager_version": getattr(adbc_driver_manager, '__version__', 'unknown'),
"flight_sql_version": getattr(flight_sql, '__version__', 'unknown')
}
except Exception as e:
logger.error(f"ADBC module import failed: {str(e)}")
return {
"success": False,
"error": f"ADBC module import failed: {str(e)}",
"error_type": "import_error"
}
async def _create_adbc_connection(self) -> Dict[str, Any]:
"""Create ADBC connection"""
try:
db_config = self.connection_manager.config.database
fe_port = int(os.getenv("FE_ARROW_FLIGHT_SQL_PORT"))
# Build connection URI
uri = f"grpc://{db_config.host}:{fe_port}"
# Create database connection parameters
db_kwargs = {
self.adbc_manager_module.DatabaseOptions.USERNAME.value: db_config.user,
self.adbc_manager_module.DatabaseOptions.PASSWORD.value: db_config.password,
}
# Create connection
self.adbc_client = self.flight_sql_module.connect(
uri=uri,
db_kwargs=db_kwargs
)
return {
"success": True,
"uri": uri,
"connection_established": True
}
except Exception as e:
logger.error(f"Failed to create ADBC connection: {str(e)}")
return {
"success": False,
"error": f"Failed to create ADBC connection: {str(e)}",
"error_type": "connection_error"
}
async def _execute_query_with_adbc(
self,
sql: str,
max_rows: int,
timeout: int,
return_format: str
) -> Dict[str, Any]:
"""Execute query using ADBC"""
try:
if not self.adbc_client:
return {
"success": False,
"error": "ADBC connection not established",
"error_type": "no_connection"
}
cursor = self.adbc_client.cursor()
start_time = time.time()
# Execute query
cursor.execute(sql)
# Get results based on return format
if return_format == "arrow":
# Return Arrow format
arrow_data = cursor.fetchallarrow()
# Limit rows
if len(arrow_data) > max_rows:
arrow_data = arrow_data.slice(0, max_rows)
# Convert Arrow data to serializable format
preview_df = arrow_data.to_pandas().head(10) if len(arrow_data) > 0 else None
result_data = {
"format": "arrow",
"num_rows": len(arrow_data),
"num_columns": len(arrow_data.schema),
"column_names": arrow_data.schema.names,
"column_types": [str(field.type) for field in arrow_data.schema],
"data_preview": _convert_dataframe_to_json_serializable(preview_df) if preview_df is not None else [],
"total_bytes": arrow_data.nbytes if hasattr(arrow_data, 'nbytes') else 0
}
elif return_format == "pandas":
# Return Pandas DataFrame
df = cursor.fetch_df()
# Limit rows
if len(df) > max_rows:
df = df.head(max_rows)
result_data = {
"format": "pandas",
"num_rows": len(df),
"num_columns": len(df.columns),
"column_names": df.columns.tolist(),
"column_types": df.dtypes.astype(str).tolist(),
"data": _convert_dataframe_to_json_serializable(df),
"memory_usage": int(df.memory_usage(deep=True).sum())
}
else: # return_format == "dict"
# Return dictionary format
arrow_data = cursor.fetchallarrow()
df = arrow_data.to_pandas()
# Limit rows
if len(df) > max_rows:
df = df.head(max_rows)
result_data = {
"format": "dict",
"num_rows": len(df),
"num_columns": len(df.columns),
"column_names": df.columns.tolist(),
"column_types": df.dtypes.astype(str).tolist(),
"data": _convert_dataframe_to_json_serializable(df)
}
execution_time = time.time() - start_time
cursor.close()
return {
"success": True,
"result": result_data,
"execution_time": round(execution_time, 3),
"sql": sql,
"max_rows_applied": len(result_data.get("data", [])) >= max_rows
}
except Exception as e:
logger.error(f"ADBC query execution failed: {str(e)}")
return {
"success": False,
"error": f"ADBC query execution failed: {str(e)}",
"error_type": "query_execution_error",
"sql": sql
}
async def get_adbc_connection_info(self) -> Dict[str, Any]:
"""Get ADBC connection information and status"""
try:
# Check port status
port_status = await self._check_arrow_flight_ports()
# Check module status
module_status = await self._import_adbc_modules()
# Get configuration information
db_config = self.connection_manager.config.database
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
connection_info = {
"adbc_available": module_status["success"],
"ports_available": port_status["success"],
"configuration": {
"fe_host": db_config.host,
"fe_arrow_flight_port": fe_port,
"be_arrow_flight_port": be_port,
"user": db_config.user
},
"port_status": port_status,
"module_status": module_status,
"timestamp": datetime.now().isoformat()
}
if port_status["success"] and module_status["success"]:
connection_info["status"] = "ready"
connection_info["message"] = "ADBC Arrow Flight SQL connection ready"
else:
connection_info["status"] = "not_ready"
errors = []
if not port_status["success"]:
errors.append(port_status["error"])
if not module_status["success"]:
errors.append(module_status["error"])
connection_info["message"] = "; ".join(errors)
return connection_info
except Exception as e:
logger.error(f"Failed to get ADBC connection information: {str(e)}")
return {
"status": "error",
"error": f"Failed to get ADBC connection information: {str(e)}",
"timestamp": datetime.now().isoformat()
}
def __del__(self):
"""Cleanup resources"""
try:
if self.adbc_client:
self.adbc_client.close()
except:
pass

View File

@@ -54,6 +54,10 @@ class DatabaseConfig:
be_hosts: list[str] = field(default_factory=list)
be_webserver_port: int = 8040
# Arrow Flight SQL Configuration (Required for ADBC tools)
fe_arrow_flight_sql_port: int | None = None
be_arrow_flight_sql_port: int | None = None
# Connection pool configuration
# Note: min_connections is fixed at 0 to avoid at_eof connection issues
# This prevents pre-creation of connections which can cause state problems
@@ -133,6 +137,22 @@ class PerformanceConfig:
max_response_content_size: int = 4096
@dataclass
class ADBCConfig:
"""ADBC (Arrow Flight SQL) configuration"""
# Default query parameters
default_max_rows: int = 100000
default_timeout: int = 60
default_return_format: str = "arrow" # "arrow", "pandas", "dict"
# Connection timeout for ADBC
connection_timeout: int = 30
# Whether to enable ADBC tools
enabled: bool = True
@dataclass
class LoggingConfig:
"""Logging configuration"""
@@ -190,6 +210,7 @@ class DorisConfig:
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
logging: LoggingConfig = field(default_factory=LoggingConfig)
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
adbc: ADBCConfig = field(default_factory=ADBCConfig)
# Custom configuration
custom_config: dict[str, Any] = field(default_factory=dict)
@@ -260,6 +281,15 @@ class DorisConfig:
if be_hosts_env:
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
config.database.be_webserver_port = int(os.getenv("DORIS_BE_WEBSERVER_PORT", str(config.database.be_webserver_port)))
# Arrow Flight SQL Configuration
fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
if fe_arrow_port_env:
config.database.fe_arrow_flight_sql_port = int(fe_arrow_port_env)
be_arrow_port_env = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
if be_arrow_port_env:
config.database.be_arrow_flight_sql_port = int(be_arrow_port_env)
# Connection pool configuration
config.database.max_connections = int(
@@ -359,6 +389,21 @@ class DorisConfig:
)
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
# ADBC configuration
config.adbc.default_max_rows = int(
os.getenv("ADBC_DEFAULT_MAX_ROWS", str(config.adbc.default_max_rows))
)
config.adbc.default_timeout = int(
os.getenv("ADBC_DEFAULT_TIMEOUT", str(config.adbc.default_timeout))
)
config.adbc.default_return_format = os.getenv("ADBC_DEFAULT_RETURN_FORMAT", config.adbc.default_return_format)
config.adbc.connection_timeout = int(
os.getenv("ADBC_CONNECTION_TIMEOUT", str(config.adbc.connection_timeout))
)
config.adbc.enabled = (
os.getenv("ADBC_ENABLED", str(config.adbc.enabled).lower()).lower() == "true"
)
# Server configuration
config.server_name = os.getenv("SERVER_NAME", config.server_name)
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
@@ -412,6 +457,13 @@ class DorisConfig:
if hasattr(config.monitoring, key):
setattr(config.monitoring, key, value)
# Update ADBC configuration
if "adbc" in config_data:
adbc_config = config_data["adbc"]
for key, value in adbc_config.items():
if hasattr(config.adbc, key):
setattr(config.adbc, key, value)
# Custom configuration
config.custom_config = config_data.get("custom", {})
@@ -434,6 +486,8 @@ class DorisConfig:
"fe_http_port": self.database.fe_http_port,
"be_hosts": self.database.be_hosts,
"be_webserver_port": self.database.be_webserver_port,
"fe_arrow_flight_sql_port": self.database.fe_arrow_flight_sql_port,
"be_arrow_flight_sql_port": self.database.be_arrow_flight_sql_port,
"min_connections": self.database.min_connections, # Always 0, shown for reference
"max_connections": self.database.max_connections,
"connection_timeout": self.database.connection_timeout,
@@ -483,6 +537,13 @@ class DorisConfig:
"enable_alerts": self.monitoring.enable_alerts,
"alert_webhook_url": self.monitoring.alert_webhook_url,
},
"adbc": {
"default_max_rows": self.adbc.default_max_rows,
"default_timeout": self.adbc.default_timeout,
"default_return_format": self.adbc.default_return_format,
"connection_timeout": self.adbc.connection_timeout,
"enabled": self.adbc.enabled,
},
"custom": self.custom_config,
}
@@ -564,6 +625,19 @@ class DorisConfig:
if not (1 <= self.monitoring.health_check_port <= 65535):
errors.append("Health check port must be in the range 1-65535")
# Validate ADBC configuration
if self.adbc.default_max_rows <= 0:
errors.append("ADBC default max rows must be greater than 0")
if self.adbc.default_timeout <= 0:
errors.append("ADBC default timeout must be greater than 0")
if self.adbc.default_return_format not in ["arrow", "pandas", "dict"]:
errors.append("ADBC default return format must be one of arrow, pandas, or dict")
if self.adbc.connection_timeout <= 0:
errors.append("ADBC connection timeout must be greater than 0")
return errors
def get_connection_string(self) -> str:
@@ -603,6 +677,7 @@ class ConfigManager:
def setup_logging(self):
"""Setup logging configuration using enhanced logger"""
from .logger import setup_logging, get_logger
import sys
# Determine log directory
log_dir = "logs"
@@ -611,11 +686,19 @@ class ConfigManager:
from pathlib import Path
log_dir = str(Path(self.config.logging.file_path).parent)
# Detect if we're in stdio mode by checking if this is likely MCP stdio communication
# In stdio mode, we shouldn't output to console as it interferes with JSON protocol
is_stdio_mode = (
self.config.transport == "stdio" or
"--transport" in sys.argv and "stdio" in sys.argv or
not sys.stdout.isatty() # Not a terminal (likely piped/redirected)
)
# Setup enhanced logging with cleanup functionality
setup_logging(
level=self.config.logging.level,
log_dir=log_dir,
enable_console=True,
enable_console=not is_stdio_mode, # Disable console logging in stdio mode
enable_file=True,
enable_audit=self.config.logging.enable_audit,
audit_file=self.config.logging.audit_file_path,

View File

@@ -0,0 +1,733 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Exploration Tools Module
Provides table data distribution analysis and exploration capabilities
"""
import time
import math
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class DataExplorationTools:
"""Data exploration tools for table distribution analysis"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("DataExplorationTools initialized")
# ==================== Private Helper Methods ====================
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
"""Build full table name with catalog and database using three-part naming convention"""
# Default catalog for internal tables
effective_catalog = catalog_name if catalog_name else "internal"
if db_name:
return f"{effective_catalog}.{db_name}.{table_name}"
else:
# If no db_name provided, need to determine the current database
return f"{effective_catalog}.{table_name}"
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get basic table information including row count"""
try:
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
result = await connection.execute(count_sql)
if result.data:
return {"row_count": result.data[0]["row_count"]}
return None
except Exception as e:
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
return {"row_count": 0}
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get detailed column information"""
try:
where_conditions = [f"table_name = '{table_name}'"]
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
else:
where_conditions.append("table_schema = DATABASE()")
columns_sql = f"""
SELECT
column_name,
data_type,
is_nullable,
column_comment,
ordinal_position
FROM information_schema.columns
WHERE {' AND '.join(where_conditions)}
ORDER BY ordinal_position
"""
result = await connection.execute(columns_sql)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
return []
async def _determine_sampling_strategy(self, connection, table_name: str, total_rows: int, sample_size: int) -> Dict[str, Any]:
"""Determine optimal sampling strategy based on table size"""
if total_rows <= sample_size:
# Use all data if table is small enough
return {
"total_rows": total_rows,
"sample_size": total_rows,
"sampling_method": "full_scan",
"sampling_ratio": 1.0,
"use_sampling": False,
"sample_table_expression": table_name
}
else:
# Use random sampling for large tables
sampling_ratio = sample_size / total_rows
return {
"total_rows": total_rows,
"sample_size": sample_size,
"sampling_method": "random_sample",
"sampling_ratio": round(sampling_ratio, 4),
"use_sampling": True,
"sample_table_expression": f"(SELECT * FROM {table_name} ORDER BY RAND() LIMIT {sample_size}) as sample_table"
}
def _select_analysis_columns(self, columns_info: List[Dict], include_all: bool) -> List[Dict]:
"""Select columns for analysis based on strategy"""
if include_all:
return columns_info
# If not analyzing all columns, prioritize key columns
priority_keywords = ['id', 'key', 'code', 'status', 'type', 'amount', 'count', 'date', 'time']
priority_columns = []
other_columns = []
for col in columns_info:
col_name_lower = col["column_name"].lower()
if any(keyword in col_name_lower for keyword in priority_keywords):
priority_columns.append(col)
else:
other_columns.append(col)
# Return priority columns plus first 10 other columns
return priority_columns + other_columns[:10]
def _is_numeric_type(self, data_type: str) -> bool:
"""Check if column type is numeric"""
numeric_types = [
'tinyint', 'smallint', 'int', 'bigint', 'largeint',
'float', 'double', 'decimal', 'numeric'
]
return any(num_type in data_type.lower() for num_type in numeric_types)
def _is_categorical_type(self, data_type: str) -> bool:
"""Check if column type is categorical"""
categorical_types = ['varchar', 'char', 'string', 'text', 'enum']
return any(cat_type in data_type.lower() for cat_type in categorical_types)
def _is_temporal_type(self, data_type: str) -> bool:
"""Check if column type is temporal"""
temporal_types = ['date', 'datetime', 'timestamp', 'time']
return any(temp_type in data_type.lower() for temp_type in temporal_types)
async def _analyze_numeric_distributions(self, connection, table_name: str, numeric_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze distribution patterns for numeric columns"""
numeric_analysis = {}
for column in numeric_columns:
col_name = column["column_name"]
try:
# Basic statistics
table_expr = sampling_info.get("sample_table_expression", table_name)
stats_sql = f"""
SELECT
COUNT({col_name}) as count,
MIN({col_name}) as min_value,
MAX({col_name}) as max_value,
AVG({col_name}) as mean_value,
STDDEV({col_name}) as std_dev
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
stats_result = await connection.execute(stats_sql)
if stats_result.data and stats_result.data[0]["count"] > 0:
stats = stats_result.data[0]
# Percentiles calculation
percentiles = await self._calculate_percentiles(connection, table_name, col_name, sampling_info)
# Outlier detection
outliers = await self._detect_numeric_outliers(connection, table_name, col_name, percentiles, sampling_info)
# Distribution shape analysis
distribution_shape = await self._analyze_distribution_shape(
connection, table_name, col_name, stats, percentiles, sampling_info
)
numeric_analysis[col_name] = {
"data_type": column["data_type"],
"statistics": {
"count": stats["count"],
"mean": round(float(stats["mean_value"]), 4) if stats["mean_value"] else None,
"std": round(float(stats["std_dev"]), 4) if stats["std_dev"] else None,
"min": float(stats["min_value"]) if stats["min_value"] else None,
"max": float(stats["max_value"]) if stats["max_value"] else None,
**percentiles
},
"distribution_shape": distribution_shape,
"outliers": outliers
}
except Exception as e:
logger.warning(f"Failed to analyze numeric column {col_name}: {str(e)}")
numeric_analysis[col_name] = {"error": str(e)}
return numeric_analysis
async def _calculate_percentiles(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, float]:
"""Calculate percentiles for numeric column"""
try:
table_expr = sampling_info.get("sample_table_expression", table_name)
percentile_sql = f"""
SELECT
PERCENTILE({col_name}, 0.25) as p25,
PERCENTILE({col_name}, 0.50) as p50,
PERCENTILE({col_name}, 0.75) as p75,
PERCENTILE({col_name}, 0.90) as p90,
PERCENTILE({col_name}, 0.95) as p95,
PERCENTILE({col_name}, 0.99) as p99
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
result = await connection.execute(percentile_sql)
if result.data:
data = result.data[0]
return {
"25%": round(float(data["p25"]), 4) if data["p25"] else None,
"50%": round(float(data["p50"]), 4) if data["p50"] else None,
"75%": round(float(data["p75"]), 4) if data["p75"] else None,
"90%": round(float(data["p90"]), 4) if data["p90"] else None,
"95%": round(float(data["p95"]), 4) if data["p95"] else None,
"99%": round(float(data["p99"]), 4) if data["p99"] else None
}
except Exception as e:
logger.warning(f"Failed to calculate percentiles for {col_name}: {str(e)}")
return {}
async def _detect_numeric_outliers(self, connection, table_name: str, col_name: str, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
"""Detect outliers using IQR method"""
try:
if "25%" not in percentiles or "75%" not in percentiles:
return {"outlier_count": 0, "outlier_rate": 0.0}
q1 = percentiles["25%"]
q3 = percentiles["75%"]
iqr = q3 - q1
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr
table_expr = sampling_info.get("sample_table_expression", table_name)
outlier_sql = f"""
SELECT
COUNT(*) as total_count,
SUM(CASE WHEN {col_name} < {lower_bound} OR {col_name} > {upper_bound} THEN 1 ELSE 0 END) as outlier_count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
result = await connection.execute(outlier_sql)
if result.data:
data = result.data[0]
total_count = data["total_count"]
outlier_count = data["outlier_count"]
outlier_rate = outlier_count / total_count if total_count > 0 else 0
return {
"outlier_count": outlier_count,
"outlier_rate": round(outlier_rate, 4),
"outlier_threshold_lower": round(lower_bound, 4),
"outlier_threshold_upper": round(upper_bound, 4),
"iqr": round(iqr, 4)
}
except Exception as e:
logger.warning(f"Failed to detect outliers for {col_name}: {str(e)}")
return {"outlier_count": 0, "outlier_rate": 0.0}
async def _analyze_distribution_shape(self, connection, table_name: str, col_name: str, stats: Dict, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
"""Analyze the shape of data distribution"""
try:
mean = stats.get("mean_value", 0)
median = percentiles.get("50%", 0)
if mean is None or median is None:
return {"distribution_type": "unknown"}
# Calculate skewness indicator
if abs(mean - median) < 0.01:
skew_indicator = "symmetric"
elif mean > median:
skew_indicator = "right_skewed"
else:
skew_indicator = "left_skewed"
# Estimate kurtosis based on percentile spread
if "25%" in percentiles and "75%" in percentiles:
iqr = percentiles["75%"] - percentiles["25%"]
range_90 = percentiles.get("90%", percentiles["75%"]) - percentiles.get("10%", percentiles["25%"])
if iqr > 0:
kurtosis_indicator = "normal" if 2.5 <= range_90/iqr <= 3.5 else ("heavy_tailed" if range_90/iqr > 3.5 else "light_tailed")
else:
kurtosis_indicator = "unknown"
else:
kurtosis_indicator = "unknown"
return {
"skewness_indicator": skew_indicator,
"kurtosis_indicator": kurtosis_indicator,
"distribution_type": self._classify_distribution_type(skew_indicator, kurtosis_indicator),
"mean_median_ratio": round(mean / median, 4) if median != 0 else None
}
except Exception as e:
logger.warning(f"Failed to analyze distribution shape for {col_name}: {str(e)}")
return {"distribution_type": "unknown"}
def _classify_distribution_type(self, skew: str, kurtosis: str) -> str:
"""Classify distribution type based on skewness and kurtosis"""
if skew == "symmetric" and kurtosis == "normal":
return "approximately_normal"
elif skew == "right_skewed":
return "right_skewed"
elif skew == "left_skewed":
return "left_skewed"
elif kurtosis == "heavy_tailed":
return "heavy_tailed"
else:
return "non_normal"
async def _analyze_categorical_distributions(self, connection, table_name: str, categorical_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze distribution patterns for categorical columns"""
categorical_analysis = {}
for column in categorical_columns:
col_name = column["column_name"]
try:
# Basic cardinality and distribution
cardinality_sql = f"""
SELECT
COUNT(DISTINCT {col_name}) as cardinality,
COUNT({col_name}) as non_null_count
FROM {table_name}
WHERE {col_name} IS NOT NULL
{sampling_info.get('sample_query_suffix', '')}
"""
cardinality_result = await connection.execute(cardinality_sql)
if cardinality_result.data:
cardinality_data = cardinality_result.data[0]
cardinality = cardinality_data["cardinality"]
non_null_count = cardinality_data["non_null_count"]
# Value distribution (top values)
value_distribution = await self._get_categorical_value_distribution(
connection, table_name, col_name, sampling_info, non_null_count
)
# Calculate entropy and concentration
entropy = self._calculate_entropy(value_distribution)
concentration_ratio = value_distribution[0]["percentage"] if value_distribution else 0
categorical_analysis[col_name] = {
"data_type": column["data_type"],
"cardinality": cardinality,
"non_null_count": non_null_count,
"value_distribution": value_distribution,
"entropy": round(entropy, 3),
"concentration_ratio": round(concentration_ratio, 4),
"diversity_score": round(cardinality / non_null_count, 4) if non_null_count > 0 else 0
}
except Exception as e:
logger.warning(f"Failed to analyze categorical column {col_name}: {str(e)}")
categorical_analysis[col_name] = {"error": str(e)}
return categorical_analysis
async def _get_categorical_value_distribution(self, connection, table_name: str, col_name: str, sampling_info: Dict, total_count: int) -> List[Dict]:
"""Get value distribution for categorical column"""
try:
# Use sample table expression if sampling is enabled
table_expr = sampling_info.get("sample_table_expression", table_name)
distribution_sql = f"""
SELECT
{col_name} as value,
COUNT(*) as count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
GROUP BY {col_name}
ORDER BY COUNT(*) DESC
LIMIT 20
"""
result = await connection.execute(distribution_sql)
if result.data:
distribution = []
for row in result.data:
count = row["count"]
percentage = count / total_count if total_count > 0 else 0
distribution.append({
"value": str(row["value"]),
"count": count,
"percentage": round(percentage, 4)
})
return distribution
except Exception as e:
logger.warning(f"Failed to get value distribution for {col_name}: {str(e)}")
return []
def _calculate_entropy(self, value_distribution: List[Dict]) -> float:
"""Calculate Shannon entropy for categorical distribution"""
if not value_distribution:
return 0.0
entropy = 0.0
for item in value_distribution:
p = item["percentage"]
if p > 0:
entropy -= p * math.log2(p)
return entropy
async def _analyze_temporal_distributions(self, connection, table_name: str, temporal_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze distribution patterns for temporal columns"""
temporal_analysis = {}
for column in temporal_columns:
col_name = column["column_name"]
try:
# Date range analysis
table_expr = sampling_info.get("sample_table_expression", table_name)
range_sql = f"""
SELECT
MIN({col_name}) as earliest,
MAX({col_name}) as latest,
COUNT({col_name}) as non_null_count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
"""
range_result = await connection.execute(range_sql)
if range_result.data and range_result.data[0]["non_null_count"] > 0:
range_data = range_result.data[0]
earliest = range_data["earliest"]
latest = range_data["latest"]
# Calculate span
date_span_info = self._calculate_date_span(earliest, latest)
# Temporal patterns analysis
temporal_patterns = await self._analyze_temporal_patterns(
connection, table_name, col_name, sampling_info
)
temporal_analysis[col_name] = {
"data_type": column["data_type"],
"non_null_count": range_data["non_null_count"],
"date_range": {
"earliest": str(earliest),
"latest": str(latest),
**date_span_info
},
"temporal_patterns": temporal_patterns
}
except Exception as e:
logger.warning(f"Failed to analyze temporal column {col_name}: {str(e)}")
temporal_analysis[col_name] = {"error": str(e)}
return temporal_analysis
def _calculate_date_span(self, earliest, latest) -> Dict[str, Any]:
"""Calculate date span information"""
try:
if isinstance(earliest, str):
earliest = datetime.fromisoformat(earliest.replace('Z', '+00:00'))
if isinstance(latest, str):
latest = datetime.fromisoformat(latest.replace('Z', '+00:00'))
span = latest - earliest
span_days = span.days
return {
"span_days": span_days,
"span_years": round(span_days / 365.25, 2),
"span_description": self._describe_time_span(span_days)
}
except Exception as e:
logger.warning(f"Failed to calculate date span: {str(e)}")
return {"span_days": 0}
def _describe_time_span(self, days: int) -> str:
"""Describe time span in human readable format"""
if days < 1:
return "less_than_day"
elif days < 7:
return "days"
elif days < 30:
return "weeks"
elif days < 365:
return "months"
else:
return "years"
async def _analyze_temporal_patterns(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, Any]:
"""Analyze temporal patterns like seasonality and trends"""
try:
table_expr = sampling_info.get("sample_table_expression", table_name)
# Weekly pattern analysis
weekly_pattern_sql = f"""
SELECT
DAYOFWEEK({col_name}) as day_of_week,
COUNT(*) as count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
GROUP BY DAYOFWEEK({col_name})
ORDER BY day_of_week
"""
weekly_result = await connection.execute(weekly_pattern_sql)
weekly_pattern = []
if weekly_result.data:
total_records = sum(row["count"] for row in weekly_result.data)
for row in weekly_result.data:
percentage = row["count"] / total_records if total_records > 0 else 0
weekly_pattern.append(round(percentage, 3))
# Monthly trend analysis (simplified)
monthly_trend_sql = f"""
SELECT
YEAR({col_name}) as year,
MONTH({col_name}) as month,
COUNT(*) as count
FROM {table_expr}
WHERE {col_name} IS NOT NULL
GROUP BY YEAR({col_name}), MONTH({col_name})
ORDER BY year, month
LIMIT 12
"""
monthly_result = await connection.execute(monthly_trend_sql)
monthly_trend = "stable" # Simplified trend analysis
if monthly_result.data and len(monthly_result.data) > 3:
counts = [row["count"] for row in monthly_result.data]
if len(counts) > 1:
trend_direction = "increasing" if counts[-1] > counts[0] else "decreasing"
monthly_trend = trend_direction
return {
"weekly_pattern": weekly_pattern,
"monthly_trend": monthly_trend,
"seasonal_component": self._estimate_seasonality(weekly_pattern)
}
except Exception as e:
logger.warning(f"Failed to analyze temporal patterns for {col_name}: {str(e)}")
return {"weekly_pattern": [], "monthly_trend": "unknown"}
def _estimate_seasonality(self, weekly_pattern: List[float]) -> float:
"""Estimate seasonality strength based on weekly pattern variance"""
if len(weekly_pattern) < 7:
return 0.0
mean_percentage = sum(weekly_pattern) / len(weekly_pattern)
variance = sum((x - mean_percentage) ** 2 for x in weekly_pattern) / len(weekly_pattern)
# Normalize variance to 0-1 scale as seasonality indicator
seasonality = min(variance * 10, 1.0) # Scaling factor
return round(seasonality, 3)
async def _generate_data_quality_insights(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Generate overall data quality insights"""
try:
total_columns = len(columns)
# Calculate null rates across all columns
null_analysis = await self._analyze_overall_null_rates(connection, table_name, columns, sampling_info)
# Identify potential data quality issues
quality_issues = []
# High null rate columns
high_null_columns = [col for col, rate in null_analysis["column_null_rates"].items() if rate > 0.2]
if high_null_columns:
quality_issues.append({
"issue_type": "high_null_rates",
"severity": "medium",
"affected_columns": high_null_columns,
"description": f"{len(high_null_columns)} columns have null rates > 20%"
})
# Calculate overall data quality score
avg_null_rate = sum(null_analysis["column_null_rates"].values()) / len(null_analysis["column_null_rates"]) if null_analysis["column_null_rates"] else 0
data_quality_score = max(0, 1 - avg_null_rate)
return {
"total_columns_analyzed": total_columns,
"null_analysis": null_analysis,
"data_quality_score": round(data_quality_score, 3),
"quality_issues": quality_issues,
"recommendations": self._generate_quality_recommendations(quality_issues, null_analysis)
}
except Exception as e:
logger.warning(f"Failed to generate data quality insights: {str(e)}")
return {"data_quality_score": 0.0, "error": str(e)}
async def _analyze_overall_null_rates(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
"""Analyze null rates across all columns"""
column_null_rates = {}
total_null_count = 0
total_cell_count = 0
for column in columns:
col_name = column["column_name"]
try:
table_expr = sampling_info.get("sample_table_expression", table_name)
null_sql = f"""
SELECT
COUNT(*) as total_count,
COUNT({col_name}) as non_null_count
FROM {table_expr}
"""
result = await connection.execute(null_sql)
if result.data:
data = result.data[0]
total_count = data["total_count"]
non_null_count = data["non_null_count"]
null_count = total_count - non_null_count
null_rate = null_count / total_count if total_count > 0 else 0
column_null_rates[col_name] = round(null_rate, 4)
total_null_count += null_count
total_cell_count += total_count
except Exception as e:
logger.warning(f"Failed to analyze null rate for column {col_name}: {str(e)}")
column_null_rates[col_name] = 0.0
overall_null_rate = total_null_count / total_cell_count if total_cell_count > 0 else 0
return {
"column_null_rates": column_null_rates,
"overall_null_rate": round(overall_null_rate, 4),
"columns_with_nulls": len([rate for rate in column_null_rates.values() if rate > 0])
}
def _generate_quality_recommendations(self, quality_issues: List[Dict], null_analysis: Dict) -> List[Dict]:
"""Generate data quality improvement recommendations"""
recommendations = []
# Recommendations based on null analysis
overall_null_rate = null_analysis.get("overall_null_rate", 0)
if overall_null_rate > 0.1:
recommendations.append({
"type": "data_completeness",
"priority": "high" if overall_null_rate > 0.3 else "medium",
"description": f"Overall null rate is {overall_null_rate:.1%}",
"action": "Review data collection and validation processes"
})
# Recommendations based on quality issues
for issue in quality_issues:
if issue["issue_type"] == "high_null_rates":
recommendations.append({
"type": "column_completeness",
"priority": issue["severity"],
"description": issue["description"],
"action": f"Focus on improving data completeness for: {', '.join(issue['affected_columns'][:3])}"
})
return recommendations
def _generate_analysis_summary(self, distribution_analysis: Dict[str, Any]) -> Dict[str, Any]:
"""Generate high-level summary of distribution analysis"""
summary = {
"numeric_columns_count": len(distribution_analysis.get("numeric_columns", {})),
"categorical_columns_count": len(distribution_analysis.get("categorical_columns", {})),
"temporal_columns_count": len(distribution_analysis.get("temporal_columns", {}))
}
# Identify interesting patterns
patterns = []
# Check for highly skewed numeric columns
numeric_cols = distribution_analysis.get("numeric_columns", {})
skewed_cols = [
col for col, info in numeric_cols.items()
if isinstance(info, dict) and
info.get("distribution_shape", {}).get("skewness_indicator") in ["right_skewed", "left_skewed"]
]
if skewed_cols:
patterns.append(f"Found {len(skewed_cols)} skewed numeric columns")
# Check for high cardinality categorical columns
categorical_cols = distribution_analysis.get("categorical_columns", {})
high_cardinality_cols = [
col for col, info in categorical_cols.items()
if isinstance(info, dict) and info.get("cardinality", 0) > 1000
]
if high_cardinality_cols:
patterns.append(f"Found {len(high_cardinality_cols)} high cardinality categorical columns")
summary["notable_patterns"] = patterns
return summary

View File

@@ -0,0 +1,869 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Governance Tools Module
Provides data completeness analysis, field lineage tracking, and data freshness monitoring
"""
import re
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class DataGovernanceTools:
"""Data governance tools suite"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("DataGovernanceTools initialized")
async def trace_column_lineage(
self,
table_name: str,
column_name: str,
depth: int = 3,
catalog_name: Optional[str] = None,
db_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Column-level lineage tracing
Args:
table_name: Table name
column_name: Column name
depth: Trace depth
catalog_name: Catalog name
db_name: Database name
"""
try:
start_time = time.time()
connection = await self.connection_manager.get_connection("query")
full_table_name = self._build_full_table_name(table_name, catalog_name, db_name)
target_column = f"{full_table_name}.{column_name}"
# 1. Verify target column exists
if not await self._verify_column_exists(connection, full_table_name, column_name):
return {"error": f"Column {column_name} not found in table {full_table_name}"}
# 2. Analyze SQL logs to get lineage relationships
source_chain = await self._analyze_sql_logs_for_lineage(
connection, full_table_name, column_name, depth
)
# 3. Analyze downstream usage
downstream_usage = await self._analyze_downstream_column_usage(
connection, full_table_name, column_name
)
# 4. Analyze field transformation rules
transformation_rules = await self._extract_transformation_rules(
connection, full_table_name, column_name
)
execution_time = time.time() - start_time
return {
"target_column": target_column,
"analysis_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"lineage_depth": depth,
"source_chain": source_chain,
"downstream_usage": downstream_usage,
"transformation_rules": transformation_rules,
"lineage_confidence": self._calculate_lineage_confidence(source_chain),
"impact_analysis": {
"upstream_dependencies": len(source_chain),
"downstream_dependencies": len(downstream_usage),
"risk_level": self._assess_lineage_risk(source_chain, downstream_usage)
}
}
except Exception as e:
logger.error(f"Column lineage tracing failed for {table_name}.{column_name}: {str(e)}")
return {
"error": str(e),
"target_column": f"{table_name}.{column_name}",
"analysis_timestamp": datetime.now().isoformat()
}
async def monitor_data_freshness(
self,
tables: Optional[List[str]] = None,
time_threshold_hours: int = 24,
catalog_name: Optional[str] = None,
db_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Data freshness monitoring
Args:
tables: List of tables to monitor, empty means monitor all tables
time_threshold_hours: Freshness threshold (hours)
catalog_name: Catalog name
db_name: Database name
"""
try:
start_time = time.time()
connection = await self.connection_manager.get_connection("query")
# 1. Get list of tables to monitor
if not tables:
tables = await self._get_all_tables(connection, catalog_name, db_name)
# 2. Analyze freshness of each table
table_freshness = {}
fresh_count = 0
stale_count = 0
for table in tables:
full_table_name = self._build_full_table_name(table, catalog_name, db_name)
freshness_info = await self._analyze_table_freshness(
connection, full_table_name, time_threshold_hours
)
table_freshness[table] = freshness_info
if freshness_info["status"] == "fresh":
fresh_count += 1
else:
stale_count += 1
# 3. Calculate overall freshness score
total_tables = len(tables)
overall_freshness_score = fresh_count / total_tables if total_tables > 0 else 0
# 4. Identify data flow issues
data_flow_issues = await self._identify_data_flow_issues(table_freshness)
execution_time = time.time() - start_time
return {
"monitoring_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"monitoring_scope": {
"catalog_name": catalog_name,
"db_name": db_name,
"time_threshold_hours": time_threshold_hours
},
"freshness_summary": {
"total_tables": total_tables,
"fresh_tables": fresh_count,
"stale_tables": stale_count,
"overall_freshness_score": round(overall_freshness_score, 3)
},
"table_freshness": table_freshness,
"data_flow_issues": data_flow_issues,
"alerts": self._generate_freshness_alerts(table_freshness, time_threshold_hours)
}
except Exception as e:
logger.error(f"Data freshness monitoring failed: {str(e)}")
return {
"error": str(e),
"monitoring_timestamp": datetime.now().isoformat()
}
# ==================== Private Helper Methods ====================
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
"""Build full table name - use three-level naming convention"""
# Default catalog is internal for internal tables
effective_catalog = catalog_name if catalog_name else "internal"
if db_name:
return f"{effective_catalog}.{db_name}.{table_name}"
else:
# If db_name is not provided, need to determine current database
return f"{effective_catalog}.{table_name}"
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get table basic information"""
try:
# Try to get table row count
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
result = await connection.execute(count_sql)
if result.data:
return {"row_count": result.data[0]["row_count"]}
return None
except Exception as e:
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
return {"row_count": 0}
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
"""Get table column information"""
try:
# Build query conditions
where_conditions = [f"table_name = '{table_name}'"]
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
else:
where_conditions.append("table_schema = DATABASE()")
columns_sql = f"""
SELECT
column_name,
data_type,
is_nullable,
column_comment,
ordinal_position
FROM information_schema.columns
WHERE {' AND '.join(where_conditions)}
ORDER BY ordinal_position
"""
result = await connection.execute(columns_sql)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
return []
async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]:
"""Analyze column completeness"""
column_completeness = {}
for column in columns_info:
column_name = column["column_name"]
try:
# Calculate null value statistics
null_sql = f"""
SELECT
COUNT(*) as total_count,
COUNT({column_name}) as non_null_count,
COUNT(*) - COUNT({column_name}) as null_count
FROM {table_name}
"""
result = await connection.execute(null_sql)
if result.data:
stats = result.data[0]
total_count = stats["total_count"]
null_count = stats["null_count"]
null_rate = null_count / total_count if total_count > 0 else 0
completeness_score = 1.0 - null_rate
column_completeness[column_name] = {
"data_type": column["data_type"],
"is_nullable": column["is_nullable"],
"total_count": total_count,
"null_count": null_count,
"non_null_count": stats["non_null_count"],
"null_rate": round(null_rate, 4),
"completeness_score": round(completeness_score, 4)
}
except Exception as e:
logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}")
column_completeness[column_name] = {
"error": str(e),
"completeness_score": 0.0
}
return column_completeness
async def _check_business_rule_compliance(self, connection, table_name: str, business_rules: List[Dict], total_rows: int) -> Dict[str, Any]:
"""Check business rule compliance"""
compliance_results = {}
for rule in business_rules:
rule_name = rule.get("rule_name", "unknown")
sql_condition = rule.get("sql_condition", "")
if not sql_condition:
continue
try:
# Check number of records meeting conditions
compliance_sql = f"""
SELECT
COUNT(*) as total_count,
SUM(CASE WHEN {sql_condition} THEN 1 ELSE 0 END) as pass_count
FROM {table_name}
"""
result = await connection.execute(compliance_sql)
if result.data:
stats = result.data[0]
pass_count = stats["pass_count"] or 0
fail_count = total_rows - pass_count
pass_rate = pass_count / total_rows if total_rows > 0 else 0
compliance_results[rule_name] = {
"rule_condition": sql_condition,
"total_records": total_rows,
"pass_count": pass_count,
"fail_count": fail_count,
"pass_rate": round(pass_rate, 4),
"compliance_score": round(pass_rate, 4)
}
except Exception as e:
logger.warning(f"Failed to check business rule {rule_name}: {str(e)}")
compliance_results[rule_name] = {
"error": str(e),
"compliance_score": 0.0
}
return compliance_results
async def _detect_data_integrity_issues(self, connection, table_name: str, columns_info: List[Dict]) -> List[Dict]:
"""Detect data integrity issues"""
issues = []
try:
# Detect duplicate values in primary key fields
primary_key_columns = [col["column_name"] for col in columns_info if "primary" in col.get("column_comment", "").lower()]
for pk_col in primary_key_columns:
duplicate_sql = f"""
SELECT COUNT(*) as duplicate_count
FROM (
SELECT {pk_col}, COUNT(*) as cnt
FROM {table_name}
WHERE {pk_col} IS NOT NULL
GROUP BY {pk_col}
HAVING COUNT(*) > 1
) t
"""
result = await connection.execute(duplicate_sql)
if result.data and result.data[0]["duplicate_count"] > 0:
issues.append({
"type": "duplicate_primary_keys",
"column": pk_col,
"count": result.data[0]["duplicate_count"],
"severity": "high",
"description": f"Found duplicate values in primary key column {pk_col}"
})
except Exception as e:
logger.warning(f"Failed to detect integrity issues: {str(e)}")
issues.append({
"type": "detection_error",
"error": str(e),
"severity": "unknown"
})
return issues
def _calculate_completeness_score(self, column_completeness: Dict, business_rule_compliance: Dict) -> float:
"""Calculate overall completeness score"""
if not column_completeness:
return 0.0
# Calculate column completeness average score
column_scores = [
col_info.get("completeness_score", 0.0)
for col_info in column_completeness.values()
if isinstance(col_info, dict) and "completeness_score" in col_info
]
avg_column_score = sum(column_scores) / len(column_scores) if column_scores else 0.0
# Calculate business rule compliance average score
compliance_scores = [
rule_info.get("compliance_score", 0.0)
for rule_info in business_rule_compliance.values()
if isinstance(rule_info, dict) and "compliance_score" in rule_info
]
avg_compliance_score = sum(compliance_scores) / len(compliance_scores) if compliance_scores else 1.0
# Comprehensive score (column completeness weight 70%, business rules weight 30%)
overall_score = avg_column_score * 0.7 + avg_compliance_score * 0.3
return round(overall_score, 4)
def _generate_completeness_recommendations(self, column_completeness: Dict, integrity_issues: List[Dict]) -> List[Dict]:
"""Generate completeness improvement recommendations"""
recommendations = []
# Generate recommendations based on column completeness
for col_name, col_info in column_completeness.items():
if isinstance(col_info, dict):
null_rate = col_info.get("null_rate", 0)
if null_rate > 0.1: # Null rate exceeds 10%
recommendations.append({
"type": "high_null_rate",
"column": col_name,
"priority": "high" if null_rate > 0.5 else "medium",
"description": f"Column {col_name} has high null rate ({null_rate:.1%})",
"suggested_action": "Review data collection process or add data validation"
})
# Generate recommendations based on integrity issues
for issue in integrity_issues:
if issue["type"] == "duplicate_primary_keys":
recommendations.append({
"type": "data_deduplication",
"column": issue["column"],
"priority": "high",
"description": f"Duplicate primary key values found in {issue['column']}",
"suggested_action": "Implement unique constraint or data deduplication process"
})
return recommendations
async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool:
"""Verify if column exists"""
try:
# Simple verification method: try to query the column
verify_sql = f"SELECT {column_name} FROM {table_name} LIMIT 1"
await connection.execute(verify_sql)
return True
except Exception:
return False
async def _analyze_sql_logs_for_lineage(self, connection, table_name: str, column_name: str, depth: int) -> List[Dict]:
"""Analyze SQL logs to get lineage relationships (simplified implementation)"""
# Note: This is a simplified implementation, actual environment needs to analyze audit logs
source_chain = []
try:
# Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range)
audit_sql = """
SELECT
stmt as sql_statement,
`time` as execution_time,
`user` as user_name
FROM internal.__internal_schema.audit_log
WHERE stmt LIKE '%{}%'
AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%')
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
ORDER BY `time` DESC
LIMIT 50
""".format(table_name.split('.')[-1]) # Use the last part of table name
result = await connection.execute(audit_sql)
if result.data:
for i, log_entry in enumerate(result.data[:depth]):
# Simplified lineage analysis: extract possible source tables
sql_stmt = log_entry.get("sql_statement", "")
source_tables = self._extract_source_tables_from_sql(sql_stmt)
if source_tables:
# Handle datetime serialization issue
execution_time = log_entry.get("execution_time")
if execution_time and hasattr(execution_time, 'isoformat'):
execution_time = execution_time.isoformat()
elif execution_time:
execution_time = str(execution_time)
source_chain.append({
"level": i + 1,
"source_table": source_tables[0], # Take the first as main source table
"source_column": column_name, # Simplified: assume same name
"transformation": self._extract_transformation_from_sql(sql_stmt, column_name),
"confidence": 0.8 - (i * 0.1), # Decreasing confidence
"execution_time": execution_time,
"user": log_entry.get("user_name")
})
except Exception as e:
logger.warning(f"Failed to analyze SQL logs for lineage: {str(e)}")
# If unable to get from audit logs, return basic information
source_chain = [{
"level": 1,
"source_table": "unknown_source",
"source_column": column_name,
"transformation": "unknown",
"confidence": 0.3,
"note": "Limited lineage information available"
}]
return source_chain
def _extract_source_tables_from_sql(self, sql: str) -> List[str]:
"""Extract source table names from SQL statement (simplified implementation)"""
# Simplified regex to match table names in FROM clause
from_pattern = r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
join_pattern = r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
tables = []
# Find tables in FROM clause
from_matches = re.findall(from_pattern, sql, re.IGNORECASE)
tables.extend(from_matches)
# Find tables in JOIN clause
join_matches = re.findall(join_pattern, sql, re.IGNORECASE)
tables.extend(join_matches)
return list(set(tables)) # Remove duplicates
def _extract_transformation_from_sql(self, sql: str, column_name: str) -> str:
"""Extract field transformation rules from SQL statement (simplified implementation)"""
# Simplified implementation: find expressions containing target field
lines = sql.split('\n')
for line in lines:
if column_name in line and ('SELECT' in line.upper() or '=' in line):
return line.strip()
return "direct_copy"
async def _analyze_downstream_column_usage(self, connection, table_name: str, column_name: str) -> List[Dict]:
"""Analyze downstream usage of field (simplified implementation)"""
downstream_usage = []
try:
# Find other tables that might use this field (through audit logs, one year range)
usage_sql = """
SELECT DISTINCT
stmt as sql_statement
FROM internal.__internal_schema.audit_log
WHERE stmt LIKE '%{}%'
AND stmt LIKE '%{}%'
AND stmt LIKE '%SELECT%'
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
LIMIT 20
""".format(table_name.split('.')[-1], column_name)
result = await connection.execute(usage_sql)
if result.data:
for entry in result.data:
sql_stmt = entry.get("sql_statement", "")
target_tables = self._extract_target_tables_from_sql(sql_stmt)
for target_table in target_tables:
if target_table != table_name.split('.')[-1]: # Not the source table itself
downstream_usage.append({
"table": target_table,
"column": column_name, # Simplified: assume same name
"usage_type": "select_reference",
"confidence": 0.7
})
except Exception as e:
logger.warning(f"Failed to analyze downstream usage: {str(e)}")
return downstream_usage
def _extract_target_tables_from_sql(self, sql: str) -> List[str]:
"""Extract target table names from SQL statement"""
# Find target tables in INSERT INTO or CREATE TABLE statements
insert_pattern = r'\bINSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
create_pattern = r'\bCREATE\s+TABLE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
tables = []
insert_matches = re.findall(insert_pattern, sql, re.IGNORECASE)
tables.extend(insert_matches)
create_matches = re.findall(create_pattern, sql, re.IGNORECASE)
tables.extend(create_matches)
return list(set(tables))
async def _extract_transformation_rules(self, connection, table_name: str, column_name: str) -> List[Dict]:
"""Extract field transformation rules"""
# Simplified implementation: return basic transformation information
return [{
"transformation_type": "unknown",
"description": "Transformation rules analysis requires detailed ETL metadata",
"confidence": 0.5
}]
def _calculate_lineage_confidence(self, source_chain: List[Dict]) -> float:
"""Calculate overall confidence of lineage tracing"""
if not source_chain:
return 0.0
confidences = [item.get("confidence", 0.0) for item in source_chain]
return round(sum(confidences) / len(confidences), 3)
def _assess_lineage_risk(self, source_chain: List[Dict], downstream_usage: List[Dict]) -> str:
"""Assess lineage risk level"""
if len(downstream_usage) > 10:
return "high"
elif len(downstream_usage) > 5:
return "medium"
else:
return "low"
async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]:
"""Get list of all tables"""
try:
where_conditions = []
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
else:
where_conditions.append("table_schema = DATABASE()")
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
tables_sql = f"""
SELECT table_name
FROM information_schema.tables
WHERE {where_clause}
AND table_type = 'BASE TABLE'
ORDER BY table_name
"""
result = await connection.execute(tables_sql)
return [row["table_name"] for row in result.data] if result.data else []
except Exception as e:
logger.warning(f"Failed to get table list: {str(e)}")
return []
async def _analyze_table_freshness(self, connection, table_name: str, threshold_hours: int) -> Dict[str, Any]:
"""Analyze freshness of single table"""
try:
# Try multiple methods to get table's last update time
freshness_methods = [
self._get_freshness_from_partition_info,
self._get_freshness_from_max_timestamp,
self._get_freshness_from_table_metadata
]
last_update = None
method_used = "unknown"
for method in freshness_methods:
try:
result = await method(connection, table_name)
if result:
last_update = result["last_update"]
method_used = result["method"]
break
except Exception as e:
continue
if not last_update:
return {
"last_update": None,
"staleness_hours": None,
"freshness_score": 0.0,
"status": "unknown",
"method_used": "none",
"error": "Unable to determine last update time"
}
# Calculate data staleness
now = datetime.now()
if isinstance(last_update, str):
last_update = datetime.fromisoformat(last_update.replace('Z', '+00:00'))
staleness_hours = (now - last_update).total_seconds() / 3600
# Calculate freshness score and status
if staleness_hours <= threshold_hours:
status = "fresh"
freshness_score = max(0.0, 1.0 - (staleness_hours / threshold_hours))
else:
status = "stale"
freshness_score = max(0.0, 1.0 - (staleness_hours / (threshold_hours * 2)))
return {
"last_update": last_update.isoformat() if hasattr(last_update, 'isoformat') else str(last_update),
"staleness_hours": round(staleness_hours, 2),
"freshness_score": round(freshness_score, 3),
"status": status,
"method_used": method_used,
"threshold_hours": threshold_hours
}
except Exception as e:
logger.warning(f"Failed to analyze freshness for table {table_name}: {str(e)}")
return {
"last_update": None,
"staleness_hours": None,
"freshness_score": 0.0,
"status": "error",
"error": str(e)
}
async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from partition information"""
try:
# Query partition information (if table has partitions)
partition_sql = f"""
SELECT MAX(CREATE_TIME) as last_update
FROM information_schema.partitions
WHERE table_name = '{table_name.split('.')[-1]}'
AND CREATE_TIME IS NOT NULL
"""
result = await connection.execute(partition_sql)
if result.data and result.data[0]["last_update"]:
return {
"last_update": result.data[0]["last_update"],
"method": "partition_info"
}
return None
except Exception:
return None
async def _get_freshness_from_max_timestamp(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from timestamp fields"""
try:
# Find possible timestamp fields
timestamp_columns = await self._find_timestamp_columns(connection, table_name)
if timestamp_columns:
max_time_sql = f"""
SELECT MAX({timestamp_columns[0]}) as last_update
FROM {table_name}
"""
result = await connection.execute(max_time_sql)
if result.data and result.data[0]["last_update"]:
return {
"last_update": result.data[0]["last_update"],
"method": f"max_timestamp({timestamp_columns[0]})"
}
return None
except Exception:
return None
async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]:
"""Get freshness from table metadata"""
try:
# Query table's update time
metadata_sql = f"""
SELECT UPDATE_TIME as last_update
FROM information_schema.tables
WHERE table_name = '{table_name.split('.')[-1]}'
AND UPDATE_TIME IS NOT NULL
"""
result = await connection.execute(metadata_sql)
if result.data and result.data[0]["last_update"]:
return {
"last_update": result.data[0]["last_update"],
"method": "table_metadata"
}
return None
except Exception:
return None
async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]:
"""Find possible timestamp fields"""
try:
timestamp_sql = f"""
SELECT column_name
FROM information_schema.columns
WHERE table_name = '{table_name.split('.')[-1]}'
AND (
data_type IN ('datetime', 'timestamp', 'date')
OR column_name LIKE '%time%'
OR column_name LIKE '%date%'
OR column_name LIKE '%created%'
OR column_name LIKE '%updated%'
)
ORDER BY
CASE
WHEN column_name LIKE '%updated%' THEN 1
WHEN column_name LIKE '%created%' THEN 2
WHEN column_name LIKE '%time%' THEN 3
ELSE 4
END
"""
result = await connection.execute(timestamp_sql)
return [row["column_name"] for row in result.data] if result.data else []
except Exception:
return []
async def _identify_data_flow_issues(self, table_freshness: Dict[str, Any]) -> List[Dict]:
"""Identify data flow issues"""
issues = []
# Identify consecutively stale tables (may indicate ETL process issues)
stale_tables = [
table_name for table_name, info in table_freshness.items()
if info.get("status") == "stale"
]
if len(stale_tables) > len(table_freshness) * 0.3: # More than 30% of tables are stale
issues.append({
"issue_type": "widespread_staleness",
"severity": "high",
"affected_tables": len(stale_tables),
"total_tables": len(table_freshness),
"description": f"High percentage of stale tables ({len(stale_tables)}/{len(table_freshness)})",
"possible_causes": ["ETL pipeline failure", "Data source issues", "Processing delays"]
})
# Identify particularly stale tables
very_stale_tables = [
(table_name, info.get("staleness_hours", 0))
for table_name, info in table_freshness.items()
if info.get("staleness_hours", 0) > 72 # More than 3 days
]
if very_stale_tables:
issues.append({
"issue_type": "very_stale_data",
"severity": "medium",
"affected_tables": [table for table, _ in very_stale_tables],
"max_staleness_hours": max(hours for _, hours in very_stale_tables),
"description": "Some tables have very stale data (>72 hours)",
"recommendation": "Check data ingestion processes for affected tables"
})
return issues
def _generate_freshness_alerts(self, table_freshness: Dict[str, Any], threshold_hours: int) -> List[Dict]:
"""Generate freshness alerts"""
alerts = []
for table_name, info in table_freshness.items():
staleness_hours = info.get("staleness_hours")
status = info.get("status")
if status == "stale" and staleness_hours:
if staleness_hours > threshold_hours * 2: # Exceeds threshold by 2x
alert_level = "critical"
elif staleness_hours > threshold_hours * 1.5: # Exceeds threshold by 1.5x
alert_level = "warning"
else:
alert_level = "info"
alerts.append({
"alert_level": alert_level,
"table_name": table_name,
"staleness_hours": staleness_hours,
"threshold_hours": threshold_hours,
"message": f"Table {table_name} is stale ({staleness_hours:.1f} hours old, threshold: {threshold_hours}h)",
"timestamp": datetime.now().isoformat()
})
elif status == "error":
alerts.append({
"alert_level": "error",
"table_name": table_name,
"message": f"Unable to determine freshness for table {table_name}",
"error": info.get("error"),
"timestamp": datetime.now().isoformat()
})
return alerts

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,978 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Dependency Analysis Tools Module
Provides data flow dependency analysis and impact assessment capabilities
"""
import time
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Tuple
from collections import defaultdict, deque
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class DependencyAnalysisTools:
"""Dependency analysis tools for data flow and impact assessment"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("DependencyAnalysisTools initialized")
async def analyze_data_flow_dependencies(
self,
target_table: Optional[str] = None,
analysis_depth: int = 3,
include_views: bool = True,
catalog_name: Optional[str] = None,
db_name: Optional[str] = None
) -> Dict[str, Any]:
"""
Analyze data flow dependencies and impact relationships
Args:
target_table: Specific table to analyze (if None, analyzes all tables)
analysis_depth: Maximum depth for dependency traversal
include_views: Whether to include views in dependency analysis
catalog_name: Catalog name
db_name: Database name
Returns:
Comprehensive dependency analysis results
"""
try:
start_time = time.time()
connection = await self.connection_manager.get_connection("query")
# 1. Get table metadata and relationships
tables_metadata = await self._get_tables_metadata(connection, catalog_name, db_name, include_views)
if not tables_metadata:
return {
"error": "No tables found for dependency analysis",
"analysis_timestamp": datetime.now().isoformat()
}
# 2. Build dependency graph from SQL analysis
dependency_graph = await self._build_dependency_graph(connection, tables_metadata, analysis_depth)
# 3. Analyze specific table or all tables
if target_table:
# Analyze specific table
table_analysis = await self._analyze_single_table_dependencies(
target_table, dependency_graph, tables_metadata
)
impact_analysis = await self._calculate_impact_analysis(
target_table, dependency_graph, "both"
)
else:
# Analyze all tables
table_analysis = await self._analyze_all_tables_dependencies(
dependency_graph, tables_metadata
)
impact_analysis = await self._calculate_global_impact_analysis(dependency_graph)
# 4. Generate insights and recommendations
dependency_insights = await self._generate_dependency_insights(
dependency_graph, table_analysis, impact_analysis
)
execution_time = time.time() - start_time
return {
"analysis_target": target_table or "all_tables",
"analysis_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"tables_analyzed": len(tables_metadata),
"dependency_graph_stats": self._get_dependency_graph_stats(dependency_graph),
"table_dependencies": table_analysis,
"impact_analysis": impact_analysis,
"dependency_insights": dependency_insights,
"recommendations": self._generate_dependency_recommendations(dependency_insights)
}
except Exception as e:
logger.error(f"Data flow dependency analysis failed: {str(e)}")
return {
"error": str(e),
"analysis_timestamp": datetime.now().isoformat()
}
# ==================== Private Helper Methods ====================
async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]:
"""Get metadata for all tables and views"""
try:
# Build conditions for query
where_conditions = []
if db_name:
where_conditions.append(f"table_schema = '{db_name}'")
else:
where_conditions.append("table_schema = DATABASE()")
table_types = ["'BASE TABLE'"]
if include_views:
table_types.append("'VIEW'")
where_conditions.append(f"table_type IN ({','.join(table_types)})")
metadata_sql = f"""
SELECT
table_schema as schema_name,
table_name,
table_type,
table_comment,
table_rows,
data_length
FROM information_schema.tables
WHERE {' AND '.join(where_conditions)}
ORDER BY table_schema, table_name
"""
result = await connection.execute(metadata_sql)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get tables metadata: {str(e)}")
return []
async def _build_dependency_graph(self, connection, tables_metadata: List[Dict], analysis_depth: int) -> Dict[str, Dict]:
"""Build dependency graph by analyzing SQL statements and DDL"""
dependency_graph = defaultdict(lambda: {
"upstream_dependencies": set(),
"downstream_dependencies": set(),
"table_type": "unknown",
"dependency_strength": {},
"sql_patterns": []
})
# Initialize graph with table metadata
for table in tables_metadata:
table_name = table["table_name"]
schema_name = table.get("schema_name", "")
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
dependency_graph[full_table_name]["table_type"] = table["table_type"]
# 1. Analyze view definitions for dependencies
await self._analyze_view_dependencies(connection, dependency_graph, tables_metadata)
# 2. Analyze audit logs for runtime dependencies
await self._analyze_runtime_dependencies(connection, dependency_graph, analysis_depth)
# 3. Analyze foreign key relationships
await self._analyze_foreign_key_dependencies(connection, dependency_graph, tables_metadata)
return dict(dependency_graph)
async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
"""Analyze view definitions to extract table dependencies"""
try:
for table in tables_metadata:
if table["table_type"] == "VIEW":
table_name = table["table_name"]
schema_name = table.get("schema_name", "")
# Get view definition
view_def_sql = f"SHOW CREATE VIEW {schema_name}.{table_name}" if schema_name else f"SHOW CREATE VIEW {table_name}"
try:
result = await connection.execute(view_def_sql)
if result.data and len(result.data) > 0:
# Extract view definition from result
view_definition = ""
for row in result.data:
for key, value in row.items():
if "create" in key.lower() and value:
view_definition = str(value)
break
if view_definition:
# Extract table dependencies from view definition
referenced_tables = self._extract_table_references(view_definition)
full_view_name = f"{schema_name}.{table_name}" if schema_name else table_name
for ref_table in referenced_tables:
# Add upstream dependency
dependency_graph[full_view_name]["upstream_dependencies"].add(ref_table)
dependency_graph[full_view_name]["dependency_strength"][ref_table] = "direct"
# Add downstream dependency for referenced table
dependency_graph[ref_table]["downstream_dependencies"].add(full_view_name)
dependency_graph[full_view_name]["sql_patterns"].append({
"pattern_type": "view_definition",
"referenced_table": ref_table,
"confidence": 1.0
})
except Exception as e:
logger.warning(f"Failed to analyze view {table_name}: {str(e)}")
continue
except Exception as e:
logger.warning(f"Failed to analyze view dependencies: {str(e)}")
async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None:
"""Analyze audit logs to discover runtime table dependencies"""
try:
# Get recent SQL statements from audit logs
audit_sql = """
SELECT
`stmt` as sql_statement,
`user` as user_name,
COUNT(*) as frequency
FROM internal.__internal_schema.audit_log
WHERE `stmt` IS NOT NULL
AND `stmt` != ''
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
GROUP BY `stmt`, `user`
HAVING frequency > 1
ORDER BY frequency DESC
LIMIT 1000
"""
result = await connection.execute(audit_sql)
if result.data:
for row in result.data:
sql_statement = row.get("sql_statement", "")
frequency = row.get("frequency", 1)
if sql_statement:
# Extract table references from SQL
referenced_tables = self._extract_table_references(sql_statement)
if len(referenced_tables) > 1:
# Infer dependencies from multi-table queries
self._infer_dependencies_from_sql(
dependency_graph, sql_statement, referenced_tables, frequency
)
except Exception as e:
logger.warning(f"Failed to analyze runtime dependencies: {str(e)}")
async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
"""Analyze foreign key constraints for explicit dependencies"""
try:
# Get foreign key information
fk_sql = """
SELECT
TABLE_SCHEMA as schema_name,
TABLE_NAME as table_name,
COLUMN_NAME as column_name,
REFERENCED_TABLE_SCHEMA as ref_schema,
REFERENCED_TABLE_NAME as ref_table_name,
REFERENCED_COLUMN_NAME as ref_column_name
FROM information_schema.KEY_COLUMN_USAGE
WHERE REFERENCED_TABLE_NAME IS NOT NULL
"""
result = await connection.execute(fk_sql)
if result.data:
for row in result.data:
schema_name = row.get("schema_name", "")
table_name = row["table_name"]
ref_schema = row.get("ref_schema", "")
ref_table_name = row["ref_table_name"]
# Build full table names
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
full_ref_table = f"{ref_schema}.{ref_table_name}" if ref_schema else ref_table_name
# Add foreign key dependency
dependency_graph[full_table_name]["upstream_dependencies"].add(full_ref_table)
dependency_graph[full_table_name]["dependency_strength"][full_ref_table] = "foreign_key"
dependency_graph[full_ref_table]["downstream_dependencies"].add(full_table_name)
dependency_graph[full_table_name]["sql_patterns"].append({
"pattern_type": "foreign_key",
"referenced_table": full_ref_table,
"confidence": 1.0,
"column": row["column_name"],
"ref_column": row["ref_column_name"]
})
except Exception as e:
logger.warning(f"Failed to analyze foreign key dependencies: {str(e)}")
def _extract_table_references(self, sql: str) -> List[str]:
"""Extract table references from SQL statement"""
if not sql:
return []
# Normalize SQL
sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL) # Remove comments
sql = re.sub(r'--.*', '', sql) # Remove line comments
sql = sql.upper()
table_references = []
# Pattern to match table names in various contexts
patterns = [
r'\bFROM\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bJOIN\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bINTO\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bUPDATE\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bDELETE\s+FROM\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
r'\bINSERT\s+INTO\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)'
]
for pattern in patterns:
matches = re.findall(pattern, sql, re.IGNORECASE)
for match in matches:
# Clean up table name
table_name = match.strip('`"\'').split()[0] # Remove quotes and aliases
if table_name and not self._is_sql_keyword(table_name):
table_references.append(table_name.lower())
return list(set(table_references))
def _is_sql_keyword(self, word: str) -> bool:
"""Check if word is a SQL keyword"""
keywords = {
'SELECT', 'FROM', 'WHERE', 'JOIN', 'INNER', 'LEFT', 'RIGHT', 'OUTER',
'ON', 'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'BETWEEN', 'LIKE',
'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'INDEX',
'TABLE', 'VIEW', 'DATABASE', 'SCHEMA', 'PRIMARY', 'KEY', 'FOREIGN',
'REFERENCES', 'CONSTRAINT', 'NULL', 'DEFAULT', 'AUTO_INCREMENT'
}
return word.upper() in keywords
def _infer_dependencies_from_sql(self, dependency_graph: Dict, sql: str, referenced_tables: List[str], frequency: int) -> None:
"""Infer table dependencies from SQL patterns"""
# Analyze SQL pattern to determine dependency relationships
sql_upper = sql.upper()
# Look for INSERT ... SELECT patterns
if 'INSERT' in sql_upper and 'SELECT' in sql_upper:
# Find target table (after INSERT INTO)
insert_match = re.search(r'INSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_.]*)', sql_upper)
if insert_match:
target_table = insert_match.group(1).lower()
# All other tables are dependencies
for ref_table in referenced_tables:
if ref_table != target_table:
dependency_graph[target_table]["upstream_dependencies"].add(ref_table)
dependency_graph[ref_table]["downstream_dependencies"].add(target_table)
# Calculate confidence based on frequency
confidence = min(0.9, 0.3 + (frequency / 100))
dependency_graph[target_table]["sql_patterns"].append({
"pattern_type": "insert_select",
"referenced_table": ref_table,
"confidence": confidence,
"frequency": frequency
})
# Look for CREATE TABLE AS SELECT patterns
elif 'CREATE' in sql_upper and 'SELECT' in sql_upper:
create_match = re.search(r'CREATE\s+TABLE\s+([a-zA-Z_][a-zA-Z0-9_.]*)', sql_upper)
if create_match:
target_table = create_match.group(1).lower()
for ref_table in referenced_tables:
if ref_table != target_table:
dependency_graph[target_table]["upstream_dependencies"].add(ref_table)
dependency_graph[ref_table]["downstream_dependencies"].add(target_table)
dependency_graph[target_table]["sql_patterns"].append({
"pattern_type": "create_table_as_select",
"referenced_table": ref_table,
"confidence": 0.95,
"frequency": frequency
})
async def _analyze_single_table_dependencies(self, target_table: str, dependency_graph: Dict, tables_metadata: List[Dict]) -> Dict[str, Any]:
"""Analyze dependencies for a specific table"""
if target_table not in dependency_graph:
return {"error": f"Table {target_table} not found in dependency graph"}
table_info = dependency_graph[target_table]
# Get upstream dependencies (tables this table depends on)
upstream_deps = await self._get_dependency_chain(target_table, dependency_graph, "upstream", 3)
# Get downstream dependencies (tables that depend on this table)
downstream_deps = await self._get_dependency_chain(target_table, dependency_graph, "downstream", 3)
return {
"table_name": target_table,
"table_type": table_info["table_type"],
"direct_upstream_dependencies": list(table_info["upstream_dependencies"]),
"direct_downstream_dependencies": list(table_info["downstream_dependencies"]),
"upstream_dependency_chain": upstream_deps,
"downstream_dependency_chain": downstream_deps,
"dependency_patterns": table_info["sql_patterns"],
"dependency_metrics": {
"upstream_count": len(table_info["upstream_dependencies"]),
"downstream_count": len(table_info["downstream_dependencies"]),
"total_upstream_chain": len(upstream_deps.get("all_dependencies", [])),
"total_downstream_chain": len(downstream_deps.get("all_dependencies", [])),
"dependency_depth": max(upstream_deps.get("max_depth", 0), downstream_deps.get("max_depth", 0))
}
}
async def _get_dependency_chain(self, start_table: str, dependency_graph: Dict, direction: str, max_depth: int) -> Dict[str, Any]:
"""Get full dependency chain in specified direction"""
visited = set()
all_dependencies = []
levels = []
current_level = [start_table]
depth = 0
while current_level and depth < max_depth:
next_level = []
level_deps = []
for table in current_level:
if table in visited:
continue
visited.add(table)
if direction == "upstream":
dependencies = dependency_graph.get(table, {}).get("upstream_dependencies", set())
else:
dependencies = dependency_graph.get(table, {}).get("downstream_dependencies", set())
for dep in dependencies:
if dep not in visited:
next_level.append(dep)
level_deps.append(dep)
all_dependencies.append(dep)
if level_deps:
levels.append({
"level": depth + 1,
"tables": level_deps
})
current_level = next_level
depth += 1
return {
"direction": direction,
"max_depth": depth,
"all_dependencies": list(set(all_dependencies)),
"dependency_levels": levels,
"total_count": len(set(all_dependencies))
}
async def _analyze_all_tables_dependencies(self, dependency_graph: Dict, tables_metadata: List[Dict]) -> Dict[str, Any]:
"""Analyze dependencies for all tables"""
table_stats = {}
for table_name, table_info in dependency_graph.items():
upstream_count = len(table_info["upstream_dependencies"])
downstream_count = len(table_info["downstream_dependencies"])
table_stats[table_name] = {
"table_type": table_info["table_type"],
"upstream_count": upstream_count,
"downstream_count": downstream_count,
"total_connections": upstream_count + downstream_count,
"dependency_score": self._calculate_dependency_score(upstream_count, downstream_count),
"role_classification": self._classify_table_role(upstream_count, downstream_count)
}
# Find key tables
most_critical_tables = sorted(
table_stats.items(),
key=lambda x: x[1]["dependency_score"],
reverse=True
)[:10]
source_tables = [name for name, stats in table_stats.items() if stats["role_classification"] == "source"]
sink_tables = [name for name, stats in table_stats.items() if stats["role_classification"] == "sink"]
hub_tables = [name for name, stats in table_stats.items() if stats["role_classification"] == "hub"]
return {
"table_statistics": table_stats,
"summary": {
"total_tables": len(table_stats),
"source_tables": len(source_tables),
"sink_tables": len(sink_tables),
"hub_tables": len(hub_tables),
"isolated_tables": len([stats for stats in table_stats.values() if stats["total_connections"] == 0])
},
"critical_tables": [{"table": name, **stats} for name, stats in most_critical_tables],
"table_roles": {
"sources": source_tables[:10],
"sinks": sink_tables[:10],
"hubs": hub_tables[:10]
}
}
def _calculate_dependency_score(self, upstream_count: int, downstream_count: int) -> float:
"""Calculate dependency importance score for a table"""
# Score based on both incoming and outgoing dependencies
# Higher weight for downstream dependencies (impact)
return round(upstream_count * 0.3 + downstream_count * 0.7, 2)
def _classify_table_role(self, upstream_count: int, downstream_count: int) -> str:
"""Classify table role based on dependency pattern"""
if upstream_count == 0 and downstream_count > 0:
return "source" # Data source
elif upstream_count > 0 and downstream_count == 0:
return "sink" # Data destination
elif upstream_count > 2 and downstream_count > 2:
return "hub" # Data hub/transformation
elif upstream_count > 0 and downstream_count > 0:
return "intermediate" # Intermediate transformation
else:
return "isolated" # No dependencies
async def _calculate_impact_analysis(self, target_table: str, dependency_graph: Dict, direction: str) -> Dict[str, Any]:
"""Calculate impact analysis for a specific table"""
if direction == "upstream" or direction == "both":
upstream_impact = await self._calculate_upstream_impact(target_table, dependency_graph)
else:
upstream_impact = {}
if direction == "downstream" or direction == "both":
downstream_impact = await self._calculate_downstream_impact(target_table, dependency_graph)
else:
downstream_impact = {}
return {
"target_table": target_table,
"upstream_impact": upstream_impact,
"downstream_impact": downstream_impact,
"total_impact_score": self._calculate_total_impact_score(upstream_impact, downstream_impact)
}
async def _calculate_upstream_impact(self, target_table: str, dependency_graph: Dict) -> Dict[str, Any]:
"""Calculate what would be impacted if upstream dependencies fail"""
upstream_deps = dependency_graph.get(target_table, {}).get("upstream_dependencies", set())
impact_scenarios = []
for dep_table in upstream_deps:
# Simulate failure of this dependency
affected_tables = await self._simulate_table_failure_impact(dep_table, dependency_graph)
impact_scenarios.append({
"failed_dependency": dep_table,
"directly_affected_tables": len(affected_tables["direct"]),
"indirectly_affected_tables": len(affected_tables["indirect"]),
"total_affected": len(affected_tables["all"]),
"critical_affected": [table for table in affected_tables["all"]
if dependency_graph.get(table, {}).get("downstream_dependencies", set())],
"impact_severity": self._assess_impact_severity(len(affected_tables["all"]))
})
return {
"dependency_count": len(upstream_deps),
"impact_scenarios": impact_scenarios,
"max_potential_impact": max([scenario["total_affected"] for scenario in impact_scenarios], default=0),
"risk_assessment": self._assess_upstream_risk(impact_scenarios)
}
async def _calculate_downstream_impact(self, target_table: str, dependency_graph: Dict) -> Dict[str, Any]:
"""Calculate what would be impacted if target table fails"""
affected_tables = await self._simulate_table_failure_impact(target_table, dependency_graph)
return {
"direct_impact": len(affected_tables["direct"]),
"indirect_impact": len(affected_tables["indirect"]),
"total_impact": len(affected_tables["all"]),
"affected_table_details": [
{
"table_name": table,
"impact_type": "direct" if table in affected_tables["direct"] else "indirect",
"table_role": self._classify_table_role(
len(dependency_graph.get(table, {}).get("upstream_dependencies", set())),
len(dependency_graph.get(table, {}).get("downstream_dependencies", set()))
)
}
for table in affected_tables["all"]
],
"impact_severity": self._assess_impact_severity(len(affected_tables["all"]))
}
async def _simulate_table_failure_impact(self, failed_table: str, dependency_graph: Dict) -> Dict[str, List[str]]:
"""Simulate the impact of a table failure"""
direct_affected = list(dependency_graph.get(failed_table, {}).get("downstream_dependencies", set()))
# Find all indirectly affected tables using BFS
visited = {failed_table}
queue = deque(direct_affected)
indirect_affected = []
while queue:
current_table = queue.popleft()
if current_table in visited:
continue
visited.add(current_table)
indirect_affected.append(current_table)
# Add downstream dependencies to queue
downstream = dependency_graph.get(current_table, {}).get("downstream_dependencies", set())
for dep in downstream:
if dep not in visited:
queue.append(dep)
# Remove direct affected from indirect (they're already counted)
indirect_only = [table for table in indirect_affected if table not in direct_affected]
return {
"direct": direct_affected,
"indirect": indirect_only,
"all": direct_affected + indirect_only
}
def _assess_impact_severity(self, affected_count: int) -> str:
"""Assess impact severity based on affected table count"""
if affected_count == 0:
return "none"
elif affected_count <= 2:
return "low"
elif affected_count <= 5:
return "medium"
elif affected_count <= 10:
return "high"
else:
return "critical"
def _assess_upstream_risk(self, impact_scenarios: List[Dict]) -> str:
"""Assess upstream dependency risk"""
if not impact_scenarios:
return "low"
max_impact = max([scenario["total_affected"] for scenario in impact_scenarios])
high_impact_scenarios = len([s for s in impact_scenarios if s["impact_severity"] in ["high", "critical"]])
if high_impact_scenarios > 0 or max_impact > 10:
return "high"
elif max_impact > 5 or len(impact_scenarios) > 3:
return "medium"
else:
return "low"
def _calculate_total_impact_score(self, upstream_impact: Dict, downstream_impact: Dict) -> float:
"""Calculate total impact score combining upstream and downstream risks"""
upstream_score = 0
downstream_score = 0
if upstream_impact:
max_upstream_impact = upstream_impact.get("max_potential_impact", 0)
upstream_score = min(max_upstream_impact * 0.3, 10) # Cap at 10
if downstream_impact:
downstream_score = min(downstream_impact.get("total_impact", 0) * 0.7, 10) # Cap at 10
return round(upstream_score + downstream_score, 2)
async def _calculate_global_impact_analysis(self, dependency_graph: Dict) -> Dict[str, Any]:
"""Calculate global impact analysis for all tables"""
table_impacts = {}
for table_name in dependency_graph.keys():
impact = await self._calculate_impact_analysis(table_name, dependency_graph, "downstream")
table_impacts[table_name] = {
"downstream_impact": impact["downstream_impact"]["total_impact"],
"impact_severity": impact["downstream_impact"]["impact_severity"],
"impact_score": impact["total_impact_score"]
}
# Find most critical tables
critical_tables = sorted(
table_impacts.items(),
key=lambda x: x[1]["impact_score"],
reverse=True
)[:15]
# Risk distribution
risk_distribution = {
"critical": len([t for t in table_impacts.values() if t["impact_severity"] == "critical"]),
"high": len([t for t in table_impacts.values() if t["impact_severity"] == "high"]),
"medium": len([t for t in table_impacts.values() if t["impact_severity"] == "medium"]),
"low": len([t for t in table_impacts.values() if t["impact_severity"] == "low"]),
"none": len([t for t in table_impacts.values() if t["impact_severity"] == "none"])
}
return {
"global_impact_summary": {
"total_tables_analyzed": len(table_impacts),
"tables_with_impact": len([t for t in table_impacts.values() if t["downstream_impact"] > 0]),
"average_impact_score": round(sum(t["impact_score"] for t in table_impacts.values()) / len(table_impacts), 2) if table_impacts else 0,
"risk_distribution": risk_distribution
},
"most_critical_tables": [{"table": name, **stats} for name, stats in critical_tables],
"risk_matrix": self._generate_risk_matrix(table_impacts)
}
def _generate_risk_matrix(self, table_impacts: Dict[str, Dict]) -> Dict[str, List[str]]:
"""Generate risk matrix categorizing tables by impact level"""
risk_matrix = {
"critical_risk": [],
"high_risk": [],
"medium_risk": [],
"low_risk": [],
"minimal_risk": []
}
for table_name, impact_data in table_impacts.items():
severity = impact_data["impact_severity"]
if severity == "critical":
risk_matrix["critical_risk"].append(table_name)
elif severity == "high":
risk_matrix["high_risk"].append(table_name)
elif severity == "medium":
risk_matrix["medium_risk"].append(table_name)
elif severity == "low":
risk_matrix["low_risk"].append(table_name)
else:
risk_matrix["minimal_risk"].append(table_name)
return risk_matrix
def _get_dependency_graph_stats(self, dependency_graph: Dict) -> Dict[str, Any]:
"""Get statistics about the dependency graph"""
total_tables = len(dependency_graph)
total_dependencies = sum(
len(table_info.get("upstream_dependencies", set())) + len(table_info.get("downstream_dependencies", set()))
for table_info in dependency_graph.values()
) // 2 # Divide by 2 to avoid double counting
tables_with_upstream = len([
table for table, info in dependency_graph.items()
if info.get("upstream_dependencies")
])
tables_with_downstream = len([
table for table, info in dependency_graph.items()
if info.get("downstream_dependencies")
])
isolated_tables = len([
table for table, info in dependency_graph.items()
if not info.get("upstream_dependencies") and not info.get("downstream_dependencies")
])
return {
"total_tables": total_tables,
"total_dependencies": total_dependencies,
"tables_with_upstream_deps": tables_with_upstream,
"tables_with_downstream_deps": tables_with_downstream,
"isolated_tables": isolated_tables,
"connectivity_ratio": round((total_tables - isolated_tables) / total_tables, 3) if total_tables > 0 else 0,
"avg_dependencies_per_table": round(total_dependencies / total_tables, 2) if total_tables > 0 else 0
}
async def _generate_dependency_insights(self, dependency_graph: Dict, table_analysis: Dict, impact_analysis: Dict) -> Dict[str, Any]:
"""Generate insights from dependency analysis"""
insights = {
"architectural_patterns": {},
"risk_assessment": {},
"optimization_opportunities": {}
}
# Architectural patterns
graph_stats = self._get_dependency_graph_stats(dependency_graph)
insights["architectural_patterns"] = {
"connectivity_level": "high" if graph_stats["connectivity_ratio"] > 0.7 else "medium" if graph_stats["connectivity_ratio"] > 0.3 else "low",
"architecture_type": self._classify_architecture_type(graph_stats),
"complexity_score": round(graph_stats["avg_dependencies_per_table"] * graph_stats["connectivity_ratio"], 2),
"isolated_tables_concern": graph_stats["isolated_tables"] > graph_stats["total_tables"] * 0.3
}
# Risk assessment
if isinstance(impact_analysis, dict) and "global_impact_summary" in impact_analysis:
global_impact = impact_analysis["global_impact_summary"]
insights["risk_assessment"] = {
"overall_risk_level": self._assess_overall_risk_level(global_impact["risk_distribution"]),
"critical_tables_count": global_impact["risk_distribution"]["critical"],
"high_risk_tables_count": global_impact["risk_distribution"]["high"],
"impact_concentration": global_impact["average_impact_score"] > 5.0,
"resilience_score": self._calculate_resilience_score(global_impact)
}
# Optimization opportunities
insights["optimization_opportunities"] = self._identify_optimization_opportunities(dependency_graph, table_analysis)
return insights
def _classify_architecture_type(self, graph_stats: Dict) -> str:
"""Classify the overall architecture type"""
connectivity = graph_stats["connectivity_ratio"]
avg_deps = graph_stats["avg_dependencies_per_table"]
if connectivity > 0.8 and avg_deps > 3:
return "highly_interconnected"
elif connectivity > 0.5 and avg_deps > 2:
return "moderately_connected"
elif connectivity < 0.3:
return "loosely_coupled"
else:
return "mixed_architecture"
def _assess_overall_risk_level(self, risk_distribution: Dict[str, int]) -> str:
"""Assess overall risk level from risk distribution"""
total = sum(risk_distribution.values())
if total == 0:
return "minimal"
critical_ratio = risk_distribution["critical"] / total
high_ratio = risk_distribution["high"] / total
if critical_ratio > 0.1 or high_ratio > 0.2:
return "high"
elif critical_ratio > 0.05 or high_ratio > 0.1:
return "medium"
else:
return "low"
def _calculate_resilience_score(self, global_impact: Dict) -> float:
"""Calculate system resilience score (0-1, higher is better)"""
total_tables = global_impact["total_tables_analyzed"]
risk_dist = global_impact["risk_distribution"]
if total_tables == 0:
return 0.0
# Calculate weighted risk score
weighted_risk = (
risk_dist["critical"] * 5 +
risk_dist["high"] * 3 +
risk_dist["medium"] * 2 +
risk_dist["low"] * 1
) / total_tables
# Convert to resilience score (inverse of risk, normalized)
max_possible_risk = 5.0
resilience = max(0, (max_possible_risk - weighted_risk) / max_possible_risk)
return round(resilience, 3)
def _identify_optimization_opportunities(self, dependency_graph: Dict, table_analysis: Dict) -> List[Dict]:
"""Identify optimization opportunities"""
opportunities = []
# Find tables with excessive dependencies
for table_name, table_info in dependency_graph.items():
upstream_count = len(table_info.get("upstream_dependencies", set()))
downstream_count = len(table_info.get("downstream_dependencies", set()))
if upstream_count > 10:
opportunities.append({
"type": "excessive_upstream_dependencies",
"table": table_name,
"description": f"Table has {upstream_count} upstream dependencies",
"recommendation": "Consider breaking down complex transformations or using intermediate tables",
"priority": "high" if upstream_count > 15 else "medium"
})
if downstream_count > 10:
opportunities.append({
"type": "excessive_downstream_dependencies",
"table": table_name,
"description": f"Table has {downstream_count} downstream dependencies",
"recommendation": "Consider if this table is doing too much or if views could be used",
"priority": "high" if downstream_count > 15 else "medium"
})
# Find potential circular dependencies (simplified check)
# This is a basic check - full cycle detection would be more complex
for table_name, table_info in dependency_graph.items():
upstream_deps = table_info.get("upstream_dependencies", set())
for upstream_table in upstream_deps:
if table_name in dependency_graph.get(upstream_table, {}).get("upstream_dependencies", set()):
opportunities.append({
"type": "potential_circular_dependency",
"table": table_name,
"related_table": upstream_table,
"description": f"Potential circular dependency between {table_name} and {upstream_table}",
"recommendation": "Review and eliminate circular dependencies",
"priority": "high"
})
return opportunities
def _generate_dependency_recommendations(self, dependency_insights: Dict) -> List[Dict]:
"""Generate recommendations based on dependency analysis"""
recommendations = []
# Architecture recommendations
arch_patterns = dependency_insights.get("architectural_patterns", {})
if arch_patterns.get("isolated_tables_concern", False):
recommendations.append({
"type": "architecture",
"priority": "medium",
"title": "High number of isolated tables",
"description": "Many tables have no dependencies, which may indicate data silos",
"action": "Review isolated tables and consider if they should be integrated into data flows"
})
complexity_score = arch_patterns.get("complexity_score", 0)
if complexity_score > 5:
recommendations.append({
"type": "architecture",
"priority": "high",
"title": "High system complexity",
"description": f"System complexity score is {complexity_score} (high)",
"action": "Consider simplifying data architecture and reducing unnecessary dependencies"
})
# Risk recommendations
risk_assessment = dependency_insights.get("risk_assessment", {})
overall_risk = risk_assessment.get("overall_risk_level", "unknown")
if overall_risk == "high":
recommendations.append({
"type": "risk_mitigation",
"priority": "high",
"title": "High overall system risk",
"description": "System has high dependency risks that could cause widespread failures",
"action": "Implement monitoring and backup strategies for critical tables"
})
critical_tables = risk_assessment.get("critical_tables_count", 0)
if critical_tables > 0:
recommendations.append({
"type": "risk_mitigation",
"priority": "high",
"title": f"{critical_tables} critical impact tables identified",
"description": "Tables with critical impact require special attention",
"action": "Implement enhanced monitoring and backup procedures for critical tables"
})
# Optimization recommendations
optimization_ops = dependency_insights.get("optimization_opportunities", [])
if optimization_ops:
high_priority_ops = [op for op in optimization_ops if op.get("priority") == "high"]
if high_priority_ops:
recommendations.append({
"type": "optimization",
"priority": "high",
"title": f"{len(high_priority_ops)} high-priority optimization opportunities",
"description": "System has optimization opportunities that should be addressed",
"action": "Review and implement suggested optimizations for better maintainability"
})
return recommendations

View File

@@ -319,7 +319,20 @@ class DorisLoggerManager:
return
self.log_dir = Path(log_dir)
self.log_dir.mkdir(parents=True, exist_ok=True)
log_dir_writable = True # Initialize the variable
# Try to create log directory, fallback to console-only if fails
try:
self.log_dir.mkdir(parents=True, exist_ok=True)
except (OSError, PermissionError) as e:
# If we can't create log directory (e.g., read-only filesystem in stdio mode),
# fall back to console-only logging
log_dir_writable = False
enable_file = False
enable_audit = False
enable_cleanup = False
# Don't use print() in stdio mode as it interferes with MCP JSON protocol
# Log the warning through the logging system instead, which will be handled after setup
# Clear existing handlers
root_logger = logging.getLogger()
@@ -414,13 +427,19 @@ class DorisLoggerManager:
logger.info("=" * 80)
logger.info("Doris MCP Server Logging System Initialized")
logger.info(f"Log Level: {level}")
logger.info(f"Log Directory: {self.log_dir.absolute()}")
if log_dir_writable:
logger.info(f"Log Directory: {self.log_dir.absolute()}")
else:
logger.info("Log Directory: Not available (console-only mode)")
logger.info(f"Console Logging: {'Enabled' if enable_console else 'Disabled'}")
logger.info(f"File Logging: {'Enabled' if enable_file else 'Disabled'}")
logger.info(f"Audit Logging: {'Enabled' if enable_audit else 'Disabled'}")
logger.info(f"Log Cleanup: {'Enabled' if enable_cleanup and enable_file else 'Disabled'}")
logger.info(f"File Logging: {'Enabled' if enable_file else 'Disabled (fallback mode)'}")
logger.info(f"Audit Logging: {'Enabled' if enable_audit else 'Disabled (fallback mode)'}")
logger.info(f"Log Cleanup: {'Enabled' if enable_cleanup and enable_file else 'Disabled (fallback mode)'}")
if enable_cleanup and enable_file:
logger.info(f"Cleanup Settings: Max age {max_age_days} days, interval {cleanup_interval_hours}h")
if not log_dir_writable:
logger.warning("Running in console-only logging mode due to filesystem permissions")
logger.warning(f"Could not create log directory '{log_dir}' - stdio mode fallback enabled")
logger.info("=" * 80)
def _setup_package_loggers(self, level: str):
@@ -568,10 +587,10 @@ def setup_logging(level: str = "INFO",
def get_logger(name: str) -> logging.Logger:
"""
Get a logger instance.
Args:
name: Logger name (usually __name__)
Returns:
Configured logger instance
"""

View File

@@ -0,0 +1,758 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Performance Analytics Tools Module
Provides slow query analysis and resource growth monitoring capabilities
"""
import time
import statistics
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict, Counter
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class PerformanceAnalyticsTools:
"""Performance analytics tools for query and resource monitoring"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("PerformanceAnalyticsTools initialized")
async def analyze_slow_queries_topn(
self,
days: int = 7,
top_n: int = 20,
min_execution_time_ms: int = 1000,
include_patterns: bool = True
) -> Dict[str, Any]:
"""
Analyze top N slowest queries and performance patterns
Args:
days: Number of days to analyze
top_n: Number of top slow queries to return
min_execution_time_ms: Minimum execution time threshold
include_patterns: Whether to include query pattern analysis
Returns:
Slow query analysis results
"""
try:
start_time = time.time()
connection = await self.connection_manager.get_connection("query")
# Get slow query data
slow_queries = await self._get_slow_query_data(
connection, days, min_execution_time_ms
)
if not slow_queries:
return {
"message": "No slow queries found for the specified criteria",
"analysis_period": {"days": days, "threshold_ms": min_execution_time_ms},
"analysis_timestamp": datetime.now().isoformat()
}
# Analyze top N queries
top_queries = await self._analyze_top_slow_queries(slow_queries, top_n)
# Performance insights
performance_insights = await self._generate_performance_insights(slow_queries)
# Query patterns analysis
pattern_analysis = {}
if include_patterns:
pattern_analysis = await self._analyze_query_patterns(slow_queries)
execution_time = time.time() - start_time
return {
"analysis_period": {
"days": days,
"threshold_ms": min_execution_time_ms,
"start_date": (datetime.now() - timedelta(days=days)).isoformat(),
"end_date": datetime.now().isoformat()
},
"analysis_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"summary": {
"total_slow_queries": len(slow_queries),
"unique_queries": len(set(q.get("sql_hash", q.get("sql", ""))[:100] for q in slow_queries)),
"top_n_analyzed": min(top_n, len(slow_queries))
},
"top_slow_queries": top_queries,
"performance_insights": performance_insights,
"query_patterns": pattern_analysis,
"recommendations": self._generate_performance_recommendations(performance_insights, pattern_analysis)
}
except Exception as e:
logger.error(f"Slow query analysis failed: {str(e)}")
return {
"error": str(e),
"analysis_timestamp": datetime.now().isoformat()
}
async def analyze_resource_growth_curves(
self,
days: int = 30,
resource_types: List[str] = None,
include_predictions: bool = False,
detailed_response: bool = False
) -> Dict[str, Any]:
"""
Analyze resource growth patterns and trends
Args:
days: Number of days to analyze
resource_types: Types of resources to analyze
include_predictions: Whether to include growth predictions
detailed_response: Whether to return detailed data including daily breakdowns
Returns:
Resource growth analysis results
"""
try:
start_time = time.time()
connection = await self.connection_manager.get_connection("query")
if resource_types is None:
resource_types = ["storage", "query_volume", "user_activity"]
# Analyze each resource type
resource_analysis = {}
if "storage" in resource_types:
resource_analysis["storage"] = await self._analyze_storage_growth(connection, days, detailed_response)
if "query_volume" in resource_types:
resource_analysis["query_volume"] = await self._analyze_query_volume_growth(connection, days, detailed_response)
if "user_activity" in resource_types:
resource_analysis["user_activity"] = await self._analyze_user_activity_growth(connection, days, detailed_response)
# Generate growth insights
growth_insights = await self._generate_growth_insights(resource_analysis, days)
# Growth predictions
predictions = {}
if include_predictions:
predictions = await self._generate_growth_predictions(resource_analysis)
execution_time = time.time() - start_time
result = {
"analysis_period": {
"days": days,
"start_date": (datetime.now() - timedelta(days=days)).isoformat(),
"end_date": datetime.now().isoformat()
},
"analysis_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"resource_types_analyzed": resource_types,
"resource_analysis": resource_analysis,
"growth_insights": growth_insights,
"growth_predictions": predictions,
"recommendations": self._generate_growth_recommendations(growth_insights, predictions)
}
# Add execution info for debugging
result["_execution_info"] = {
"tool_name": "analyze_resource_growth_curves",
"execution_time": round(execution_time, 3),
"timestamp": datetime.now().isoformat(),
"detailed_response": detailed_response
}
return result
except Exception as e:
logger.error(f"Resource growth analysis failed: {str(e)}")
return {
"error": str(e),
"analysis_timestamp": datetime.now().isoformat()
}
# ==================== Private Helper Methods ====================
async def _get_slow_query_data(self, connection, days: int, min_execution_time_ms: int) -> List[Dict]:
"""Get slow query data from audit logs"""
try:
start_date = datetime.now() - timedelta(days=days)
slow_query_sql = f"""
SELECT
`user` as user_name,
`client_ip` as host,
`time` as query_time,
`stmt` as sql_statement,
`query_time` as execution_time_ms,
`scan_bytes` as scan_bytes,
`scan_rows` as scan_rows,
`return_rows` as return_rows
FROM internal.__internal_schema.audit_log
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `query_time` >= {min_execution_time_ms}
AND `stmt` IS NOT NULL
AND `stmt` != ''
ORDER BY `query_time` DESC
LIMIT 5000
"""
result = await connection.execute(slow_query_sql)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get slow query data: {str(e)}")
return []
async def _analyze_top_slow_queries(self, slow_queries: List[Dict], top_n: int) -> List[Dict]:
"""Analyze top N slowest queries"""
# Sort by execution time and take top N
sorted_queries = sorted(
slow_queries,
key=lambda x: x.get("execution_time_ms", 0),
reverse=True
)[:top_n]
analyzed_queries = []
for i, query in enumerate(sorted_queries):
sql = query.get("sql_statement", "")
execution_time = query.get("execution_time_ms", 0)
analyzed_query = {
"rank": i + 1,
"execution_time_ms": execution_time,
"execution_time_seconds": round(execution_time / 1000, 2),
"user_name": query.get("user_name", "unknown"),
"query_time": str(query.get("query_time", "")),
"sql_statement": sql[:500] + "..." if len(sql) > 500 else sql,
"sql_length": len(sql),
"query_type": self._classify_query_type(sql),
"scan_metrics": {
"scan_bytes": query.get("scan_bytes", 0),
"scan_rows": query.get("scan_rows", 0),
"return_rows": query.get("return_rows", 0)
},
"performance_issues": self._identify_performance_issues(query)
}
analyzed_queries.append(analyzed_query)
return analyzed_queries
def _classify_query_type(self, sql: str) -> str:
"""Classify SQL query type"""
if not sql:
return "unknown"
sql_upper = sql.upper().strip()
if sql_upper.startswith('SELECT'):
return "SELECT"
elif sql_upper.startswith('INSERT'):
return "INSERT"
elif sql_upper.startswith('UPDATE'):
return "UPDATE"
elif sql_upper.startswith('DELETE'):
return "DELETE"
else:
return "OTHER"
def _identify_performance_issues(self, query: Dict) -> List[str]:
"""Identify potential performance issues in query"""
issues = []
sql = query.get("sql_statement", "").upper()
execution_time = query.get("execution_time_ms", 0)
scan_bytes = query.get("scan_bytes", 0)
scan_rows = query.get("scan_rows", 0)
return_rows = query.get("return_rows", 0)
# High execution time
if execution_time > 60000: # > 1 minute
issues.append("very_long_execution")
elif execution_time > 10000: # > 10 seconds
issues.append("long_execution")
# Large data scan
if scan_bytes > 1024**3: # > 1GB
issues.append("large_data_scan")
# High row scan vs return ratio
if scan_rows > 0 and return_rows > 0:
scan_ratio = scan_rows / return_rows
if scan_ratio > 1000:
issues.append("inefficient_filtering")
# SQL pattern issues
if "SELECT *" in sql:
issues.append("select_all_columns")
if "ORDER BY" in sql and "LIMIT" not in sql:
issues.append("unlimited_sort")
return issues
async def _generate_performance_insights(self, slow_queries: List[Dict]) -> Dict[str, Any]:
"""Generate performance insights from slow queries"""
if not slow_queries:
return {}
execution_times = [q.get("execution_time_ms", 0) for q in slow_queries]
scan_bytes = [q.get("scan_bytes", 0) for q in slow_queries if q.get("scan_bytes", 0) > 0]
# User analysis
user_query_counts = Counter(q.get("user_name", "unknown") for q in slow_queries)
# Query type distribution
query_types = Counter(self._classify_query_type(q.get("sql_statement", "")) for q in slow_queries)
# Time pattern analysis
query_hours = []
for query in slow_queries:
try:
query_time = query.get("query_time")
if query_time:
if isinstance(query_time, str):
dt = datetime.fromisoformat(query_time.replace('Z', '+00:00'))
else:
dt = query_time
query_hours.append(dt.hour)
except:
continue
hour_distribution = Counter(query_hours)
return {
"execution_time_stats": {
"avg_ms": round(statistics.mean(execution_times), 2) if execution_times else 0,
"median_ms": round(statistics.median(execution_times), 2) if execution_times else 0,
"max_ms": max(execution_times) if execution_times else 0,
"min_ms": min(execution_times) if execution_times else 0
},
"data_scan_stats": {
"avg_bytes": round(statistics.mean(scan_bytes), 2) if scan_bytes else 0,
"max_bytes": max(scan_bytes) if scan_bytes else 0,
"total_bytes_scanned": sum(scan_bytes) if scan_bytes else 0
},
"user_analysis": {
"top_slow_query_users": dict(user_query_counts.most_common(10)),
"unique_users": len(user_query_counts)
},
"query_type_distribution": dict(query_types),
"temporal_patterns": {
"hourly_distribution": dict(hour_distribution),
"peak_hour": max(hour_distribution, key=hour_distribution.get) if hour_distribution else None
}
}
async def _analyze_query_patterns(self, slow_queries: List[Dict]) -> Dict[str, Any]:
"""Analyze query patterns in slow queries"""
patterns = {
"common_issues": Counter(),
"table_access_patterns": Counter(),
"query_complexity": []
}
for query in slow_queries:
sql = query.get("sql_statement", "")
# Identify common issues
issues = self._identify_performance_issues(query)
patterns["common_issues"].update(issues)
# Extract table names
tables = self._extract_table_names(sql)
patterns["table_access_patterns"].update(tables)
# Query complexity metrics
complexity = self._calculate_query_complexity(sql)
patterns["query_complexity"].append(complexity)
return {
"common_performance_issues": dict(patterns["common_issues"].most_common(10)),
"frequently_accessed_tables": dict(patterns["table_access_patterns"].most_common(15)),
"complexity_analysis": {
"avg_complexity": round(statistics.mean(patterns["query_complexity"]), 2) if patterns["query_complexity"] else 0,
"max_complexity": max(patterns["query_complexity"]) if patterns["query_complexity"] else 0,
"high_complexity_queries": len([c for c in patterns["query_complexity"] if c > 10])
}
}
def _extract_table_names(self, sql: str) -> List[str]:
"""Extract table names from SQL (simplified)"""
import re
if not sql:
return []
# Simple pattern matching for table names
patterns = [
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_.]*)',
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_.]*)',
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_.]*)',
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_.]*)'
]
tables = []
for pattern in patterns:
matches = re.findall(pattern, sql, re.IGNORECASE)
tables.extend(matches)
return [table.lower() for table in tables if table]
def _calculate_query_complexity(self, sql: str) -> int:
"""Calculate query complexity score"""
if not sql:
return 0
sql_upper = sql.upper()
complexity = 0
# Basic complexity factors
complexity += sql_upper.count('JOIN') * 2
complexity += sql_upper.count('SUBQUERY') * 3
complexity += sql_upper.count('UNION') * 2
complexity += sql_upper.count('GROUP BY') * 1
complexity += sql_upper.count('ORDER BY') * 1
complexity += sql_upper.count('HAVING') * 2
complexity += sql_upper.count('CASE') * 1
# Length factor
complexity += len(sql) // 100
return complexity
async def _analyze_storage_growth(self, connection, days: int, detailed_response: bool = False) -> Dict[str, Any]:
"""Analyze storage growth patterns"""
try:
# Get table size data over time
# This is a simplified approach - in practice you'd need historical data
size_sql = """
SELECT
table_schema,
table_name,
ROUND(data_length / 1024 / 1024, 2) as size_mb,
table_rows
FROM information_schema.tables
WHERE table_type = 'BASE TABLE'
AND data_length > 0
ORDER BY data_length DESC
LIMIT 50
"""
result = await connection.execute(size_sql)
if result.data:
total_size = sum(row.get("size_mb", 0) for row in result.data)
total_rows = sum(row.get("table_rows", 0) for row in result.data)
# Estimate growth (simplified - would need historical data for real analysis)
growth_estimate = self._estimate_storage_growth(result.data)
storage_result = {
"current_storage_mb": round(total_size, 2),
"total_rows": total_rows,
"table_count": len(result.data),
"estimated_growth": growth_estimate,
"growth_trend": "stable" # Simplified
}
# Include detailed table information only if requested
if detailed_response:
storage_result["largest_tables"] = result.data[:10]
else:
# Only include top 3 for summary
storage_result["top_tables_summary"] = result.data[:3]
return storage_result
return {"current_storage_mb": 0, "message": "No storage data available"}
except Exception as e:
logger.warning(f"Failed to analyze storage growth: {str(e)}")
return {"error": str(e)}
def _estimate_storage_growth(self, table_data: List[Dict]) -> Dict[str, Any]:
"""Estimate storage growth based on current data"""
# This is a simplified estimation - real implementation would use historical data
total_size = sum(row.get("size_mb", 0) for row in table_data)
return {
"daily_growth_estimate_mb": round(total_size * 0.01, 2), # 1% daily growth estimate
"monthly_growth_estimate_mb": round(total_size * 0.3, 2), # 30% monthly growth estimate
"confidence": "low", # Low confidence without historical data
"method": "simplified_estimation"
}
async def _analyze_query_volume_growth(self, connection, days: int, detailed_response: bool = False) -> Dict[str, Any]:
"""Analyze query volume growth patterns"""
try:
start_date = datetime.now() - timedelta(days=days)
volume_sql = f"""
SELECT
DATE(`time`) as query_date,
COUNT(*) as query_count,
COUNT(DISTINCT `user`) as unique_users
FROM internal.__internal_schema.audit_log
WHERE `time` >= '{start_date.strftime('%Y-%m-%d')}'
AND `stmt` IS NOT NULL
GROUP BY DATE(`time`)
ORDER BY query_date DESC
LIMIT {days}
"""
result = await connection.execute(volume_sql)
if result.data:
daily_volumes = [row.get("query_count", 0) for row in result.data]
avg_daily_queries = statistics.mean(daily_volumes) if daily_volumes else 0
# Simple trend analysis
trend = "stable"
if len(daily_volumes) > 3:
recent_avg = statistics.mean(daily_volumes[:3])
older_avg = statistics.mean(daily_volumes[-3:])
if recent_avg > older_avg * 1.1:
trend = "increasing"
elif recent_avg < older_avg * 0.9:
trend = "decreasing"
volume_result = {
"avg_daily_queries": round(avg_daily_queries, 2),
"max_daily_queries": max(daily_volumes) if daily_volumes else 0,
"min_daily_queries": min(daily_volumes) if daily_volumes else 0,
"total_queries": sum(daily_volumes) if daily_volumes else 0,
"trend": trend
}
# Include detailed daily data only if requested
if detailed_response:
# Fix date serialization in daily_data
serializable_data = []
for row in result.data:
serializable_row = {}
for key, value in row.items():
if hasattr(value, 'isoformat'): # datetime/date object
serializable_row[key] = value.isoformat()
else:
serializable_row[key] = value
serializable_data.append(serializable_row)
volume_result["daily_data"] = serializable_data
else:
# Only include recent data summary
volume_result["recent_days_summary"] = {
"last_7_days_avg": round(statistics.mean(daily_volumes[:7]) if len(daily_volumes) >= 7 else avg_daily_queries, 2),
"data_points": min(len(daily_volumes), 7)
}
return volume_result
return {"avg_daily_queries": 0, "message": "No query volume data available"}
except Exception as e:
logger.warning(f"Failed to analyze query volume growth: {str(e)}")
return {"error": str(e)}
async def _analyze_user_activity_growth(self, connection, days: int, detailed_response: bool = False) -> Dict[str, Any]:
"""Analyze user activity growth patterns"""
try:
start_date = datetime.now() - timedelta(days=days)
activity_sql = f"""
SELECT
DATE(`time`) as activity_date,
COUNT(DISTINCT `user`) as active_users,
COUNT(*) as total_queries
FROM internal.__internal_schema.audit_log
WHERE `time` >= '{start_date.strftime('%Y-%m-%d')}'
AND `stmt` IS NOT NULL
GROUP BY DATE(`time`)
ORDER BY activity_date DESC
LIMIT {days}
"""
result = await connection.execute(activity_sql)
if result.data:
daily_users = [row.get("active_users", 0) for row in result.data]
avg_daily_users = statistics.mean(daily_users) if daily_users else 0
activity_result = {
"avg_daily_active_users": round(avg_daily_users, 2),
"max_daily_users": max(daily_users) if daily_users else 0,
"total_unique_users": len(set(row.get("active_users", 0) for row in result.data))
}
# Include detailed daily activity only if requested
if detailed_response:
# Fix date serialization in daily_activity
serializable_activity = []
for row in result.data:
serializable_row = {}
for key, value in row.items():
if hasattr(value, 'isoformat'): # datetime/date object
serializable_row[key] = value.isoformat()
else:
serializable_row[key] = value
serializable_activity.append(serializable_row)
activity_result["daily_activity"] = serializable_activity
else:
# Only include recent activity summary
recent_queries = [row.get("total_queries", 0) for row in result.data[:7]]
activity_result["recent_activity_summary"] = {
"last_7_days_avg_queries": round(statistics.mean(recent_queries) if recent_queries else 0, 2),
"activity_trend": "active" if avg_daily_users > 1 else "low"
}
return activity_result
return {"avg_daily_active_users": 0, "message": "No user activity data available"}
except Exception as e:
logger.warning(f"Failed to analyze user activity growth: {str(e)}")
return {"error": str(e)}
async def _generate_growth_insights(self, resource_analysis: Dict, days: int) -> Dict[str, Any]:
"""Generate insights from resource growth analysis"""
insights = {}
# Storage insights
if "storage" in resource_analysis:
storage = resource_analysis["storage"]
if "current_storage_mb" in storage:
insights["storage"] = {
"current_size_gb": round(storage["current_storage_mb"] / 1024, 2),
"growth_rate": storage.get("estimated_growth", {}).get("daily_growth_estimate_mb", 0),
"capacity_concern": storage["current_storage_mb"] > 10000 # > 10GB
}
# Query volume insights
if "query_volume" in resource_analysis:
volume = resource_analysis["query_volume"]
insights["query_volume"] = {
"daily_average": volume.get("avg_daily_queries", 0),
"load_level": "high" if volume.get("avg_daily_queries", 0) > 1000 else "normal",
"trend": volume.get("trend", "stable")
}
# User activity insights
if "user_activity" in resource_analysis:
activity = resource_analysis["user_activity"]
insights["user_activity"] = {
"active_user_base": activity.get("avg_daily_active_users", 0),
"user_engagement": "high" if activity.get("avg_daily_active_users", 0) > 10 else "normal"
}
return insights
async def _generate_growth_predictions(self, resource_analysis: Dict) -> Dict[str, Any]:
"""Generate growth predictions (simplified)"""
predictions = {}
# This is a simplified prediction model
# Real implementation would use time series analysis
if "storage" in resource_analysis:
storage = resource_analysis["storage"]
current_size = storage.get("current_storage_mb", 0)
daily_growth = storage.get("estimated_growth", {}).get("daily_growth_estimate_mb", 0)
predictions["storage"] = {
"30_day_projection_mb": round(current_size + (daily_growth * 30), 2),
"90_day_projection_mb": round(current_size + (daily_growth * 90), 2),
"confidence": "low"
}
return predictions
def _generate_performance_recommendations(self, performance_insights: Dict, pattern_analysis: Dict) -> List[Dict]:
"""Generate performance improvement recommendations"""
recommendations = []
# Execution time recommendations
exec_stats = performance_insights.get("execution_time_stats", {})
avg_time = exec_stats.get("avg_ms", 0)
if avg_time > 30000: # > 30 seconds
recommendations.append({
"type": "query_optimization",
"priority": "high",
"title": "High average query execution time",
"description": f"Average slow query time is {avg_time/1000:.1f} seconds",
"action": "Review and optimize slowest queries, consider indexing strategies"
})
# Pattern-based recommendations
if pattern_analysis:
common_issues = pattern_analysis.get("common_performance_issues", {})
if common_issues.get("select_all_columns", 0) > 5:
recommendations.append({
"type": "query_best_practices",
"priority": "medium",
"title": "Frequent SELECT * usage detected",
"description": "Many queries use SELECT * which can impact performance",
"action": "Replace SELECT * with specific column names in queries"
})
if common_issues.get("large_data_scan", 0) > 3:
recommendations.append({
"type": "data_access_optimization",
"priority": "high",
"title": "Large data scans detected",
"description": "Multiple queries are scanning large amounts of data",
"action": "Review partitioning strategies and add appropriate indexes"
})
return recommendations
def _generate_growth_recommendations(self, growth_insights: Dict, predictions: Dict) -> List[Dict]:
"""Generate resource growth recommendations"""
recommendations = []
# Storage recommendations
storage_insights = growth_insights.get("storage", {})
if storage_insights.get("capacity_concern", False):
recommendations.append({
"type": "capacity_planning",
"priority": "medium",
"title": "Storage capacity monitoring needed",
"description": "Current storage usage is significant",
"action": "Implement storage monitoring and consider data archival strategies"
})
# Query volume recommendations
query_insights = growth_insights.get("query_volume", {})
if query_insights.get("load_level") == "high":
recommendations.append({
"type": "performance_scaling",
"priority": "medium",
"title": "High query volume detected",
"description": "System is handling high query volumes",
"action": "Monitor system performance and consider scaling strategies"
})
return recommendations

View File

@@ -0,0 +1,744 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Security Analytics Tools Module
Provides data access analysis, user behavior monitoring, and security insights
"""
import time
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional
from collections import Counter, defaultdict
from .db import DorisConnectionManager
from .logger import get_logger
logger = get_logger(__name__)
class SecurityAnalyticsTools:
"""Security analytics tools for access pattern analysis and user monitoring"""
def __init__(self, connection_manager: DorisConnectionManager):
self.connection_manager = connection_manager
logger.info("SecurityAnalyticsTools initialized")
async def analyze_data_access_patterns(
self,
days: int = 7,
include_system_users: bool = False,
min_query_threshold: int = 5
) -> Dict[str, Any]:
"""
Analyze data access patterns for users and roles
Args:
days: Number of days to analyze
include_system_users: Whether to include system/service users
min_query_threshold: Minimum queries for a user to be included in analysis
Returns:
Comprehensive access pattern analysis
"""
try:
start_time = time.time()
connection = await self.connection_manager.get_connection("query")
# Define analysis period
end_date = datetime.now()
start_date = end_date - timedelta(days=days)
# 1. Get audit log data
audit_data = await self._get_audit_log_data(connection, start_date, end_date, include_system_users)
if not audit_data:
return {
"error": "No audit data available for the specified period",
"analysis_period": {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"days": days
}
}
# 2. Analyze user access patterns
user_access_analysis = await self._analyze_user_access_patterns(
audit_data, min_query_threshold
)
# 3. Analyze role-based access
role_access_analysis = await self._analyze_role_access_patterns(
connection, user_access_analysis
)
# 4. Detect security anomalies
security_alerts = await self._detect_security_anomalies(
audit_data, user_access_analysis
)
# 5. Generate access insights
access_insights = await self._generate_access_insights(
user_access_analysis, role_access_analysis
)
execution_time = time.time() - start_time
return {
"analysis_period": {
"start_date": start_date.isoformat(),
"end_date": end_date.isoformat(),
"days": days
},
"analysis_timestamp": datetime.now().isoformat(),
"execution_time_seconds": round(execution_time, 3),
"user_access_summary": self._generate_user_access_summary(user_access_analysis),
"user_access_details": user_access_analysis,
"role_analysis": role_access_analysis,
"security_alerts": security_alerts,
"access_insights": access_insights,
"recommendations": self._generate_security_recommendations(security_alerts, access_insights)
}
except Exception as e:
logger.error(f"Data access pattern analysis failed: {str(e)}")
return {
"error": str(e),
"analysis_timestamp": datetime.now().isoformat()
}
# ==================== Private Helper Methods ====================
async def _get_audit_log_data(self, connection, start_date: datetime, end_date: datetime, include_system_users: bool) -> List[Dict]:
"""Retrieve audit log data for the specified period"""
try:
# System users filter
system_user_filter = ""
if not include_system_users:
system_users = ['root', 'admin', 'system', 'doris', 'information_schema']
user_list = ','.join([f'"{user}"' for user in system_users])
system_user_filter = f"AND `user` NOT IN ({user_list})"
audit_sql = f"""
SELECT
`user` as user_name,
`client_ip` as host,
`time` as query_time,
`stmt` as sql_statement,
`state` as query_status,
`scan_bytes` as scan_bytes,
`scan_rows` as scan_rows,
`return_rows` as return_rows,
`query_time` as execution_time_ms
FROM internal.__internal_schema.audit_log
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `stmt` IS NOT NULL
AND `stmt` != ''
{system_user_filter}
ORDER BY `time` DESC
LIMIT 10000
"""
result = await connection.execute(audit_sql)
return result.data if result.data else []
except Exception as e:
logger.warning(f"Failed to get audit log data: {str(e)}")
# Try alternative method without detailed metrics
try:
simple_audit_sql = f"""
SELECT
`user` as user_name,
`client_ip` as host,
`time` as query_time,
`stmt` as sql_statement,
`state` as query_status
FROM internal.__internal_schema.audit_log
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
AND `stmt` IS NOT NULL
{system_user_filter}
ORDER BY `time` DESC
LIMIT 10000
"""
result = await connection.execute(simple_audit_sql)
return result.data if result.data else []
except Exception as e2:
logger.error(f"Failed to get simplified audit log data: {str(e2)}")
return []
async def _analyze_user_access_patterns(self, audit_data: List[Dict], min_query_threshold: int) -> List[Dict]:
"""Analyze access patterns for individual users"""
user_stats = defaultdict(lambda: {
"total_queries": 0,
"unique_tables_accessed": set(),
"hosts": set(),
"query_types": Counter(),
"query_times": [],
"failed_queries": 0,
"data_volume_read_bytes": 0,
"data_volume_read_rows": 0,
"hourly_pattern": [0] * 24,
"daily_pattern": [0] * 7,
"query_statements": []
})
# Process audit data
for entry in audit_data:
user_name = entry.get("user_name", "unknown")
query_time = entry.get("query_time")
sql_statement = entry.get("sql_statement", "")
query_status = entry.get("query_status", "")
stats = user_stats[user_name]
stats["total_queries"] += 1
# Extract table names from SQL
tables = self._extract_table_names_from_sql(sql_statement)
stats["unique_tables_accessed"].update(tables)
# Host tracking
if entry.get("host"):
stats["hosts"].add(entry["host"])
# Query type analysis
query_type = self._classify_query_type(sql_statement)
stats["query_types"][query_type] += 1
# Query time patterns
if query_time:
try:
if isinstance(query_time, str):
query_dt = datetime.fromisoformat(query_time.replace('Z', '+00:00'))
else:
query_dt = query_time
stats["query_times"].append(query_dt)
stats["hourly_pattern"][query_dt.hour] += 1
stats["daily_pattern"][query_dt.weekday()] += 1
except Exception:
pass
# Error tracking
if query_status and "error" in query_status.lower():
stats["failed_queries"] += 1
# Data volume tracking
if entry.get("scan_bytes"):
try:
stats["data_volume_read_bytes"] += int(entry["scan_bytes"])
except (ValueError, TypeError):
pass
if entry.get("scan_rows"):
try:
stats["data_volume_read_rows"] += int(entry["scan_rows"])
except (ValueError, TypeError):
pass
# Store sample queries
if len(stats["query_statements"]) < 10:
stats["query_statements"].append({
"sql": sql_statement[:200] + "..." if len(sql_statement) > 200 else sql_statement,
"timestamp": str(query_time),
"type": query_type
})
# Convert to analysis results
user_analysis = []
for user_name, stats in user_stats.items():
if stats["total_queries"] >= min_query_threshold:
# Calculate patterns and insights
access_pattern = self._classify_access_pattern(stats["hourly_pattern"])
table_access_frequency = dict(Counter(
table for entry in audit_data
if entry.get("user_name") == user_name
for table in self._extract_table_names_from_sql(entry.get("sql_statement", ""))
).most_common(10))
user_analysis.append({
"user_name": user_name,
"access_stats": {
"total_queries": stats["total_queries"],
"unique_tables_accessed": len(stats["unique_tables_accessed"]),
"unique_hosts": len(stats["hosts"]),
"data_volume_read_gb": round(stats["data_volume_read_bytes"] / (1024**3), 3),
"data_volume_read_rows": stats["data_volume_read_rows"],
"failed_queries": stats["failed_queries"],
"success_rate": round((stats["total_queries"] - stats["failed_queries"]) / stats["total_queries"], 3) if stats["total_queries"] > 0 else 0,
"peak_access_hour": stats["hourly_pattern"].index(max(stats["hourly_pattern"])) if max(stats["hourly_pattern"]) > 0 else None,
"access_pattern": access_pattern
},
"query_type_distribution": dict(stats["query_types"]),
"table_access_frequency": table_access_frequency,
"hosts_used": list(stats["hosts"]),
"sample_queries": stats["query_statements"],
"temporal_patterns": {
"hourly_distribution": stats["hourly_pattern"],
"daily_distribution": stats["daily_pattern"]
}
})
return sorted(user_analysis, key=lambda x: x["access_stats"]["total_queries"], reverse=True)
def _extract_table_names_from_sql(self, sql: str) -> List[str]:
"""Extract table names from SQL statement (simplified implementation)"""
if not sql:
return []
import re
# Simple regex patterns to match table names
patterns = [
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
]
tables = []
for pattern in patterns:
matches = re.findall(pattern, sql, re.IGNORECASE)
tables.extend(matches)
# Clean up table names (remove quotes, aliases, etc.)
cleaned_tables = []
for table in tables:
# Remove backticks, quotes, and get just the table name
clean_table = table.strip('`"\'').split(' ')[0]
if clean_table and not clean_table.upper() in ['SELECT', 'WHERE', 'AND', 'OR']:
cleaned_tables.append(clean_table)
return list(set(cleaned_tables))
def _classify_query_type(self, sql: str) -> str:
"""Classify SQL query type"""
if not sql:
return "unknown"
sql_upper = sql.upper().strip()
if sql_upper.startswith('SELECT'):
return "SELECT"
elif sql_upper.startswith('INSERT'):
return "INSERT"
elif sql_upper.startswith('UPDATE'):
return "UPDATE"
elif sql_upper.startswith('DELETE'):
return "DELETE"
elif sql_upper.startswith('CREATE'):
return "CREATE"
elif sql_upper.startswith('ALTER'):
return "ALTER"
elif sql_upper.startswith('DROP'):
return "DROP"
elif sql_upper.startswith('SHOW'):
return "SHOW"
elif sql_upper.startswith('DESCRIBE') or sql_upper.startswith('DESC'):
return "DESCRIBE"
else:
return "OTHER"
def _classify_access_pattern(self, hourly_pattern: List[int]) -> str:
"""Classify user access pattern based on hourly distribution"""
if not hourly_pattern or max(hourly_pattern) == 0:
return "no_pattern"
# Find peak hours
max_queries = max(hourly_pattern)
peak_hours = [i for i, count in enumerate(hourly_pattern) if count == max_queries]
# Business hours: 9-17
business_hours = set(range(9, 18))
peak_in_business_hours = any(hour in business_hours for hour in peak_hours)
# Night hours: 22-6
night_hours = set(list(range(22, 24)) + list(range(0, 7)))
peak_in_night_hours = any(hour in night_hours for hour in peak_hours)
if peak_in_business_hours and not peak_in_night_hours:
return "regular_business_hours"
elif peak_in_night_hours:
return "night_shift_or_batch"
elif len(peak_hours) > 6: # Distributed throughout day
return "distributed_access"
else:
return "irregular_pattern"
async def _analyze_role_access_patterns(self, connection, user_access_analysis: List[Dict]) -> Dict[str, Any]:
"""Analyze access patterns by role"""
try:
# Get user roles information
user_roles = await self._get_user_roles(connection)
# Group users by roles
role_stats = defaultdict(lambda: {
"user_count": 0,
"total_queries": 0,
"unique_tables": set(),
"query_types": Counter(),
"avg_queries_per_user": 0,
"users": []
})
# Process user access data
for user_data in user_access_analysis:
user_name = user_data["user_name"]
user_stats = user_data["access_stats"]
query_types = user_data["query_type_distribution"]
# Get user roles (default to 'unknown' if not found)
roles = user_roles.get(user_name, ["unknown"])
for role in roles:
stats = role_stats[role]
stats["user_count"] += 1
stats["total_queries"] += user_stats["total_queries"]
stats["users"].append(user_name)
# Aggregate query types
for query_type, count in query_types.items():
stats["query_types"][query_type] += count
# Calculate role analysis
role_analysis = {}
for role, stats in role_stats.items():
if stats["user_count"] > 0:
avg_queries = stats["total_queries"] / stats["user_count"]
# Calculate privilege usage (simplified)
total_role_queries = sum(stats["query_types"].values())
privilege_usage = {}
if total_role_queries > 0:
privilege_usage = {
query_type: round(count / total_role_queries, 3)
for query_type, count in stats["query_types"].items()
}
role_analysis[role] = {
"user_count": stats["user_count"],
"users": stats["users"],
"total_queries": stats["total_queries"],
"avg_queries_per_user": round(avg_queries, 1),
"query_type_distribution": dict(stats["query_types"]),
"privilege_usage": privilege_usage,
"activity_level": self._classify_role_activity_level(avg_queries)
}
return role_analysis
except Exception as e:
logger.warning(f"Failed to analyze role access patterns: {str(e)}")
return {}
async def _get_user_roles(self, connection) -> Dict[str, List[str]]:
"""Get user roles mapping"""
try:
# Try to get user role information
roles_sql = """
SELECT
User as user_name,
COALESCE(Default_role, 'default') as role_name
FROM mysql.user
"""
result = await connection.execute(roles_sql)
user_roles = defaultdict(list)
if result.data:
for row in result.data:
user_name = row.get("user_name", "")
role_name = row.get("role_name", "default")
if user_name:
user_roles[user_name].append(role_name)
return dict(user_roles)
except Exception as e:
logger.warning(f"Failed to get user roles: {str(e)}")
return {}
def _classify_role_activity_level(self, avg_queries: float) -> str:
"""Classify role activity level based on average queries"""
if avg_queries > 100:
return "high"
elif avg_queries > 20:
return "medium"
elif avg_queries > 5:
return "low"
else:
return "minimal"
async def _detect_security_anomalies(self, audit_data: List[Dict], user_access_analysis: List[Dict]) -> List[Dict]:
"""Detect potential security anomalies"""
alerts = []
# 1. Detect unusual access times
for user_data in user_access_analysis:
user_name = user_data["user_name"]
hourly_pattern = user_data["temporal_patterns"]["hourly_distribution"]
# Check for significant night-time activity
night_queries = sum(hourly_pattern[22:24]) + sum(hourly_pattern[0:6])
total_queries = sum(hourly_pattern)
if total_queries > 0 and night_queries / total_queries > 0.3: # >30% night activity
alerts.append({
"alert_type": "unusual_access_time",
"severity": "medium",
"user": user_name,
"description": f"User {user_name} has {night_queries/total_queries:.1%} of queries during night hours",
"night_query_percentage": round(night_queries/total_queries, 3),
"timestamp": datetime.now().isoformat()
})
# 2. Detect users with high failure rates
for user_data in user_access_analysis:
user_name = user_data["user_name"]
success_rate = user_data["access_stats"]["success_rate"]
total_queries = user_data["access_stats"]["total_queries"]
if total_queries > 10 and success_rate < 0.8: # <80% success rate
alerts.append({
"alert_type": "high_failure_rate",
"severity": "medium",
"user": user_name,
"description": f"User {user_name} has low query success rate ({success_rate:.1%})",
"success_rate": success_rate,
"total_queries": total_queries,
"timestamp": datetime.now().isoformat()
})
# 3. Detect unusual data volume access
data_volumes = [user["access_stats"]["data_volume_read_gb"] for user in user_access_analysis]
if data_volumes:
avg_volume = sum(data_volumes) / len(data_volumes)
std_dev = (sum((x - avg_volume) ** 2 for x in data_volumes) / len(data_volumes)) ** 0.5
threshold = avg_volume + 2 * std_dev # 2 standard deviations above mean
for user_data in user_access_analysis:
user_name = user_data["user_name"]
volume = user_data["access_stats"]["data_volume_read_gb"]
if volume > threshold and volume > 1.0: # >1GB and above threshold
alerts.append({
"alert_type": "unusual_data_volume",
"severity": "high" if volume > threshold * 2 else "medium",
"user": user_name,
"description": f"User {user_name} read {volume:.2f}GB (threshold: {threshold:.2f}GB)",
"data_volume_gb": volume,
"threshold_gb": round(threshold, 2),
"timestamp": datetime.now().isoformat()
})
# 4. Detect users accessing many different tables
for user_data in user_access_analysis:
user_name = user_data["user_name"]
unique_tables = user_data["access_stats"]["unique_tables_accessed"]
total_queries = user_data["access_stats"]["total_queries"]
# High table diversity might indicate privilege escalation or data mining
if unique_tables > 20 and total_queries > 50:
alerts.append({
"alert_type": "broad_table_access",
"severity": "medium",
"user": user_name,
"description": f"User {user_name} accessed {unique_tables} different tables",
"unique_tables_count": unique_tables,
"total_queries": total_queries,
"timestamp": datetime.now().isoformat()
})
return sorted(alerts, key=lambda x: {"high": 3, "medium": 2, "low": 1}.get(x["severity"], 0), reverse=True)
async def _generate_access_insights(self, user_access_analysis: List[Dict], role_analysis: Dict[str, Any]) -> Dict[str, Any]:
"""Generate access insights and patterns"""
insights = {
"user_behavior_patterns": {},
"role_effectiveness": {},
"security_posture": {}
}
# User behavior patterns
if user_access_analysis:
total_users = len(user_access_analysis)
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
power_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
# Access pattern distribution
pattern_distribution = Counter(
user["access_stats"]["access_pattern"] for user in user_access_analysis
)
insights["user_behavior_patterns"] = {
"total_users_analyzed": total_users,
"active_users": active_users,
"power_users": power_users,
"access_pattern_distribution": dict(pattern_distribution),
"avg_queries_per_user": round(
sum(u["access_stats"]["total_queries"] for u in user_access_analysis) / total_users, 1
) if total_users > 0 else 0
}
# Role effectiveness
if role_analysis:
most_active_role = max(role_analysis.items(), key=lambda x: x[1]["total_queries"])
least_active_role = min(role_analysis.items(), key=lambda x: x[1]["total_queries"])
insights["role_effectiveness"] = {
"total_roles": len(role_analysis),
"most_active_role": {
"role": most_active_role[0],
"total_queries": most_active_role[1]["total_queries"],
"user_count": most_active_role[1]["user_count"]
},
"least_active_role": {
"role": least_active_role[0],
"total_queries": least_active_role[1]["total_queries"],
"user_count": least_active_role[1]["user_count"]
},
"avg_users_per_role": round(
sum(role_info["user_count"] for role_info in role_analysis.values()) / len(role_analysis), 1
)
}
# Security posture assessment
if user_access_analysis:
users_with_failures = len([u for u in user_access_analysis if u["access_stats"]["failed_queries"] > 0])
users_night_access = len([
u for u in user_access_analysis
if any(u["temporal_patterns"]["hourly_distribution"][hour] > 0 for hour in list(range(22, 24)) + list(range(0, 6)))
])
insights["security_posture"] = {
"users_with_query_failures": users_with_failures,
"users_with_night_access": users_night_access,
"security_score": self._calculate_security_score(user_access_analysis),
"risk_level": self._assess_overall_risk_level(user_access_analysis)
}
return insights
def _calculate_security_score(self, user_access_analysis: List[Dict]) -> float:
"""Calculate overall security score (0-1, higher is better)"""
if not user_access_analysis:
return 0.0
total_users = len(user_access_analysis)
# Factors that contribute to security score
users_with_high_success_rate = len([u for u in user_access_analysis if u["access_stats"]["success_rate"] > 0.9])
users_with_normal_patterns = len([u for u in user_access_analysis if u["access_stats"]["access_pattern"] == "regular_business_hours"])
success_rate_score = users_with_high_success_rate / total_users
pattern_score = users_with_normal_patterns / total_users
# Combined score
overall_score = (success_rate_score * 0.6 + pattern_score * 0.4)
return round(overall_score, 3)
def _assess_overall_risk_level(self, user_access_analysis: List[Dict]) -> str:
"""Assess overall security risk level"""
security_score = self._calculate_security_score(user_access_analysis)
if security_score > 0.8:
return "low"
elif security_score > 0.6:
return "medium"
else:
return "high"
def _generate_user_access_summary(self, user_access_analysis: List[Dict]) -> Dict[str, Any]:
"""Generate summary statistics for user access"""
if not user_access_analysis:
return {
"total_users": 0,
"active_users": 0,
"high_activity_users": 0,
"dormant_users": 0
}
total_users = len(user_access_analysis)
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
high_activity_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
dormant_users = total_users - active_users
return {
"total_users": total_users,
"active_users": active_users,
"high_activity_users": high_activity_users,
"dormant_users": dormant_users,
"activity_distribution": {
"high": high_activity_users,
"medium": active_users - high_activity_users,
"low": dormant_users
}
}
def _generate_security_recommendations(self, security_alerts: List[Dict], access_insights: Dict[str, Any]) -> List[Dict]:
"""Generate security recommendations based on analysis"""
recommendations = []
# Recommendations based on alerts
if security_alerts:
high_severity_alerts = [alert for alert in security_alerts if alert["severity"] == "high"]
if high_severity_alerts:
recommendations.append({
"type": "urgent_security_review",
"priority": "high",
"description": f"Found {len(high_severity_alerts)} high-severity security alerts",
"action": "Immediate review of flagged users and access patterns required",
"affected_users": list(set(alert["user"] for alert in high_severity_alerts if "user" in alert))
})
# Night access recommendations
night_access_alerts = [alert for alert in security_alerts if alert["alert_type"] == "unusual_access_time"]
if night_access_alerts:
recommendations.append({
"type": "access_time_policy",
"priority": "medium",
"description": f"{len(night_access_alerts)} users have significant night-time access",
"action": "Review access time policies and consider time-based restrictions",
"affected_users": [alert["user"] for alert in night_access_alerts]
})
# Recommendations based on insights
security_posture = access_insights.get("security_posture", {})
risk_level = security_posture.get("risk_level", "unknown")
if risk_level == "high":
recommendations.append({
"type": "overall_security_improvement",
"priority": "high",
"description": "Overall security posture indicates high risk",
"action": "Comprehensive security audit and policy review recommended"
})
# Role-based recommendations
role_effectiveness = access_insights.get("role_effectiveness", {})
if role_effectiveness and role_effectiveness.get("total_roles", 0) < 3:
recommendations.append({
"type": "role_management",
"priority": "medium",
"description": "Limited role diversity detected",
"action": "Consider implementing more granular role-based access control"
})
return recommendations