0.3.0 Release Version
This commit is contained in:
145
test/security/test_sql_validation.py
Normal file
145
test/security/test_sql_validation.py
Normal file
@@ -0,0 +1,145 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
SQL security validation tests
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from doris_mcp_server.utils.security import (
|
||||
SQLSecurityValidator,
|
||||
AuthContext,
|
||||
SecurityLevel,
|
||||
ValidationResult
|
||||
)
|
||||
|
||||
|
||||
class TestSQLSecurityValidator:
|
||||
"""SQL security validator tests"""
|
||||
|
||||
@pytest.fixture
|
||||
def sql_validator(self, test_config):
|
||||
"""Create SQL validator instance"""
|
||||
return SQLSecurityValidator(test_config)
|
||||
|
||||
@pytest.fixture
|
||||
def analyst_context(self):
|
||||
"""Create analyst auth context"""
|
||||
return AuthContext(
|
||||
user_id="analyst1",
|
||||
roles=["data_analyst"],
|
||||
permissions=["read_data"],
|
||||
session_id="session_123",
|
||||
security_level=SecurityLevel.INTERNAL
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_safe_select_query(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test safe SELECT query validation"""
|
||||
sql = test_sql_queries["safe_select"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.error_message is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_drop_operation(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test blocked DROP operation"""
|
||||
sql = test_sql_queries["dangerous_drop"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "blocked operations" in result.error_message.lower()
|
||||
assert "DROP" in result.blocked_operations
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sql_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test SQL injection detection"""
|
||||
sql = test_sql_queries["sql_injection"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "injection" in result.error_message.lower()
|
||||
assert result.risk_level == "high"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_union_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test UNION injection detection"""
|
||||
sql = test_sql_queries["union_injection"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "injection" in result.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_comment_injection_detection(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test comment injection detection"""
|
||||
sql = test_sql_queries["comment_injection"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert "comment" in result.error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complex_query_validation(self, sql_validator, analyst_context, test_sql_queries):
|
||||
"""Test complex query validation"""
|
||||
sql = test_sql_queries["complex_query"]
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
# Complex query should pass if within limits
|
||||
assert result.is_valid is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_keywords_detection(self, sql_validator, analyst_context):
|
||||
"""Test blocked keywords detection"""
|
||||
blocked_sqls = [
|
||||
"DELETE FROM users WHERE id = 1",
|
||||
"TRUNCATE TABLE logs",
|
||||
"ALTER TABLE users ADD COLUMN new_col VARCHAR(50)",
|
||||
"CREATE TABLE test (id INT)",
|
||||
"INSERT INTO users VALUES (1, 'test')",
|
||||
"UPDATE users SET name = 'test' WHERE id = 1"
|
||||
]
|
||||
|
||||
for sql in blocked_sqls:
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
assert result.is_valid is False
|
||||
assert result.blocked_operations is not None
|
||||
assert len(result.blocked_operations) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_access_validation(self, sql_validator, analyst_context):
|
||||
"""Test table access validation"""
|
||||
# Test access to sensitive table
|
||||
sql = "SELECT * FROM sensitive_data"
|
||||
|
||||
result = await sql_validator.validate(sql, analyst_context)
|
||||
|
||||
# Should fail for non-admin users
|
||||
assert result.is_valid is False
|
||||
assert "access" in result.error_message.lower()
|
||||
|
||||
def test_extract_table_names(self, sql_validator):
|
||||
"""Test table name extraction"""
|
||||
sql = "SELECT u.name FROM users u JOIN departments d ON u.dept_id = d.id"
|
||||
|
||||
parsed = __import__('sqlparse').parse(sql)[0]
|
||||
tables = sql_validator._extract_table_names(parsed)
|
||||
|
||||
# Should extract at least one table name
|
||||
assert len(tables) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_sql_handling(self, sql_validator, analyst_context):
|
||||
"""Test malformed SQL handling"""
|
||||
malformed_sql = "SELECT * FROM users WHERE"
|
||||
|
||||
result = await sql_validator.validate(malformed_sql, analyst_context)
|
||||
|
||||
# Should handle gracefully
|
||||
assert isinstance(result, ValidationResult)
|
||||
Reference in New Issue
Block a user