0.3.0 Release Version
This commit is contained in:
87
test/security/test_authentication.py
Normal file
87
test/security/test_authentication.py
Normal file
@@ -0,0 +1,87 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Authentication module tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from datetime import datetime
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
AuthenticationProvider,
|
||||
AuthContext,
|
||||
SecurityLevel
|
||||
)
|
||||
|
||||
|
||||
class TestAuthenticationProvider:
|
||||
"""Authentication provider tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def auth_provider(self, test_config):
|
||||
"""Create authentication provider instance"""
|
||||
return AuthenticationProvider(test_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_authentication_success(self, auth_provider):
|
||||
"""Test successful token authentication"""
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "valid_token_123"
|
||||
}
|
||||
|
||||
result = await auth_provider.authenticate(auth_info)
|
||||
|
||||
assert isinstance(result, AuthContext)
|
||||
assert result.user_id == "test_user"
|
||||
assert "data_analyst" in result.roles
|
||||
assert result.security_level == SecurityLevel.INTERNAL
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_token_authentication_failure(self, auth_provider):
|
||||
"""Test failed token authentication"""
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "invalid_token"
|
||||
}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await auth_provider.authenticate(auth_info)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_authentication_success(self, auth_provider):
|
||||
"""Test successful basic authentication"""
|
||||
auth_info = {
|
||||
"type": "basic",
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
|
||||
result = await auth_provider.authenticate(auth_info)
|
||||
|
||||
assert isinstance(result, AuthContext)
|
||||
assert result.user_id == "admin_user"
|
||||
assert "data_admin" in result.roles
|
||||
assert result.security_level == SecurityLevel.SECRET
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_authentication_failure(self, auth_provider):
|
||||
"""Test failed basic authentication"""
|
||||
auth_info = {
|
||||
"type": "basic",
|
||||
"username": "admin",
|
||||
"password": "wrong_password"
|
||||
}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await auth_provider.authenticate(auth_info)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_auth_type(self, auth_provider):
|
||||
"""Test unsupported authentication type"""
|
||||
auth_info = {
|
||||
"type": "oauth",
|
||||
"token": "oauth_token"
|
||||
}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await auth_provider.authenticate(auth_info)
|
||||
131
test/security/test_authorization.py
Normal file
131
test/security/test_authorization.py
Normal file
@@ -0,0 +1,131 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Authorization module tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
AuthorizationProvider,
|
||||
AuthContext,
|
||||
SecurityLevel
|
||||
)
|
||||
|
||||
|
||||
class TestAuthorizationProvider:
|
||||
"""Authorization provider tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def authz_provider(self, test_config):
|
||||
"""Create authorization provider instance"""
|
||||
return AuthorizationProvider(test_config)
|
||||
|
||||
@pytest.fixture
|
||||
def analyst_context(self):
|
||||
"""Create analyst auth context"""
|
||||
return AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def admin_context(self):
|
||||
"""Create admin auth context"""
|
||||
return AuthContext(
|
||||
user_id="admin1",
|
||||
roles=["data_admin"],
|
||||
permissions=["admin"],
|
||||
session_id="session_456",
|
||||
security_level=SecurityLevel.SECRET
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyst_access_public_resource(self, authz_provider, analyst_context):
|
||||
"""Test analyst accessing public resource"""
|
||||
resource_uri = "/api/table/public_reports"
|
||||
|
||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_analyst_denied_confidential_resource(self, authz_provider):
|
||||
"""Test analyst denied access to confidential resource"""
|
||||
# Create analyst with lower security level
|
||||
analyst_context = AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.PUBLIC # Lower than CONFIDENTIAL
|
||||
)
|
||||
|
||||
resource_uri = "/api/table/user_info"
|
||||
|
||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
|
||||
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_access_secret_resource(self, authz_provider, admin_context):
|
||||
"""Test admin accessing secret resource"""
|
||||
resource_uri = "/api/table/payment_records"
|
||||
|
||||
result = await authz_provider.check_permission(admin_context, resource_uri, "read")
|
||||
|
||||
assert result is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_role_based_permission(self, authz_provider):
|
||||
"""Test role-based permission check"""
|
||||
# Create analyst context
|
||||
analyst_context = AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
resource_uri = "/api/table/some_table"
|
||||
|
||||
# Analyst should have read permission
|
||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "read")
|
||||
assert result is True
|
||||
|
||||
# Analyst should not have write permission
|
||||
result = await authz_provider.check_permission(analyst_context, resource_uri, "write")
|
||||
assert result is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_override(self, authz_provider, admin_context):
|
||||
"""Test admin permission override"""
|
||||
resource_uri = "/api/table/any_table"
|
||||
|
||||
# Admin should have all permissions
|
||||
result = await authz_provider.check_permission(admin_context, resource_uri, "read")
|
||||
assert result is True
|
||||
|
||||
result = await authz_provider.check_permission(admin_context, resource_uri, "write")
|
||||
assert result is True
|
||||
|
||||
def test_parse_resource_uri(self, authz_provider):
|
||||
"""Test resource URI parsing"""
|
||||
uri = "/api/table/user_info/default"
|
||||
|
||||
result = authz_provider._parse_resource_uri(uri)
|
||||
|
||||
assert result["type"] == "table"
|
||||
assert result["name"] == "user_info"
|
||||
assert result["schema"] == "default"
|
||||
|
||||
def test_get_resource_security_level(self, authz_provider):
|
||||
"""Test getting resource security level"""
|
||||
resource_info = {"name": "user_info", "type": "table"}
|
||||
|
||||
level = authz_provider._get_resource_security_level(resource_info)
|
||||
|
||||
assert level == SecurityLevel.CONFIDENTIAL
|
||||
181
test/security/test_data_masking.py
Normal file
181
test/security/test_data_masking.py
Normal file
@@ -0,0 +1,181 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Data masking tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
DataMaskingProcessor,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
MaskingRule
|
||||
)
|
||||
|
||||
|
||||
class TestDataMaskingProcessor:
|
||||
"""Data masking processor tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def masking_processor(self, test_config):
|
||||
"""Create data masking processor instance"""
|
||||
return DataMaskingProcessor(test_config)
|
||||
|
||||
@pytest.fixture
|
||||
def internal_user_context(self):
|
||||
"""Create internal user auth context"""
|
||||
return AuthContext(
|
||||
user_id="internal_user",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
def admin_context(self):
|
||||
"""Create admin auth context"""
|
||||
return AuthContext(
|
||||
user_id="admin",
|
||||
roles=["data_admin"],
|
||||
permissions=["admin"],
|
||||
session_id="session_456",
|
||||
security_level=SecurityLevel.SECRET
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_phone_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
|
||||
"""Test phone number masking for internal user"""
|
||||
result = await masking_processor.process(sample_data, internal_user_context)
|
||||
|
||||
# Phone numbers should be masked
|
||||
assert result[0]["phone"] == "138****5678"
|
||||
assert result[1]["phone"] == "139****4321"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_email_masking_for_internal_user(self, masking_processor, internal_user_context, sample_data):
|
||||
"""Test email masking for internal user"""
|
||||
result = await masking_processor.process(sample_data, internal_user_context)
|
||||
|
||||
# Emails should be masked
|
||||
assert result[0]["email"] == "z******n@example.com"
|
||||
assert result[1]["email"] == "l**i@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_masking_for_admin(self, masking_processor, admin_context, sample_data):
|
||||
"""Test no masking for admin user"""
|
||||
result = await masking_processor.process(sample_data, admin_context)
|
||||
|
||||
# Admin should see original data
|
||||
assert result[0]["phone"] == "13812345678"
|
||||
assert result[0]["email"] == "zhangsan@example.com"
|
||||
assert result[1]["phone"] == "13987654321"
|
||||
assert result[1]["email"] == "lisi@example.com"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_id_card_masking_for_confidential_data(self, masking_processor, internal_user_context, sample_data):
|
||||
"""Test ID card masking for confidential data"""
|
||||
# Internal user should not see ID card details (confidential level)
|
||||
result = await masking_processor.process(sample_data, internal_user_context)
|
||||
|
||||
# ID cards should be masked for internal users
|
||||
assert result[0]["id_card"] == "110101********1234"
|
||||
assert result[1]["id_card"] == "110101********2345"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_data_handling(self, masking_processor, internal_user_context):
|
||||
"""Test empty data handling"""
|
||||
empty_data = []
|
||||
|
||||
result = await masking_processor.process(empty_data, internal_user_context)
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_null_value_handling(self, masking_processor, internal_user_context):
|
||||
"""Test null value handling"""
|
||||
data_with_nulls = [
|
||||
{
|
||||
"id": 1,
|
||||
"name": "张三",
|
||||
"phone": None,
|
||||
"email": None,
|
||||
"id_card": None
|
||||
}
|
||||
]
|
||||
|
||||
result = await masking_processor.process(data_with_nulls, internal_user_context)
|
||||
|
||||
# Null values should remain null
|
||||
assert result[0]["phone"] is None
|
||||
assert result[0]["email"] is None
|
||||
assert result[0]["id_card"] is None
|
||||
|
||||
def test_phone_masking_algorithm(self, masking_processor):
|
||||
"""Test phone masking algorithm"""
|
||||
params = {"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4}
|
||||
|
||||
result = masking_processor._mask_phone("13812345678", params)
|
||||
|
||||
assert result == "138****5678"
|
||||
|
||||
def test_email_masking_algorithm(self, masking_processor):
|
||||
"""Test email masking algorithm"""
|
||||
params = {"mask_char": "*"}
|
||||
|
||||
result = masking_processor._mask_email("zhangsan@example.com", params)
|
||||
|
||||
assert result == "z******n@example.com"
|
||||
|
||||
def test_id_card_masking_algorithm(self, masking_processor):
|
||||
"""Test ID card masking algorithm"""
|
||||
params = {"mask_char": "*", "keep_prefix": 6, "keep_suffix": 4}
|
||||
|
||||
result = masking_processor._mask_id_card("110101199001011234", params)
|
||||
|
||||
assert result == "110101********1234"
|
||||
|
||||
def test_name_masking_algorithm(self, masking_processor):
|
||||
"""Test name masking algorithm"""
|
||||
params = {"mask_char": "*"}
|
||||
|
||||
# Test 2-character name
|
||||
result = masking_processor._mask_name("张三", params)
|
||||
assert result == "张*"
|
||||
|
||||
# Test 3-character name
|
||||
result = masking_processor._mask_name("李小明", params)
|
||||
assert result == "李*明"
|
||||
|
||||
def test_partial_masking_algorithm(self, masking_processor):
|
||||
"""Test partial masking algorithm"""
|
||||
params = {"mask_char": "*", "mask_ratio": 0.5}
|
||||
|
||||
result = masking_processor._mask_partial("1234567890", params)
|
||||
|
||||
# Should mask middle 50% of the string
|
||||
assert "*" in result
|
||||
assert len(result) == 10
|
||||
|
||||
def test_should_apply_rule_logic(self, masking_processor, internal_user_context, admin_context):
|
||||
"""Test masking rule application logic"""
|
||||
rule = MaskingRule(
|
||||
column_pattern=r".*phone.*",
|
||||
algorithm="phone_mask",
|
||||
parameters={"mask_char": "*", "keep_prefix": 3, "keep_suffix": 4},
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
# Internal user should have rule applied
|
||||
assert masking_processor._should_apply_rule(rule, internal_user_context) is True
|
||||
|
||||
# Admin should not have rule applied
|
||||
assert masking_processor._should_apply_rule(rule, admin_context) is False
|
||||
|
||||
def test_get_applicable_rules(self, masking_processor, internal_user_context):
|
||||
"""Test getting applicable rules"""
|
||||
rules = masking_processor._get_applicable_rules(internal_user_context)
|
||||
|
||||
# Should return some rules for internal user
|
||||
assert len(rules) > 0
|
||||
assert all(isinstance(rule, MaskingRule) for rule in rules)
|
||||
156
test/security/test_security_manager.py
Normal file
156
test/security/test_security_manager.py
Normal file
@@ -0,0 +1,156 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Security manager integration tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
DorisSecurityManager,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
ValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestDorisSecurityManager:
|
||||
"""Doris security manager integration tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def security_manager(self, test_config):
|
||||
"""Create security manager instance"""
|
||||
return DorisSecurityManager(test_config)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_security_workflow(self, security_manager, sample_data):
|
||||
"""Test complete security workflow"""
|
||||
# 1. Authentication
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "valid_token_123"
|
||||
}
|
||||
|
||||
auth_context = await security_manager.authenticate_request(auth_info)
|
||||
assert isinstance(auth_context, AuthContext)
|
||||
assert auth_context.security_level == SecurityLevel.INTERNAL
|
||||
|
||||
# 2. Authorization
|
||||
resource_uri = "/api/table/public_reports"
|
||||
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
|
||||
assert has_access is True
|
||||
|
||||
# 3. SQL Validation
|
||||
safe_sql = "SELECT name, email FROM users WHERE department = 'sales'"
|
||||
validation_result = await security_manager.validate_sql_security(safe_sql, auth_context)
|
||||
assert validation_result.is_valid is True
|
||||
|
||||
# 4. Data Masking
|
||||
masked_data = await security_manager.apply_data_masking(sample_data, auth_context)
|
||||
assert masked_data[0]["phone"] == "138****5678" # Should be masked
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_workflow(self, security_manager, sample_data):
|
||||
"""Test admin user workflow"""
|
||||
# Admin authentication
|
||||
auth_info = {
|
||||
"type": "basic",
|
||||
"username": "admin",
|
||||
"password": "admin123"
|
||||
}
|
||||
|
||||
auth_context = await security_manager.authenticate_request(auth_info)
|
||||
assert auth_context.security_level == SecurityLevel.SECRET
|
||||
|
||||
# Admin should access secret resources
|
||||
resource_uri = "/api/table/payment_records"
|
||||
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
|
||||
assert has_access is True
|
||||
|
||||
# Admin should see original data (no masking)
|
||||
masked_data = await security_manager.apply_data_masking(sample_data, auth_context)
|
||||
assert masked_data[0]["phone"] == "13812345678" # Original data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_security_violation_detection(self, security_manager):
|
||||
"""Test security violation detection"""
|
||||
# Authenticate as regular user
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "valid_token_123"
|
||||
}
|
||||
|
||||
auth_context = await security_manager.authenticate_request(auth_info)
|
||||
|
||||
# Try to access confidential resource (user_info is CONFIDENTIAL, user is INTERNAL)
|
||||
# INTERNAL(1) should not access CONFIDENTIAL(2) resource
|
||||
resource_uri = "/api/table/user_info"
|
||||
has_access = await security_manager.authorize_resource_access(auth_context, resource_uri)
|
||||
assert has_access is False
|
||||
|
||||
# Try dangerous SQL
|
||||
dangerous_sql = "DROP TABLE users"
|
||||
validation_result = await security_manager.validate_sql_security(dangerous_sql, auth_context)
|
||||
assert validation_result.is_valid is False
|
||||
assert "DROP" in validation_result.blocked_operations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sql_injection_prevention(self, security_manager):
|
||||
"""Test SQL injection prevention"""
|
||||
auth_info = {
|
||||
"type": "token",
|
||||
"token": "valid_token_123"
|
||||
}
|
||||
|
||||
auth_context = await security_manager.authenticate_request(auth_info)
|
||||
|
||||
# Test various injection attempts
|
||||
injection_attempts = [
|
||||
"SELECT * FROM users WHERE id = 1; DROP TABLE users;",
|
||||
"SELECT * FROM users UNION SELECT password FROM admin_users",
|
||||
"SELECT * FROM users WHERE id = 1 OR 1=1",
|
||||
"SELECT * FROM users WHERE name = 'test' -- AND password = 'secret'"
|
||||
]
|
||||
|
||||
for sql in injection_attempts:
|
||||
result = await security_manager.validate_sql_security(sql, auth_context)
|
||||
assert result.is_valid is False
|
||||
assert result.risk_level in ["medium", "high"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_failure_handling(self, security_manager):
|
||||
"""Test authentication failure handling"""
|
||||
invalid_auth_info = {
|
||||
"type": "token",
|
||||
"token": "invalid_token"
|
||||
}
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await security_manager.authenticate_request(invalid_auth_info)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_configuration_loading(self, security_manager):
|
||||
"""Test security configuration loading"""
|
||||
# Test blocked keywords loading
|
||||
assert "DROP" in security_manager.blocked_keywords
|
||||
assert "DELETE" in security_manager.blocked_keywords
|
||||
|
||||
# Test sensitive tables loading
|
||||
assert SecurityLevel.CONFIDENTIAL in security_manager.sensitive_tables.values()
|
||||
assert SecurityLevel.SECRET in security_manager.sensitive_tables.values()
|
||||
|
||||
# Test masking rules loading
|
||||
assert len(security_manager.masking_rules) > 0
|
||||
phone_rules = [rule for rule in security_manager.masking_rules
|
||||
if "phone" in rule.column_pattern]
|
||||
assert len(phone_rules) > 0
|
||||
|
||||
def test_security_level_hierarchy(self, security_manager):
|
||||
"""Test security level hierarchy"""
|
||||
# Test that hierarchy is correctly defined
|
||||
levels = [SecurityLevel.PUBLIC, SecurityLevel.INTERNAL,
|
||||
SecurityLevel.CONFIDENTIAL, SecurityLevel.SECRET]
|
||||
|
||||
# Each level should be properly defined
|
||||
for level in levels:
|
||||
assert isinstance(level, SecurityLevel)
|
||||
assert level.value in ["public", "internal", "confidential", "secret"]
|
||||
145
test/security/test_sql_validation.py
Normal file
145
test/security/test_sql_validation.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SQL security validation tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
SQLSecurityValidator,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
ValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestSQLSecurityValidator:
|
||||
"""SQL security validator tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def sql_validator(self, test_config):
|
||||
"""Create SQL validator instance"""
|
||||
return SQLSecurityValidator(test_config)
|
||||
|
||||
@pytest.fixture
|
||||
def analyst_context(self):
|
||||
"""Create analyst auth context"""
|
||||
return AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_safe_select_query(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test safe SELECT query validation"""
|
||||
sql = test_sql_queries["safe_select"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.error_message is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_drop_operation(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test blocked DROP operation"""
|
||||
sql = test_sql_queries["dangerous_drop"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "blocked operations" in result.error_message.lower()
|
||||
assert "DROP" in result.blocked_operations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sql_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test SQL injection detection"""
|
||||
sql = test_sql_queries["sql_injection"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "injection" in result.error_message.lower()
|
||||
assert result.risk_level == "high"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_union_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test UNION injection detection"""
|
||||
sql = test_sql_queries["union_injection"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "injection" in result.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_comment_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test comment injection detection"""
|
||||
sql = test_sql_queries["comment_injection"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "comment" in result.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_query_validation(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test complex query validation"""
|
||||
sql = test_sql_queries["complex_query"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
# Complex query should pass if within limits
|
||||
assert result.is_valid is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_keywords_detection(self, sql_validator, analyst_context):
|
||||
"""Test blocked keywords detection"""
|
||||
blocked_sqls = [
|
||||
"DELETE FROM users WHERE id = 1",
|
||||
"TRUNCATE TABLE logs",
|
||||
"ALTER TABLE users ADD COLUMN new_col VARCHAR(50)",
|
||||
"CREATE TABLE test (id INT)",
|
||||
"INSERT INTO users VALUES (1, 'test')",
|
||||
"UPDATE users SET name = 'test' WHERE id = 1"
|
||||
]
|
||||
|
||||
for sql in blocked_sqls:
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
assert result.is_valid is False
|
||||
assert result.blocked_operations is not None
|
||||
assert len(result.blocked_operations) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_access_validation(self, sql_validator, analyst_context):
|
||||
"""Test table access validation"""
|
||||
# Test access to sensitive table
|
||||
sql = "SELECT * FROM sensitive_data"
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
# Should fail for non-admin users
|
||||
assert result.is_valid is False
|
||||
assert "access" in result.error_message.lower()
|
||||
|
||||
def test_extract_table_names(self, sql_validator):
|
||||
"""Test table name extraction"""
|
||||
sql = "SELECT u.name FROM users u JOIN departments d ON u.dept_id = d.id"
|
||||
|
||||
parsed = __import__('sqlparse').parse(sql)[0]
|
||||
tables = sql_validator._extract_table_names(parsed)
|
||||
|
||||
# Should extract at least one table name
|
||||
assert len(tables) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_sql_handling(self, sql_validator, analyst_context):
|
||||
"""Test malformed SQL handling"""
|
||||
malformed_sql = "SELECT * FROM users WHERE"
|
||||
|
||||
result = await sql_validator.validate(malformed_sql, analyst_context)
|
||||
|
||||
# Should handle gracefully
|
||||
assert isinstance(result, ValidationResult)
|
||||
Reference in New Issue
Block a user