[Test]Update tests (#29)

This commit is contained in:
ivin
2025-08-07 23:27:36 +08:00
committed by GitHub
parent ecb5db8137
commit affa4a0319
9 changed files with 52 additions and 59 deletions

6
.gitignore vendored
View File

@@ -19,7 +19,5 @@ ENV/
env.bak/ env.bak/
venv.bak/ venv.bak/
.idea/ .idea/
.coverage
coverage.xml

View File

@@ -59,7 +59,6 @@ def test_config():
config.database.password = "test_password" config.database.password = "test_password"
config.database.database = "test_db" config.database.database = "test_db"
config.database.health_check_interval = 60 config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20 config.database.max_connections = 20
config.database.connection_timeout = 30 config.database.connection_timeout = 30
config.database.max_connection_age = 3600 config.database.max_connection_age = 3600

View File

@@ -34,7 +34,7 @@ class TestEndToEndIntegration:
@pytest.fixture @pytest.fixture
def mock_config(self): def mock_config(self):
"""Create mock configuration""" """Create mock configuration"""
from doris_mcp_server.utils.config import DatabaseConfig, SecurityConfig from doris_mcp_server.utils.config import ADBCConfig, DatabaseConfig, SecurityConfig
config = Mock(spec=DorisConfig) config = Mock(spec=DorisConfig)
@@ -46,7 +46,6 @@ class TestEndToEndIntegration:
config.database.password = "test_password" config.database.password = "test_password"
config.database.database = "test_db" config.database.database = "test_db"
config.database.health_check_interval = 60 config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20 config.database.max_connections = 20
config.database.connection_timeout = 30 config.database.connection_timeout = 30
config.database.max_connection_age = 3600 config.database.max_connection_age = 3600
@@ -57,7 +56,12 @@ class TestEndToEndIntegration:
config.security.auth_type = "token" config.security.auth_type = "token"
config.security.token_secret = "test_secret" config.security.token_secret = "test_secret"
config.security.token_expiry = 3600 config.security.token_expiry = 3600
config.security.blocked_keywords = ["DROP"]
# Add adbc config
config.adbc = Mock(spec=ADBCConfig)
config.adbc.enabled = True
return config return config
@pytest.fixture @pytest.fixture
@@ -231,7 +235,7 @@ class TestEndToEndIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_execution_with_security(self, doris_server): async def test_tool_execution_with_security(self, doris_server):
"""Test tool execution with security checks""" """Test tool execution with security checks"""
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute: with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
mock_execute.return_value = [{"Database": "test_db"}] mock_execute.return_value = [{"Database": "test_db"}]
# Test tool execution through tools manager # Test tool execution through tools manager
@@ -258,7 +262,7 @@ class TestEndToEndIntegration:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_performance_monitoring_integration(self, doris_server): async def test_performance_monitoring_integration(self, doris_server):
"""Test performance monitoring integration""" """Test performance monitoring integration"""
with patch.object(doris_server.tools_manager.query_executor, 'execute_query') as mock_execute: with patch.object(doris_server.tools_manager.connection_manager, 'execute_query') as mock_execute:
mock_execute.return_value = [ mock_execute.return_value = [
{ {
"query_count": 1500, "query_count": 1500,
@@ -285,4 +289,4 @@ class TestEndToEndIntegration:
# Verify tools are available - use list_tools instead # Verify tools are available - use list_tools instead
import asyncio import asyncio
tools = asyncio.run(doris_server.tools_manager.list_tools()) tools = asyncio.run(doris_server.tools_manager.list_tools())
assert len(tools) > 0 assert len(tools) > 0

View File

@@ -44,22 +44,31 @@
} }
}, },
"expected_tools": [ "expected_tools": [
"analyze_columns",
"analyze_data_access_patterns",
"analyze_data_flow_dependencies",
"analyze_resource_growth_curves",
"analyze_slow_queries_topn",
"analyze_table_storage",
"exec_adbc_query",
"exec_query", "exec_query",
"get_db_list", "get_adbc_connection_info",
"get_db_table_list",
"get_table_schema",
"get_table_comment",
"get_table_column_comments",
"get_table_indexes",
"get_recent_audit_logs",
"get_catalog_list", "get_catalog_list",
"get_db_list",
"get_db_table_list",
"get_memory_stats",
"get_monitoring_metrics",
"get_recent_audit_logs",
"get_sql_explain", "get_sql_explain",
"get_sql_profile", "get_sql_profile",
"get_table_basic_info",
"get_table_column_comments",
"get_table_comment",
"get_table_data_size", "get_table_data_size",
"get_monitoring_metrics_info", "get_table_indexes",
"get_monitoring_metrics_data", "get_table_schema",
"get_realtime_memory_stats", "monitor_data_freshness",
"get_historical_memory_stats" "trace_column_lineage"
], ],
"expected_resources": [ "expected_resources": [
"database", "database",
@@ -71,4 +80,4 @@
"data_analysis_helper", "data_analysis_helper",
"schema_explorer" "schema_explorer"
] ]
} }

View File

@@ -185,8 +185,9 @@ async def test_server_connectivity(transport: Optional[str] = None) -> bool:
logger.error(f"Connectivity test failed: {e}") logger.error(f"Connectivity test failed: {e}")
return False return False
result = await client.connect_and_run(test_connection) await client.connect_and_run(test_connection)
return result return True
except Exception as e: except Exception as e:
logger.error(f"Failed to test server connectivity: {e}") logger.error(f"Failed to test server connectivity: {e}")
return False return False
@@ -211,4 +212,4 @@ if __name__ == "__main__":
stdio_ok = await test_server_connectivity("stdio") stdio_ok = await test_server_connectivity("stdio")
print(f" Stdio connectivity: {'' if stdio_ok else ''}") print(f" Stdio connectivity: {'' if stdio_ok else ''}")
asyncio.run(main()) asyncio.run(main())

View File

@@ -72,8 +72,7 @@ class TestToolsClientServer:
return tools return tools
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert len(result) > 0
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_call_tool_exec_query_via_client(self, client, test_config): async def test_call_tool_exec_query_via_client(self, client, test_config):
@@ -91,14 +90,13 @@ class TestToolsClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
if result["success"]: if result["success"]:
assert "result" in result, "Successful result should contain 'result' field" assert "data" in result, "Successful result should contain 'data' field"
else: else:
assert "error" in result, "Failed result should contain 'error' field" assert "error" in result, "Failed result should contain 'error' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
# Don't assert success=True as it depends on actual server state
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_call_tool_get_db_list_via_client(self, client, test_config): async def test_call_tool_get_db_list_via_client(self, client, test_config):
@@ -115,8 +113,7 @@ class TestToolsClientServer:
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_call_tool_get_table_schema_via_client(self, client, test_config): async def test_call_tool_get_table_schema_via_client(self, client, test_config):
@@ -133,10 +130,7 @@ class TestToolsClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_error_handling_via_client(self, client, test_config): async def test_tool_error_handling_via_client(self, client, test_config):
@@ -151,8 +145,7 @@ class TestToolsClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_tool_with_auth_token_via_client(self, client, test_config): async def test_tool_with_auth_token_via_client(self, client, test_config):
@@ -171,5 +164,4 @@ class TestToolsClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result

View File

@@ -45,7 +45,6 @@ class TestDorisToolsManager:
config.database.password = "test_password" config.database.password = "test_password"
config.database.database = "test_db" config.database.database = "test_db"
config.database.health_check_interval = 60 config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20 config.database.max_connections = 20
config.database.connection_timeout = 30 config.database.connection_timeout = 30
config.database.max_connection_age = 3600 config.database.max_connection_age = 3600
@@ -268,4 +267,4 @@ class TestDorisToolsManager:
# Required fields should be defined # Required fields should be defined
if 'required' in tool.inputSchema: if 'required' in tool.inputSchema:
assert isinstance(tool.inputSchema['required'], list) assert isinstance(tool.inputSchema['required'], list)

View File

@@ -44,7 +44,6 @@ class TestDorisQueryExecutor:
config.database.password = "test_password" config.database.password = "test_password"
config.database.database = "test_db" config.database.database = "test_db"
config.database.health_check_interval = 60 config.database.health_check_interval = 60
config.database.min_connections = 5
config.database.max_connections = 20 config.database.max_connections = 20
config.database.connection_timeout = 30 config.database.connection_timeout = 30
config.database.max_connection_age = 3600 config.database.max_connection_age = 3600
@@ -201,4 +200,4 @@ class TestDorisQueryExecutor:
assert "success" in result assert "success" in result
if result["success"]: if result["success"]:
assert "data" in result assert "data" in result
assert "row_count" in result assert "row_count" in result

View File

@@ -21,8 +21,6 @@ Tests the query execution functionality through actual MCP client-server communi
Assumes the server is already running and configured properly Assumes the server is already running and configured properly
""" """
import asyncio
import json
import pytest import pytest
import os import os
import sys import sys
@@ -66,14 +64,13 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
if result["success"]: if result["success"]:
assert "result" in result, "Successful result should contain 'result' field" assert "data" in result, "Successful result should contain 'data' field"
else: else:
assert "error" in result, "Failed result should contain 'error' field" assert "error" in result, "Failed result should contain 'error' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_show_databases_query_via_client(self, client, test_config): async def test_show_databases_query_via_client(self, client, test_config):
@@ -87,8 +84,7 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_information_schema_query_via_client(self, client, test_config): async def test_information_schema_query_via_client(self, client, test_config):
@@ -102,8 +98,7 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_with_max_rows_parameter_via_client(self, client, test_config): async def test_query_with_max_rows_parameter_via_client(self, client, test_config):
@@ -118,8 +113,7 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_error_handling_via_client(self, client, test_config): async def test_query_error_handling_via_client(self, client, test_config):
@@ -131,8 +125,7 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_query_with_auth_token_via_client(self, client, test_config): async def test_query_with_auth_token_via_client(self, client, test_config):
@@ -152,5 +145,4 @@ class TestQueryExecutorClientServer:
assert "success" in result, "Result should contain 'success' field" assert "success" in result, "Result should contain 'success' field"
return result return result
result = await client.connect_and_run(test_callback) await client.connect_and_run(test_callback)
assert "success" in result