526
doris_mcp_server/utils/adbc_query_tools.py
Normal file
526
doris_mcp_server/utils/adbc_query_tools.py
Normal file
@@ -0,0 +1,526 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Apache Doris ADBC Query Tools
|
||||
High-performance data querying using Apache Arrow Flight SQL protocol
|
||||
"""
|
||||
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.db import DorisConnectionManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _convert_numpy_types(obj):
|
||||
"""Convert numpy types to native Python types for JSON serialization"""
|
||||
try:
|
||||
# Import numpy only when needed
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
if isinstance(obj, np.integer):
|
||||
return int(obj)
|
||||
elif isinstance(obj, np.floating):
|
||||
return float(obj)
|
||||
elif isinstance(obj, np.bool_):
|
||||
return bool(obj)
|
||||
elif isinstance(obj, np.ndarray):
|
||||
return obj.tolist()
|
||||
elif isinstance(obj, (pd.Timestamp, pd.NaT.__class__)):
|
||||
return str(obj)
|
||||
elif pd.isna(obj):
|
||||
return None
|
||||
else:
|
||||
return obj
|
||||
except ImportError:
|
||||
# If numpy/pandas not available, return as-is
|
||||
return obj
|
||||
|
||||
|
||||
def _convert_dataframe_to_json_serializable(df):
|
||||
"""Convert DataFrame to JSON serializable format"""
|
||||
try:
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
# Convert DataFrame to records
|
||||
records = df.to_dict('records')
|
||||
|
||||
# Convert each record's values
|
||||
converted_records = []
|
||||
for record in records:
|
||||
converted_record = {}
|
||||
for key, value in record.items():
|
||||
converted_record[key] = _convert_numpy_types(value)
|
||||
converted_records.append(converted_record)
|
||||
|
||||
return converted_records
|
||||
except ImportError:
|
||||
# Fallback to basic dict conversion
|
||||
return df.to_dict('records')
|
||||
|
||||
|
||||
class DorisADBCQueryTools:
|
||||
"""ADBC Query Tools for high-performance data transfer using Arrow Flight SQL"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
self.adbc_client = None
|
||||
self.flight_sql_module = None
|
||||
self.adbc_manager_module = None
|
||||
|
||||
async def exec_adbc_query(
|
||||
self,
|
||||
sql: str,
|
||||
max_rows: int | None = None,
|
||||
timeout: int | None = None,
|
||||
return_format: str | None = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Execute SQL query using ADBC (Arrow Flight SQL) protocol
|
||||
|
||||
Args:
|
||||
sql: SQL statement to execute
|
||||
max_rows: Maximum number of rows to return (uses config default if None)
|
||||
timeout: Query timeout in seconds (uses config default if None)
|
||||
return_format: Format for returned data ("arrow", "pandas", "dict", uses config default if None)
|
||||
|
||||
Returns:
|
||||
Query results in specified format with metadata
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Use configuration defaults if parameters not specified
|
||||
adbc_config = self.connection_manager.config.adbc
|
||||
max_rows = max_rows if max_rows is not None else adbc_config.default_max_rows
|
||||
timeout = timeout if timeout is not None else adbc_config.default_timeout
|
||||
return_format = return_format if return_format is not None else adbc_config.default_return_format
|
||||
|
||||
# Step 1: Check environment variables and port availability
|
||||
port_check_result = await self._check_arrow_flight_ports()
|
||||
if not port_check_result["success"]:
|
||||
return port_check_result
|
||||
|
||||
# Step 2: Import required ADBC modules
|
||||
import_result = await self._import_adbc_modules()
|
||||
if not import_result["success"]:
|
||||
return import_result
|
||||
|
||||
# Step 3: Create ADBC connection
|
||||
connection_result = await self._create_adbc_connection()
|
||||
if not connection_result["success"]:
|
||||
return connection_result
|
||||
|
||||
# Step 4: Execute query using ADBC
|
||||
query_result = await self._execute_query_with_adbc(
|
||||
sql, max_rows, timeout, return_format
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
if query_result["success"]:
|
||||
query_result["execution_time"] = round(execution_time, 3)
|
||||
query_result["protocol"] = "ADBC_Arrow_Flight_SQL"
|
||||
query_result["timestamp"] = datetime.now().isoformat()
|
||||
|
||||
return query_result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ADBC query execution failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"ADBC query execution failed: {str(e)}",
|
||||
"error_type": "execution_error",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
async def _check_arrow_flight_ports(self) -> Dict[str, Any]:
|
||||
"""Check Arrow Flight SQL port configuration and availability"""
|
||||
try:
|
||||
# Check environment variables
|
||||
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||
|
||||
if not fe_port:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing environment variable FE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL FE port in .env file",
|
||||
"error_type": "missing_fe_port_config"
|
||||
}
|
||||
|
||||
if not be_port:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing environment variable BE_ARROW_FLIGHT_SQL_PORT, please configure Arrow Flight SQL BE port in .env file",
|
||||
"error_type": "missing_be_port_config"
|
||||
}
|
||||
|
||||
# Convert to integer and validate
|
||||
try:
|
||||
fe_port = int(fe_port)
|
||||
be_port = int(be_port)
|
||||
except ValueError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Invalid Arrow Flight SQL port configuration, please ensure FE_ARROW_FLIGHT_SQL_PORT and BE_ARROW_FLIGHT_SQL_PORT are valid numbers",
|
||||
"error_type": "invalid_port_format"
|
||||
}
|
||||
|
||||
# Get host address
|
||||
db_config = self.connection_manager.config.database
|
||||
fe_host = db_config.host
|
||||
|
||||
# Check FE Arrow Flight SQL port availability
|
||||
fe_available = self._check_port_connectivity(fe_host, fe_port)
|
||||
if not fe_available:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Cannot connect to FE Arrow Flight SQL port {fe_host}:{fe_port}, please check if service is running",
|
||||
"error_type": "fe_port_unavailable",
|
||||
"fe_host": fe_host,
|
||||
"fe_port": fe_port
|
||||
}
|
||||
|
||||
# Get BE host list
|
||||
be_hosts = await self._get_be_hosts()
|
||||
if not be_hosts:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Cannot get BE node information, please check cluster status",
|
||||
"error_type": "no_be_hosts"
|
||||
}
|
||||
|
||||
# Check at least one BE Arrow Flight SQL port availability
|
||||
be_available_count = 0
|
||||
be_check_results = []
|
||||
|
||||
for be_host in be_hosts[:3]: # Check first 3 BE nodes
|
||||
be_available = self._check_port_connectivity(be_host, be_port)
|
||||
be_check_results.append({
|
||||
"host": be_host,
|
||||
"port": be_port,
|
||||
"available": be_available
|
||||
})
|
||||
if be_available:
|
||||
be_available_count += 1
|
||||
|
||||
if be_available_count == 0:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Cannot connect to any BE Arrow Flight SQL port (port: {be_port}), please check if BE services are running",
|
||||
"error_type": "no_be_ports_available",
|
||||
"be_check_results": be_check_results
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"fe_host": fe_host,
|
||||
"fe_port": fe_port,
|
||||
"be_port": be_port,
|
||||
"be_hosts": be_hosts,
|
||||
"be_available_count": be_available_count,
|
||||
"be_check_results": be_check_results
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Arrow Flight port check failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Arrow Flight port check failed: {str(e)}",
|
||||
"error_type": "port_check_error"
|
||||
}
|
||||
|
||||
def _check_port_connectivity(self, host: str, port: int, timeout: int | None = None) -> bool:
|
||||
"""Check port connectivity"""
|
||||
try:
|
||||
# Use config timeout if not specified
|
||||
if timeout is None:
|
||||
timeout = self.connection_manager.config.adbc.connection_timeout
|
||||
|
||||
with socket.create_connection((host, port), timeout=timeout):
|
||||
return True
|
||||
except (socket.timeout, socket.error, OSError):
|
||||
return False
|
||||
|
||||
async def _get_be_hosts(self) -> List[str]:
|
||||
"""Get BE host list"""
|
||||
try:
|
||||
db_config = self.connection_manager.config.database
|
||||
|
||||
# Use configured BE hosts first
|
||||
if db_config.be_hosts:
|
||||
logger.info(f"Using configured BE hosts: {db_config.be_hosts}")
|
||||
return db_config.be_hosts
|
||||
|
||||
# Get BE nodes via SHOW BACKENDS
|
||||
logger.info("No BE hosts configured, getting BE node information via SHOW BACKENDS")
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
result = await connection.execute("SHOW BACKENDS")
|
||||
|
||||
be_hosts = []
|
||||
for row in result.data:
|
||||
host = row.get("Host")
|
||||
alive = row.get("Alive", "").lower()
|
||||
if host and alive == "true":
|
||||
be_hosts.append(host)
|
||||
|
||||
logger.info(f"Got {len(be_hosts)} active BE nodes from SHOW BACKENDS")
|
||||
return be_hosts
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get BE hosts: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _import_adbc_modules(self) -> Dict[str, Any]:
|
||||
"""Import ADBC related modules"""
|
||||
try:
|
||||
# Import ADBC Driver Manager
|
||||
try:
|
||||
import adbc_driver_manager
|
||||
self.adbc_manager_module = adbc_driver_manager
|
||||
except ImportError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing adbc_driver_manager module, please install: pip install adbc_driver_manager",
|
||||
"error_type": "missing_adbc_manager"
|
||||
}
|
||||
|
||||
# Import ADBC Flight SQL Driver
|
||||
try:
|
||||
import adbc_driver_flightsql.dbapi as flight_sql
|
||||
self.flight_sql_module = flight_sql
|
||||
except ImportError:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Missing adbc_driver_flightsql module, please install: pip install adbc_driver_flightsql",
|
||||
"error_type": "missing_flight_sql_driver"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"adbc_manager_version": getattr(adbc_driver_manager, '__version__', 'unknown'),
|
||||
"flight_sql_version": getattr(flight_sql, '__version__', 'unknown')
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ADBC module import failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"ADBC module import failed: {str(e)}",
|
||||
"error_type": "import_error"
|
||||
}
|
||||
|
||||
async def _create_adbc_connection(self) -> Dict[str, Any]:
|
||||
"""Create ADBC connection"""
|
||||
try:
|
||||
db_config = self.connection_manager.config.database
|
||||
fe_port = int(os.getenv("FE_ARROW_FLIGHT_SQL_PORT"))
|
||||
|
||||
# Build connection URI
|
||||
uri = f"grpc://{db_config.host}:{fe_port}"
|
||||
|
||||
# Create database connection parameters
|
||||
db_kwargs = {
|
||||
self.adbc_manager_module.DatabaseOptions.USERNAME.value: db_config.user,
|
||||
self.adbc_manager_module.DatabaseOptions.PASSWORD.value: db_config.password,
|
||||
}
|
||||
|
||||
# Create connection
|
||||
self.adbc_client = self.flight_sql_module.connect(
|
||||
uri=uri,
|
||||
db_kwargs=db_kwargs
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"uri": uri,
|
||||
"connection_established": True
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create ADBC connection: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"Failed to create ADBC connection: {str(e)}",
|
||||
"error_type": "connection_error"
|
||||
}
|
||||
|
||||
async def _execute_query_with_adbc(
|
||||
self,
|
||||
sql: str,
|
||||
max_rows: int,
|
||||
timeout: int,
|
||||
return_format: str
|
||||
) -> Dict[str, Any]:
|
||||
"""Execute query using ADBC"""
|
||||
try:
|
||||
if not self.adbc_client:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "ADBC connection not established",
|
||||
"error_type": "no_connection"
|
||||
}
|
||||
|
||||
cursor = self.adbc_client.cursor()
|
||||
start_time = time.time()
|
||||
|
||||
# Execute query
|
||||
cursor.execute(sql)
|
||||
|
||||
# Get results based on return format
|
||||
if return_format == "arrow":
|
||||
# Return Arrow format
|
||||
arrow_data = cursor.fetchallarrow()
|
||||
|
||||
# Limit rows
|
||||
if len(arrow_data) > max_rows:
|
||||
arrow_data = arrow_data.slice(0, max_rows)
|
||||
|
||||
# Convert Arrow data to serializable format
|
||||
preview_df = arrow_data.to_pandas().head(10) if len(arrow_data) > 0 else None
|
||||
result_data = {
|
||||
"format": "arrow",
|
||||
"num_rows": len(arrow_data),
|
||||
"num_columns": len(arrow_data.schema),
|
||||
"column_names": arrow_data.schema.names,
|
||||
"column_types": [str(field.type) for field in arrow_data.schema],
|
||||
"data_preview": _convert_dataframe_to_json_serializable(preview_df) if preview_df is not None else [],
|
||||
"total_bytes": arrow_data.nbytes if hasattr(arrow_data, 'nbytes') else 0
|
||||
}
|
||||
|
||||
elif return_format == "pandas":
|
||||
# Return Pandas DataFrame
|
||||
df = cursor.fetch_df()
|
||||
|
||||
# Limit rows
|
||||
if len(df) > max_rows:
|
||||
df = df.head(max_rows)
|
||||
|
||||
result_data = {
|
||||
"format": "pandas",
|
||||
"num_rows": len(df),
|
||||
"num_columns": len(df.columns),
|
||||
"column_names": df.columns.tolist(),
|
||||
"column_types": df.dtypes.astype(str).tolist(),
|
||||
"data": _convert_dataframe_to_json_serializable(df),
|
||||
"memory_usage": int(df.memory_usage(deep=True).sum())
|
||||
}
|
||||
|
||||
else: # return_format == "dict"
|
||||
# Return dictionary format
|
||||
arrow_data = cursor.fetchallarrow()
|
||||
df = arrow_data.to_pandas()
|
||||
|
||||
# Limit rows
|
||||
if len(df) > max_rows:
|
||||
df = df.head(max_rows)
|
||||
|
||||
result_data = {
|
||||
"format": "dict",
|
||||
"num_rows": len(df),
|
||||
"num_columns": len(df.columns),
|
||||
"column_names": df.columns.tolist(),
|
||||
"column_types": df.dtypes.astype(str).tolist(),
|
||||
"data": _convert_dataframe_to_json_serializable(df)
|
||||
}
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
cursor.close()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"result": result_data,
|
||||
"execution_time": round(execution_time, 3),
|
||||
"sql": sql,
|
||||
"max_rows_applied": len(result_data.get("data", [])) >= max_rows
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ADBC query execution failed: {str(e)}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"ADBC query execution failed: {str(e)}",
|
||||
"error_type": "query_execution_error",
|
||||
"sql": sql
|
||||
}
|
||||
|
||||
async def get_adbc_connection_info(self) -> Dict[str, Any]:
|
||||
"""Get ADBC connection information and status"""
|
||||
try:
|
||||
# Check port status
|
||||
port_status = await self._check_arrow_flight_ports()
|
||||
|
||||
# Check module status
|
||||
module_status = await self._import_adbc_modules()
|
||||
|
||||
# Get configuration information
|
||||
db_config = self.connection_manager.config.database
|
||||
fe_port = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
be_port = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||
|
||||
connection_info = {
|
||||
"adbc_available": module_status["success"],
|
||||
"ports_available": port_status["success"],
|
||||
"configuration": {
|
||||
"fe_host": db_config.host,
|
||||
"fe_arrow_flight_port": fe_port,
|
||||
"be_arrow_flight_port": be_port,
|
||||
"user": db_config.user
|
||||
},
|
||||
"port_status": port_status,
|
||||
"module_status": module_status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
if port_status["success"] and module_status["success"]:
|
||||
connection_info["status"] = "ready"
|
||||
connection_info["message"] = "ADBC Arrow Flight SQL connection ready"
|
||||
else:
|
||||
connection_info["status"] = "not_ready"
|
||||
errors = []
|
||||
if not port_status["success"]:
|
||||
errors.append(port_status["error"])
|
||||
if not module_status["success"]:
|
||||
errors.append(module_status["error"])
|
||||
connection_info["message"] = "; ".join(errors)
|
||||
|
||||
return connection_info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get ADBC connection information: {str(e)}")
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Failed to get ADBC connection information: {str(e)}",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup resources"""
|
||||
try:
|
||||
if self.adbc_client:
|
||||
self.adbc_client.close()
|
||||
except:
|
||||
pass
|
||||
@@ -54,6 +54,10 @@ class DatabaseConfig:
|
||||
be_hosts: list[str] = field(default_factory=list)
|
||||
be_webserver_port: int = 8040
|
||||
|
||||
# Arrow Flight SQL Configuration (Required for ADBC tools)
|
||||
fe_arrow_flight_sql_port: int | None = None
|
||||
be_arrow_flight_sql_port: int | None = None
|
||||
|
||||
# Connection pool configuration
|
||||
# Note: min_connections is fixed at 0 to avoid at_eof connection issues
|
||||
# This prevents pre-creation of connections which can cause state problems
|
||||
@@ -133,6 +137,22 @@ class PerformanceConfig:
|
||||
max_response_content_size: int = 4096
|
||||
|
||||
|
||||
@dataclass
|
||||
class ADBCConfig:
|
||||
"""ADBC (Arrow Flight SQL) configuration"""
|
||||
|
||||
# Default query parameters
|
||||
default_max_rows: int = 100000
|
||||
default_timeout: int = 60
|
||||
default_return_format: str = "arrow" # "arrow", "pandas", "dict"
|
||||
|
||||
# Connection timeout for ADBC
|
||||
connection_timeout: int = 30
|
||||
|
||||
# Whether to enable ADBC tools
|
||||
enabled: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoggingConfig:
|
||||
"""Logging configuration"""
|
||||
@@ -190,6 +210,7 @@ class DorisConfig:
|
||||
performance: PerformanceConfig = field(default_factory=PerformanceConfig)
|
||||
logging: LoggingConfig = field(default_factory=LoggingConfig)
|
||||
monitoring: MonitoringConfig = field(default_factory=MonitoringConfig)
|
||||
adbc: ADBCConfig = field(default_factory=ADBCConfig)
|
||||
|
||||
# Custom configuration
|
||||
custom_config: dict[str, Any] = field(default_factory=dict)
|
||||
@@ -260,6 +281,15 @@ class DorisConfig:
|
||||
if be_hosts_env:
|
||||
config.database.be_hosts = [host.strip() for host in be_hosts_env.split(",") if host.strip()]
|
||||
config.database.be_webserver_port = int(os.getenv("DORIS_BE_WEBSERVER_PORT", str(config.database.be_webserver_port)))
|
||||
|
||||
# Arrow Flight SQL Configuration
|
||||
fe_arrow_port_env = os.getenv("FE_ARROW_FLIGHT_SQL_PORT")
|
||||
if fe_arrow_port_env:
|
||||
config.database.fe_arrow_flight_sql_port = int(fe_arrow_port_env)
|
||||
|
||||
be_arrow_port_env = os.getenv("BE_ARROW_FLIGHT_SQL_PORT")
|
||||
if be_arrow_port_env:
|
||||
config.database.be_arrow_flight_sql_port = int(be_arrow_port_env)
|
||||
|
||||
# Connection pool configuration
|
||||
config.database.max_connections = int(
|
||||
@@ -359,6 +389,21 @@ class DorisConfig:
|
||||
)
|
||||
config.monitoring.alert_webhook_url = os.getenv("ALERT_WEBHOOK_URL", config.monitoring.alert_webhook_url)
|
||||
|
||||
# ADBC configuration
|
||||
config.adbc.default_max_rows = int(
|
||||
os.getenv("ADBC_DEFAULT_MAX_ROWS", str(config.adbc.default_max_rows))
|
||||
)
|
||||
config.adbc.default_timeout = int(
|
||||
os.getenv("ADBC_DEFAULT_TIMEOUT", str(config.adbc.default_timeout))
|
||||
)
|
||||
config.adbc.default_return_format = os.getenv("ADBC_DEFAULT_RETURN_FORMAT", config.adbc.default_return_format)
|
||||
config.adbc.connection_timeout = int(
|
||||
os.getenv("ADBC_CONNECTION_TIMEOUT", str(config.adbc.connection_timeout))
|
||||
)
|
||||
config.adbc.enabled = (
|
||||
os.getenv("ADBC_ENABLED", str(config.adbc.enabled).lower()).lower() == "true"
|
||||
)
|
||||
|
||||
# Server configuration
|
||||
config.server_name = os.getenv("SERVER_NAME", config.server_name)
|
||||
config.server_version = os.getenv("SERVER_VERSION", config.server_version)
|
||||
@@ -412,6 +457,13 @@ class DorisConfig:
|
||||
if hasattr(config.monitoring, key):
|
||||
setattr(config.monitoring, key, value)
|
||||
|
||||
# Update ADBC configuration
|
||||
if "adbc" in config_data:
|
||||
adbc_config = config_data["adbc"]
|
||||
for key, value in adbc_config.items():
|
||||
if hasattr(config.adbc, key):
|
||||
setattr(config.adbc, key, value)
|
||||
|
||||
# Custom configuration
|
||||
config.custom_config = config_data.get("custom", {})
|
||||
|
||||
@@ -434,6 +486,8 @@ class DorisConfig:
|
||||
"fe_http_port": self.database.fe_http_port,
|
||||
"be_hosts": self.database.be_hosts,
|
||||
"be_webserver_port": self.database.be_webserver_port,
|
||||
"fe_arrow_flight_sql_port": self.database.fe_arrow_flight_sql_port,
|
||||
"be_arrow_flight_sql_port": self.database.be_arrow_flight_sql_port,
|
||||
"min_connections": self.database.min_connections, # Always 0, shown for reference
|
||||
"max_connections": self.database.max_connections,
|
||||
"connection_timeout": self.database.connection_timeout,
|
||||
@@ -483,6 +537,13 @@ class DorisConfig:
|
||||
"enable_alerts": self.monitoring.enable_alerts,
|
||||
"alert_webhook_url": self.monitoring.alert_webhook_url,
|
||||
},
|
||||
"adbc": {
|
||||
"default_max_rows": self.adbc.default_max_rows,
|
||||
"default_timeout": self.adbc.default_timeout,
|
||||
"default_return_format": self.adbc.default_return_format,
|
||||
"connection_timeout": self.adbc.connection_timeout,
|
||||
"enabled": self.adbc.enabled,
|
||||
},
|
||||
"custom": self.custom_config,
|
||||
}
|
||||
|
||||
@@ -564,6 +625,19 @@ class DorisConfig:
|
||||
if not (1 <= self.monitoring.health_check_port <= 65535):
|
||||
errors.append("Health check port must be in the range 1-65535")
|
||||
|
||||
# Validate ADBC configuration
|
||||
if self.adbc.default_max_rows <= 0:
|
||||
errors.append("ADBC default max rows must be greater than 0")
|
||||
|
||||
if self.adbc.default_timeout <= 0:
|
||||
errors.append("ADBC default timeout must be greater than 0")
|
||||
|
||||
if self.adbc.default_return_format not in ["arrow", "pandas", "dict"]:
|
||||
errors.append("ADBC default return format must be one of arrow, pandas, or dict")
|
||||
|
||||
if self.adbc.connection_timeout <= 0:
|
||||
errors.append("ADBC connection timeout must be greater than 0")
|
||||
|
||||
return errors
|
||||
|
||||
def get_connection_string(self) -> str:
|
||||
@@ -603,6 +677,7 @@ class ConfigManager:
|
||||
def setup_logging(self):
|
||||
"""Setup logging configuration using enhanced logger"""
|
||||
from .logger import setup_logging, get_logger
|
||||
import sys
|
||||
|
||||
# Determine log directory
|
||||
log_dir = "logs"
|
||||
@@ -611,11 +686,19 @@ class ConfigManager:
|
||||
from pathlib import Path
|
||||
log_dir = str(Path(self.config.logging.file_path).parent)
|
||||
|
||||
# Detect if we're in stdio mode by checking if this is likely MCP stdio communication
|
||||
# In stdio mode, we shouldn't output to console as it interferes with JSON protocol
|
||||
is_stdio_mode = (
|
||||
self.config.transport == "stdio" or
|
||||
"--transport" in sys.argv and "stdio" in sys.argv or
|
||||
not sys.stdout.isatty() # Not a terminal (likely piped/redirected)
|
||||
)
|
||||
|
||||
# Setup enhanced logging with cleanup functionality
|
||||
setup_logging(
|
||||
level=self.config.logging.level,
|
||||
log_dir=log_dir,
|
||||
enable_console=True,
|
||||
enable_console=not is_stdio_mode, # Disable console logging in stdio mode
|
||||
enable_file=True,
|
||||
enable_audit=self.config.logging.enable_audit,
|
||||
audit_file=self.config.logging.audit_file_path,
|
||||
|
||||
733
doris_mcp_server/utils/data_exploration_tools.py
Normal file
733
doris_mcp_server/utils/data_exploration_tools.py
Normal file
@@ -0,0 +1,733 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data Exploration Tools Module
|
||||
Provides table data distribution analysis and exploration capabilities
|
||||
"""
|
||||
|
||||
import time
|
||||
import math
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataExplorationTools:
|
||||
"""Data exploration tools for table distribution analysis"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
logger.info("DataExplorationTools initialized")
|
||||
|
||||
|
||||
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||
"""Build full table name with catalog and database using three-part naming convention"""
|
||||
# Default catalog for internal tables
|
||||
effective_catalog = catalog_name if catalog_name else "internal"
|
||||
|
||||
if db_name:
|
||||
return f"{effective_catalog}.{db_name}.{table_name}"
|
||||
else:
|
||||
# If no db_name provided, need to determine the current database
|
||||
return f"{effective_catalog}.{table_name}"
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get basic table information including row count"""
|
||||
try:
|
||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||
result = await connection.execute(count_sql)
|
||||
|
||||
if result.data:
|
||||
return {"row_count": result.data[0]["row_count"]}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get detailed column information"""
|
||||
try:
|
||||
where_conditions = [f"table_name = '{table_name}'"]
|
||||
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
columns_sql = f"""
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_comment,
|
||||
ordinal_position
|
||||
FROM information_schema.columns
|
||||
WHERE {' AND '.join(where_conditions)}
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_sql)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _determine_sampling_strategy(self, connection, table_name: str, total_rows: int, sample_size: int) -> Dict[str, Any]:
|
||||
"""Determine optimal sampling strategy based on table size"""
|
||||
if total_rows <= sample_size:
|
||||
# Use all data if table is small enough
|
||||
return {
|
||||
"total_rows": total_rows,
|
||||
"sample_size": total_rows,
|
||||
"sampling_method": "full_scan",
|
||||
"sampling_ratio": 1.0,
|
||||
"use_sampling": False,
|
||||
"sample_table_expression": table_name
|
||||
}
|
||||
else:
|
||||
# Use random sampling for large tables
|
||||
sampling_ratio = sample_size / total_rows
|
||||
return {
|
||||
"total_rows": total_rows,
|
||||
"sample_size": sample_size,
|
||||
"sampling_method": "random_sample",
|
||||
"sampling_ratio": round(sampling_ratio, 4),
|
||||
"use_sampling": True,
|
||||
"sample_table_expression": f"(SELECT * FROM {table_name} ORDER BY RAND() LIMIT {sample_size}) as sample_table"
|
||||
}
|
||||
|
||||
def _select_analysis_columns(self, columns_info: List[Dict], include_all: bool) -> List[Dict]:
|
||||
"""Select columns for analysis based on strategy"""
|
||||
if include_all:
|
||||
return columns_info
|
||||
|
||||
# If not analyzing all columns, prioritize key columns
|
||||
priority_keywords = ['id', 'key', 'code', 'status', 'type', 'amount', 'count', 'date', 'time']
|
||||
|
||||
priority_columns = []
|
||||
other_columns = []
|
||||
|
||||
for col in columns_info:
|
||||
col_name_lower = col["column_name"].lower()
|
||||
if any(keyword in col_name_lower for keyword in priority_keywords):
|
||||
priority_columns.append(col)
|
||||
else:
|
||||
other_columns.append(col)
|
||||
|
||||
# Return priority columns plus first 10 other columns
|
||||
return priority_columns + other_columns[:10]
|
||||
|
||||
def _is_numeric_type(self, data_type: str) -> bool:
|
||||
"""Check if column type is numeric"""
|
||||
numeric_types = [
|
||||
'tinyint', 'smallint', 'int', 'bigint', 'largeint',
|
||||
'float', 'double', 'decimal', 'numeric'
|
||||
]
|
||||
return any(num_type in data_type.lower() for num_type in numeric_types)
|
||||
|
||||
def _is_categorical_type(self, data_type: str) -> bool:
|
||||
"""Check if column type is categorical"""
|
||||
categorical_types = ['varchar', 'char', 'string', 'text', 'enum']
|
||||
return any(cat_type in data_type.lower() for cat_type in categorical_types)
|
||||
|
||||
def _is_temporal_type(self, data_type: str) -> bool:
|
||||
"""Check if column type is temporal"""
|
||||
temporal_types = ['date', 'datetime', 'timestamp', 'time']
|
||||
return any(temp_type in data_type.lower() for temp_type in temporal_types)
|
||||
|
||||
async def _analyze_numeric_distributions(self, connection, table_name: str, numeric_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze distribution patterns for numeric columns"""
|
||||
numeric_analysis = {}
|
||||
|
||||
for column in numeric_columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
# Basic statistics
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
stats_sql = f"""
|
||||
SELECT
|
||||
COUNT({col_name}) as count,
|
||||
MIN({col_name}) as min_value,
|
||||
MAX({col_name}) as max_value,
|
||||
AVG({col_name}) as mean_value,
|
||||
STDDEV({col_name}) as std_dev
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
stats_result = await connection.execute(stats_sql)
|
||||
|
||||
if stats_result.data and stats_result.data[0]["count"] > 0:
|
||||
stats = stats_result.data[0]
|
||||
|
||||
# Percentiles calculation
|
||||
percentiles = await self._calculate_percentiles(connection, table_name, col_name, sampling_info)
|
||||
|
||||
# Outlier detection
|
||||
outliers = await self._detect_numeric_outliers(connection, table_name, col_name, percentiles, sampling_info)
|
||||
|
||||
# Distribution shape analysis
|
||||
distribution_shape = await self._analyze_distribution_shape(
|
||||
connection, table_name, col_name, stats, percentiles, sampling_info
|
||||
)
|
||||
|
||||
numeric_analysis[col_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"statistics": {
|
||||
"count": stats["count"],
|
||||
"mean": round(float(stats["mean_value"]), 4) if stats["mean_value"] else None,
|
||||
"std": round(float(stats["std_dev"]), 4) if stats["std_dev"] else None,
|
||||
"min": float(stats["min_value"]) if stats["min_value"] else None,
|
||||
"max": float(stats["max_value"]) if stats["max_value"] else None,
|
||||
**percentiles
|
||||
},
|
||||
"distribution_shape": distribution_shape,
|
||||
"outliers": outliers
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze numeric column {col_name}: {str(e)}")
|
||||
numeric_analysis[col_name] = {"error": str(e)}
|
||||
|
||||
return numeric_analysis
|
||||
|
||||
async def _calculate_percentiles(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, float]:
|
||||
"""Calculate percentiles for numeric column"""
|
||||
try:
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
percentile_sql = f"""
|
||||
SELECT
|
||||
PERCENTILE({col_name}, 0.25) as p25,
|
||||
PERCENTILE({col_name}, 0.50) as p50,
|
||||
PERCENTILE({col_name}, 0.75) as p75,
|
||||
PERCENTILE({col_name}, 0.90) as p90,
|
||||
PERCENTILE({col_name}, 0.95) as p95,
|
||||
PERCENTILE({col_name}, 0.99) as p99
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(percentile_sql)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
return {
|
||||
"25%": round(float(data["p25"]), 4) if data["p25"] else None,
|
||||
"50%": round(float(data["p50"]), 4) if data["p50"] else None,
|
||||
"75%": round(float(data["p75"]), 4) if data["p75"] else None,
|
||||
"90%": round(float(data["p90"]), 4) if data["p90"] else None,
|
||||
"95%": round(float(data["p95"]), 4) if data["p95"] else None,
|
||||
"99%": round(float(data["p99"]), 4) if data["p99"] else None
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to calculate percentiles for {col_name}: {str(e)}")
|
||||
|
||||
return {}
|
||||
|
||||
async def _detect_numeric_outliers(self, connection, table_name: str, col_name: str, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Detect outliers using IQR method"""
|
||||
try:
|
||||
if "25%" not in percentiles or "75%" not in percentiles:
|
||||
return {"outlier_count": 0, "outlier_rate": 0.0}
|
||||
|
||||
q1 = percentiles["25%"]
|
||||
q3 = percentiles["75%"]
|
||||
iqr = q3 - q1
|
||||
|
||||
lower_bound = q1 - 1.5 * iqr
|
||||
upper_bound = q3 + 1.5 * iqr
|
||||
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
outlier_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
SUM(CASE WHEN {col_name} < {lower_bound} OR {col_name} > {upper_bound} THEN 1 ELSE 0 END) as outlier_count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(outlier_sql)
|
||||
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
total_count = data["total_count"]
|
||||
outlier_count = data["outlier_count"]
|
||||
outlier_rate = outlier_count / total_count if total_count > 0 else 0
|
||||
|
||||
return {
|
||||
"outlier_count": outlier_count,
|
||||
"outlier_rate": round(outlier_rate, 4),
|
||||
"outlier_threshold_lower": round(lower_bound, 4),
|
||||
"outlier_threshold_upper": round(upper_bound, 4),
|
||||
"iqr": round(iqr, 4)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to detect outliers for {col_name}: {str(e)}")
|
||||
|
||||
return {"outlier_count": 0, "outlier_rate": 0.0}
|
||||
|
||||
async def _analyze_distribution_shape(self, connection, table_name: str, col_name: str, stats: Dict, percentiles: Dict, sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze the shape of data distribution"""
|
||||
try:
|
||||
mean = stats.get("mean_value", 0)
|
||||
median = percentiles.get("50%", 0)
|
||||
|
||||
if mean is None or median is None:
|
||||
return {"distribution_type": "unknown"}
|
||||
|
||||
# Calculate skewness indicator
|
||||
if abs(mean - median) < 0.01:
|
||||
skew_indicator = "symmetric"
|
||||
elif mean > median:
|
||||
skew_indicator = "right_skewed"
|
||||
else:
|
||||
skew_indicator = "left_skewed"
|
||||
|
||||
# Estimate kurtosis based on percentile spread
|
||||
if "25%" in percentiles and "75%" in percentiles:
|
||||
iqr = percentiles["75%"] - percentiles["25%"]
|
||||
range_90 = percentiles.get("90%", percentiles["75%"]) - percentiles.get("10%", percentiles["25%"])
|
||||
|
||||
if iqr > 0:
|
||||
kurtosis_indicator = "normal" if 2.5 <= range_90/iqr <= 3.5 else ("heavy_tailed" if range_90/iqr > 3.5 else "light_tailed")
|
||||
else:
|
||||
kurtosis_indicator = "unknown"
|
||||
else:
|
||||
kurtosis_indicator = "unknown"
|
||||
|
||||
return {
|
||||
"skewness_indicator": skew_indicator,
|
||||
"kurtosis_indicator": kurtosis_indicator,
|
||||
"distribution_type": self._classify_distribution_type(skew_indicator, kurtosis_indicator),
|
||||
"mean_median_ratio": round(mean / median, 4) if median != 0 else None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze distribution shape for {col_name}: {str(e)}")
|
||||
return {"distribution_type": "unknown"}
|
||||
|
||||
def _classify_distribution_type(self, skew: str, kurtosis: str) -> str:
|
||||
"""Classify distribution type based on skewness and kurtosis"""
|
||||
if skew == "symmetric" and kurtosis == "normal":
|
||||
return "approximately_normal"
|
||||
elif skew == "right_skewed":
|
||||
return "right_skewed"
|
||||
elif skew == "left_skewed":
|
||||
return "left_skewed"
|
||||
elif kurtosis == "heavy_tailed":
|
||||
return "heavy_tailed"
|
||||
else:
|
||||
return "non_normal"
|
||||
|
||||
async def _analyze_categorical_distributions(self, connection, table_name: str, categorical_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze distribution patterns for categorical columns"""
|
||||
categorical_analysis = {}
|
||||
|
||||
for column in categorical_columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
# Basic cardinality and distribution
|
||||
cardinality_sql = f"""
|
||||
SELECT
|
||||
COUNT(DISTINCT {col_name}) as cardinality,
|
||||
COUNT({col_name}) as non_null_count
|
||||
FROM {table_name}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
{sampling_info.get('sample_query_suffix', '')}
|
||||
"""
|
||||
|
||||
cardinality_result = await connection.execute(cardinality_sql)
|
||||
|
||||
if cardinality_result.data:
|
||||
cardinality_data = cardinality_result.data[0]
|
||||
cardinality = cardinality_data["cardinality"]
|
||||
non_null_count = cardinality_data["non_null_count"]
|
||||
|
||||
# Value distribution (top values)
|
||||
value_distribution = await self._get_categorical_value_distribution(
|
||||
connection, table_name, col_name, sampling_info, non_null_count
|
||||
)
|
||||
|
||||
# Calculate entropy and concentration
|
||||
entropy = self._calculate_entropy(value_distribution)
|
||||
concentration_ratio = value_distribution[0]["percentage"] if value_distribution else 0
|
||||
|
||||
categorical_analysis[col_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"cardinality": cardinality,
|
||||
"non_null_count": non_null_count,
|
||||
"value_distribution": value_distribution,
|
||||
"entropy": round(entropy, 3),
|
||||
"concentration_ratio": round(concentration_ratio, 4),
|
||||
"diversity_score": round(cardinality / non_null_count, 4) if non_null_count > 0 else 0
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze categorical column {col_name}: {str(e)}")
|
||||
categorical_analysis[col_name] = {"error": str(e)}
|
||||
|
||||
return categorical_analysis
|
||||
|
||||
async def _get_categorical_value_distribution(self, connection, table_name: str, col_name: str, sampling_info: Dict, total_count: int) -> List[Dict]:
|
||||
"""Get value distribution for categorical column"""
|
||||
try:
|
||||
# Use sample table expression if sampling is enabled
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
|
||||
distribution_sql = f"""
|
||||
SELECT
|
||||
{col_name} as value,
|
||||
COUNT(*) as count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
GROUP BY {col_name}
|
||||
ORDER BY COUNT(*) DESC
|
||||
LIMIT 20
|
||||
"""
|
||||
|
||||
result = await connection.execute(distribution_sql)
|
||||
|
||||
if result.data:
|
||||
distribution = []
|
||||
for row in result.data:
|
||||
count = row["count"]
|
||||
percentage = count / total_count if total_count > 0 else 0
|
||||
distribution.append({
|
||||
"value": str(row["value"]),
|
||||
"count": count,
|
||||
"percentage": round(percentage, 4)
|
||||
})
|
||||
return distribution
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get value distribution for {col_name}: {str(e)}")
|
||||
|
||||
return []
|
||||
|
||||
def _calculate_entropy(self, value_distribution: List[Dict]) -> float:
|
||||
"""Calculate Shannon entropy for categorical distribution"""
|
||||
if not value_distribution:
|
||||
return 0.0
|
||||
|
||||
entropy = 0.0
|
||||
for item in value_distribution:
|
||||
p = item["percentage"]
|
||||
if p > 0:
|
||||
entropy -= p * math.log2(p)
|
||||
|
||||
return entropy
|
||||
|
||||
async def _analyze_temporal_distributions(self, connection, table_name: str, temporal_columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze distribution patterns for temporal columns"""
|
||||
temporal_analysis = {}
|
||||
|
||||
for column in temporal_columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
# Date range analysis
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
range_sql = f"""
|
||||
SELECT
|
||||
MIN({col_name}) as earliest,
|
||||
MAX({col_name}) as latest,
|
||||
COUNT({col_name}) as non_null_count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
"""
|
||||
|
||||
range_result = await connection.execute(range_sql)
|
||||
|
||||
if range_result.data and range_result.data[0]["non_null_count"] > 0:
|
||||
range_data = range_result.data[0]
|
||||
earliest = range_data["earliest"]
|
||||
latest = range_data["latest"]
|
||||
|
||||
# Calculate span
|
||||
date_span_info = self._calculate_date_span(earliest, latest)
|
||||
|
||||
# Temporal patterns analysis
|
||||
temporal_patterns = await self._analyze_temporal_patterns(
|
||||
connection, table_name, col_name, sampling_info
|
||||
)
|
||||
|
||||
temporal_analysis[col_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"non_null_count": range_data["non_null_count"],
|
||||
"date_range": {
|
||||
"earliest": str(earliest),
|
||||
"latest": str(latest),
|
||||
**date_span_info
|
||||
},
|
||||
"temporal_patterns": temporal_patterns
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze temporal column {col_name}: {str(e)}")
|
||||
temporal_analysis[col_name] = {"error": str(e)}
|
||||
|
||||
return temporal_analysis
|
||||
|
||||
def _calculate_date_span(self, earliest, latest) -> Dict[str, Any]:
|
||||
"""Calculate date span information"""
|
||||
try:
|
||||
if isinstance(earliest, str):
|
||||
earliest = datetime.fromisoformat(earliest.replace('Z', '+00:00'))
|
||||
if isinstance(latest, str):
|
||||
latest = datetime.fromisoformat(latest.replace('Z', '+00:00'))
|
||||
|
||||
span = latest - earliest
|
||||
span_days = span.days
|
||||
|
||||
return {
|
||||
"span_days": span_days,
|
||||
"span_years": round(span_days / 365.25, 2),
|
||||
"span_description": self._describe_time_span(span_days)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to calculate date span: {str(e)}")
|
||||
return {"span_days": 0}
|
||||
|
||||
def _describe_time_span(self, days: int) -> str:
|
||||
"""Describe time span in human readable format"""
|
||||
if days < 1:
|
||||
return "less_than_day"
|
||||
elif days < 7:
|
||||
return "days"
|
||||
elif days < 30:
|
||||
return "weeks"
|
||||
elif days < 365:
|
||||
return "months"
|
||||
else:
|
||||
return "years"
|
||||
|
||||
async def _analyze_temporal_patterns(self, connection, table_name: str, col_name: str, sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze temporal patterns like seasonality and trends"""
|
||||
try:
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
# Weekly pattern analysis
|
||||
weekly_pattern_sql = f"""
|
||||
SELECT
|
||||
DAYOFWEEK({col_name}) as day_of_week,
|
||||
COUNT(*) as count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
GROUP BY DAYOFWEEK({col_name})
|
||||
ORDER BY day_of_week
|
||||
"""
|
||||
|
||||
weekly_result = await connection.execute(weekly_pattern_sql)
|
||||
|
||||
weekly_pattern = []
|
||||
if weekly_result.data:
|
||||
total_records = sum(row["count"] for row in weekly_result.data)
|
||||
for row in weekly_result.data:
|
||||
percentage = row["count"] / total_records if total_records > 0 else 0
|
||||
weekly_pattern.append(round(percentage, 3))
|
||||
|
||||
# Monthly trend analysis (simplified)
|
||||
monthly_trend_sql = f"""
|
||||
SELECT
|
||||
YEAR({col_name}) as year,
|
||||
MONTH({col_name}) as month,
|
||||
COUNT(*) as count
|
||||
FROM {table_expr}
|
||||
WHERE {col_name} IS NOT NULL
|
||||
GROUP BY YEAR({col_name}), MONTH({col_name})
|
||||
ORDER BY year, month
|
||||
LIMIT 12
|
||||
"""
|
||||
|
||||
monthly_result = await connection.execute(monthly_trend_sql)
|
||||
monthly_trend = "stable" # Simplified trend analysis
|
||||
|
||||
if monthly_result.data and len(monthly_result.data) > 3:
|
||||
counts = [row["count"] for row in monthly_result.data]
|
||||
if len(counts) > 1:
|
||||
trend_direction = "increasing" if counts[-1] > counts[0] else "decreasing"
|
||||
monthly_trend = trend_direction
|
||||
|
||||
return {
|
||||
"weekly_pattern": weekly_pattern,
|
||||
"monthly_trend": monthly_trend,
|
||||
"seasonal_component": self._estimate_seasonality(weekly_pattern)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze temporal patterns for {col_name}: {str(e)}")
|
||||
return {"weekly_pattern": [], "monthly_trend": "unknown"}
|
||||
|
||||
def _estimate_seasonality(self, weekly_pattern: List[float]) -> float:
|
||||
"""Estimate seasonality strength based on weekly pattern variance"""
|
||||
if len(weekly_pattern) < 7:
|
||||
return 0.0
|
||||
|
||||
mean_percentage = sum(weekly_pattern) / len(weekly_pattern)
|
||||
variance = sum((x - mean_percentage) ** 2 for x in weekly_pattern) / len(weekly_pattern)
|
||||
|
||||
# Normalize variance to 0-1 scale as seasonality indicator
|
||||
seasonality = min(variance * 10, 1.0) # Scaling factor
|
||||
return round(seasonality, 3)
|
||||
|
||||
async def _generate_data_quality_insights(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Generate overall data quality insights"""
|
||||
try:
|
||||
total_columns = len(columns)
|
||||
|
||||
# Calculate null rates across all columns
|
||||
null_analysis = await self._analyze_overall_null_rates(connection, table_name, columns, sampling_info)
|
||||
|
||||
# Identify potential data quality issues
|
||||
quality_issues = []
|
||||
|
||||
# High null rate columns
|
||||
high_null_columns = [col for col, rate in null_analysis["column_null_rates"].items() if rate > 0.2]
|
||||
if high_null_columns:
|
||||
quality_issues.append({
|
||||
"issue_type": "high_null_rates",
|
||||
"severity": "medium",
|
||||
"affected_columns": high_null_columns,
|
||||
"description": f"{len(high_null_columns)} columns have null rates > 20%"
|
||||
})
|
||||
|
||||
# Calculate overall data quality score
|
||||
avg_null_rate = sum(null_analysis["column_null_rates"].values()) / len(null_analysis["column_null_rates"]) if null_analysis["column_null_rates"] else 0
|
||||
data_quality_score = max(0, 1 - avg_null_rate)
|
||||
|
||||
return {
|
||||
"total_columns_analyzed": total_columns,
|
||||
"null_analysis": null_analysis,
|
||||
"data_quality_score": round(data_quality_score, 3),
|
||||
"quality_issues": quality_issues,
|
||||
"recommendations": self._generate_quality_recommendations(quality_issues, null_analysis)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to generate data quality insights: {str(e)}")
|
||||
return {"data_quality_score": 0.0, "error": str(e)}
|
||||
|
||||
async def _analyze_overall_null_rates(self, connection, table_name: str, columns: List[Dict], sampling_info: Dict) -> Dict[str, Any]:
|
||||
"""Analyze null rates across all columns"""
|
||||
column_null_rates = {}
|
||||
total_null_count = 0
|
||||
total_cell_count = 0
|
||||
|
||||
for column in columns:
|
||||
col_name = column["column_name"]
|
||||
try:
|
||||
table_expr = sampling_info.get("sample_table_expression", table_name)
|
||||
null_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
COUNT({col_name}) as non_null_count
|
||||
FROM {table_expr}
|
||||
"""
|
||||
|
||||
result = await connection.execute(null_sql)
|
||||
if result.data:
|
||||
data = result.data[0]
|
||||
total_count = data["total_count"]
|
||||
non_null_count = data["non_null_count"]
|
||||
null_count = total_count - non_null_count
|
||||
null_rate = null_count / total_count if total_count > 0 else 0
|
||||
|
||||
column_null_rates[col_name] = round(null_rate, 4)
|
||||
total_null_count += null_count
|
||||
total_cell_count += total_count
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze null rate for column {col_name}: {str(e)}")
|
||||
column_null_rates[col_name] = 0.0
|
||||
|
||||
overall_null_rate = total_null_count / total_cell_count if total_cell_count > 0 else 0
|
||||
|
||||
return {
|
||||
"column_null_rates": column_null_rates,
|
||||
"overall_null_rate": round(overall_null_rate, 4),
|
||||
"columns_with_nulls": len([rate for rate in column_null_rates.values() if rate > 0])
|
||||
}
|
||||
|
||||
def _generate_quality_recommendations(self, quality_issues: List[Dict], null_analysis: Dict) -> List[Dict]:
|
||||
"""Generate data quality improvement recommendations"""
|
||||
recommendations = []
|
||||
|
||||
# Recommendations based on null analysis
|
||||
overall_null_rate = null_analysis.get("overall_null_rate", 0)
|
||||
if overall_null_rate > 0.1:
|
||||
recommendations.append({
|
||||
"type": "data_completeness",
|
||||
"priority": "high" if overall_null_rate > 0.3 else "medium",
|
||||
"description": f"Overall null rate is {overall_null_rate:.1%}",
|
||||
"action": "Review data collection and validation processes"
|
||||
})
|
||||
|
||||
# Recommendations based on quality issues
|
||||
for issue in quality_issues:
|
||||
if issue["issue_type"] == "high_null_rates":
|
||||
recommendations.append({
|
||||
"type": "column_completeness",
|
||||
"priority": issue["severity"],
|
||||
"description": issue["description"],
|
||||
"action": f"Focus on improving data completeness for: {', '.join(issue['affected_columns'][:3])}"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
|
||||
def _generate_analysis_summary(self, distribution_analysis: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate high-level summary of distribution analysis"""
|
||||
summary = {
|
||||
"numeric_columns_count": len(distribution_analysis.get("numeric_columns", {})),
|
||||
"categorical_columns_count": len(distribution_analysis.get("categorical_columns", {})),
|
||||
"temporal_columns_count": len(distribution_analysis.get("temporal_columns", {}))
|
||||
}
|
||||
|
||||
# Identify interesting patterns
|
||||
patterns = []
|
||||
|
||||
# Check for highly skewed numeric columns
|
||||
numeric_cols = distribution_analysis.get("numeric_columns", {})
|
||||
skewed_cols = [
|
||||
col for col, info in numeric_cols.items()
|
||||
if isinstance(info, dict) and
|
||||
info.get("distribution_shape", {}).get("skewness_indicator") in ["right_skewed", "left_skewed"]
|
||||
]
|
||||
|
||||
if skewed_cols:
|
||||
patterns.append(f"Found {len(skewed_cols)} skewed numeric columns")
|
||||
|
||||
# Check for high cardinality categorical columns
|
||||
categorical_cols = distribution_analysis.get("categorical_columns", {})
|
||||
high_cardinality_cols = [
|
||||
col for col, info in categorical_cols.items()
|
||||
if isinstance(info, dict) and info.get("cardinality", 0) > 1000
|
||||
]
|
||||
|
||||
if high_cardinality_cols:
|
||||
patterns.append(f"Found {len(high_cardinality_cols)} high cardinality categorical columns")
|
||||
|
||||
summary["notable_patterns"] = patterns
|
||||
|
||||
return summary
|
||||
869
doris_mcp_server/utils/data_governance_tools.py
Normal file
869
doris_mcp_server/utils/data_governance_tools.py
Normal file
@@ -0,0 +1,869 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Data Governance Tools Module
|
||||
Provides data completeness analysis, field lineage tracking, and data freshness monitoring
|
||||
"""
|
||||
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DataGovernanceTools:
|
||||
"""Data governance tools suite"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
logger.info("DataGovernanceTools initialized")
|
||||
|
||||
|
||||
|
||||
async def trace_column_lineage(
|
||||
self,
|
||||
table_name: str,
|
||||
column_name: str,
|
||||
depth: int = 3,
|
||||
catalog_name: Optional[str] = None,
|
||||
db_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Column-level lineage tracing
|
||||
|
||||
Args:
|
||||
table_name: Table name
|
||||
column_name: Column name
|
||||
depth: Trace depth
|
||||
catalog_name: Catalog name
|
||||
db_name: Database name
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
full_table_name = self._build_full_table_name(table_name, catalog_name, db_name)
|
||||
target_column = f"{full_table_name}.{column_name}"
|
||||
|
||||
# 1. Verify target column exists
|
||||
if not await self._verify_column_exists(connection, full_table_name, column_name):
|
||||
return {"error": f"Column {column_name} not found in table {full_table_name}"}
|
||||
|
||||
# 2. Analyze SQL logs to get lineage relationships
|
||||
source_chain = await self._analyze_sql_logs_for_lineage(
|
||||
connection, full_table_name, column_name, depth
|
||||
)
|
||||
|
||||
# 3. Analyze downstream usage
|
||||
downstream_usage = await self._analyze_downstream_column_usage(
|
||||
connection, full_table_name, column_name
|
||||
)
|
||||
|
||||
# 4. Analyze field transformation rules
|
||||
transformation_rules = await self._extract_transformation_rules(
|
||||
connection, full_table_name, column_name
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"target_column": target_column,
|
||||
"analysis_timestamp": datetime.now().isoformat(),
|
||||
"execution_time_seconds": round(execution_time, 3),
|
||||
"lineage_depth": depth,
|
||||
"source_chain": source_chain,
|
||||
"downstream_usage": downstream_usage,
|
||||
"transformation_rules": transformation_rules,
|
||||
"lineage_confidence": self._calculate_lineage_confidence(source_chain),
|
||||
"impact_analysis": {
|
||||
"upstream_dependencies": len(source_chain),
|
||||
"downstream_dependencies": len(downstream_usage),
|
||||
"risk_level": self._assess_lineage_risk(source_chain, downstream_usage)
|
||||
}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Column lineage tracing failed for {table_name}.{column_name}: {str(e)}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"target_column": f"{table_name}.{column_name}",
|
||||
"analysis_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
async def monitor_data_freshness(
|
||||
self,
|
||||
tables: Optional[List[str]] = None,
|
||||
time_threshold_hours: int = 24,
|
||||
catalog_name: Optional[str] = None,
|
||||
db_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Data freshness monitoring
|
||||
|
||||
Args:
|
||||
tables: List of tables to monitor, empty means monitor all tables
|
||||
time_threshold_hours: Freshness threshold (hours)
|
||||
catalog_name: Catalog name
|
||||
db_name: Database name
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# 1. Get list of tables to monitor
|
||||
if not tables:
|
||||
tables = await self._get_all_tables(connection, catalog_name, db_name)
|
||||
|
||||
# 2. Analyze freshness of each table
|
||||
table_freshness = {}
|
||||
fresh_count = 0
|
||||
stale_count = 0
|
||||
|
||||
for table in tables:
|
||||
full_table_name = self._build_full_table_name(table, catalog_name, db_name)
|
||||
freshness_info = await self._analyze_table_freshness(
|
||||
connection, full_table_name, time_threshold_hours
|
||||
)
|
||||
table_freshness[table] = freshness_info
|
||||
|
||||
if freshness_info["status"] == "fresh":
|
||||
fresh_count += 1
|
||||
else:
|
||||
stale_count += 1
|
||||
|
||||
# 3. Calculate overall freshness score
|
||||
total_tables = len(tables)
|
||||
overall_freshness_score = fresh_count / total_tables if total_tables > 0 else 0
|
||||
|
||||
# 4. Identify data flow issues
|
||||
data_flow_issues = await self._identify_data_flow_issues(table_freshness)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"monitoring_timestamp": datetime.now().isoformat(),
|
||||
"execution_time_seconds": round(execution_time, 3),
|
||||
"monitoring_scope": {
|
||||
"catalog_name": catalog_name,
|
||||
"db_name": db_name,
|
||||
"time_threshold_hours": time_threshold_hours
|
||||
},
|
||||
"freshness_summary": {
|
||||
"total_tables": total_tables,
|
||||
"fresh_tables": fresh_count,
|
||||
"stale_tables": stale_count,
|
||||
"overall_freshness_score": round(overall_freshness_score, 3)
|
||||
},
|
||||
"table_freshness": table_freshness,
|
||||
"data_flow_issues": data_flow_issues,
|
||||
"alerts": self._generate_freshness_alerts(table_freshness, time_threshold_hours)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data freshness monitoring failed: {str(e)}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"monitoring_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
def _build_full_table_name(self, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> str:
|
||||
"""Build full table name - use three-level naming convention"""
|
||||
# Default catalog is internal for internal tables
|
||||
effective_catalog = catalog_name if catalog_name else "internal"
|
||||
|
||||
if db_name:
|
||||
return f"{effective_catalog}.{db_name}.{table_name}"
|
||||
else:
|
||||
# If db_name is not provided, need to determine current database
|
||||
return f"{effective_catalog}.{table_name}"
|
||||
|
||||
async def _get_table_basic_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get table basic information"""
|
||||
try:
|
||||
# Try to get table row count
|
||||
count_sql = f"SELECT COUNT(*) as row_count FROM {table_name}"
|
||||
result = await connection.execute(count_sql)
|
||||
|
||||
if result.data:
|
||||
return {"row_count": result.data[0]["row_count"]}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get basic info for table {table_name}: {str(e)}")
|
||||
return {"row_count": 0}
|
||||
|
||||
async def _get_table_columns_info(self, connection, table_name: str, catalog_name: Optional[str], db_name: Optional[str]) -> List[Dict]:
|
||||
"""Get table column information"""
|
||||
try:
|
||||
# Build query conditions
|
||||
where_conditions = [f"table_name = '{table_name}'"]
|
||||
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
columns_sql = f"""
|
||||
SELECT
|
||||
column_name,
|
||||
data_type,
|
||||
is_nullable,
|
||||
column_comment,
|
||||
ordinal_position
|
||||
FROM information_schema.columns
|
||||
WHERE {' AND '.join(where_conditions)}
|
||||
ORDER BY ordinal_position
|
||||
"""
|
||||
|
||||
result = await connection.execute(columns_sql)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get columns info for table {table_name}: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _analyze_column_completeness(self, connection, table_name: str, columns_info: List[Dict]) -> Dict[str, Any]:
|
||||
"""Analyze column completeness"""
|
||||
column_completeness = {}
|
||||
|
||||
for column in columns_info:
|
||||
column_name = column["column_name"]
|
||||
try:
|
||||
# Calculate null value statistics
|
||||
null_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
COUNT({column_name}) as non_null_count,
|
||||
COUNT(*) - COUNT({column_name}) as null_count
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
result = await connection.execute(null_sql)
|
||||
if result.data:
|
||||
stats = result.data[0]
|
||||
total_count = stats["total_count"]
|
||||
null_count = stats["null_count"]
|
||||
null_rate = null_count / total_count if total_count > 0 else 0
|
||||
completeness_score = 1.0 - null_rate
|
||||
|
||||
column_completeness[column_name] = {
|
||||
"data_type": column["data_type"],
|
||||
"is_nullable": column["is_nullable"],
|
||||
"total_count": total_count,
|
||||
"null_count": null_count,
|
||||
"non_null_count": stats["non_null_count"],
|
||||
"null_rate": round(null_rate, 4),
|
||||
"completeness_score": round(completeness_score, 4)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze completeness for column {column_name}: {str(e)}")
|
||||
column_completeness[column_name] = {
|
||||
"error": str(e),
|
||||
"completeness_score": 0.0
|
||||
}
|
||||
|
||||
return column_completeness
|
||||
|
||||
async def _check_business_rule_compliance(self, connection, table_name: str, business_rules: List[Dict], total_rows: int) -> Dict[str, Any]:
|
||||
"""Check business rule compliance"""
|
||||
compliance_results = {}
|
||||
|
||||
for rule in business_rules:
|
||||
rule_name = rule.get("rule_name", "unknown")
|
||||
sql_condition = rule.get("sql_condition", "")
|
||||
|
||||
if not sql_condition:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Check number of records meeting conditions
|
||||
compliance_sql = f"""
|
||||
SELECT
|
||||
COUNT(*) as total_count,
|
||||
SUM(CASE WHEN {sql_condition} THEN 1 ELSE 0 END) as pass_count
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
result = await connection.execute(compliance_sql)
|
||||
if result.data:
|
||||
stats = result.data[0]
|
||||
pass_count = stats["pass_count"] or 0
|
||||
fail_count = total_rows - pass_count
|
||||
pass_rate = pass_count / total_rows if total_rows > 0 else 0
|
||||
|
||||
compliance_results[rule_name] = {
|
||||
"rule_condition": sql_condition,
|
||||
"total_records": total_rows,
|
||||
"pass_count": pass_count,
|
||||
"fail_count": fail_count,
|
||||
"pass_rate": round(pass_rate, 4),
|
||||
"compliance_score": round(pass_rate, 4)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to check business rule {rule_name}: {str(e)}")
|
||||
compliance_results[rule_name] = {
|
||||
"error": str(e),
|
||||
"compliance_score": 0.0
|
||||
}
|
||||
|
||||
return compliance_results
|
||||
|
||||
async def _detect_data_integrity_issues(self, connection, table_name: str, columns_info: List[Dict]) -> List[Dict]:
|
||||
"""Detect data integrity issues"""
|
||||
issues = []
|
||||
|
||||
try:
|
||||
# Detect duplicate values in primary key fields
|
||||
primary_key_columns = [col["column_name"] for col in columns_info if "primary" in col.get("column_comment", "").lower()]
|
||||
|
||||
for pk_col in primary_key_columns:
|
||||
duplicate_sql = f"""
|
||||
SELECT COUNT(*) as duplicate_count
|
||||
FROM (
|
||||
SELECT {pk_col}, COUNT(*) as cnt
|
||||
FROM {table_name}
|
||||
WHERE {pk_col} IS NOT NULL
|
||||
GROUP BY {pk_col}
|
||||
HAVING COUNT(*) > 1
|
||||
) t
|
||||
"""
|
||||
|
||||
result = await connection.execute(duplicate_sql)
|
||||
if result.data and result.data[0]["duplicate_count"] > 0:
|
||||
issues.append({
|
||||
"type": "duplicate_primary_keys",
|
||||
"column": pk_col,
|
||||
"count": result.data[0]["duplicate_count"],
|
||||
"severity": "high",
|
||||
"description": f"Found duplicate values in primary key column {pk_col}"
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to detect integrity issues: {str(e)}")
|
||||
issues.append({
|
||||
"type": "detection_error",
|
||||
"error": str(e),
|
||||
"severity": "unknown"
|
||||
})
|
||||
|
||||
return issues
|
||||
|
||||
def _calculate_completeness_score(self, column_completeness: Dict, business_rule_compliance: Dict) -> float:
|
||||
"""Calculate overall completeness score"""
|
||||
if not column_completeness:
|
||||
return 0.0
|
||||
|
||||
# Calculate column completeness average score
|
||||
column_scores = [
|
||||
col_info.get("completeness_score", 0.0)
|
||||
for col_info in column_completeness.values()
|
||||
if isinstance(col_info, dict) and "completeness_score" in col_info
|
||||
]
|
||||
avg_column_score = sum(column_scores) / len(column_scores) if column_scores else 0.0
|
||||
|
||||
# Calculate business rule compliance average score
|
||||
compliance_scores = [
|
||||
rule_info.get("compliance_score", 0.0)
|
||||
for rule_info in business_rule_compliance.values()
|
||||
if isinstance(rule_info, dict) and "compliance_score" in rule_info
|
||||
]
|
||||
avg_compliance_score = sum(compliance_scores) / len(compliance_scores) if compliance_scores else 1.0
|
||||
|
||||
# Comprehensive score (column completeness weight 70%, business rules weight 30%)
|
||||
overall_score = avg_column_score * 0.7 + avg_compliance_score * 0.3
|
||||
return round(overall_score, 4)
|
||||
|
||||
def _generate_completeness_recommendations(self, column_completeness: Dict, integrity_issues: List[Dict]) -> List[Dict]:
|
||||
"""Generate completeness improvement recommendations"""
|
||||
recommendations = []
|
||||
|
||||
# Generate recommendations based on column completeness
|
||||
for col_name, col_info in column_completeness.items():
|
||||
if isinstance(col_info, dict):
|
||||
null_rate = col_info.get("null_rate", 0)
|
||||
if null_rate > 0.1: # Null rate exceeds 10%
|
||||
recommendations.append({
|
||||
"type": "high_null_rate",
|
||||
"column": col_name,
|
||||
"priority": "high" if null_rate > 0.5 else "medium",
|
||||
"description": f"Column {col_name} has high null rate ({null_rate:.1%})",
|
||||
"suggested_action": "Review data collection process or add data validation"
|
||||
})
|
||||
|
||||
# Generate recommendations based on integrity issues
|
||||
for issue in integrity_issues:
|
||||
if issue["type"] == "duplicate_primary_keys":
|
||||
recommendations.append({
|
||||
"type": "data_deduplication",
|
||||
"column": issue["column"],
|
||||
"priority": "high",
|
||||
"description": f"Duplicate primary key values found in {issue['column']}",
|
||||
"suggested_action": "Implement unique constraint or data deduplication process"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
|
||||
async def _verify_column_exists(self, connection, table_name: str, column_name: str) -> bool:
|
||||
"""Verify if column exists"""
|
||||
try:
|
||||
# Simple verification method: try to query the column
|
||||
verify_sql = f"SELECT {column_name} FROM {table_name} LIMIT 1"
|
||||
await connection.execute(verify_sql)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _analyze_sql_logs_for_lineage(self, connection, table_name: str, column_name: str, depth: int) -> List[Dict]:
|
||||
"""Analyze SQL logs to get lineage relationships (simplified implementation)"""
|
||||
# Note: This is a simplified implementation, actual environment needs to analyze audit logs
|
||||
source_chain = []
|
||||
|
||||
try:
|
||||
# Try to find related INSERT/CREATE TABLE AS SELECT statements from audit logs (one year range)
|
||||
audit_sql = """
|
||||
SELECT
|
||||
stmt as sql_statement,
|
||||
`time` as execution_time,
|
||||
`user` as user_name
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE stmt LIKE '%{}%'
|
||||
AND (stmt LIKE '%INSERT%' OR stmt LIKE '%CREATE%' OR stmt LIKE '%SELECT%')
|
||||
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
||||
ORDER BY `time` DESC
|
||||
LIMIT 50
|
||||
""".format(table_name.split('.')[-1]) # Use the last part of table name
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
|
||||
if result.data:
|
||||
for i, log_entry in enumerate(result.data[:depth]):
|
||||
# Simplified lineage analysis: extract possible source tables
|
||||
sql_stmt = log_entry.get("sql_statement", "")
|
||||
source_tables = self._extract_source_tables_from_sql(sql_stmt)
|
||||
|
||||
if source_tables:
|
||||
# Handle datetime serialization issue
|
||||
execution_time = log_entry.get("execution_time")
|
||||
if execution_time and hasattr(execution_time, 'isoformat'):
|
||||
execution_time = execution_time.isoformat()
|
||||
elif execution_time:
|
||||
execution_time = str(execution_time)
|
||||
|
||||
source_chain.append({
|
||||
"level": i + 1,
|
||||
"source_table": source_tables[0], # Take the first as main source table
|
||||
"source_column": column_name, # Simplified: assume same name
|
||||
"transformation": self._extract_transformation_from_sql(sql_stmt, column_name),
|
||||
"confidence": 0.8 - (i * 0.1), # Decreasing confidence
|
||||
"execution_time": execution_time,
|
||||
"user": log_entry.get("user_name")
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze SQL logs for lineage: {str(e)}")
|
||||
# If unable to get from audit logs, return basic information
|
||||
source_chain = [{
|
||||
"level": 1,
|
||||
"source_table": "unknown_source",
|
||||
"source_column": column_name,
|
||||
"transformation": "unknown",
|
||||
"confidence": 0.3,
|
||||
"note": "Limited lineage information available"
|
||||
}]
|
||||
|
||||
return source_chain
|
||||
|
||||
def _extract_source_tables_from_sql(self, sql: str) -> List[str]:
|
||||
"""Extract source table names from SQL statement (simplified implementation)"""
|
||||
# Simplified regex to match table names in FROM clause
|
||||
from_pattern = r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
|
||||
join_pattern = r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
|
||||
|
||||
tables = []
|
||||
|
||||
# Find tables in FROM clause
|
||||
from_matches = re.findall(from_pattern, sql, re.IGNORECASE)
|
||||
tables.extend(from_matches)
|
||||
|
||||
# Find tables in JOIN clause
|
||||
join_matches = re.findall(join_pattern, sql, re.IGNORECASE)
|
||||
tables.extend(join_matches)
|
||||
|
||||
return list(set(tables)) # Remove duplicates
|
||||
|
||||
def _extract_transformation_from_sql(self, sql: str, column_name: str) -> str:
|
||||
"""Extract field transformation rules from SQL statement (simplified implementation)"""
|
||||
# Simplified implementation: find expressions containing target field
|
||||
lines = sql.split('\n')
|
||||
for line in lines:
|
||||
if column_name in line and ('SELECT' in line.upper() or '=' in line):
|
||||
return line.strip()
|
||||
|
||||
return "direct_copy"
|
||||
|
||||
async def _analyze_downstream_column_usage(self, connection, table_name: str, column_name: str) -> List[Dict]:
|
||||
"""Analyze downstream usage of field (simplified implementation)"""
|
||||
downstream_usage = []
|
||||
|
||||
try:
|
||||
# Find other tables that might use this field (through audit logs, one year range)
|
||||
usage_sql = """
|
||||
SELECT DISTINCT
|
||||
stmt as sql_statement
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE stmt LIKE '%{}%'
|
||||
AND stmt LIKE '%{}%'
|
||||
AND stmt LIKE '%SELECT%'
|
||||
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
||||
LIMIT 20
|
||||
""".format(table_name.split('.')[-1], column_name)
|
||||
|
||||
result = await connection.execute(usage_sql)
|
||||
|
||||
if result.data:
|
||||
for entry in result.data:
|
||||
sql_stmt = entry.get("sql_statement", "")
|
||||
target_tables = self._extract_target_tables_from_sql(sql_stmt)
|
||||
|
||||
for target_table in target_tables:
|
||||
if target_table != table_name.split('.')[-1]: # Not the source table itself
|
||||
downstream_usage.append({
|
||||
"table": target_table,
|
||||
"column": column_name, # Simplified: assume same name
|
||||
"usage_type": "select_reference",
|
||||
"confidence": 0.7
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze downstream usage: {str(e)}")
|
||||
|
||||
return downstream_usage
|
||||
|
||||
def _extract_target_tables_from_sql(self, sql: str) -> List[str]:
|
||||
"""Extract target table names from SQL statement"""
|
||||
# Find target tables in INSERT INTO or CREATE TABLE statements
|
||||
insert_pattern = r'\bINSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
|
||||
create_pattern = r'\bCREATE\s+TABLE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
|
||||
|
||||
tables = []
|
||||
|
||||
insert_matches = re.findall(insert_pattern, sql, re.IGNORECASE)
|
||||
tables.extend(insert_matches)
|
||||
|
||||
create_matches = re.findall(create_pattern, sql, re.IGNORECASE)
|
||||
tables.extend(create_matches)
|
||||
|
||||
return list(set(tables))
|
||||
|
||||
async def _extract_transformation_rules(self, connection, table_name: str, column_name: str) -> List[Dict]:
|
||||
"""Extract field transformation rules"""
|
||||
# Simplified implementation: return basic transformation information
|
||||
return [{
|
||||
"transformation_type": "unknown",
|
||||
"description": "Transformation rules analysis requires detailed ETL metadata",
|
||||
"confidence": 0.5
|
||||
}]
|
||||
|
||||
def _calculate_lineage_confidence(self, source_chain: List[Dict]) -> float:
|
||||
"""Calculate overall confidence of lineage tracing"""
|
||||
if not source_chain:
|
||||
return 0.0
|
||||
|
||||
confidences = [item.get("confidence", 0.0) for item in source_chain]
|
||||
return round(sum(confidences) / len(confidences), 3)
|
||||
|
||||
def _assess_lineage_risk(self, source_chain: List[Dict], downstream_usage: List[Dict]) -> str:
|
||||
"""Assess lineage risk level"""
|
||||
if len(downstream_usage) > 10:
|
||||
return "high"
|
||||
elif len(downstream_usage) > 5:
|
||||
return "medium"
|
||||
else:
|
||||
return "low"
|
||||
|
||||
async def _get_all_tables(self, connection, catalog_name: Optional[str], db_name: Optional[str]) -> List[str]:
|
||||
"""Get list of all tables"""
|
||||
try:
|
||||
where_conditions = []
|
||||
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"
|
||||
|
||||
tables_sql = f"""
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE {where_clause}
|
||||
AND table_type = 'BASE TABLE'
|
||||
ORDER BY table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(tables_sql)
|
||||
return [row["table_name"] for row in result.data] if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get table list: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _analyze_table_freshness(self, connection, table_name: str, threshold_hours: int) -> Dict[str, Any]:
|
||||
"""Analyze freshness of single table"""
|
||||
try:
|
||||
# Try multiple methods to get table's last update time
|
||||
freshness_methods = [
|
||||
self._get_freshness_from_partition_info,
|
||||
self._get_freshness_from_max_timestamp,
|
||||
self._get_freshness_from_table_metadata
|
||||
]
|
||||
|
||||
last_update = None
|
||||
method_used = "unknown"
|
||||
|
||||
for method in freshness_methods:
|
||||
try:
|
||||
result = await method(connection, table_name)
|
||||
if result:
|
||||
last_update = result["last_update"]
|
||||
method_used = result["method"]
|
||||
break
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
if not last_update:
|
||||
return {
|
||||
"last_update": None,
|
||||
"staleness_hours": None,
|
||||
"freshness_score": 0.0,
|
||||
"status": "unknown",
|
||||
"method_used": "none",
|
||||
"error": "Unable to determine last update time"
|
||||
}
|
||||
|
||||
# Calculate data staleness
|
||||
now = datetime.now()
|
||||
if isinstance(last_update, str):
|
||||
last_update = datetime.fromisoformat(last_update.replace('Z', '+00:00'))
|
||||
|
||||
staleness_hours = (now - last_update).total_seconds() / 3600
|
||||
|
||||
# Calculate freshness score and status
|
||||
if staleness_hours <= threshold_hours:
|
||||
status = "fresh"
|
||||
freshness_score = max(0.0, 1.0 - (staleness_hours / threshold_hours))
|
||||
else:
|
||||
status = "stale"
|
||||
freshness_score = max(0.0, 1.0 - (staleness_hours / (threshold_hours * 2)))
|
||||
|
||||
return {
|
||||
"last_update": last_update.isoformat() if hasattr(last_update, 'isoformat') else str(last_update),
|
||||
"staleness_hours": round(staleness_hours, 2),
|
||||
"freshness_score": round(freshness_score, 3),
|
||||
"status": status,
|
||||
"method_used": method_used,
|
||||
"threshold_hours": threshold_hours
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze freshness for table {table_name}: {str(e)}")
|
||||
return {
|
||||
"last_update": None,
|
||||
"staleness_hours": None,
|
||||
"freshness_score": 0.0,
|
||||
"status": "error",
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
async def _get_freshness_from_partition_info(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get freshness from partition information"""
|
||||
try:
|
||||
# Query partition information (if table has partitions)
|
||||
partition_sql = f"""
|
||||
SELECT MAX(CREATE_TIME) as last_update
|
||||
FROM information_schema.partitions
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
AND CREATE_TIME IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(partition_sql)
|
||||
if result.data and result.data[0]["last_update"]:
|
||||
return {
|
||||
"last_update": result.data[0]["last_update"],
|
||||
"method": "partition_info"
|
||||
}
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _get_freshness_from_max_timestamp(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get freshness from timestamp fields"""
|
||||
try:
|
||||
# Find possible timestamp fields
|
||||
timestamp_columns = await self._find_timestamp_columns(connection, table_name)
|
||||
|
||||
if timestamp_columns:
|
||||
max_time_sql = f"""
|
||||
SELECT MAX({timestamp_columns[0]}) as last_update
|
||||
FROM {table_name}
|
||||
"""
|
||||
|
||||
result = await connection.execute(max_time_sql)
|
||||
if result.data and result.data[0]["last_update"]:
|
||||
return {
|
||||
"last_update": result.data[0]["last_update"],
|
||||
"method": f"max_timestamp({timestamp_columns[0]})"
|
||||
}
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _get_freshness_from_table_metadata(self, connection, table_name: str) -> Optional[Dict]:
|
||||
"""Get freshness from table metadata"""
|
||||
try:
|
||||
# Query table's update time
|
||||
metadata_sql = f"""
|
||||
SELECT UPDATE_TIME as last_update
|
||||
FROM information_schema.tables
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
AND UPDATE_TIME IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(metadata_sql)
|
||||
if result.data and result.data[0]["last_update"]:
|
||||
return {
|
||||
"last_update": result.data[0]["last_update"],
|
||||
"method": "table_metadata"
|
||||
}
|
||||
return None
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def _find_timestamp_columns(self, connection, table_name: str) -> List[str]:
|
||||
"""Find possible timestamp fields"""
|
||||
try:
|
||||
timestamp_sql = f"""
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name = '{table_name.split('.')[-1]}'
|
||||
AND (
|
||||
data_type IN ('datetime', 'timestamp', 'date')
|
||||
OR column_name LIKE '%time%'
|
||||
OR column_name LIKE '%date%'
|
||||
OR column_name LIKE '%created%'
|
||||
OR column_name LIKE '%updated%'
|
||||
)
|
||||
ORDER BY
|
||||
CASE
|
||||
WHEN column_name LIKE '%updated%' THEN 1
|
||||
WHEN column_name LIKE '%created%' THEN 2
|
||||
WHEN column_name LIKE '%time%' THEN 3
|
||||
ELSE 4
|
||||
END
|
||||
"""
|
||||
|
||||
result = await connection.execute(timestamp_sql)
|
||||
return [row["column_name"] for row in result.data] if result.data else []
|
||||
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
async def _identify_data_flow_issues(self, table_freshness: Dict[str, Any]) -> List[Dict]:
|
||||
"""Identify data flow issues"""
|
||||
issues = []
|
||||
|
||||
# Identify consecutively stale tables (may indicate ETL process issues)
|
||||
stale_tables = [
|
||||
table_name for table_name, info in table_freshness.items()
|
||||
if info.get("status") == "stale"
|
||||
]
|
||||
|
||||
if len(stale_tables) > len(table_freshness) * 0.3: # More than 30% of tables are stale
|
||||
issues.append({
|
||||
"issue_type": "widespread_staleness",
|
||||
"severity": "high",
|
||||
"affected_tables": len(stale_tables),
|
||||
"total_tables": len(table_freshness),
|
||||
"description": f"High percentage of stale tables ({len(stale_tables)}/{len(table_freshness)})",
|
||||
"possible_causes": ["ETL pipeline failure", "Data source issues", "Processing delays"]
|
||||
})
|
||||
|
||||
# Identify particularly stale tables
|
||||
very_stale_tables = [
|
||||
(table_name, info.get("staleness_hours", 0))
|
||||
for table_name, info in table_freshness.items()
|
||||
if info.get("staleness_hours", 0) > 72 # More than 3 days
|
||||
]
|
||||
|
||||
if very_stale_tables:
|
||||
issues.append({
|
||||
"issue_type": "very_stale_data",
|
||||
"severity": "medium",
|
||||
"affected_tables": [table for table, _ in very_stale_tables],
|
||||
"max_staleness_hours": max(hours for _, hours in very_stale_tables),
|
||||
"description": "Some tables have very stale data (>72 hours)",
|
||||
"recommendation": "Check data ingestion processes for affected tables"
|
||||
})
|
||||
|
||||
return issues
|
||||
|
||||
def _generate_freshness_alerts(self, table_freshness: Dict[str, Any], threshold_hours: int) -> List[Dict]:
|
||||
"""Generate freshness alerts"""
|
||||
alerts = []
|
||||
|
||||
for table_name, info in table_freshness.items():
|
||||
staleness_hours = info.get("staleness_hours")
|
||||
status = info.get("status")
|
||||
|
||||
if status == "stale" and staleness_hours:
|
||||
if staleness_hours > threshold_hours * 2: # Exceeds threshold by 2x
|
||||
alert_level = "critical"
|
||||
elif staleness_hours > threshold_hours * 1.5: # Exceeds threshold by 1.5x
|
||||
alert_level = "warning"
|
||||
else:
|
||||
alert_level = "info"
|
||||
|
||||
alerts.append({
|
||||
"alert_level": alert_level,
|
||||
"table_name": table_name,
|
||||
"staleness_hours": staleness_hours,
|
||||
"threshold_hours": threshold_hours,
|
||||
"message": f"Table {table_name} is stale ({staleness_hours:.1f} hours old, threshold: {threshold_hours}h)",
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
elif status == "error":
|
||||
alerts.append({
|
||||
"alert_level": "error",
|
||||
"table_name": table_name,
|
||||
"message": f"Unable to determine freshness for table {table_name}",
|
||||
"error": info.get("error"),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
return alerts
|
||||
1342
doris_mcp_server/utils/data_quality_tools.py
Normal file
1342
doris_mcp_server/utils/data_quality_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
978
doris_mcp_server/utils/dependency_analysis_tools.py
Normal file
978
doris_mcp_server/utils/dependency_analysis_tools.py
Normal file
@@ -0,0 +1,978 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Dependency Analysis Tools Module
|
||||
Provides data flow dependency analysis and impact assessment capabilities
|
||||
"""
|
||||
|
||||
import time
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
from collections import defaultdict, deque
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DependencyAnalysisTools:
|
||||
"""Dependency analysis tools for data flow and impact assessment"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
logger.info("DependencyAnalysisTools initialized")
|
||||
|
||||
async def analyze_data_flow_dependencies(
|
||||
self,
|
||||
target_table: Optional[str] = None,
|
||||
analysis_depth: int = 3,
|
||||
include_views: bool = True,
|
||||
catalog_name: Optional[str] = None,
|
||||
db_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze data flow dependencies and impact relationships
|
||||
|
||||
Args:
|
||||
target_table: Specific table to analyze (if None, analyzes all tables)
|
||||
analysis_depth: Maximum depth for dependency traversal
|
||||
include_views: Whether to include views in dependency analysis
|
||||
catalog_name: Catalog name
|
||||
db_name: Database name
|
||||
|
||||
Returns:
|
||||
Comprehensive dependency analysis results
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# 1. Get table metadata and relationships
|
||||
tables_metadata = await self._get_tables_metadata(connection, catalog_name, db_name, include_views)
|
||||
|
||||
if not tables_metadata:
|
||||
return {
|
||||
"error": "No tables found for dependency analysis",
|
||||
"analysis_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# 2. Build dependency graph from SQL analysis
|
||||
dependency_graph = await self._build_dependency_graph(connection, tables_metadata, analysis_depth)
|
||||
|
||||
# 3. Analyze specific table or all tables
|
||||
if target_table:
|
||||
# Analyze specific table
|
||||
table_analysis = await self._analyze_single_table_dependencies(
|
||||
target_table, dependency_graph, tables_metadata
|
||||
)
|
||||
impact_analysis = await self._calculate_impact_analysis(
|
||||
target_table, dependency_graph, "both"
|
||||
)
|
||||
else:
|
||||
# Analyze all tables
|
||||
table_analysis = await self._analyze_all_tables_dependencies(
|
||||
dependency_graph, tables_metadata
|
||||
)
|
||||
impact_analysis = await self._calculate_global_impact_analysis(dependency_graph)
|
||||
|
||||
# 4. Generate insights and recommendations
|
||||
dependency_insights = await self._generate_dependency_insights(
|
||||
dependency_graph, table_analysis, impact_analysis
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"analysis_target": target_table or "all_tables",
|
||||
"analysis_timestamp": datetime.now().isoformat(),
|
||||
"execution_time_seconds": round(execution_time, 3),
|
||||
"tables_analyzed": len(tables_metadata),
|
||||
"dependency_graph_stats": self._get_dependency_graph_stats(dependency_graph),
|
||||
"table_dependencies": table_analysis,
|
||||
"impact_analysis": impact_analysis,
|
||||
"dependency_insights": dependency_insights,
|
||||
"recommendations": self._generate_dependency_recommendations(dependency_insights)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data flow dependency analysis failed: {str(e)}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"analysis_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
async def _get_tables_metadata(self, connection, catalog_name: Optional[str], db_name: Optional[str], include_views: bool) -> List[Dict]:
|
||||
"""Get metadata for all tables and views"""
|
||||
try:
|
||||
# Build conditions for query
|
||||
where_conditions = []
|
||||
if db_name:
|
||||
where_conditions.append(f"table_schema = '{db_name}'")
|
||||
else:
|
||||
where_conditions.append("table_schema = DATABASE()")
|
||||
|
||||
table_types = ["'BASE TABLE'"]
|
||||
if include_views:
|
||||
table_types.append("'VIEW'")
|
||||
|
||||
where_conditions.append(f"table_type IN ({','.join(table_types)})")
|
||||
|
||||
metadata_sql = f"""
|
||||
SELECT
|
||||
table_schema as schema_name,
|
||||
table_name,
|
||||
table_type,
|
||||
table_comment,
|
||||
table_rows,
|
||||
data_length
|
||||
FROM information_schema.tables
|
||||
WHERE {' AND '.join(where_conditions)}
|
||||
ORDER BY table_schema, table_name
|
||||
"""
|
||||
|
||||
result = await connection.execute(metadata_sql)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get tables metadata: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _build_dependency_graph(self, connection, tables_metadata: List[Dict], analysis_depth: int) -> Dict[str, Dict]:
|
||||
"""Build dependency graph by analyzing SQL statements and DDL"""
|
||||
dependency_graph = defaultdict(lambda: {
|
||||
"upstream_dependencies": set(),
|
||||
"downstream_dependencies": set(),
|
||||
"table_type": "unknown",
|
||||
"dependency_strength": {},
|
||||
"sql_patterns": []
|
||||
})
|
||||
|
||||
# Initialize graph with table metadata
|
||||
for table in tables_metadata:
|
||||
table_name = table["table_name"]
|
||||
schema_name = table.get("schema_name", "")
|
||||
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
|
||||
|
||||
dependency_graph[full_table_name]["table_type"] = table["table_type"]
|
||||
|
||||
# 1. Analyze view definitions for dependencies
|
||||
await self._analyze_view_dependencies(connection, dependency_graph, tables_metadata)
|
||||
|
||||
# 2. Analyze audit logs for runtime dependencies
|
||||
await self._analyze_runtime_dependencies(connection, dependency_graph, analysis_depth)
|
||||
|
||||
# 3. Analyze foreign key relationships
|
||||
await self._analyze_foreign_key_dependencies(connection, dependency_graph, tables_metadata)
|
||||
|
||||
return dict(dependency_graph)
|
||||
|
||||
async def _analyze_view_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
||||
"""Analyze view definitions to extract table dependencies"""
|
||||
try:
|
||||
for table in tables_metadata:
|
||||
if table["table_type"] == "VIEW":
|
||||
table_name = table["table_name"]
|
||||
schema_name = table.get("schema_name", "")
|
||||
|
||||
# Get view definition
|
||||
view_def_sql = f"SHOW CREATE VIEW {schema_name}.{table_name}" if schema_name else f"SHOW CREATE VIEW {table_name}"
|
||||
|
||||
try:
|
||||
result = await connection.execute(view_def_sql)
|
||||
if result.data and len(result.data) > 0:
|
||||
# Extract view definition from result
|
||||
view_definition = ""
|
||||
for row in result.data:
|
||||
for key, value in row.items():
|
||||
if "create" in key.lower() and value:
|
||||
view_definition = str(value)
|
||||
break
|
||||
|
||||
if view_definition:
|
||||
# Extract table dependencies from view definition
|
||||
referenced_tables = self._extract_table_references(view_definition)
|
||||
|
||||
full_view_name = f"{schema_name}.{table_name}" if schema_name else table_name
|
||||
|
||||
for ref_table in referenced_tables:
|
||||
# Add upstream dependency
|
||||
dependency_graph[full_view_name]["upstream_dependencies"].add(ref_table)
|
||||
dependency_graph[full_view_name]["dependency_strength"][ref_table] = "direct"
|
||||
|
||||
# Add downstream dependency for referenced table
|
||||
dependency_graph[ref_table]["downstream_dependencies"].add(full_view_name)
|
||||
|
||||
dependency_graph[full_view_name]["sql_patterns"].append({
|
||||
"pattern_type": "view_definition",
|
||||
"referenced_table": ref_table,
|
||||
"confidence": 1.0
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze view {table_name}: {str(e)}")
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze view dependencies: {str(e)}")
|
||||
|
||||
async def _analyze_runtime_dependencies(self, connection, dependency_graph: Dict, analysis_depth: int) -> None:
|
||||
"""Analyze audit logs to discover runtime table dependencies"""
|
||||
try:
|
||||
# Get recent SQL statements from audit logs
|
||||
audit_sql = """
|
||||
SELECT
|
||||
`stmt` as sql_statement,
|
||||
`user` as user_name,
|
||||
COUNT(*) as frequency
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `stmt` IS NOT NULL
|
||||
AND `stmt` != ''
|
||||
AND `time` >= DATE_SUB(NOW(), INTERVAL 1 YEAR)
|
||||
GROUP BY `stmt`, `user`
|
||||
HAVING frequency > 1
|
||||
ORDER BY frequency DESC
|
||||
LIMIT 1000
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
sql_statement = row.get("sql_statement", "")
|
||||
frequency = row.get("frequency", 1)
|
||||
|
||||
if sql_statement:
|
||||
# Extract table references from SQL
|
||||
referenced_tables = self._extract_table_references(sql_statement)
|
||||
|
||||
if len(referenced_tables) > 1:
|
||||
# Infer dependencies from multi-table queries
|
||||
self._infer_dependencies_from_sql(
|
||||
dependency_graph, sql_statement, referenced_tables, frequency
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze runtime dependencies: {str(e)}")
|
||||
|
||||
async def _analyze_foreign_key_dependencies(self, connection, dependency_graph: Dict, tables_metadata: List[Dict]) -> None:
|
||||
"""Analyze foreign key constraints for explicit dependencies"""
|
||||
try:
|
||||
# Get foreign key information
|
||||
fk_sql = """
|
||||
SELECT
|
||||
TABLE_SCHEMA as schema_name,
|
||||
TABLE_NAME as table_name,
|
||||
COLUMN_NAME as column_name,
|
||||
REFERENCED_TABLE_SCHEMA as ref_schema,
|
||||
REFERENCED_TABLE_NAME as ref_table_name,
|
||||
REFERENCED_COLUMN_NAME as ref_column_name
|
||||
FROM information_schema.KEY_COLUMN_USAGE
|
||||
WHERE REFERENCED_TABLE_NAME IS NOT NULL
|
||||
"""
|
||||
|
||||
result = await connection.execute(fk_sql)
|
||||
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
schema_name = row.get("schema_name", "")
|
||||
table_name = row["table_name"]
|
||||
ref_schema = row.get("ref_schema", "")
|
||||
ref_table_name = row["ref_table_name"]
|
||||
|
||||
# Build full table names
|
||||
full_table_name = f"{schema_name}.{table_name}" if schema_name else table_name
|
||||
full_ref_table = f"{ref_schema}.{ref_table_name}" if ref_schema else ref_table_name
|
||||
|
||||
# Add foreign key dependency
|
||||
dependency_graph[full_table_name]["upstream_dependencies"].add(full_ref_table)
|
||||
dependency_graph[full_table_name]["dependency_strength"][full_ref_table] = "foreign_key"
|
||||
dependency_graph[full_ref_table]["downstream_dependencies"].add(full_table_name)
|
||||
|
||||
dependency_graph[full_table_name]["sql_patterns"].append({
|
||||
"pattern_type": "foreign_key",
|
||||
"referenced_table": full_ref_table,
|
||||
"confidence": 1.0,
|
||||
"column": row["column_name"],
|
||||
"ref_column": row["ref_column_name"]
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze foreign key dependencies: {str(e)}")
|
||||
|
||||
def _extract_table_references(self, sql: str) -> List[str]:
|
||||
"""Extract table references from SQL statement"""
|
||||
if not sql:
|
||||
return []
|
||||
|
||||
# Normalize SQL
|
||||
sql = re.sub(r'/\*.*?\*/', '', sql, flags=re.DOTALL) # Remove comments
|
||||
sql = re.sub(r'--.*', '', sql) # Remove line comments
|
||||
sql = sql.upper()
|
||||
|
||||
table_references = []
|
||||
|
||||
# Pattern to match table names in various contexts
|
||||
patterns = [
|
||||
r'\bFROM\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
|
||||
r'\bJOIN\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
|
||||
r'\bINTO\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
|
||||
r'\bUPDATE\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
|
||||
r'\bDELETE\s+FROM\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)',
|
||||
r'\bINSERT\s+INTO\s+([`"]?[a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*[`"]?)'
|
||||
]
|
||||
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, sql, re.IGNORECASE)
|
||||
for match in matches:
|
||||
# Clean up table name
|
||||
table_name = match.strip('`"\'').split()[0] # Remove quotes and aliases
|
||||
if table_name and not self._is_sql_keyword(table_name):
|
||||
table_references.append(table_name.lower())
|
||||
|
||||
return list(set(table_references))
|
||||
|
||||
def _is_sql_keyword(self, word: str) -> bool:
|
||||
"""Check if word is a SQL keyword"""
|
||||
keywords = {
|
||||
'SELECT', 'FROM', 'WHERE', 'JOIN', 'INNER', 'LEFT', 'RIGHT', 'OUTER',
|
||||
'ON', 'AND', 'OR', 'NOT', 'IN', 'EXISTS', 'BETWEEN', 'LIKE',
|
||||
'INSERT', 'UPDATE', 'DELETE', 'CREATE', 'ALTER', 'DROP', 'INDEX',
|
||||
'TABLE', 'VIEW', 'DATABASE', 'SCHEMA', 'PRIMARY', 'KEY', 'FOREIGN',
|
||||
'REFERENCES', 'CONSTRAINT', 'NULL', 'DEFAULT', 'AUTO_INCREMENT'
|
||||
}
|
||||
return word.upper() in keywords
|
||||
|
||||
def _infer_dependencies_from_sql(self, dependency_graph: Dict, sql: str, referenced_tables: List[str], frequency: int) -> None:
|
||||
"""Infer table dependencies from SQL patterns"""
|
||||
# Analyze SQL pattern to determine dependency relationships
|
||||
sql_upper = sql.upper()
|
||||
|
||||
# Look for INSERT ... SELECT patterns
|
||||
if 'INSERT' in sql_upper and 'SELECT' in sql_upper:
|
||||
# Find target table (after INSERT INTO)
|
||||
insert_match = re.search(r'INSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_.]*)', sql_upper)
|
||||
if insert_match:
|
||||
target_table = insert_match.group(1).lower()
|
||||
|
||||
# All other tables are dependencies
|
||||
for ref_table in referenced_tables:
|
||||
if ref_table != target_table:
|
||||
dependency_graph[target_table]["upstream_dependencies"].add(ref_table)
|
||||
dependency_graph[ref_table]["downstream_dependencies"].add(target_table)
|
||||
|
||||
# Calculate confidence based on frequency
|
||||
confidence = min(0.9, 0.3 + (frequency / 100))
|
||||
dependency_graph[target_table]["sql_patterns"].append({
|
||||
"pattern_type": "insert_select",
|
||||
"referenced_table": ref_table,
|
||||
"confidence": confidence,
|
||||
"frequency": frequency
|
||||
})
|
||||
|
||||
# Look for CREATE TABLE AS SELECT patterns
|
||||
elif 'CREATE' in sql_upper and 'SELECT' in sql_upper:
|
||||
create_match = re.search(r'CREATE\s+TABLE\s+([a-zA-Z_][a-zA-Z0-9_.]*)', sql_upper)
|
||||
if create_match:
|
||||
target_table = create_match.group(1).lower()
|
||||
|
||||
for ref_table in referenced_tables:
|
||||
if ref_table != target_table:
|
||||
dependency_graph[target_table]["upstream_dependencies"].add(ref_table)
|
||||
dependency_graph[ref_table]["downstream_dependencies"].add(target_table)
|
||||
|
||||
dependency_graph[target_table]["sql_patterns"].append({
|
||||
"pattern_type": "create_table_as_select",
|
||||
"referenced_table": ref_table,
|
||||
"confidence": 0.95,
|
||||
"frequency": frequency
|
||||
})
|
||||
|
||||
async def _analyze_single_table_dependencies(self, target_table: str, dependency_graph: Dict, tables_metadata: List[Dict]) -> Dict[str, Any]:
|
||||
"""Analyze dependencies for a specific table"""
|
||||
if target_table not in dependency_graph:
|
||||
return {"error": f"Table {target_table} not found in dependency graph"}
|
||||
|
||||
table_info = dependency_graph[target_table]
|
||||
|
||||
# Get upstream dependencies (tables this table depends on)
|
||||
upstream_deps = await self._get_dependency_chain(target_table, dependency_graph, "upstream", 3)
|
||||
|
||||
# Get downstream dependencies (tables that depend on this table)
|
||||
downstream_deps = await self._get_dependency_chain(target_table, dependency_graph, "downstream", 3)
|
||||
|
||||
return {
|
||||
"table_name": target_table,
|
||||
"table_type": table_info["table_type"],
|
||||
"direct_upstream_dependencies": list(table_info["upstream_dependencies"]),
|
||||
"direct_downstream_dependencies": list(table_info["downstream_dependencies"]),
|
||||
"upstream_dependency_chain": upstream_deps,
|
||||
"downstream_dependency_chain": downstream_deps,
|
||||
"dependency_patterns": table_info["sql_patterns"],
|
||||
"dependency_metrics": {
|
||||
"upstream_count": len(table_info["upstream_dependencies"]),
|
||||
"downstream_count": len(table_info["downstream_dependencies"]),
|
||||
"total_upstream_chain": len(upstream_deps.get("all_dependencies", [])),
|
||||
"total_downstream_chain": len(downstream_deps.get("all_dependencies", [])),
|
||||
"dependency_depth": max(upstream_deps.get("max_depth", 0), downstream_deps.get("max_depth", 0))
|
||||
}
|
||||
}
|
||||
|
||||
async def _get_dependency_chain(self, start_table: str, dependency_graph: Dict, direction: str, max_depth: int) -> Dict[str, Any]:
|
||||
"""Get full dependency chain in specified direction"""
|
||||
visited = set()
|
||||
all_dependencies = []
|
||||
levels = []
|
||||
current_level = [start_table]
|
||||
depth = 0
|
||||
|
||||
while current_level and depth < max_depth:
|
||||
next_level = []
|
||||
level_deps = []
|
||||
|
||||
for table in current_level:
|
||||
if table in visited:
|
||||
continue
|
||||
|
||||
visited.add(table)
|
||||
|
||||
if direction == "upstream":
|
||||
dependencies = dependency_graph.get(table, {}).get("upstream_dependencies", set())
|
||||
else:
|
||||
dependencies = dependency_graph.get(table, {}).get("downstream_dependencies", set())
|
||||
|
||||
for dep in dependencies:
|
||||
if dep not in visited:
|
||||
next_level.append(dep)
|
||||
level_deps.append(dep)
|
||||
all_dependencies.append(dep)
|
||||
|
||||
if level_deps:
|
||||
levels.append({
|
||||
"level": depth + 1,
|
||||
"tables": level_deps
|
||||
})
|
||||
|
||||
current_level = next_level
|
||||
depth += 1
|
||||
|
||||
return {
|
||||
"direction": direction,
|
||||
"max_depth": depth,
|
||||
"all_dependencies": list(set(all_dependencies)),
|
||||
"dependency_levels": levels,
|
||||
"total_count": len(set(all_dependencies))
|
||||
}
|
||||
|
||||
async def _analyze_all_tables_dependencies(self, dependency_graph: Dict, tables_metadata: List[Dict]) -> Dict[str, Any]:
|
||||
"""Analyze dependencies for all tables"""
|
||||
table_stats = {}
|
||||
|
||||
for table_name, table_info in dependency_graph.items():
|
||||
upstream_count = len(table_info["upstream_dependencies"])
|
||||
downstream_count = len(table_info["downstream_dependencies"])
|
||||
|
||||
table_stats[table_name] = {
|
||||
"table_type": table_info["table_type"],
|
||||
"upstream_count": upstream_count,
|
||||
"downstream_count": downstream_count,
|
||||
"total_connections": upstream_count + downstream_count,
|
||||
"dependency_score": self._calculate_dependency_score(upstream_count, downstream_count),
|
||||
"role_classification": self._classify_table_role(upstream_count, downstream_count)
|
||||
}
|
||||
|
||||
# Find key tables
|
||||
most_critical_tables = sorted(
|
||||
table_stats.items(),
|
||||
key=lambda x: x[1]["dependency_score"],
|
||||
reverse=True
|
||||
)[:10]
|
||||
|
||||
source_tables = [name for name, stats in table_stats.items() if stats["role_classification"] == "source"]
|
||||
sink_tables = [name for name, stats in table_stats.items() if stats["role_classification"] == "sink"]
|
||||
hub_tables = [name for name, stats in table_stats.items() if stats["role_classification"] == "hub"]
|
||||
|
||||
return {
|
||||
"table_statistics": table_stats,
|
||||
"summary": {
|
||||
"total_tables": len(table_stats),
|
||||
"source_tables": len(source_tables),
|
||||
"sink_tables": len(sink_tables),
|
||||
"hub_tables": len(hub_tables),
|
||||
"isolated_tables": len([stats for stats in table_stats.values() if stats["total_connections"] == 0])
|
||||
},
|
||||
"critical_tables": [{"table": name, **stats} for name, stats in most_critical_tables],
|
||||
"table_roles": {
|
||||
"sources": source_tables[:10],
|
||||
"sinks": sink_tables[:10],
|
||||
"hubs": hub_tables[:10]
|
||||
}
|
||||
}
|
||||
|
||||
def _calculate_dependency_score(self, upstream_count: int, downstream_count: int) -> float:
|
||||
"""Calculate dependency importance score for a table"""
|
||||
# Score based on both incoming and outgoing dependencies
|
||||
# Higher weight for downstream dependencies (impact)
|
||||
return round(upstream_count * 0.3 + downstream_count * 0.7, 2)
|
||||
|
||||
def _classify_table_role(self, upstream_count: int, downstream_count: int) -> str:
|
||||
"""Classify table role based on dependency pattern"""
|
||||
if upstream_count == 0 and downstream_count > 0:
|
||||
return "source" # Data source
|
||||
elif upstream_count > 0 and downstream_count == 0:
|
||||
return "sink" # Data destination
|
||||
elif upstream_count > 2 and downstream_count > 2:
|
||||
return "hub" # Data hub/transformation
|
||||
elif upstream_count > 0 and downstream_count > 0:
|
||||
return "intermediate" # Intermediate transformation
|
||||
else:
|
||||
return "isolated" # No dependencies
|
||||
|
||||
async def _calculate_impact_analysis(self, target_table: str, dependency_graph: Dict, direction: str) -> Dict[str, Any]:
|
||||
"""Calculate impact analysis for a specific table"""
|
||||
if direction == "upstream" or direction == "both":
|
||||
upstream_impact = await self._calculate_upstream_impact(target_table, dependency_graph)
|
||||
else:
|
||||
upstream_impact = {}
|
||||
|
||||
if direction == "downstream" or direction == "both":
|
||||
downstream_impact = await self._calculate_downstream_impact(target_table, dependency_graph)
|
||||
else:
|
||||
downstream_impact = {}
|
||||
|
||||
return {
|
||||
"target_table": target_table,
|
||||
"upstream_impact": upstream_impact,
|
||||
"downstream_impact": downstream_impact,
|
||||
"total_impact_score": self._calculate_total_impact_score(upstream_impact, downstream_impact)
|
||||
}
|
||||
|
||||
async def _calculate_upstream_impact(self, target_table: str, dependency_graph: Dict) -> Dict[str, Any]:
|
||||
"""Calculate what would be impacted if upstream dependencies fail"""
|
||||
upstream_deps = dependency_graph.get(target_table, {}).get("upstream_dependencies", set())
|
||||
|
||||
impact_scenarios = []
|
||||
for dep_table in upstream_deps:
|
||||
# Simulate failure of this dependency
|
||||
affected_tables = await self._simulate_table_failure_impact(dep_table, dependency_graph)
|
||||
|
||||
impact_scenarios.append({
|
||||
"failed_dependency": dep_table,
|
||||
"directly_affected_tables": len(affected_tables["direct"]),
|
||||
"indirectly_affected_tables": len(affected_tables["indirect"]),
|
||||
"total_affected": len(affected_tables["all"]),
|
||||
"critical_affected": [table for table in affected_tables["all"]
|
||||
if dependency_graph.get(table, {}).get("downstream_dependencies", set())],
|
||||
"impact_severity": self._assess_impact_severity(len(affected_tables["all"]))
|
||||
})
|
||||
|
||||
return {
|
||||
"dependency_count": len(upstream_deps),
|
||||
"impact_scenarios": impact_scenarios,
|
||||
"max_potential_impact": max([scenario["total_affected"] for scenario in impact_scenarios], default=0),
|
||||
"risk_assessment": self._assess_upstream_risk(impact_scenarios)
|
||||
}
|
||||
|
||||
async def _calculate_downstream_impact(self, target_table: str, dependency_graph: Dict) -> Dict[str, Any]:
|
||||
"""Calculate what would be impacted if target table fails"""
|
||||
affected_tables = await self._simulate_table_failure_impact(target_table, dependency_graph)
|
||||
|
||||
return {
|
||||
"direct_impact": len(affected_tables["direct"]),
|
||||
"indirect_impact": len(affected_tables["indirect"]),
|
||||
"total_impact": len(affected_tables["all"]),
|
||||
"affected_table_details": [
|
||||
{
|
||||
"table_name": table,
|
||||
"impact_type": "direct" if table in affected_tables["direct"] else "indirect",
|
||||
"table_role": self._classify_table_role(
|
||||
len(dependency_graph.get(table, {}).get("upstream_dependencies", set())),
|
||||
len(dependency_graph.get(table, {}).get("downstream_dependencies", set()))
|
||||
)
|
||||
}
|
||||
for table in affected_tables["all"]
|
||||
],
|
||||
"impact_severity": self._assess_impact_severity(len(affected_tables["all"]))
|
||||
}
|
||||
|
||||
async def _simulate_table_failure_impact(self, failed_table: str, dependency_graph: Dict) -> Dict[str, List[str]]:
|
||||
"""Simulate the impact of a table failure"""
|
||||
direct_affected = list(dependency_graph.get(failed_table, {}).get("downstream_dependencies", set()))
|
||||
|
||||
# Find all indirectly affected tables using BFS
|
||||
visited = {failed_table}
|
||||
queue = deque(direct_affected)
|
||||
indirect_affected = []
|
||||
|
||||
while queue:
|
||||
current_table = queue.popleft()
|
||||
if current_table in visited:
|
||||
continue
|
||||
|
||||
visited.add(current_table)
|
||||
indirect_affected.append(current_table)
|
||||
|
||||
# Add downstream dependencies to queue
|
||||
downstream = dependency_graph.get(current_table, {}).get("downstream_dependencies", set())
|
||||
for dep in downstream:
|
||||
if dep not in visited:
|
||||
queue.append(dep)
|
||||
|
||||
# Remove direct affected from indirect (they're already counted)
|
||||
indirect_only = [table for table in indirect_affected if table not in direct_affected]
|
||||
|
||||
return {
|
||||
"direct": direct_affected,
|
||||
"indirect": indirect_only,
|
||||
"all": direct_affected + indirect_only
|
||||
}
|
||||
|
||||
def _assess_impact_severity(self, affected_count: int) -> str:
|
||||
"""Assess impact severity based on affected table count"""
|
||||
if affected_count == 0:
|
||||
return "none"
|
||||
elif affected_count <= 2:
|
||||
return "low"
|
||||
elif affected_count <= 5:
|
||||
return "medium"
|
||||
elif affected_count <= 10:
|
||||
return "high"
|
||||
else:
|
||||
return "critical"
|
||||
|
||||
def _assess_upstream_risk(self, impact_scenarios: List[Dict]) -> str:
|
||||
"""Assess upstream dependency risk"""
|
||||
if not impact_scenarios:
|
||||
return "low"
|
||||
|
||||
max_impact = max([scenario["total_affected"] for scenario in impact_scenarios])
|
||||
high_impact_scenarios = len([s for s in impact_scenarios if s["impact_severity"] in ["high", "critical"]])
|
||||
|
||||
if high_impact_scenarios > 0 or max_impact > 10:
|
||||
return "high"
|
||||
elif max_impact > 5 or len(impact_scenarios) > 3:
|
||||
return "medium"
|
||||
else:
|
||||
return "low"
|
||||
|
||||
def _calculate_total_impact_score(self, upstream_impact: Dict, downstream_impact: Dict) -> float:
|
||||
"""Calculate total impact score combining upstream and downstream risks"""
|
||||
upstream_score = 0
|
||||
downstream_score = 0
|
||||
|
||||
if upstream_impact:
|
||||
max_upstream_impact = upstream_impact.get("max_potential_impact", 0)
|
||||
upstream_score = min(max_upstream_impact * 0.3, 10) # Cap at 10
|
||||
|
||||
if downstream_impact:
|
||||
downstream_score = min(downstream_impact.get("total_impact", 0) * 0.7, 10) # Cap at 10
|
||||
|
||||
return round(upstream_score + downstream_score, 2)
|
||||
|
||||
async def _calculate_global_impact_analysis(self, dependency_graph: Dict) -> Dict[str, Any]:
|
||||
"""Calculate global impact analysis for all tables"""
|
||||
table_impacts = {}
|
||||
|
||||
for table_name in dependency_graph.keys():
|
||||
impact = await self._calculate_impact_analysis(table_name, dependency_graph, "downstream")
|
||||
table_impacts[table_name] = {
|
||||
"downstream_impact": impact["downstream_impact"]["total_impact"],
|
||||
"impact_severity": impact["downstream_impact"]["impact_severity"],
|
||||
"impact_score": impact["total_impact_score"]
|
||||
}
|
||||
|
||||
# Find most critical tables
|
||||
critical_tables = sorted(
|
||||
table_impacts.items(),
|
||||
key=lambda x: x[1]["impact_score"],
|
||||
reverse=True
|
||||
)[:15]
|
||||
|
||||
# Risk distribution
|
||||
risk_distribution = {
|
||||
"critical": len([t for t in table_impacts.values() if t["impact_severity"] == "critical"]),
|
||||
"high": len([t for t in table_impacts.values() if t["impact_severity"] == "high"]),
|
||||
"medium": len([t for t in table_impacts.values() if t["impact_severity"] == "medium"]),
|
||||
"low": len([t for t in table_impacts.values() if t["impact_severity"] == "low"]),
|
||||
"none": len([t for t in table_impacts.values() if t["impact_severity"] == "none"])
|
||||
}
|
||||
|
||||
return {
|
||||
"global_impact_summary": {
|
||||
"total_tables_analyzed": len(table_impacts),
|
||||
"tables_with_impact": len([t for t in table_impacts.values() if t["downstream_impact"] > 0]),
|
||||
"average_impact_score": round(sum(t["impact_score"] for t in table_impacts.values()) / len(table_impacts), 2) if table_impacts else 0,
|
||||
"risk_distribution": risk_distribution
|
||||
},
|
||||
"most_critical_tables": [{"table": name, **stats} for name, stats in critical_tables],
|
||||
"risk_matrix": self._generate_risk_matrix(table_impacts)
|
||||
}
|
||||
|
||||
def _generate_risk_matrix(self, table_impacts: Dict[str, Dict]) -> Dict[str, List[str]]:
|
||||
"""Generate risk matrix categorizing tables by impact level"""
|
||||
risk_matrix = {
|
||||
"critical_risk": [],
|
||||
"high_risk": [],
|
||||
"medium_risk": [],
|
||||
"low_risk": [],
|
||||
"minimal_risk": []
|
||||
}
|
||||
|
||||
for table_name, impact_data in table_impacts.items():
|
||||
severity = impact_data["impact_severity"]
|
||||
if severity == "critical":
|
||||
risk_matrix["critical_risk"].append(table_name)
|
||||
elif severity == "high":
|
||||
risk_matrix["high_risk"].append(table_name)
|
||||
elif severity == "medium":
|
||||
risk_matrix["medium_risk"].append(table_name)
|
||||
elif severity == "low":
|
||||
risk_matrix["low_risk"].append(table_name)
|
||||
else:
|
||||
risk_matrix["minimal_risk"].append(table_name)
|
||||
|
||||
return risk_matrix
|
||||
|
||||
def _get_dependency_graph_stats(self, dependency_graph: Dict) -> Dict[str, Any]:
|
||||
"""Get statistics about the dependency graph"""
|
||||
total_tables = len(dependency_graph)
|
||||
total_dependencies = sum(
|
||||
len(table_info.get("upstream_dependencies", set())) + len(table_info.get("downstream_dependencies", set()))
|
||||
for table_info in dependency_graph.values()
|
||||
) // 2 # Divide by 2 to avoid double counting
|
||||
|
||||
tables_with_upstream = len([
|
||||
table for table, info in dependency_graph.items()
|
||||
if info.get("upstream_dependencies")
|
||||
])
|
||||
|
||||
tables_with_downstream = len([
|
||||
table for table, info in dependency_graph.items()
|
||||
if info.get("downstream_dependencies")
|
||||
])
|
||||
|
||||
isolated_tables = len([
|
||||
table for table, info in dependency_graph.items()
|
||||
if not info.get("upstream_dependencies") and not info.get("downstream_dependencies")
|
||||
])
|
||||
|
||||
return {
|
||||
"total_tables": total_tables,
|
||||
"total_dependencies": total_dependencies,
|
||||
"tables_with_upstream_deps": tables_with_upstream,
|
||||
"tables_with_downstream_deps": tables_with_downstream,
|
||||
"isolated_tables": isolated_tables,
|
||||
"connectivity_ratio": round((total_tables - isolated_tables) / total_tables, 3) if total_tables > 0 else 0,
|
||||
"avg_dependencies_per_table": round(total_dependencies / total_tables, 2) if total_tables > 0 else 0
|
||||
}
|
||||
|
||||
async def _generate_dependency_insights(self, dependency_graph: Dict, table_analysis: Dict, impact_analysis: Dict) -> Dict[str, Any]:
|
||||
"""Generate insights from dependency analysis"""
|
||||
insights = {
|
||||
"architectural_patterns": {},
|
||||
"risk_assessment": {},
|
||||
"optimization_opportunities": {}
|
||||
}
|
||||
|
||||
# Architectural patterns
|
||||
graph_stats = self._get_dependency_graph_stats(dependency_graph)
|
||||
|
||||
insights["architectural_patterns"] = {
|
||||
"connectivity_level": "high" if graph_stats["connectivity_ratio"] > 0.7 else "medium" if graph_stats["connectivity_ratio"] > 0.3 else "low",
|
||||
"architecture_type": self._classify_architecture_type(graph_stats),
|
||||
"complexity_score": round(graph_stats["avg_dependencies_per_table"] * graph_stats["connectivity_ratio"], 2),
|
||||
"isolated_tables_concern": graph_stats["isolated_tables"] > graph_stats["total_tables"] * 0.3
|
||||
}
|
||||
|
||||
# Risk assessment
|
||||
if isinstance(impact_analysis, dict) and "global_impact_summary" in impact_analysis:
|
||||
global_impact = impact_analysis["global_impact_summary"]
|
||||
|
||||
insights["risk_assessment"] = {
|
||||
"overall_risk_level": self._assess_overall_risk_level(global_impact["risk_distribution"]),
|
||||
"critical_tables_count": global_impact["risk_distribution"]["critical"],
|
||||
"high_risk_tables_count": global_impact["risk_distribution"]["high"],
|
||||
"impact_concentration": global_impact["average_impact_score"] > 5.0,
|
||||
"resilience_score": self._calculate_resilience_score(global_impact)
|
||||
}
|
||||
|
||||
# Optimization opportunities
|
||||
insights["optimization_opportunities"] = self._identify_optimization_opportunities(dependency_graph, table_analysis)
|
||||
|
||||
return insights
|
||||
|
||||
def _classify_architecture_type(self, graph_stats: Dict) -> str:
|
||||
"""Classify the overall architecture type"""
|
||||
connectivity = graph_stats["connectivity_ratio"]
|
||||
avg_deps = graph_stats["avg_dependencies_per_table"]
|
||||
|
||||
if connectivity > 0.8 and avg_deps > 3:
|
||||
return "highly_interconnected"
|
||||
elif connectivity > 0.5 and avg_deps > 2:
|
||||
return "moderately_connected"
|
||||
elif connectivity < 0.3:
|
||||
return "loosely_coupled"
|
||||
else:
|
||||
return "mixed_architecture"
|
||||
|
||||
def _assess_overall_risk_level(self, risk_distribution: Dict[str, int]) -> str:
|
||||
"""Assess overall risk level from risk distribution"""
|
||||
total = sum(risk_distribution.values())
|
||||
if total == 0:
|
||||
return "minimal"
|
||||
|
||||
critical_ratio = risk_distribution["critical"] / total
|
||||
high_ratio = risk_distribution["high"] / total
|
||||
|
||||
if critical_ratio > 0.1 or high_ratio > 0.2:
|
||||
return "high"
|
||||
elif critical_ratio > 0.05 or high_ratio > 0.1:
|
||||
return "medium"
|
||||
else:
|
||||
return "low"
|
||||
|
||||
def _calculate_resilience_score(self, global_impact: Dict) -> float:
|
||||
"""Calculate system resilience score (0-1, higher is better)"""
|
||||
total_tables = global_impact["total_tables_analyzed"]
|
||||
risk_dist = global_impact["risk_distribution"]
|
||||
|
||||
if total_tables == 0:
|
||||
return 0.0
|
||||
|
||||
# Calculate weighted risk score
|
||||
weighted_risk = (
|
||||
risk_dist["critical"] * 5 +
|
||||
risk_dist["high"] * 3 +
|
||||
risk_dist["medium"] * 2 +
|
||||
risk_dist["low"] * 1
|
||||
) / total_tables
|
||||
|
||||
# Convert to resilience score (inverse of risk, normalized)
|
||||
max_possible_risk = 5.0
|
||||
resilience = max(0, (max_possible_risk - weighted_risk) / max_possible_risk)
|
||||
|
||||
return round(resilience, 3)
|
||||
|
||||
def _identify_optimization_opportunities(self, dependency_graph: Dict, table_analysis: Dict) -> List[Dict]:
|
||||
"""Identify optimization opportunities"""
|
||||
opportunities = []
|
||||
|
||||
# Find tables with excessive dependencies
|
||||
for table_name, table_info in dependency_graph.items():
|
||||
upstream_count = len(table_info.get("upstream_dependencies", set()))
|
||||
downstream_count = len(table_info.get("downstream_dependencies", set()))
|
||||
|
||||
if upstream_count > 10:
|
||||
opportunities.append({
|
||||
"type": "excessive_upstream_dependencies",
|
||||
"table": table_name,
|
||||
"description": f"Table has {upstream_count} upstream dependencies",
|
||||
"recommendation": "Consider breaking down complex transformations or using intermediate tables",
|
||||
"priority": "high" if upstream_count > 15 else "medium"
|
||||
})
|
||||
|
||||
if downstream_count > 10:
|
||||
opportunities.append({
|
||||
"type": "excessive_downstream_dependencies",
|
||||
"table": table_name,
|
||||
"description": f"Table has {downstream_count} downstream dependencies",
|
||||
"recommendation": "Consider if this table is doing too much or if views could be used",
|
||||
"priority": "high" if downstream_count > 15 else "medium"
|
||||
})
|
||||
|
||||
# Find potential circular dependencies (simplified check)
|
||||
# This is a basic check - full cycle detection would be more complex
|
||||
for table_name, table_info in dependency_graph.items():
|
||||
upstream_deps = table_info.get("upstream_dependencies", set())
|
||||
for upstream_table in upstream_deps:
|
||||
if table_name in dependency_graph.get(upstream_table, {}).get("upstream_dependencies", set()):
|
||||
opportunities.append({
|
||||
"type": "potential_circular_dependency",
|
||||
"table": table_name,
|
||||
"related_table": upstream_table,
|
||||
"description": f"Potential circular dependency between {table_name} and {upstream_table}",
|
||||
"recommendation": "Review and eliminate circular dependencies",
|
||||
"priority": "high"
|
||||
})
|
||||
|
||||
return opportunities
|
||||
|
||||
def _generate_dependency_recommendations(self, dependency_insights: Dict) -> List[Dict]:
|
||||
"""Generate recommendations based on dependency analysis"""
|
||||
recommendations = []
|
||||
|
||||
# Architecture recommendations
|
||||
arch_patterns = dependency_insights.get("architectural_patterns", {})
|
||||
if arch_patterns.get("isolated_tables_concern", False):
|
||||
recommendations.append({
|
||||
"type": "architecture",
|
||||
"priority": "medium",
|
||||
"title": "High number of isolated tables",
|
||||
"description": "Many tables have no dependencies, which may indicate data silos",
|
||||
"action": "Review isolated tables and consider if they should be integrated into data flows"
|
||||
})
|
||||
|
||||
complexity_score = arch_patterns.get("complexity_score", 0)
|
||||
if complexity_score > 5:
|
||||
recommendations.append({
|
||||
"type": "architecture",
|
||||
"priority": "high",
|
||||
"title": "High system complexity",
|
||||
"description": f"System complexity score is {complexity_score} (high)",
|
||||
"action": "Consider simplifying data architecture and reducing unnecessary dependencies"
|
||||
})
|
||||
|
||||
# Risk recommendations
|
||||
risk_assessment = dependency_insights.get("risk_assessment", {})
|
||||
overall_risk = risk_assessment.get("overall_risk_level", "unknown")
|
||||
|
||||
if overall_risk == "high":
|
||||
recommendations.append({
|
||||
"type": "risk_mitigation",
|
||||
"priority": "high",
|
||||
"title": "High overall system risk",
|
||||
"description": "System has high dependency risks that could cause widespread failures",
|
||||
"action": "Implement monitoring and backup strategies for critical tables"
|
||||
})
|
||||
|
||||
critical_tables = risk_assessment.get("critical_tables_count", 0)
|
||||
if critical_tables > 0:
|
||||
recommendations.append({
|
||||
"type": "risk_mitigation",
|
||||
"priority": "high",
|
||||
"title": f"{critical_tables} critical impact tables identified",
|
||||
"description": "Tables with critical impact require special attention",
|
||||
"action": "Implement enhanced monitoring and backup procedures for critical tables"
|
||||
})
|
||||
|
||||
# Optimization recommendations
|
||||
optimization_ops = dependency_insights.get("optimization_opportunities", [])
|
||||
if optimization_ops:
|
||||
high_priority_ops = [op for op in optimization_ops if op.get("priority") == "high"]
|
||||
if high_priority_ops:
|
||||
recommendations.append({
|
||||
"type": "optimization",
|
||||
"priority": "high",
|
||||
"title": f"{len(high_priority_ops)} high-priority optimization opportunities",
|
||||
"description": "System has optimization opportunities that should be addressed",
|
||||
"action": "Review and implement suggested optimizations for better maintainability"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
@@ -319,7 +319,20 @@ class DorisLoggerManager:
|
||||
return
|
||||
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
log_dir_writable = True # Initialize the variable
|
||||
|
||||
# Try to create log directory, fallback to console-only if fails
|
||||
try:
|
||||
self.log_dir.mkdir(parents=True, exist_ok=True)
|
||||
except (OSError, PermissionError) as e:
|
||||
# If we can't create log directory (e.g., read-only filesystem in stdio mode),
|
||||
# fall back to console-only logging
|
||||
log_dir_writable = False
|
||||
enable_file = False
|
||||
enable_audit = False
|
||||
enable_cleanup = False
|
||||
# Don't use print() in stdio mode as it interferes with MCP JSON protocol
|
||||
# Log the warning through the logging system instead, which will be handled after setup
|
||||
|
||||
# Clear existing handlers
|
||||
root_logger = logging.getLogger()
|
||||
@@ -414,13 +427,19 @@ class DorisLoggerManager:
|
||||
logger.info("=" * 80)
|
||||
logger.info("Doris MCP Server Logging System Initialized")
|
||||
logger.info(f"Log Level: {level}")
|
||||
logger.info(f"Log Directory: {self.log_dir.absolute()}")
|
||||
if log_dir_writable:
|
||||
logger.info(f"Log Directory: {self.log_dir.absolute()}")
|
||||
else:
|
||||
logger.info("Log Directory: Not available (console-only mode)")
|
||||
logger.info(f"Console Logging: {'Enabled' if enable_console else 'Disabled'}")
|
||||
logger.info(f"File Logging: {'Enabled' if enable_file else 'Disabled'}")
|
||||
logger.info(f"Audit Logging: {'Enabled' if enable_audit else 'Disabled'}")
|
||||
logger.info(f"Log Cleanup: {'Enabled' if enable_cleanup and enable_file else 'Disabled'}")
|
||||
logger.info(f"File Logging: {'Enabled' if enable_file else 'Disabled (fallback mode)'}")
|
||||
logger.info(f"Audit Logging: {'Enabled' if enable_audit else 'Disabled (fallback mode)'}")
|
||||
logger.info(f"Log Cleanup: {'Enabled' if enable_cleanup and enable_file else 'Disabled (fallback mode)'}")
|
||||
if enable_cleanup and enable_file:
|
||||
logger.info(f"Cleanup Settings: Max age {max_age_days} days, interval {cleanup_interval_hours}h")
|
||||
if not log_dir_writable:
|
||||
logger.warning("Running in console-only logging mode due to filesystem permissions")
|
||||
logger.warning(f"Could not create log directory '{log_dir}' - stdio mode fallback enabled")
|
||||
logger.info("=" * 80)
|
||||
|
||||
def _setup_package_loggers(self, level: str):
|
||||
@@ -568,10 +587,10 @@ def setup_logging(level: str = "INFO",
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
"""
|
||||
Get a logger instance.
|
||||
|
||||
|
||||
Args:
|
||||
name: Logger name (usually __name__)
|
||||
|
||||
|
||||
Returns:
|
||||
Configured logger instance
|
||||
"""
|
||||
|
||||
758
doris_mcp_server/utils/performance_analytics_tools.py
Normal file
758
doris_mcp_server/utils/performance_analytics_tools.py
Normal file
@@ -0,0 +1,758 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Performance Analytics Tools Module
|
||||
Provides slow query analysis and resource growth monitoring capabilities
|
||||
"""
|
||||
|
||||
import time
|
||||
import statistics
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PerformanceAnalyticsTools:
|
||||
"""Performance analytics tools for query and resource monitoring"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
logger.info("PerformanceAnalyticsTools initialized")
|
||||
|
||||
async def analyze_slow_queries_topn(
|
||||
self,
|
||||
days: int = 7,
|
||||
top_n: int = 20,
|
||||
min_execution_time_ms: int = 1000,
|
||||
include_patterns: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze top N slowest queries and performance patterns
|
||||
|
||||
Args:
|
||||
days: Number of days to analyze
|
||||
top_n: Number of top slow queries to return
|
||||
min_execution_time_ms: Minimum execution time threshold
|
||||
include_patterns: Whether to include query pattern analysis
|
||||
|
||||
Returns:
|
||||
Slow query analysis results
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Get slow query data
|
||||
slow_queries = await self._get_slow_query_data(
|
||||
connection, days, min_execution_time_ms
|
||||
)
|
||||
|
||||
if not slow_queries:
|
||||
return {
|
||||
"message": "No slow queries found for the specified criteria",
|
||||
"analysis_period": {"days": days, "threshold_ms": min_execution_time_ms},
|
||||
"analysis_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Analyze top N queries
|
||||
top_queries = await self._analyze_top_slow_queries(slow_queries, top_n)
|
||||
|
||||
# Performance insights
|
||||
performance_insights = await self._generate_performance_insights(slow_queries)
|
||||
|
||||
# Query patterns analysis
|
||||
pattern_analysis = {}
|
||||
if include_patterns:
|
||||
pattern_analysis = await self._analyze_query_patterns(slow_queries)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"analysis_period": {
|
||||
"days": days,
|
||||
"threshold_ms": min_execution_time_ms,
|
||||
"start_date": (datetime.now() - timedelta(days=days)).isoformat(),
|
||||
"end_date": datetime.now().isoformat()
|
||||
},
|
||||
"analysis_timestamp": datetime.now().isoformat(),
|
||||
"execution_time_seconds": round(execution_time, 3),
|
||||
"summary": {
|
||||
"total_slow_queries": len(slow_queries),
|
||||
"unique_queries": len(set(q.get("sql_hash", q.get("sql", ""))[:100] for q in slow_queries)),
|
||||
"top_n_analyzed": min(top_n, len(slow_queries))
|
||||
},
|
||||
"top_slow_queries": top_queries,
|
||||
"performance_insights": performance_insights,
|
||||
"query_patterns": pattern_analysis,
|
||||
"recommendations": self._generate_performance_recommendations(performance_insights, pattern_analysis)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Slow query analysis failed: {str(e)}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"analysis_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
async def analyze_resource_growth_curves(
|
||||
self,
|
||||
days: int = 30,
|
||||
resource_types: List[str] = None,
|
||||
include_predictions: bool = False,
|
||||
detailed_response: bool = False
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze resource growth patterns and trends
|
||||
|
||||
Args:
|
||||
days: Number of days to analyze
|
||||
resource_types: Types of resources to analyze
|
||||
include_predictions: Whether to include growth predictions
|
||||
detailed_response: Whether to return detailed data including daily breakdowns
|
||||
|
||||
Returns:
|
||||
Resource growth analysis results
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
if resource_types is None:
|
||||
resource_types = ["storage", "query_volume", "user_activity"]
|
||||
|
||||
# Analyze each resource type
|
||||
resource_analysis = {}
|
||||
|
||||
if "storage" in resource_types:
|
||||
resource_analysis["storage"] = await self._analyze_storage_growth(connection, days, detailed_response)
|
||||
|
||||
if "query_volume" in resource_types:
|
||||
resource_analysis["query_volume"] = await self._analyze_query_volume_growth(connection, days, detailed_response)
|
||||
|
||||
if "user_activity" in resource_types:
|
||||
resource_analysis["user_activity"] = await self._analyze_user_activity_growth(connection, days, detailed_response)
|
||||
|
||||
# Generate growth insights
|
||||
growth_insights = await self._generate_growth_insights(resource_analysis, days)
|
||||
|
||||
# Growth predictions
|
||||
predictions = {}
|
||||
if include_predictions:
|
||||
predictions = await self._generate_growth_predictions(resource_analysis)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
result = {
|
||||
"analysis_period": {
|
||||
"days": days,
|
||||
"start_date": (datetime.now() - timedelta(days=days)).isoformat(),
|
||||
"end_date": datetime.now().isoformat()
|
||||
},
|
||||
"analysis_timestamp": datetime.now().isoformat(),
|
||||
"execution_time_seconds": round(execution_time, 3),
|
||||
"resource_types_analyzed": resource_types,
|
||||
"resource_analysis": resource_analysis,
|
||||
"growth_insights": growth_insights,
|
||||
"growth_predictions": predictions,
|
||||
"recommendations": self._generate_growth_recommendations(growth_insights, predictions)
|
||||
}
|
||||
|
||||
# Add execution info for debugging
|
||||
result["_execution_info"] = {
|
||||
"tool_name": "analyze_resource_growth_curves",
|
||||
"execution_time": round(execution_time, 3),
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"detailed_response": detailed_response
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Resource growth analysis failed: {str(e)}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"analysis_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
async def _get_slow_query_data(self, connection, days: int, min_execution_time_ms: int) -> List[Dict]:
|
||||
"""Get slow query data from audit logs"""
|
||||
try:
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
slow_query_sql = f"""
|
||||
SELECT
|
||||
`user` as user_name,
|
||||
`client_ip` as host,
|
||||
`time` as query_time,
|
||||
`stmt` as sql_statement,
|
||||
`query_time` as execution_time_ms,
|
||||
`scan_bytes` as scan_bytes,
|
||||
`scan_rows` as scan_rows,
|
||||
`return_rows` as return_rows
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `query_time` >= {min_execution_time_ms}
|
||||
AND `stmt` IS NOT NULL
|
||||
AND `stmt` != ''
|
||||
ORDER BY `query_time` DESC
|
||||
LIMIT 5000
|
||||
"""
|
||||
|
||||
result = await connection.execute(slow_query_sql)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get slow query data: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _analyze_top_slow_queries(self, slow_queries: List[Dict], top_n: int) -> List[Dict]:
|
||||
"""Analyze top N slowest queries"""
|
||||
# Sort by execution time and take top N
|
||||
sorted_queries = sorted(
|
||||
slow_queries,
|
||||
key=lambda x: x.get("execution_time_ms", 0),
|
||||
reverse=True
|
||||
)[:top_n]
|
||||
|
||||
analyzed_queries = []
|
||||
for i, query in enumerate(sorted_queries):
|
||||
sql = query.get("sql_statement", "")
|
||||
execution_time = query.get("execution_time_ms", 0)
|
||||
|
||||
analyzed_query = {
|
||||
"rank": i + 1,
|
||||
"execution_time_ms": execution_time,
|
||||
"execution_time_seconds": round(execution_time / 1000, 2),
|
||||
"user_name": query.get("user_name", "unknown"),
|
||||
"query_time": str(query.get("query_time", "")),
|
||||
"sql_statement": sql[:500] + "..." if len(sql) > 500 else sql,
|
||||
"sql_length": len(sql),
|
||||
"query_type": self._classify_query_type(sql),
|
||||
"scan_metrics": {
|
||||
"scan_bytes": query.get("scan_bytes", 0),
|
||||
"scan_rows": query.get("scan_rows", 0),
|
||||
"return_rows": query.get("return_rows", 0)
|
||||
},
|
||||
"performance_issues": self._identify_performance_issues(query)
|
||||
}
|
||||
|
||||
analyzed_queries.append(analyzed_query)
|
||||
|
||||
return analyzed_queries
|
||||
|
||||
def _classify_query_type(self, sql: str) -> str:
|
||||
"""Classify SQL query type"""
|
||||
if not sql:
|
||||
return "unknown"
|
||||
|
||||
sql_upper = sql.upper().strip()
|
||||
|
||||
if sql_upper.startswith('SELECT'):
|
||||
return "SELECT"
|
||||
elif sql_upper.startswith('INSERT'):
|
||||
return "INSERT"
|
||||
elif sql_upper.startswith('UPDATE'):
|
||||
return "UPDATE"
|
||||
elif sql_upper.startswith('DELETE'):
|
||||
return "DELETE"
|
||||
else:
|
||||
return "OTHER"
|
||||
|
||||
def _identify_performance_issues(self, query: Dict) -> List[str]:
|
||||
"""Identify potential performance issues in query"""
|
||||
issues = []
|
||||
|
||||
sql = query.get("sql_statement", "").upper()
|
||||
execution_time = query.get("execution_time_ms", 0)
|
||||
scan_bytes = query.get("scan_bytes", 0)
|
||||
scan_rows = query.get("scan_rows", 0)
|
||||
return_rows = query.get("return_rows", 0)
|
||||
|
||||
# High execution time
|
||||
if execution_time > 60000: # > 1 minute
|
||||
issues.append("very_long_execution")
|
||||
elif execution_time > 10000: # > 10 seconds
|
||||
issues.append("long_execution")
|
||||
|
||||
# Large data scan
|
||||
if scan_bytes > 1024**3: # > 1GB
|
||||
issues.append("large_data_scan")
|
||||
|
||||
# High row scan vs return ratio
|
||||
if scan_rows > 0 and return_rows > 0:
|
||||
scan_ratio = scan_rows / return_rows
|
||||
if scan_ratio > 1000:
|
||||
issues.append("inefficient_filtering")
|
||||
|
||||
# SQL pattern issues
|
||||
if "SELECT *" in sql:
|
||||
issues.append("select_all_columns")
|
||||
|
||||
if "ORDER BY" in sql and "LIMIT" not in sql:
|
||||
issues.append("unlimited_sort")
|
||||
|
||||
return issues
|
||||
|
||||
async def _generate_performance_insights(self, slow_queries: List[Dict]) -> Dict[str, Any]:
|
||||
"""Generate performance insights from slow queries"""
|
||||
if not slow_queries:
|
||||
return {}
|
||||
|
||||
execution_times = [q.get("execution_time_ms", 0) for q in slow_queries]
|
||||
scan_bytes = [q.get("scan_bytes", 0) for q in slow_queries if q.get("scan_bytes", 0) > 0]
|
||||
|
||||
# User analysis
|
||||
user_query_counts = Counter(q.get("user_name", "unknown") for q in slow_queries)
|
||||
|
||||
# Query type distribution
|
||||
query_types = Counter(self._classify_query_type(q.get("sql_statement", "")) for q in slow_queries)
|
||||
|
||||
# Time pattern analysis
|
||||
query_hours = []
|
||||
for query in slow_queries:
|
||||
try:
|
||||
query_time = query.get("query_time")
|
||||
if query_time:
|
||||
if isinstance(query_time, str):
|
||||
dt = datetime.fromisoformat(query_time.replace('Z', '+00:00'))
|
||||
else:
|
||||
dt = query_time
|
||||
query_hours.append(dt.hour)
|
||||
except:
|
||||
continue
|
||||
|
||||
hour_distribution = Counter(query_hours)
|
||||
|
||||
return {
|
||||
"execution_time_stats": {
|
||||
"avg_ms": round(statistics.mean(execution_times), 2) if execution_times else 0,
|
||||
"median_ms": round(statistics.median(execution_times), 2) if execution_times else 0,
|
||||
"max_ms": max(execution_times) if execution_times else 0,
|
||||
"min_ms": min(execution_times) if execution_times else 0
|
||||
},
|
||||
"data_scan_stats": {
|
||||
"avg_bytes": round(statistics.mean(scan_bytes), 2) if scan_bytes else 0,
|
||||
"max_bytes": max(scan_bytes) if scan_bytes else 0,
|
||||
"total_bytes_scanned": sum(scan_bytes) if scan_bytes else 0
|
||||
},
|
||||
"user_analysis": {
|
||||
"top_slow_query_users": dict(user_query_counts.most_common(10)),
|
||||
"unique_users": len(user_query_counts)
|
||||
},
|
||||
"query_type_distribution": dict(query_types),
|
||||
"temporal_patterns": {
|
||||
"hourly_distribution": dict(hour_distribution),
|
||||
"peak_hour": max(hour_distribution, key=hour_distribution.get) if hour_distribution else None
|
||||
}
|
||||
}
|
||||
|
||||
async def _analyze_query_patterns(self, slow_queries: List[Dict]) -> Dict[str, Any]:
|
||||
"""Analyze query patterns in slow queries"""
|
||||
patterns = {
|
||||
"common_issues": Counter(),
|
||||
"table_access_patterns": Counter(),
|
||||
"query_complexity": []
|
||||
}
|
||||
|
||||
for query in slow_queries:
|
||||
sql = query.get("sql_statement", "")
|
||||
|
||||
# Identify common issues
|
||||
issues = self._identify_performance_issues(query)
|
||||
patterns["common_issues"].update(issues)
|
||||
|
||||
# Extract table names
|
||||
tables = self._extract_table_names(sql)
|
||||
patterns["table_access_patterns"].update(tables)
|
||||
|
||||
# Query complexity metrics
|
||||
complexity = self._calculate_query_complexity(sql)
|
||||
patterns["query_complexity"].append(complexity)
|
||||
|
||||
return {
|
||||
"common_performance_issues": dict(patterns["common_issues"].most_common(10)),
|
||||
"frequently_accessed_tables": dict(patterns["table_access_patterns"].most_common(15)),
|
||||
"complexity_analysis": {
|
||||
"avg_complexity": round(statistics.mean(patterns["query_complexity"]), 2) if patterns["query_complexity"] else 0,
|
||||
"max_complexity": max(patterns["query_complexity"]) if patterns["query_complexity"] else 0,
|
||||
"high_complexity_queries": len([c for c in patterns["query_complexity"] if c > 10])
|
||||
}
|
||||
}
|
||||
|
||||
def _extract_table_names(self, sql: str) -> List[str]:
|
||||
"""Extract table names from SQL (simplified)"""
|
||||
import re
|
||||
|
||||
if not sql:
|
||||
return []
|
||||
|
||||
# Simple pattern matching for table names
|
||||
patterns = [
|
||||
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_.]*)',
|
||||
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_.]*)',
|
||||
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_.]*)',
|
||||
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_.]*)'
|
||||
]
|
||||
|
||||
tables = []
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, sql, re.IGNORECASE)
|
||||
tables.extend(matches)
|
||||
|
||||
return [table.lower() for table in tables if table]
|
||||
|
||||
def _calculate_query_complexity(self, sql: str) -> int:
|
||||
"""Calculate query complexity score"""
|
||||
if not sql:
|
||||
return 0
|
||||
|
||||
sql_upper = sql.upper()
|
||||
complexity = 0
|
||||
|
||||
# Basic complexity factors
|
||||
complexity += sql_upper.count('JOIN') * 2
|
||||
complexity += sql_upper.count('SUBQUERY') * 3
|
||||
complexity += sql_upper.count('UNION') * 2
|
||||
complexity += sql_upper.count('GROUP BY') * 1
|
||||
complexity += sql_upper.count('ORDER BY') * 1
|
||||
complexity += sql_upper.count('HAVING') * 2
|
||||
complexity += sql_upper.count('CASE') * 1
|
||||
|
||||
# Length factor
|
||||
complexity += len(sql) // 100
|
||||
|
||||
return complexity
|
||||
|
||||
async def _analyze_storage_growth(self, connection, days: int, detailed_response: bool = False) -> Dict[str, Any]:
|
||||
"""Analyze storage growth patterns"""
|
||||
try:
|
||||
# Get table size data over time
|
||||
# This is a simplified approach - in practice you'd need historical data
|
||||
size_sql = """
|
||||
SELECT
|
||||
table_schema,
|
||||
table_name,
|
||||
ROUND(data_length / 1024 / 1024, 2) as size_mb,
|
||||
table_rows
|
||||
FROM information_schema.tables
|
||||
WHERE table_type = 'BASE TABLE'
|
||||
AND data_length > 0
|
||||
ORDER BY data_length DESC
|
||||
LIMIT 50
|
||||
"""
|
||||
|
||||
result = await connection.execute(size_sql)
|
||||
|
||||
if result.data:
|
||||
total_size = sum(row.get("size_mb", 0) for row in result.data)
|
||||
total_rows = sum(row.get("table_rows", 0) for row in result.data)
|
||||
|
||||
# Estimate growth (simplified - would need historical data for real analysis)
|
||||
growth_estimate = self._estimate_storage_growth(result.data)
|
||||
|
||||
storage_result = {
|
||||
"current_storage_mb": round(total_size, 2),
|
||||
"total_rows": total_rows,
|
||||
"table_count": len(result.data),
|
||||
"estimated_growth": growth_estimate,
|
||||
"growth_trend": "stable" # Simplified
|
||||
}
|
||||
|
||||
# Include detailed table information only if requested
|
||||
if detailed_response:
|
||||
storage_result["largest_tables"] = result.data[:10]
|
||||
else:
|
||||
# Only include top 3 for summary
|
||||
storage_result["top_tables_summary"] = result.data[:3]
|
||||
|
||||
return storage_result
|
||||
|
||||
return {"current_storage_mb": 0, "message": "No storage data available"}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze storage growth: {str(e)}")
|
||||
return {"error": str(e)}
|
||||
|
||||
def _estimate_storage_growth(self, table_data: List[Dict]) -> Dict[str, Any]:
|
||||
"""Estimate storage growth based on current data"""
|
||||
# This is a simplified estimation - real implementation would use historical data
|
||||
total_size = sum(row.get("size_mb", 0) for row in table_data)
|
||||
|
||||
return {
|
||||
"daily_growth_estimate_mb": round(total_size * 0.01, 2), # 1% daily growth estimate
|
||||
"monthly_growth_estimate_mb": round(total_size * 0.3, 2), # 30% monthly growth estimate
|
||||
"confidence": "low", # Low confidence without historical data
|
||||
"method": "simplified_estimation"
|
||||
}
|
||||
|
||||
async def _analyze_query_volume_growth(self, connection, days: int, detailed_response: bool = False) -> Dict[str, Any]:
|
||||
"""Analyze query volume growth patterns"""
|
||||
try:
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
volume_sql = f"""
|
||||
SELECT
|
||||
DATE(`time`) as query_date,
|
||||
COUNT(*) as query_count,
|
||||
COUNT(DISTINCT `user`) as unique_users
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= '{start_date.strftime('%Y-%m-%d')}'
|
||||
AND `stmt` IS NOT NULL
|
||||
GROUP BY DATE(`time`)
|
||||
ORDER BY query_date DESC
|
||||
LIMIT {days}
|
||||
"""
|
||||
|
||||
result = await connection.execute(volume_sql)
|
||||
|
||||
if result.data:
|
||||
daily_volumes = [row.get("query_count", 0) for row in result.data]
|
||||
avg_daily_queries = statistics.mean(daily_volumes) if daily_volumes else 0
|
||||
|
||||
# Simple trend analysis
|
||||
trend = "stable"
|
||||
if len(daily_volumes) > 3:
|
||||
recent_avg = statistics.mean(daily_volumes[:3])
|
||||
older_avg = statistics.mean(daily_volumes[-3:])
|
||||
if recent_avg > older_avg * 1.1:
|
||||
trend = "increasing"
|
||||
elif recent_avg < older_avg * 0.9:
|
||||
trend = "decreasing"
|
||||
|
||||
volume_result = {
|
||||
"avg_daily_queries": round(avg_daily_queries, 2),
|
||||
"max_daily_queries": max(daily_volumes) if daily_volumes else 0,
|
||||
"min_daily_queries": min(daily_volumes) if daily_volumes else 0,
|
||||
"total_queries": sum(daily_volumes) if daily_volumes else 0,
|
||||
"trend": trend
|
||||
}
|
||||
|
||||
# Include detailed daily data only if requested
|
||||
if detailed_response:
|
||||
# Fix date serialization in daily_data
|
||||
serializable_data = []
|
||||
for row in result.data:
|
||||
serializable_row = {}
|
||||
for key, value in row.items():
|
||||
if hasattr(value, 'isoformat'): # datetime/date object
|
||||
serializable_row[key] = value.isoformat()
|
||||
else:
|
||||
serializable_row[key] = value
|
||||
serializable_data.append(serializable_row)
|
||||
volume_result["daily_data"] = serializable_data
|
||||
else:
|
||||
# Only include recent data summary
|
||||
volume_result["recent_days_summary"] = {
|
||||
"last_7_days_avg": round(statistics.mean(daily_volumes[:7]) if len(daily_volumes) >= 7 else avg_daily_queries, 2),
|
||||
"data_points": min(len(daily_volumes), 7)
|
||||
}
|
||||
|
||||
return volume_result
|
||||
|
||||
return {"avg_daily_queries": 0, "message": "No query volume data available"}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze query volume growth: {str(e)}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _analyze_user_activity_growth(self, connection, days: int, detailed_response: bool = False) -> Dict[str, Any]:
|
||||
"""Analyze user activity growth patterns"""
|
||||
try:
|
||||
start_date = datetime.now() - timedelta(days=days)
|
||||
|
||||
activity_sql = f"""
|
||||
SELECT
|
||||
DATE(`time`) as activity_date,
|
||||
COUNT(DISTINCT `user`) as active_users,
|
||||
COUNT(*) as total_queries
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= '{start_date.strftime('%Y-%m-%d')}'
|
||||
AND `stmt` IS NOT NULL
|
||||
GROUP BY DATE(`time`)
|
||||
ORDER BY activity_date DESC
|
||||
LIMIT {days}
|
||||
"""
|
||||
|
||||
result = await connection.execute(activity_sql)
|
||||
|
||||
if result.data:
|
||||
daily_users = [row.get("active_users", 0) for row in result.data]
|
||||
avg_daily_users = statistics.mean(daily_users) if daily_users else 0
|
||||
|
||||
activity_result = {
|
||||
"avg_daily_active_users": round(avg_daily_users, 2),
|
||||
"max_daily_users": max(daily_users) if daily_users else 0,
|
||||
"total_unique_users": len(set(row.get("active_users", 0) for row in result.data))
|
||||
}
|
||||
|
||||
# Include detailed daily activity only if requested
|
||||
if detailed_response:
|
||||
# Fix date serialization in daily_activity
|
||||
serializable_activity = []
|
||||
for row in result.data:
|
||||
serializable_row = {}
|
||||
for key, value in row.items():
|
||||
if hasattr(value, 'isoformat'): # datetime/date object
|
||||
serializable_row[key] = value.isoformat()
|
||||
else:
|
||||
serializable_row[key] = value
|
||||
serializable_activity.append(serializable_row)
|
||||
activity_result["daily_activity"] = serializable_activity
|
||||
else:
|
||||
# Only include recent activity summary
|
||||
recent_queries = [row.get("total_queries", 0) for row in result.data[:7]]
|
||||
activity_result["recent_activity_summary"] = {
|
||||
"last_7_days_avg_queries": round(statistics.mean(recent_queries) if recent_queries else 0, 2),
|
||||
"activity_trend": "active" if avg_daily_users > 1 else "low"
|
||||
}
|
||||
|
||||
return activity_result
|
||||
|
||||
return {"avg_daily_active_users": 0, "message": "No user activity data available"}
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze user activity growth: {str(e)}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _generate_growth_insights(self, resource_analysis: Dict, days: int) -> Dict[str, Any]:
|
||||
"""Generate insights from resource growth analysis"""
|
||||
insights = {}
|
||||
|
||||
# Storage insights
|
||||
if "storage" in resource_analysis:
|
||||
storage = resource_analysis["storage"]
|
||||
if "current_storage_mb" in storage:
|
||||
insights["storage"] = {
|
||||
"current_size_gb": round(storage["current_storage_mb"] / 1024, 2),
|
||||
"growth_rate": storage.get("estimated_growth", {}).get("daily_growth_estimate_mb", 0),
|
||||
"capacity_concern": storage["current_storage_mb"] > 10000 # > 10GB
|
||||
}
|
||||
|
||||
# Query volume insights
|
||||
if "query_volume" in resource_analysis:
|
||||
volume = resource_analysis["query_volume"]
|
||||
insights["query_volume"] = {
|
||||
"daily_average": volume.get("avg_daily_queries", 0),
|
||||
"load_level": "high" if volume.get("avg_daily_queries", 0) > 1000 else "normal",
|
||||
"trend": volume.get("trend", "stable")
|
||||
}
|
||||
|
||||
# User activity insights
|
||||
if "user_activity" in resource_analysis:
|
||||
activity = resource_analysis["user_activity"]
|
||||
insights["user_activity"] = {
|
||||
"active_user_base": activity.get("avg_daily_active_users", 0),
|
||||
"user_engagement": "high" if activity.get("avg_daily_active_users", 0) > 10 else "normal"
|
||||
}
|
||||
|
||||
return insights
|
||||
|
||||
async def _generate_growth_predictions(self, resource_analysis: Dict) -> Dict[str, Any]:
|
||||
"""Generate growth predictions (simplified)"""
|
||||
predictions = {}
|
||||
|
||||
# This is a simplified prediction model
|
||||
# Real implementation would use time series analysis
|
||||
|
||||
if "storage" in resource_analysis:
|
||||
storage = resource_analysis["storage"]
|
||||
current_size = storage.get("current_storage_mb", 0)
|
||||
daily_growth = storage.get("estimated_growth", {}).get("daily_growth_estimate_mb", 0)
|
||||
|
||||
predictions["storage"] = {
|
||||
"30_day_projection_mb": round(current_size + (daily_growth * 30), 2),
|
||||
"90_day_projection_mb": round(current_size + (daily_growth * 90), 2),
|
||||
"confidence": "low"
|
||||
}
|
||||
|
||||
return predictions
|
||||
|
||||
def _generate_performance_recommendations(self, performance_insights: Dict, pattern_analysis: Dict) -> List[Dict]:
|
||||
"""Generate performance improvement recommendations"""
|
||||
recommendations = []
|
||||
|
||||
# Execution time recommendations
|
||||
exec_stats = performance_insights.get("execution_time_stats", {})
|
||||
avg_time = exec_stats.get("avg_ms", 0)
|
||||
|
||||
if avg_time > 30000: # > 30 seconds
|
||||
recommendations.append({
|
||||
"type": "query_optimization",
|
||||
"priority": "high",
|
||||
"title": "High average query execution time",
|
||||
"description": f"Average slow query time is {avg_time/1000:.1f} seconds",
|
||||
"action": "Review and optimize slowest queries, consider indexing strategies"
|
||||
})
|
||||
|
||||
# Pattern-based recommendations
|
||||
if pattern_analysis:
|
||||
common_issues = pattern_analysis.get("common_performance_issues", {})
|
||||
|
||||
if common_issues.get("select_all_columns", 0) > 5:
|
||||
recommendations.append({
|
||||
"type": "query_best_practices",
|
||||
"priority": "medium",
|
||||
"title": "Frequent SELECT * usage detected",
|
||||
"description": "Many queries use SELECT * which can impact performance",
|
||||
"action": "Replace SELECT * with specific column names in queries"
|
||||
})
|
||||
|
||||
if common_issues.get("large_data_scan", 0) > 3:
|
||||
recommendations.append({
|
||||
"type": "data_access_optimization",
|
||||
"priority": "high",
|
||||
"title": "Large data scans detected",
|
||||
"description": "Multiple queries are scanning large amounts of data",
|
||||
"action": "Review partitioning strategies and add appropriate indexes"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
|
||||
def _generate_growth_recommendations(self, growth_insights: Dict, predictions: Dict) -> List[Dict]:
|
||||
"""Generate resource growth recommendations"""
|
||||
recommendations = []
|
||||
|
||||
# Storage recommendations
|
||||
storage_insights = growth_insights.get("storage", {})
|
||||
if storage_insights.get("capacity_concern", False):
|
||||
recommendations.append({
|
||||
"type": "capacity_planning",
|
||||
"priority": "medium",
|
||||
"title": "Storage capacity monitoring needed",
|
||||
"description": "Current storage usage is significant",
|
||||
"action": "Implement storage monitoring and consider data archival strategies"
|
||||
})
|
||||
|
||||
# Query volume recommendations
|
||||
query_insights = growth_insights.get("query_volume", {})
|
||||
if query_insights.get("load_level") == "high":
|
||||
recommendations.append({
|
||||
"type": "performance_scaling",
|
||||
"priority": "medium",
|
||||
"title": "High query volume detected",
|
||||
"description": "System is handling high query volumes",
|
||||
"action": "Monitor system performance and consider scaling strategies"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
744
doris_mcp_server/utils/security_analytics_tools.py
Normal file
744
doris_mcp_server/utils/security_analytics_tools.py
Normal file
@@ -0,0 +1,744 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Security Analytics Tools Module
|
||||
Provides data access analysis, user behavior monitoring, and security insights
|
||||
"""
|
||||
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional
|
||||
from collections import Counter, defaultdict
|
||||
|
||||
from .db import DorisConnectionManager
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SecurityAnalyticsTools:
|
||||
"""Security analytics tools for access pattern analysis and user monitoring"""
|
||||
|
||||
def __init__(self, connection_manager: DorisConnectionManager):
|
||||
self.connection_manager = connection_manager
|
||||
logger.info("SecurityAnalyticsTools initialized")
|
||||
|
||||
async def analyze_data_access_patterns(
|
||||
self,
|
||||
days: int = 7,
|
||||
include_system_users: bool = False,
|
||||
min_query_threshold: int = 5
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Analyze data access patterns for users and roles
|
||||
|
||||
Args:
|
||||
days: Number of days to analyze
|
||||
include_system_users: Whether to include system/service users
|
||||
min_query_threshold: Minimum queries for a user to be included in analysis
|
||||
|
||||
Returns:
|
||||
Comprehensive access pattern analysis
|
||||
"""
|
||||
try:
|
||||
start_time = time.time()
|
||||
connection = await self.connection_manager.get_connection("query")
|
||||
|
||||
# Define analysis period
|
||||
end_date = datetime.now()
|
||||
start_date = end_date - timedelta(days=days)
|
||||
|
||||
# 1. Get audit log data
|
||||
audit_data = await self._get_audit_log_data(connection, start_date, end_date, include_system_users)
|
||||
|
||||
if not audit_data:
|
||||
return {
|
||||
"error": "No audit data available for the specified period",
|
||||
"analysis_period": {
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"days": days
|
||||
}
|
||||
}
|
||||
|
||||
# 2. Analyze user access patterns
|
||||
user_access_analysis = await self._analyze_user_access_patterns(
|
||||
audit_data, min_query_threshold
|
||||
)
|
||||
|
||||
# 3. Analyze role-based access
|
||||
role_access_analysis = await self._analyze_role_access_patterns(
|
||||
connection, user_access_analysis
|
||||
)
|
||||
|
||||
# 4. Detect security anomalies
|
||||
security_alerts = await self._detect_security_anomalies(
|
||||
audit_data, user_access_analysis
|
||||
)
|
||||
|
||||
# 5. Generate access insights
|
||||
access_insights = await self._generate_access_insights(
|
||||
user_access_analysis, role_access_analysis
|
||||
)
|
||||
|
||||
execution_time = time.time() - start_time
|
||||
|
||||
return {
|
||||
"analysis_period": {
|
||||
"start_date": start_date.isoformat(),
|
||||
"end_date": end_date.isoformat(),
|
||||
"days": days
|
||||
},
|
||||
"analysis_timestamp": datetime.now().isoformat(),
|
||||
"execution_time_seconds": round(execution_time, 3),
|
||||
"user_access_summary": self._generate_user_access_summary(user_access_analysis),
|
||||
"user_access_details": user_access_analysis,
|
||||
"role_analysis": role_access_analysis,
|
||||
"security_alerts": security_alerts,
|
||||
"access_insights": access_insights,
|
||||
"recommendations": self._generate_security_recommendations(security_alerts, access_insights)
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Data access pattern analysis failed: {str(e)}")
|
||||
return {
|
||||
"error": str(e),
|
||||
"analysis_timestamp": datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# ==================== Private Helper Methods ====================
|
||||
|
||||
async def _get_audit_log_data(self, connection, start_date: datetime, end_date: datetime, include_system_users: bool) -> List[Dict]:
|
||||
"""Retrieve audit log data for the specified period"""
|
||||
try:
|
||||
# System users filter
|
||||
system_user_filter = ""
|
||||
if not include_system_users:
|
||||
system_users = ['root', 'admin', 'system', 'doris', 'information_schema']
|
||||
user_list = ','.join([f'"{user}"' for user in system_users])
|
||||
system_user_filter = f"AND `user` NOT IN ({user_list})"
|
||||
|
||||
audit_sql = f"""
|
||||
SELECT
|
||||
`user` as user_name,
|
||||
`client_ip` as host,
|
||||
`time` as query_time,
|
||||
`stmt` as sql_statement,
|
||||
`state` as query_status,
|
||||
`scan_bytes` as scan_bytes,
|
||||
`scan_rows` as scan_rows,
|
||||
`return_rows` as return_rows,
|
||||
`query_time` as execution_time_ms
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `stmt` IS NOT NULL
|
||||
AND `stmt` != ''
|
||||
{system_user_filter}
|
||||
ORDER BY `time` DESC
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
result = await connection.execute(audit_sql)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get audit log data: {str(e)}")
|
||||
# Try alternative method without detailed metrics
|
||||
try:
|
||||
simple_audit_sql = f"""
|
||||
SELECT
|
||||
`user` as user_name,
|
||||
`client_ip` as host,
|
||||
`time` as query_time,
|
||||
`stmt` as sql_statement,
|
||||
`state` as query_status
|
||||
FROM internal.__internal_schema.audit_log
|
||||
WHERE `time` >= '{start_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `time` <= '{end_date.strftime('%Y-%m-%d %H:%M:%S')}'
|
||||
AND `stmt` IS NOT NULL
|
||||
{system_user_filter}
|
||||
ORDER BY `time` DESC
|
||||
LIMIT 10000
|
||||
"""
|
||||
|
||||
result = await connection.execute(simple_audit_sql)
|
||||
return result.data if result.data else []
|
||||
|
||||
except Exception as e2:
|
||||
logger.error(f"Failed to get simplified audit log data: {str(e2)}")
|
||||
return []
|
||||
|
||||
async def _analyze_user_access_patterns(self, audit_data: List[Dict], min_query_threshold: int) -> List[Dict]:
|
||||
"""Analyze access patterns for individual users"""
|
||||
user_stats = defaultdict(lambda: {
|
||||
"total_queries": 0,
|
||||
"unique_tables_accessed": set(),
|
||||
"hosts": set(),
|
||||
"query_types": Counter(),
|
||||
"query_times": [],
|
||||
"failed_queries": 0,
|
||||
"data_volume_read_bytes": 0,
|
||||
"data_volume_read_rows": 0,
|
||||
"hourly_pattern": [0] * 24,
|
||||
"daily_pattern": [0] * 7,
|
||||
"query_statements": []
|
||||
})
|
||||
|
||||
# Process audit data
|
||||
for entry in audit_data:
|
||||
user_name = entry.get("user_name", "unknown")
|
||||
query_time = entry.get("query_time")
|
||||
sql_statement = entry.get("sql_statement", "")
|
||||
query_status = entry.get("query_status", "")
|
||||
|
||||
stats = user_stats[user_name]
|
||||
stats["total_queries"] += 1
|
||||
|
||||
# Extract table names from SQL
|
||||
tables = self._extract_table_names_from_sql(sql_statement)
|
||||
stats["unique_tables_accessed"].update(tables)
|
||||
|
||||
# Host tracking
|
||||
if entry.get("host"):
|
||||
stats["hosts"].add(entry["host"])
|
||||
|
||||
# Query type analysis
|
||||
query_type = self._classify_query_type(sql_statement)
|
||||
stats["query_types"][query_type] += 1
|
||||
|
||||
# Query time patterns
|
||||
if query_time:
|
||||
try:
|
||||
if isinstance(query_time, str):
|
||||
query_dt = datetime.fromisoformat(query_time.replace('Z', '+00:00'))
|
||||
else:
|
||||
query_dt = query_time
|
||||
|
||||
stats["query_times"].append(query_dt)
|
||||
stats["hourly_pattern"][query_dt.hour] += 1
|
||||
stats["daily_pattern"][query_dt.weekday()] += 1
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Error tracking
|
||||
if query_status and "error" in query_status.lower():
|
||||
stats["failed_queries"] += 1
|
||||
|
||||
# Data volume tracking
|
||||
if entry.get("scan_bytes"):
|
||||
try:
|
||||
stats["data_volume_read_bytes"] += int(entry["scan_bytes"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
if entry.get("scan_rows"):
|
||||
try:
|
||||
stats["data_volume_read_rows"] += int(entry["scan_rows"])
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
|
||||
# Store sample queries
|
||||
if len(stats["query_statements"]) < 10:
|
||||
stats["query_statements"].append({
|
||||
"sql": sql_statement[:200] + "..." if len(sql_statement) > 200 else sql_statement,
|
||||
"timestamp": str(query_time),
|
||||
"type": query_type
|
||||
})
|
||||
|
||||
# Convert to analysis results
|
||||
user_analysis = []
|
||||
for user_name, stats in user_stats.items():
|
||||
if stats["total_queries"] >= min_query_threshold:
|
||||
# Calculate patterns and insights
|
||||
access_pattern = self._classify_access_pattern(stats["hourly_pattern"])
|
||||
table_access_frequency = dict(Counter(
|
||||
table for entry in audit_data
|
||||
if entry.get("user_name") == user_name
|
||||
for table in self._extract_table_names_from_sql(entry.get("sql_statement", ""))
|
||||
).most_common(10))
|
||||
|
||||
user_analysis.append({
|
||||
"user_name": user_name,
|
||||
"access_stats": {
|
||||
"total_queries": stats["total_queries"],
|
||||
"unique_tables_accessed": len(stats["unique_tables_accessed"]),
|
||||
"unique_hosts": len(stats["hosts"]),
|
||||
"data_volume_read_gb": round(stats["data_volume_read_bytes"] / (1024**3), 3),
|
||||
"data_volume_read_rows": stats["data_volume_read_rows"],
|
||||
"failed_queries": stats["failed_queries"],
|
||||
"success_rate": round((stats["total_queries"] - stats["failed_queries"]) / stats["total_queries"], 3) if stats["total_queries"] > 0 else 0,
|
||||
"peak_access_hour": stats["hourly_pattern"].index(max(stats["hourly_pattern"])) if max(stats["hourly_pattern"]) > 0 else None,
|
||||
"access_pattern": access_pattern
|
||||
},
|
||||
"query_type_distribution": dict(stats["query_types"]),
|
||||
"table_access_frequency": table_access_frequency,
|
||||
"hosts_used": list(stats["hosts"]),
|
||||
"sample_queries": stats["query_statements"],
|
||||
"temporal_patterns": {
|
||||
"hourly_distribution": stats["hourly_pattern"],
|
||||
"daily_distribution": stats["daily_pattern"]
|
||||
}
|
||||
})
|
||||
|
||||
return sorted(user_analysis, key=lambda x: x["access_stats"]["total_queries"], reverse=True)
|
||||
|
||||
def _extract_table_names_from_sql(self, sql: str) -> List[str]:
|
||||
"""Extract table names from SQL statement (simplified implementation)"""
|
||||
if not sql:
|
||||
return []
|
||||
|
||||
import re
|
||||
|
||||
# Simple regex patterns to match table names
|
||||
patterns = [
|
||||
r'\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bINTO\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)',
|
||||
r'\bDELETE\s+FROM\s+([a-zA-Z_][a-zA-Z0-9_]*(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)'
|
||||
]
|
||||
|
||||
tables = []
|
||||
for pattern in patterns:
|
||||
matches = re.findall(pattern, sql, re.IGNORECASE)
|
||||
tables.extend(matches)
|
||||
|
||||
# Clean up table names (remove quotes, aliases, etc.)
|
||||
cleaned_tables = []
|
||||
for table in tables:
|
||||
# Remove backticks, quotes, and get just the table name
|
||||
clean_table = table.strip('`"\'').split(' ')[0]
|
||||
if clean_table and not clean_table.upper() in ['SELECT', 'WHERE', 'AND', 'OR']:
|
||||
cleaned_tables.append(clean_table)
|
||||
|
||||
return list(set(cleaned_tables))
|
||||
|
||||
def _classify_query_type(self, sql: str) -> str:
|
||||
"""Classify SQL query type"""
|
||||
if not sql:
|
||||
return "unknown"
|
||||
|
||||
sql_upper = sql.upper().strip()
|
||||
|
||||
if sql_upper.startswith('SELECT'):
|
||||
return "SELECT"
|
||||
elif sql_upper.startswith('INSERT'):
|
||||
return "INSERT"
|
||||
elif sql_upper.startswith('UPDATE'):
|
||||
return "UPDATE"
|
||||
elif sql_upper.startswith('DELETE'):
|
||||
return "DELETE"
|
||||
elif sql_upper.startswith('CREATE'):
|
||||
return "CREATE"
|
||||
elif sql_upper.startswith('ALTER'):
|
||||
return "ALTER"
|
||||
elif sql_upper.startswith('DROP'):
|
||||
return "DROP"
|
||||
elif sql_upper.startswith('SHOW'):
|
||||
return "SHOW"
|
||||
elif sql_upper.startswith('DESCRIBE') or sql_upper.startswith('DESC'):
|
||||
return "DESCRIBE"
|
||||
else:
|
||||
return "OTHER"
|
||||
|
||||
def _classify_access_pattern(self, hourly_pattern: List[int]) -> str:
|
||||
"""Classify user access pattern based on hourly distribution"""
|
||||
if not hourly_pattern or max(hourly_pattern) == 0:
|
||||
return "no_pattern"
|
||||
|
||||
# Find peak hours
|
||||
max_queries = max(hourly_pattern)
|
||||
peak_hours = [i for i, count in enumerate(hourly_pattern) if count == max_queries]
|
||||
|
||||
# Business hours: 9-17
|
||||
business_hours = set(range(9, 18))
|
||||
peak_in_business_hours = any(hour in business_hours for hour in peak_hours)
|
||||
|
||||
# Night hours: 22-6
|
||||
night_hours = set(list(range(22, 24)) + list(range(0, 7)))
|
||||
peak_in_night_hours = any(hour in night_hours for hour in peak_hours)
|
||||
|
||||
if peak_in_business_hours and not peak_in_night_hours:
|
||||
return "regular_business_hours"
|
||||
elif peak_in_night_hours:
|
||||
return "night_shift_or_batch"
|
||||
elif len(peak_hours) > 6: # Distributed throughout day
|
||||
return "distributed_access"
|
||||
else:
|
||||
return "irregular_pattern"
|
||||
|
||||
async def _analyze_role_access_patterns(self, connection, user_access_analysis: List[Dict]) -> Dict[str, Any]:
|
||||
"""Analyze access patterns by role"""
|
||||
try:
|
||||
# Get user roles information
|
||||
user_roles = await self._get_user_roles(connection)
|
||||
|
||||
# Group users by roles
|
||||
role_stats = defaultdict(lambda: {
|
||||
"user_count": 0,
|
||||
"total_queries": 0,
|
||||
"unique_tables": set(),
|
||||
"query_types": Counter(),
|
||||
"avg_queries_per_user": 0,
|
||||
"users": []
|
||||
})
|
||||
|
||||
# Process user access data
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
user_stats = user_data["access_stats"]
|
||||
query_types = user_data["query_type_distribution"]
|
||||
|
||||
# Get user roles (default to 'unknown' if not found)
|
||||
roles = user_roles.get(user_name, ["unknown"])
|
||||
|
||||
for role in roles:
|
||||
stats = role_stats[role]
|
||||
stats["user_count"] += 1
|
||||
stats["total_queries"] += user_stats["total_queries"]
|
||||
stats["users"].append(user_name)
|
||||
|
||||
# Aggregate query types
|
||||
for query_type, count in query_types.items():
|
||||
stats["query_types"][query_type] += count
|
||||
|
||||
# Calculate role analysis
|
||||
role_analysis = {}
|
||||
for role, stats in role_stats.items():
|
||||
if stats["user_count"] > 0:
|
||||
avg_queries = stats["total_queries"] / stats["user_count"]
|
||||
|
||||
# Calculate privilege usage (simplified)
|
||||
total_role_queries = sum(stats["query_types"].values())
|
||||
privilege_usage = {}
|
||||
if total_role_queries > 0:
|
||||
privilege_usage = {
|
||||
query_type: round(count / total_role_queries, 3)
|
||||
for query_type, count in stats["query_types"].items()
|
||||
}
|
||||
|
||||
role_analysis[role] = {
|
||||
"user_count": stats["user_count"],
|
||||
"users": stats["users"],
|
||||
"total_queries": stats["total_queries"],
|
||||
"avg_queries_per_user": round(avg_queries, 1),
|
||||
"query_type_distribution": dict(stats["query_types"]),
|
||||
"privilege_usage": privilege_usage,
|
||||
"activity_level": self._classify_role_activity_level(avg_queries)
|
||||
}
|
||||
|
||||
return role_analysis
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to analyze role access patterns: {str(e)}")
|
||||
return {}
|
||||
|
||||
async def _get_user_roles(self, connection) -> Dict[str, List[str]]:
|
||||
"""Get user roles mapping"""
|
||||
try:
|
||||
# Try to get user role information
|
||||
roles_sql = """
|
||||
SELECT
|
||||
User as user_name,
|
||||
COALESCE(Default_role, 'default') as role_name
|
||||
FROM mysql.user
|
||||
"""
|
||||
|
||||
result = await connection.execute(roles_sql)
|
||||
|
||||
user_roles = defaultdict(list)
|
||||
if result.data:
|
||||
for row in result.data:
|
||||
user_name = row.get("user_name", "")
|
||||
role_name = row.get("role_name", "default")
|
||||
if user_name:
|
||||
user_roles[user_name].append(role_name)
|
||||
|
||||
return dict(user_roles)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get user roles: {str(e)}")
|
||||
return {}
|
||||
|
||||
def _classify_role_activity_level(self, avg_queries: float) -> str:
|
||||
"""Classify role activity level based on average queries"""
|
||||
if avg_queries > 100:
|
||||
return "high"
|
||||
elif avg_queries > 20:
|
||||
return "medium"
|
||||
elif avg_queries > 5:
|
||||
return "low"
|
||||
else:
|
||||
return "minimal"
|
||||
|
||||
async def _detect_security_anomalies(self, audit_data: List[Dict], user_access_analysis: List[Dict]) -> List[Dict]:
|
||||
"""Detect potential security anomalies"""
|
||||
alerts = []
|
||||
|
||||
# 1. Detect unusual access times
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
hourly_pattern = user_data["temporal_patterns"]["hourly_distribution"]
|
||||
|
||||
# Check for significant night-time activity
|
||||
night_queries = sum(hourly_pattern[22:24]) + sum(hourly_pattern[0:6])
|
||||
total_queries = sum(hourly_pattern)
|
||||
|
||||
if total_queries > 0 and night_queries / total_queries > 0.3: # >30% night activity
|
||||
alerts.append({
|
||||
"alert_type": "unusual_access_time",
|
||||
"severity": "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} has {night_queries/total_queries:.1%} of queries during night hours",
|
||||
"night_query_percentage": round(night_queries/total_queries, 3),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 2. Detect users with high failure rates
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
success_rate = user_data["access_stats"]["success_rate"]
|
||||
total_queries = user_data["access_stats"]["total_queries"]
|
||||
|
||||
if total_queries > 10 and success_rate < 0.8: # <80% success rate
|
||||
alerts.append({
|
||||
"alert_type": "high_failure_rate",
|
||||
"severity": "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} has low query success rate ({success_rate:.1%})",
|
||||
"success_rate": success_rate,
|
||||
"total_queries": total_queries,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 3. Detect unusual data volume access
|
||||
data_volumes = [user["access_stats"]["data_volume_read_gb"] for user in user_access_analysis]
|
||||
if data_volumes:
|
||||
avg_volume = sum(data_volumes) / len(data_volumes)
|
||||
std_dev = (sum((x - avg_volume) ** 2 for x in data_volumes) / len(data_volumes)) ** 0.5
|
||||
threshold = avg_volume + 2 * std_dev # 2 standard deviations above mean
|
||||
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
volume = user_data["access_stats"]["data_volume_read_gb"]
|
||||
|
||||
if volume > threshold and volume > 1.0: # >1GB and above threshold
|
||||
alerts.append({
|
||||
"alert_type": "unusual_data_volume",
|
||||
"severity": "high" if volume > threshold * 2 else "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} read {volume:.2f}GB (threshold: {threshold:.2f}GB)",
|
||||
"data_volume_gb": volume,
|
||||
"threshold_gb": round(threshold, 2),
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
# 4. Detect users accessing many different tables
|
||||
for user_data in user_access_analysis:
|
||||
user_name = user_data["user_name"]
|
||||
unique_tables = user_data["access_stats"]["unique_tables_accessed"]
|
||||
total_queries = user_data["access_stats"]["total_queries"]
|
||||
|
||||
# High table diversity might indicate privilege escalation or data mining
|
||||
if unique_tables > 20 and total_queries > 50:
|
||||
alerts.append({
|
||||
"alert_type": "broad_table_access",
|
||||
"severity": "medium",
|
||||
"user": user_name,
|
||||
"description": f"User {user_name} accessed {unique_tables} different tables",
|
||||
"unique_tables_count": unique_tables,
|
||||
"total_queries": total_queries,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
})
|
||||
|
||||
return sorted(alerts, key=lambda x: {"high": 3, "medium": 2, "low": 1}.get(x["severity"], 0), reverse=True)
|
||||
|
||||
async def _generate_access_insights(self, user_access_analysis: List[Dict], role_analysis: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Generate access insights and patterns"""
|
||||
insights = {
|
||||
"user_behavior_patterns": {},
|
||||
"role_effectiveness": {},
|
||||
"security_posture": {}
|
||||
}
|
||||
|
||||
# User behavior patterns
|
||||
if user_access_analysis:
|
||||
total_users = len(user_access_analysis)
|
||||
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
|
||||
power_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
|
||||
|
||||
# Access pattern distribution
|
||||
pattern_distribution = Counter(
|
||||
user["access_stats"]["access_pattern"] for user in user_access_analysis
|
||||
)
|
||||
|
||||
insights["user_behavior_patterns"] = {
|
||||
"total_users_analyzed": total_users,
|
||||
"active_users": active_users,
|
||||
"power_users": power_users,
|
||||
"access_pattern_distribution": dict(pattern_distribution),
|
||||
"avg_queries_per_user": round(
|
||||
sum(u["access_stats"]["total_queries"] for u in user_access_analysis) / total_users, 1
|
||||
) if total_users > 0 else 0
|
||||
}
|
||||
|
||||
# Role effectiveness
|
||||
if role_analysis:
|
||||
most_active_role = max(role_analysis.items(), key=lambda x: x[1]["total_queries"])
|
||||
least_active_role = min(role_analysis.items(), key=lambda x: x[1]["total_queries"])
|
||||
|
||||
insights["role_effectiveness"] = {
|
||||
"total_roles": len(role_analysis),
|
||||
"most_active_role": {
|
||||
"role": most_active_role[0],
|
||||
"total_queries": most_active_role[1]["total_queries"],
|
||||
"user_count": most_active_role[1]["user_count"]
|
||||
},
|
||||
"least_active_role": {
|
||||
"role": least_active_role[0],
|
||||
"total_queries": least_active_role[1]["total_queries"],
|
||||
"user_count": least_active_role[1]["user_count"]
|
||||
},
|
||||
"avg_users_per_role": round(
|
||||
sum(role_info["user_count"] for role_info in role_analysis.values()) / len(role_analysis), 1
|
||||
)
|
||||
}
|
||||
|
||||
# Security posture assessment
|
||||
if user_access_analysis:
|
||||
users_with_failures = len([u for u in user_access_analysis if u["access_stats"]["failed_queries"] > 0])
|
||||
users_night_access = len([
|
||||
u for u in user_access_analysis
|
||||
if any(u["temporal_patterns"]["hourly_distribution"][hour] > 0 for hour in list(range(22, 24)) + list(range(0, 6)))
|
||||
])
|
||||
|
||||
insights["security_posture"] = {
|
||||
"users_with_query_failures": users_with_failures,
|
||||
"users_with_night_access": users_night_access,
|
||||
"security_score": self._calculate_security_score(user_access_analysis),
|
||||
"risk_level": self._assess_overall_risk_level(user_access_analysis)
|
||||
}
|
||||
|
||||
return insights
|
||||
|
||||
def _calculate_security_score(self, user_access_analysis: List[Dict]) -> float:
|
||||
"""Calculate overall security score (0-1, higher is better)"""
|
||||
if not user_access_analysis:
|
||||
return 0.0
|
||||
|
||||
total_users = len(user_access_analysis)
|
||||
|
||||
# Factors that contribute to security score
|
||||
users_with_high_success_rate = len([u for u in user_access_analysis if u["access_stats"]["success_rate"] > 0.9])
|
||||
users_with_normal_patterns = len([u for u in user_access_analysis if u["access_stats"]["access_pattern"] == "regular_business_hours"])
|
||||
|
||||
success_rate_score = users_with_high_success_rate / total_users
|
||||
pattern_score = users_with_normal_patterns / total_users
|
||||
|
||||
# Combined score
|
||||
overall_score = (success_rate_score * 0.6 + pattern_score * 0.4)
|
||||
return round(overall_score, 3)
|
||||
|
||||
def _assess_overall_risk_level(self, user_access_analysis: List[Dict]) -> str:
|
||||
"""Assess overall security risk level"""
|
||||
security_score = self._calculate_security_score(user_access_analysis)
|
||||
|
||||
if security_score > 0.8:
|
||||
return "low"
|
||||
elif security_score > 0.6:
|
||||
return "medium"
|
||||
else:
|
||||
return "high"
|
||||
|
||||
def _generate_user_access_summary(self, user_access_analysis: List[Dict]) -> Dict[str, Any]:
|
||||
"""Generate summary statistics for user access"""
|
||||
if not user_access_analysis:
|
||||
return {
|
||||
"total_users": 0,
|
||||
"active_users": 0,
|
||||
"high_activity_users": 0,
|
||||
"dormant_users": 0
|
||||
}
|
||||
|
||||
total_users = len(user_access_analysis)
|
||||
active_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 10])
|
||||
high_activity_users = len([u for u in user_access_analysis if u["access_stats"]["total_queries"] > 100])
|
||||
dormant_users = total_users - active_users
|
||||
|
||||
return {
|
||||
"total_users": total_users,
|
||||
"active_users": active_users,
|
||||
"high_activity_users": high_activity_users,
|
||||
"dormant_users": dormant_users,
|
||||
"activity_distribution": {
|
||||
"high": high_activity_users,
|
||||
"medium": active_users - high_activity_users,
|
||||
"low": dormant_users
|
||||
}
|
||||
}
|
||||
|
||||
def _generate_security_recommendations(self, security_alerts: List[Dict], access_insights: Dict[str, Any]) -> List[Dict]:
|
||||
"""Generate security recommendations based on analysis"""
|
||||
recommendations = []
|
||||
|
||||
# Recommendations based on alerts
|
||||
if security_alerts:
|
||||
high_severity_alerts = [alert for alert in security_alerts if alert["severity"] == "high"]
|
||||
if high_severity_alerts:
|
||||
recommendations.append({
|
||||
"type": "urgent_security_review",
|
||||
"priority": "high",
|
||||
"description": f"Found {len(high_severity_alerts)} high-severity security alerts",
|
||||
"action": "Immediate review of flagged users and access patterns required",
|
||||
"affected_users": list(set(alert["user"] for alert in high_severity_alerts if "user" in alert))
|
||||
})
|
||||
|
||||
# Night access recommendations
|
||||
night_access_alerts = [alert for alert in security_alerts if alert["alert_type"] == "unusual_access_time"]
|
||||
if night_access_alerts:
|
||||
recommendations.append({
|
||||
"type": "access_time_policy",
|
||||
"priority": "medium",
|
||||
"description": f"{len(night_access_alerts)} users have significant night-time access",
|
||||
"action": "Review access time policies and consider time-based restrictions",
|
||||
"affected_users": [alert["user"] for alert in night_access_alerts]
|
||||
})
|
||||
|
||||
# Recommendations based on insights
|
||||
security_posture = access_insights.get("security_posture", {})
|
||||
risk_level = security_posture.get("risk_level", "unknown")
|
||||
|
||||
if risk_level == "high":
|
||||
recommendations.append({
|
||||
"type": "overall_security_improvement",
|
||||
"priority": "high",
|
||||
"description": "Overall security posture indicates high risk",
|
||||
"action": "Comprehensive security audit and policy review recommended"
|
||||
})
|
||||
|
||||
# Role-based recommendations
|
||||
role_effectiveness = access_insights.get("role_effectiveness", {})
|
||||
if role_effectiveness and role_effectiveness.get("total_roles", 0) < 3:
|
||||
recommendations.append({
|
||||
"type": "role_management",
|
||||
"priority": "medium",
|
||||
"description": "Limited role diversity detected",
|
||||
"action": "Consider implementing more granular role-based access control"
|
||||
})
|
||||
|
||||
return recommendations
|
||||
Reference in New Issue
Block a user