145 lines
5.1 KiB
Python
145 lines
5.1 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""
|
||
|
|
SQL security validation tests
|
||
|
|
"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from doris_mcp_server.utils.security import (
|
||
|
|
SQLSecurityValidator,
|
||
|
|
AuthContext,
|
||
|
|
SecurityLevel,
|
||
|
|
ValidationResult
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class TestSQLSecurityValidator:
|
||
|
|
"""SQL security validator tests"""
|
||
|
|
|
||
|
|
@pytest.fixture
|
||
|
|
def sql_validator(self, test_config):
|
||
|
|
"""Create SQL validator instance"""
|
||
|
|
return SQLSecurityValidator(test_config)
|
||
|
|
|
||
|
|
@pytest.fixture
|
||
|
|
def analyst_context(self):
|
||
|
|
"""Create analyst auth context"""
|
||
|
|
return AuthContext(
|
||
|
|
user_id="analyst1",
|
||
|
|
roles=["data_analyst"],
|
||
|
|
permissions=["read_data"],
|
||
|
|
session_id="session_123",
|
||
|
|
security_level=SecurityLevel.INTERNAL
|
||
|
|
)
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_safe_select_query(self, sql_validator, analyst_context, test_sql_queries):
|
||
|
|
"""Test safe SELECT query validation"""
|
||
|
|
sql = test_sql_queries["safe_select"]
|
||
|
|
|
||
|
|
result = await sql_validator.validate(sql, analyst_context)
|
||
|
|
|
||
|
|
assert result.is_valid is True
|
||
|
|
assert result.error_message is None
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_blocked_drop_operation(self, sql_validator, analyst_context, test_sql_queries):
|
||
|
|
"""Test blocked DROP operation"""
|
||
|
|
sql = test_sql_queries["dangerous_drop"]
|
||
|
|
|
||
|
|
result = await sql_validator.validate(sql, analyst_context)
|
||
|
|
|
||
|
|
assert result.is_valid is False
|
||
|
|
assert "blocked operations" in result.error_message.lower()
|
||
|
|
assert "DROP" in result.blocked_operations
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_sql_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||
|
|
"""Test SQL injection detection"""
|
||
|
|
sql = test_sql_queries["sql_injection"]
|
||
|
|
|
||
|
|
result = await sql_validator.validate(sql, analyst_context)
|
||
|
|
|
||
|
|
assert result.is_valid is False
|
||
|
|
assert "injection" in result.error_message.lower()
|
||
|
|
assert result.risk_level == "high"
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_union_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||
|
|
"""Test UNION injection detection"""
|
||
|
|
sql = test_sql_queries["union_injection"]
|
||
|
|
|
||
|
|
result = await sql_validator.validate(sql, analyst_context)
|
||
|
|
|
||
|
|
assert result.is_valid is False
|
||
|
|
assert "injection" in result.error_message.lower()
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_comment_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||
|
|
"""Test comment injection detection"""
|
||
|
|
sql = test_sql_queries["comment_injection"]
|
||
|
|
|
||
|
|
result = await sql_validator.validate(sql, analyst_context)
|
||
|
|
|
||
|
|
assert result.is_valid is False
|
||
|
|
assert "comment" in result.error_message.lower()
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_complex_query_validation(self, sql_validator, analyst_context, test_sql_queries):
|
||
|
|
"""Test complex query validation"""
|
||
|
|
sql = test_sql_queries["complex_query"]
|
||
|
|
|
||
|
|
result = await sql_validator.validate(sql, analyst_context)
|
||
|
|
|
||
|
|
# Complex query should pass if within limits
|
||
|
|
assert result.is_valid is True
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_blocked_keywords_detection(self, sql_validator, analyst_context):
|
||
|
|
"""Test blocked keywords detection"""
|
||
|
|
blocked_sqls = [
|
||
|
|
"DELETE FROM users WHERE id = 1",
|
||
|
|
"TRUNCATE TABLE logs",
|
||
|
|
"ALTER TABLE users ADD COLUMN new_col VARCHAR(50)",
|
||
|
|
"CREATE TABLE test (id INT)",
|
||
|
|
"INSERT INTO users VALUES (1, 'test')",
|
||
|
|
"UPDATE users SET name = 'test' WHERE id = 1"
|
||
|
|
]
|
||
|
|
|
||
|
|
for sql in blocked_sqls:
|
||
|
|
result = await sql_validator.validate(sql, analyst_context)
|
||
|
|
assert result.is_valid is False
|
||
|
|
assert result.blocked_operations is not None
|
||
|
|
assert len(result.blocked_operations) > 0
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_table_access_validation(self, sql_validator, analyst_context):
|
||
|
|
"""Test table access validation"""
|
||
|
|
# Test access to sensitive table
|
||
|
|
sql = "SELECT * FROM sensitive_data"
|
||
|
|
|
||
|
|
result = await sql_validator.validate(sql, analyst_context)
|
||
|
|
|
||
|
|
# Should fail for non-admin users
|
||
|
|
assert result.is_valid is False
|
||
|
|
assert "access" in result.error_message.lower()
|
||
|
|
|
||
|
|
def test_extract_table_names(self, sql_validator):
|
||
|
|
"""Test table name extraction"""
|
||
|
|
sql = "SELECT u.name FROM users u JOIN departments d ON u.dept_id = d.id"
|
||
|
|
|
||
|
|
parsed = __import__('sqlparse').parse(sql)[0]
|
||
|
|
tables = sql_validator._extract_table_names(parsed)
|
||
|
|
|
||
|
|
# Should extract at least one table name
|
||
|
|
assert len(tables) > 0
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_malformed_sql_handling(self, sql_validator, analyst_context):
|
||
|
|
"""Test malformed SQL handling"""
|
||
|
|
malformed_sql = "SELECT * FROM users WHERE"
|
||
|
|
|
||
|
|
result = await sql_validator.validate(malformed_sql, analyst_context)
|
||
|
|
|
||
|
|
# Should handle gracefully
|
||
|
|
assert isinstance(result, ValidationResult)
|