Files
doris-mcp-server/doris_mcp_server/utils/security.py

888 lines
33 KiB
Python
Raw Normal View History

2025-06-08 18:44:40 +08:00
#!/usr/bin/env python3
2025-06-08 19:22:13 +08:00
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
2025-06-08 18:44:40 +08:00
"""
Doris Security Management Module
Implements enterprise-level authentication, authorization, SQL security validation and data masking functionality
"""
import logging
import re
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any
import sqlparse
from sqlparse.sql import Statement
from sqlparse.tokens import Keyword, Name
class SecurityLevel(Enum):
"""Security level enumeration"""
PUBLIC = "public"
INTERNAL = "internal"
CONFIDENTIAL = "confidential"
SECRET = "secret"
@dataclass
class AuthContext:
"""Authentication context"""
user_id: str
roles: list[str]
permissions: list[str]
session_id: str
login_time: datetime | None = None
last_activity: datetime | None = None
security_level: SecurityLevel = SecurityLevel.INTERNAL
@dataclass
class ValidationResult:
"""Validation result"""
is_valid: bool
error_message: str | None = None
risk_level: str = "low"
blocked_operations: list[str] = None
def __post_init__(self):
if self.blocked_operations is None:
self.blocked_operations = []
@dataclass
class MaskingRule:
"""Data masking rule"""
column_pattern: str
algorithm: str
parameters: dict[str, Any]
security_level: SecurityLevel
class DorisSecurityManager:
"""Doris security manager
Provides complete security control functionality, including authentication, authorization, SQL security validation and data masking
"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
# Initialize security components
self.auth_provider = AuthenticationProvider(config)
self.authz_provider = AuthorizationProvider(config)
self.sql_validator = SQLSecurityValidator(config)
self.masking_processor = DataMaskingProcessor(config)
# Security rule configuration
self.blocked_keywords = self._load_blocked_keywords()
self.sensitive_tables = self._load_sensitive_tables()
self.masking_rules = self._load_masking_rules()
def _load_blocked_keywords(self) -> set[str]:
2025-06-26 18:55:30 +08:00
"""Load blocked SQL keywords from configuration"""
# Load keywords from configuration, unified source of truth
2025-06-08 18:44:40 +08:00
if hasattr(self.config, 'get'):
2025-06-26 18:55:30 +08:00
# Dictionary-style configuration
blocked_keywords = self.config.get("blocked_keywords", [])
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'blocked_keywords'):
# DorisConfig object, get through security.blocked_keywords
blocked_keywords = self.config.security.blocked_keywords
2025-06-08 18:44:40 +08:00
else:
2025-06-26 18:55:30 +08:00
# Fallback to default if no configuration available
blocked_keywords = [
"DROP", "CREATE", "ALTER", "TRUNCATE",
"DELETE", "INSERT", "UPDATE",
"GRANT", "REVOKE",
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
]
2025-06-08 18:44:40 +08:00
2025-06-26 18:55:30 +08:00
return set(blocked_keywords)
2025-06-08 18:44:40 +08:00
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
"""Load sensitive table configuration"""
default_tables = {
"user_info": SecurityLevel.CONFIDENTIAL,
"payment_records": SecurityLevel.SECRET,
"employee_data": SecurityLevel.CONFIDENTIAL,
"public_reports": SecurityLevel.PUBLIC,
}
if hasattr(self.config, 'get'):
config_tables = self.config.get("sensitive_tables", {})
# Convert string values to SecurityLevel enum
for table_name, level in config_tables.items():
if isinstance(level, str):
try:
default_tables[table_name] = SecurityLevel(level.lower())
except ValueError:
default_tables[table_name] = SecurityLevel.INTERNAL
else:
default_tables[table_name] = level
return default_tables
else:
return default_tables
def _load_masking_rules(self) -> list[MaskingRule]:
"""Load data masking rules"""
default_rules = [
MaskingRule(
column_pattern=r".*phone.*|.*mobile.*",
algorithm="phone_mask",
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
security_level=SecurityLevel.INTERNAL,
),
MaskingRule(
column_pattern=r".*email.*",
algorithm="email_mask",
parameters={"mask_char": "*"},
security_level=SecurityLevel.INTERNAL,
),
MaskingRule(
column_pattern=r".*id_card.*|.*identity.*",
algorithm="id_mask",
parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
security_level=SecurityLevel.CONFIDENTIAL,
),
]
# Load custom rules from configuration
custom_rules = []
if hasattr(self.config, 'get'):
custom_rules = self.config.get("masking_rules", [])
elif hasattr(self.config, 'security') and hasattr(self.config.security, 'masking_rules'):
custom_rules = self.config.security.masking_rules
for rule_config in custom_rules:
if isinstance(rule_config, dict):
default_rules.append(MaskingRule(**rule_config))
elif isinstance(rule_config, MaskingRule):
default_rules.append(rule_config)
return default_rules
async def authenticate_request(self, auth_info: dict[str, Any]) -> AuthContext:
"""Validate request authentication information"""
return await self.auth_provider.authenticate(auth_info)
async def authorize_resource_access(
self, auth_context: AuthContext, resource_uri: str
) -> bool:
"""Validate resource access permissions"""
return await self.authz_provider.check_permission(
auth_context, resource_uri, "read"
)
async def validate_sql_security(
self, sql: str, auth_context: AuthContext
) -> ValidationResult:
"""Validate SQL query security"""
return await self.sql_validator.validate(sql, auth_context)
async def apply_data_masking(
self, data: list[dict[str, Any]], auth_context: AuthContext
) -> list[dict[str, Any]]:
"""Apply data masking processing"""
return await self.masking_processor.process(data, auth_context)
class AuthenticationProvider:
"""Authentication provider"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
self.session_cache = {}
async def authenticate(self, auth_info: dict[str, Any]) -> AuthContext:
"""Perform identity authentication"""
auth_type = auth_info.get("type", "token")
if auth_type == "token":
return await self._authenticate_token(auth_info)
elif auth_type == "basic":
return await self._authenticate_basic(auth_info)
else:
raise ValueError(f"Unsupported authentication type: {auth_type}")
async def _authenticate_token(self, auth_info: dict[str, Any]) -> AuthContext:
"""Token authentication"""
token = auth_info.get("token")
if not token:
raise ValueError("Missing authentication token")
# Validate token (simplified implementation, should validate JWT or query authentication service in practice)
user_info = await self._validate_token(token)
return AuthContext(
user_id=user_info["user_id"],
roles=user_info["roles"],
permissions=user_info["permissions"],
session_id=auth_info.get("session_id", "default"),
login_time=datetime.utcnow(),
security_level=SecurityLevel(user_info.get("security_level", "internal")),
)
async def _authenticate_basic(self, auth_info: dict[str, Any]) -> AuthContext:
"""Basic authentication (username password)"""
username = auth_info.get("username")
password = auth_info.get("password")
if not username or not password:
raise ValueError("Missing username or password")
# Validate username password (simplified implementation)
user_info = await self._validate_credentials(username, password)
return AuthContext(
user_id=user_info["user_id"],
roles=user_info["roles"],
permissions=user_info["permissions"],
session_id=auth_info.get("session_id", "default"),
login_time=datetime.utcnow(),
security_level=SecurityLevel(user_info.get("security_level", "internal")),
)
async def _validate_token(self, token: str) -> dict[str, Any]:
"""Validate token validity"""
# Simplified implementation for testing, should parse JWT or query authentication service in practice
valid_tokens = {
"valid_token_123": {
"user_id": "test_user",
"roles": ["data_analyst"],
"permissions": ["read_data"],
"security_level": SecurityLevel.INTERNAL,
},
"admin_token_456": {
"user_id": "admin_user",
"roles": ["data_admin"],
"permissions": ["admin"],
"security_level": SecurityLevel.SECRET,
}
}
if token in valid_tokens:
return valid_tokens[token]
else:
raise ValueError("Invalid token")
async def _validate_credentials(
self, username: str, password: str
) -> dict[str, Any]:
"""Validate user credentials"""
# Simplified implementation for testing, should query user database in practice
valid_users = {
"admin": {
"password": "admin123",
"user_id": "admin_user",
"roles": ["data_admin"],
"permissions": ["admin", "read_data", "write_data"],
"security_level": SecurityLevel.SECRET,
},
"analyst": {
"password": "analyst123",
"user_id": "analyst_user",
"roles": ["data_analyst"],
"permissions": ["read_data"],
"security_level": SecurityLevel.INTERNAL,
}
}
if username in valid_users and valid_users[username]["password"] == password:
user_info = valid_users[username].copy()
del user_info["password"] # Remove password from returned info
return user_info
else:
raise ValueError("Incorrect username or password")
class AuthorizationProvider:
"""Authorization provider"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
self.permission_cache = {}
# Load sensitive tables configuration
self.sensitive_tables = self._load_sensitive_tables()
def _load_sensitive_tables(self) -> dict[str, SecurityLevel]:
"""Load sensitive table configuration"""
default_tables = {
"user_info": SecurityLevel.CONFIDENTIAL,
"payment_records": SecurityLevel.SECRET,
"employee_data": SecurityLevel.CONFIDENTIAL,
"public_reports": SecurityLevel.PUBLIC,
}
if hasattr(self.config, 'get'):
config_tables = self.config.get("sensitive_tables", {})
# Convert string values to SecurityLevel enum
for table_name, level in config_tables.items():
if isinstance(level, str):
try:
default_tables[table_name] = SecurityLevel(level.lower())
except ValueError:
default_tables[table_name] = SecurityLevel.INTERNAL
else:
default_tables[table_name] = level
return default_tables
else:
return default_tables
async def check_permission(
self, auth_context: AuthContext, resource_uri: str, action: str
) -> bool:
"""Check permissions"""
# Parse resource information
resource_info = self._parse_resource_uri(resource_uri)
# First check security level - this is mandatory
if not await self._check_security_level_permission(auth_context, resource_info):
return False
# Then check role-based permissions
if await self._check_role_permission(auth_context, resource_info, action):
return True
# Finally check user-based permissions
if await self._check_user_permission(auth_context, resource_info, action):
return True
return False
def _parse_resource_uri(self, uri: str) -> dict[str, str]:
"""Parse resource URI"""
parts = uri.split("/")
if len(parts) >= 3:
return {
"type": parts[2], # table, view, etc.
"name": parts[3] if len(parts) > 3 else "",
"schema": parts[4] if len(parts) > 4 else "default",
}
return {"type": "unknown", "name": "", "schema": "default"}
async def _check_role_permission(
self, auth_context: AuthContext, resource_info: dict[str, str], action: str
) -> bool:
"""Check role-based permissions"""
# Role permission mapping
role_permissions = {
"data_analyst": {"table": ["read"], "view": ["read"]},
"data_admin": {
"table": ["read", "write", "admin"],
"view": ["read", "write", "admin"],
},
}
for role in auth_context.roles:
role_perms = role_permissions.get(role, {})
resource_perms = role_perms.get(resource_info["type"], [])
if action in resource_perms:
return True
return False
async def _check_user_permission(
self, auth_context: AuthContext, resource_info: dict[str, str], action: str
) -> bool:
"""Check user-based permissions"""
# User-specific permission check
if "admin" in auth_context.permissions:
return True
if action == "read" and "read_data" in auth_context.permissions:
return True
return False
async def _check_security_level_permission(
self, auth_context: AuthContext, resource_info: dict[str, str]
) -> bool:
"""Check security level permissions"""
# Get resource security level
resource_security_level = self._get_resource_security_level(resource_info)
# Check if user security level is sufficient
security_hierarchy = {
SecurityLevel.PUBLIC: 0,
SecurityLevel.INTERNAL: 1,
SecurityLevel.CONFIDENTIAL: 2,
SecurityLevel.SECRET: 3,
}
user_level = security_hierarchy.get(auth_context.security_level, 0)
resource_level = security_hierarchy.get(resource_security_level, 0)
# User must have higher or equal security level to access resource
return user_level >= resource_level
def _get_resource_security_level(
self, resource_info: dict[str, str]
) -> SecurityLevel:
"""Get resource security level"""
# Get table security level from configuration
table_name = resource_info.get("name", "")
# Use the loaded sensitive tables
sensitive_tables = self.sensitive_tables
# Convert string values to SecurityLevel enum if needed
security_level = sensitive_tables.get(table_name, SecurityLevel.INTERNAL)
if isinstance(security_level, str):
try:
security_level = SecurityLevel(security_level.lower())
except ValueError:
security_level = SecurityLevel.INTERNAL
return security_level
class SQLSecurityValidator:
"""SQL security validator"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
# Handle DorisConfig object or dictionary configuration
if hasattr(config, 'get'):
# Dictionary configuration
self.blocked_keywords = set(config.get("blocked_keywords", []))
self.max_query_complexity = config.get("max_query_complexity", 100)
2025-06-26 18:55:30 +08:00
self.enable_security_check = config.get("enable_security_check", True)
elif hasattr(config, 'security'):
# DorisConfig object with security attribute - unified source from config
self.blocked_keywords = set(config.security.blocked_keywords)
self.max_query_complexity = config.security.max_query_complexity
self.enable_security_check = getattr(config.security, 'enable_security_check', True)
2025-06-08 18:44:40 +08:00
else:
2025-06-26 18:55:30 +08:00
# Fallback to default if no configuration available
self.blocked_keywords = set([
"DROP", "CREATE", "ALTER", "TRUNCATE",
"DELETE", "INSERT", "UPDATE",
"GRANT", "REVOKE",
"EXEC", "EXECUTE", "SHUTDOWN", "KILL"
])
2025-06-08 18:44:40 +08:00
self.max_query_complexity = 100
2025-06-26 18:55:30 +08:00
self.enable_security_check = True
2025-06-08 18:44:40 +08:00
async def validate(self, sql: str, auth_context: AuthContext) -> ValidationResult:
"""Validate SQL query security"""
2025-06-26 18:55:30 +08:00
# If security check is disabled, always return valid
if not self.enable_security_check:
self.logger.debug("SQL security check is disabled, allowing all queries")
return ValidationResult(is_valid=True)
2025-06-08 18:44:40 +08:00
try:
# Parse SQL statement
parsed = sqlparse.parse(sql)[0]
# Check blocked operations first (more specific)
keyword_result = await self._check_blocked_keywords(parsed)
if not keyword_result.is_valid:
return keyword_result
# Check SQL injection risks
injection_result = await self._check_sql_injection(sql, parsed)
if not injection_result.is_valid:
return injection_result
# Check query complexity
complexity_result = await self._check_query_complexity(parsed)
if not complexity_result.is_valid:
return complexity_result
# Check table access permissions
table_result = await self._check_table_access(parsed, auth_context)
if not table_result.is_valid:
return table_result
return ValidationResult(is_valid=True)
except Exception as e:
self.logger.error(f"SQL security validation failed: {e}")
return ValidationResult(
is_valid=False,
error_message=f"SQL parsing error: {str(e)}",
risk_level="high",
)
async def _check_sql_injection(
self, sql: str, parsed: Statement
) -> ValidationResult:
"""Check SQL injection risks"""
# Check common SQL injection patterns
injection_patterns = [
r"(\s|^)(union|select|insert|update|delete|drop|create|alter)\s+.*\s+(union|select|insert|update|delete|drop|create|alter)",
r"(\s|^)(or|and)\s+\d+\s*=\s*\d+",
r"(\s|^)(or|and)\s+['\"].*['\"]",
r";\s*(drop|delete|truncate|alter|create)",
r"(exec|execute|sp_|xp_)",
r"(script|javascript|vbscript)",
r"(char|ascii|substring|concat)\s*\(",
]
sql_lower = sql.lower()
for pattern in injection_patterns:
if re.search(pattern, sql_lower, re.IGNORECASE):
return ValidationResult(
is_valid=False,
error_message="Potential SQL injection risk detected",
risk_level="high",
)
# Check suspicious quotes and comments
if self._has_suspicious_quotes_or_comments(sql):
return ValidationResult(
is_valid=False,
error_message="Suspicious quote or comment pattern detected",
risk_level="medium",
)
return ValidationResult(is_valid=True)
def _has_suspicious_quotes_or_comments(self, sql: str) -> bool:
"""Check suspicious quote and comment patterns"""
# Check unmatched quotes
single_quotes = sql.count("'")
double_quotes = sql.count('"')
if single_quotes % 2 != 0 or double_quotes % 2 != 0:
return True
# Check SQL comments
if "--" in sql or "/*" in sql:
return True
return False
async def _check_blocked_keywords(self, parsed: Statement) -> ValidationResult:
"""Check blocked keywords"""
blocked_operations = []
# Check all tokens in the parsed statement
for token in parsed.flatten():
# Check if token is a keyword (including DML/DDL) or name that matches blocked operations
if (token.ttype is Keyword or
token.ttype is Name or
(token.ttype and str(token.ttype).startswith('Token.Keyword'))):
token_value = token.value.upper().strip()
if token_value in self.blocked_keywords:
blocked_operations.append(token_value)
# Also check for DDL/DML keywords in token values
elif hasattr(token, 'value') and token.value:
token_value = token.value.upper().strip()
for blocked_keyword in self.blocked_keywords:
if blocked_keyword in token_value:
blocked_operations.append(blocked_keyword)
if blocked_operations:
return ValidationResult(
is_valid=False,
error_message=f"Contains blocked operations: {', '.join(set(blocked_operations))}",
risk_level="high",
blocked_operations=list(set(blocked_operations)),
)
return ValidationResult(is_valid=True)
async def _check_query_complexity(self, parsed: Statement) -> ValidationResult:
"""Check query complexity"""
complexity_score = 0
# Calculate complexity score
for token in parsed.flatten():
if token.ttype is Keyword:
keyword = token.value.upper()
if keyword in ["JOIN", "INNER", "LEFT", "RIGHT", "FULL"]:
complexity_score += 10
elif keyword in ["UNION", "INTERSECT", "EXCEPT"]:
complexity_score += 15
elif keyword in ["GROUP BY", "ORDER BY", "HAVING"]:
complexity_score += 5
elif keyword in ["SUBQUERY", "EXISTS", "IN"]:
complexity_score += 8
if complexity_score > self.max_query_complexity:
return ValidationResult(
is_valid=False,
error_message=f"Query complexity too high (score: {complexity_score}, limit: {self.max_query_complexity})",
risk_level="medium",
)
return ValidationResult(is_valid=True)
async def _check_table_access(
self, parsed: Statement, auth_context: AuthContext
) -> ValidationResult:
"""Check table access permissions"""
# Extract table names from query
tables = self._extract_table_names(parsed)
# Check access permissions for each table
unauthorized_tables = []
for table in tables:
# Should call authorization provider to check permissions
# Simplified implementation, assume some tables require special permissions
if (
table.lower() in ["sensitive_data", "admin_logs"]
and "admin" not in auth_context.roles
):
unauthorized_tables.append(table)
if unauthorized_tables:
return ValidationResult(
is_valid=False,
error_message=f"No access to tables: {', '.join(unauthorized_tables)}",
risk_level="high",
)
return ValidationResult(is_valid=True)
def _extract_table_names(self, parsed: Statement) -> list[str]:
"""Extract table names from SQL statement"""
tables = []
# Simplified table name extraction logic
tokens = list(parsed.flatten())
for i, token in enumerate(tokens):
if token.ttype is Keyword and token.value.upper() == "FROM":
# Find table name after FROM
for j in range(i + 1, len(tokens)):
next_token = tokens[j]
if next_token.ttype is Name:
tables.append(next_token.value)
break
elif next_token.ttype is Keyword:
break
return tables
class DataMaskingProcessor:
"""Data masking processor"""
def __init__(self, config):
self.config = config
self.logger = logging.getLogger(__name__)
self.masking_algorithms = self._init_masking_algorithms()
self.masking_rules = self._load_masking_rules()
def _load_masking_rules(self) -> list[MaskingRule]:
"""Load data masking rules"""
default_rules = [
MaskingRule(
column_pattern=r".*phone.*|.*mobile.*",
algorithm="phone_mask",
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
security_level=SecurityLevel.INTERNAL,
),
MaskingRule(
column_pattern=r".*email.*",
algorithm="email_mask",
parameters={"mask_char": "*"},
security_level=SecurityLevel.INTERNAL,
),
MaskingRule(
column_pattern=r".*id_card.*|.*identity.*",
algorithm="id_mask",
parameters={"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4},
security_level=SecurityLevel.CONFIDENTIAL,
),
]
# Load custom rules from configuration
if hasattr(self.config, 'get'):
custom_rules = self.config.get("masking_rules", [])
for rule_config in custom_rules:
if isinstance(rule_config, dict):
# Convert string security level to enum
if 'security_level' in rule_config and isinstance(rule_config['security_level'], str):
try:
rule_config['security_level'] = SecurityLevel(rule_config['security_level'].lower())
except ValueError:
rule_config['security_level'] = SecurityLevel.INTERNAL
default_rules.append(MaskingRule(**rule_config))
elif isinstance(rule_config, MaskingRule):
default_rules.append(rule_config)
return default_rules
def _init_masking_algorithms(self) -> dict[str, callable]:
"""Initialize masking algorithms"""
return {
"phone_mask": self._mask_phone,
"email_mask": self._mask_email,
"id_mask": self._mask_id_card,
"name_mask": self._mask_name,
"partial_mask": self._mask_partial,
}
async def process(
self, data: list[dict[str, Any]], auth_context: AuthContext
) -> list[dict[str, Any]]:
"""Process data masking"""
if not data:
return data
# Get applicable masking rules
applicable_rules = self._get_applicable_rules(auth_context)
masked_data = []
for row in data:
masked_row = {}
for column, value in row.items():
masked_value = await self._apply_masking_rules(
column, value, applicable_rules
)
masked_row[column] = masked_value
masked_data.append(masked_row)
return masked_data
def _get_applicable_rules(self, auth_context: AuthContext) -> list[MaskingRule]:
"""Get applicable masking rules"""
applicable_rules = []
for rule in self.masking_rules:
# Decide whether to apply masking rules based on user security level
if self._should_apply_rule(rule, auth_context):
applicable_rules.append(rule)
return applicable_rules
def _should_apply_rule(self, rule: MaskingRule, auth_context: AuthContext) -> bool:
"""Determine whether masking rule should be applied"""
# Admin users can see original data
if "admin" in auth_context.roles:
return False
# Decide based on security level
security_hierarchy = {
SecurityLevel.PUBLIC: 0,
SecurityLevel.INTERNAL: 1,
SecurityLevel.CONFIDENTIAL: 2,
SecurityLevel.SECRET: 3,
}
user_level = security_hierarchy.get(auth_context.security_level, 0)
rule_level = security_hierarchy.get(rule.security_level, 0)
# Apply masking if user level is less than or equal to rule level
return user_level <= rule_level
async def _apply_masking_rules(
self, column: str, value: Any, rules: list[MaskingRule]
) -> Any:
"""Apply masking rules"""
if value is None:
return value
for rule in rules:
if re.match(rule.column_pattern, column, re.IGNORECASE):
algorithm = self.masking_algorithms.get(rule.algorithm)
if algorithm:
return algorithm(str(value), rule.parameters)
return value
def _mask_phone(self, value: str, params: dict[str, Any]) -> str:
"""Phone number masking"""
if len(value) < 7:
return value
mask_char = params.get("mask_char", "*")
keep_prefix = params.get("keep_prefix", 3)
keep_suffix = params.get("keep_suffix", 4)
if len(value) <= keep_prefix + keep_suffix:
return mask_char * len(value)
prefix = value[:keep_prefix]
suffix = value[-keep_suffix:]
middle_length = len(value) - keep_prefix - keep_suffix
return prefix + mask_char * middle_length + suffix
def _mask_email(self, value: str, params: dict[str, Any]) -> str:
"""Email masking"""
if "@" not in value:
return value
mask_char = params.get("mask_char", "*")
local, domain = value.split("@", 1)
if len(local) <= 2:
masked_local = mask_char * len(local)
else:
masked_local = local[0] + mask_char * (len(local) - 2) + local[-1]
return f"{masked_local}@{domain}"
def _mask_id_card(self, value: str, params: dict[str, Any]) -> str:
"""ID card number masking"""
if len(value) < 10:
return value
mask_char = params.get("mask_char", "*")
keep_prefix = params.get("keep_prefix", 6)
keep_suffix = params.get("keep_suffix", 4)
if len(value) <= keep_prefix + keep_suffix:
return mask_char * len(value)
prefix = value[:keep_prefix]
suffix = value[-keep_suffix:]
middle_length = len(value) - keep_prefix - keep_suffix
return prefix + mask_char * middle_length + suffix
def _mask_name(self, value: str, params: dict[str, Any]) -> str:
"""Name masking"""
if len(value) <= 1:
return value
mask_char = params.get("mask_char", "*")
if len(value) == 2:
return value[0] + mask_char
else:
return value[0] + mask_char * (len(value) - 2) + value[-1]
def _mask_partial(self, value: str, params: dict[str, Any]) -> str:
"""Partial masking"""
mask_char = params.get("mask_char", "*")
mask_ratio = params.get("mask_ratio", 0.5)
mask_length = int(len(value) * mask_ratio)
start_pos = (len(value) - mask_length) // 2
result = list(value)
for i in range(start_pos, start_pos + mask_length):
if i < len(result):
result[i] = mask_char
return "".join(result)