feat: add batch SQL execution support for MCP (#70)
* feat: add batch SQL execution support for MCP
  - Add a `sql` field to `QueryResult` to track the executed query
  - Implement `execute_batch_sqls_for_mcp` for executing multiple SQL statements
  - Use `sqlparse` to split a single request into multiple SQL statements and execute each of them
  - Improve error handling in `execute_batch_queries`
  - Return the multiple-results format when batch queries are detected
* test: add a multi-statement SQL test for the query executor
This commit is contained in:
@@ -59,6 +59,7 @@ class QueryResult:
|
|||||||
metadata: dict[str, Any]
|
metadata: dict[str, Any]
|
||||||
execution_time: float
|
execution_time: float
|
||||||
row_count: int
|
row_count: int
|
||||||
|
sql: str
|
||||||
|
|
||||||
|
|
||||||
class DorisConnection:
|
class DorisConnection:
|
||||||
@@ -132,6 +133,7 @@ class DorisConnection:
|
|||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
execution_time=execution_time,
|
execution_time=execution_time,
|
||||||
row_count=row_count,
|
row_count=row_count,
|
||||||
|
sql=sql
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -33,6 +33,8 @@ from datetime import datetime, timedelta, date
|
|||||||
from typing import Any, Dict
|
from typing import Any, Dict
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
|
||||||
|
import sqlparse
|
||||||
|
|
||||||
from .db import DorisConnectionManager, QueryResult
|
from .db import DorisConnectionManager, QueryResult
|
||||||
from .logger import get_logger
|
from .logger import get_logger
|
||||||
from .sql_security_utils import get_auth_context
|
from .sql_security_utils import get_auth_context
|
||||||
@@ -468,6 +470,51 @@ class DorisQueryExecutor:
|
|||||||
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
|
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def execute_batch_sqls_for_mcp(
|
||||||
|
self, sqls: list[str],
|
||||||
|
timeout: int = 30,
|
||||||
|
session_id: str = "mcp_session",
|
||||||
|
user_id: str = "mcp_user",
|
||||||
|
auth_context=None
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Execute multiple sqls in batch"""
|
||||||
|
if not sqls:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": "SQL query is required",
|
||||||
|
"data": None
|
||||||
|
}
|
||||||
|
query_requests = [
|
||||||
|
QueryRequest(
|
||||||
|
sql=sql,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
timeout=timeout,
|
||||||
|
cache_enabled=False
|
||||||
|
)
|
||||||
|
for sql in sqls
|
||||||
|
]
|
||||||
|
query_results = await self.execute_batch_queries(query_requests, auth_context)
|
||||||
|
# Serialize data for JSON response
|
||||||
|
results = [
|
||||||
|
{
|
||||||
|
"data": [self._serialize_row_data(data) for data in result.data],
|
||||||
|
"row_count": result.row_count,
|
||||||
|
"execution_time": result.execution_time,
|
||||||
|
"metadata": {
|
||||||
|
"columns": result.metadata.get("columns", []),
|
||||||
|
"query": result.sql
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for result in query_results
|
||||||
|
]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"multiple_results": True,
|
||||||
|
"results": results
|
||||||
|
}
|
||||||
|
|
||||||
async def execute_batch_queries(
|
async def execute_batch_queries(
|
||||||
self, query_requests: list[QueryRequest], auth_context=None
|
self, query_requests: list[QueryRequest], auth_context=None
|
||||||
) -> list[QueryResult]:
|
) -> list[QueryResult]:
|
||||||
@@ -485,13 +532,16 @@ class DorisQueryExecutor:
|
|||||||
self.execute_query(request, auth_context) for request in query_requests
|
self.execute_query(request, auth_context) for request in query_requests
|
||||||
]
|
]
|
||||||
|
|
||||||
try:
|
query_results = []
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
except Exception as e:
|
for result in results:
|
||||||
self.logger.error(f"Batch query execution failed: {e}")
|
if isinstance(result, Exception):
|
||||||
raise
|
self.logger.error(f"Batch query execution failed: {result}")
|
||||||
|
raise result
|
||||||
|
else:
|
||||||
|
query_results.append(result)
|
||||||
|
|
||||||
return results
|
return query_results
|
||||||
|
|
||||||
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
|
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
|
||||||
"""Get query execution plan"""
|
"""Get query execution plan"""
|
||||||
@@ -635,6 +685,15 @@ class DorisQueryExecutor:
|
|||||||
sql = sql[:-1]
|
sql = sql[:-1]
|
||||||
sql = f"{sql} LIMIT {limit}"
|
sql = f"{sql} LIMIT {limit}"
|
||||||
|
|
||||||
|
all_statements = [
|
||||||
|
s.strip()
|
||||||
|
for s in sqlparse.split(sql)
|
||||||
|
if s.strip()
|
||||||
|
]
|
||||||
|
if len(all_statements) > 1:
|
||||||
|
return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
|
||||||
|
session_id=session_id, user_id=user_id,
|
||||||
|
auth_context=auth_context)
|
||||||
# Create query request
|
# Create query request
|
||||||
query_request = QueryRequest(
|
query_request = QueryRequest(
|
||||||
sql=sql,
|
sql=sql,
|
||||||
|
|||||||
@@ -201,3 +201,73 @@ class TestDorisQueryExecutor:
|
|||||||
if result["success"]:
|
if result["success"]:
|
||||||
assert "data" in result
|
assert "data" in result
|
||||||
assert "row_count" in result
|
assert "row_count" in result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_multi_sql_statements(self, query_executor):
|
||||||
|
"""Test execution of multiple SQL statements"""
|
||||||
|
from doris_mcp_server.utils.query_executor import QueryResult
|
||||||
|
|
||||||
|
# Disable security check for this test
|
||||||
|
query_executor.connection_manager.config.security.enable_security_check = False
|
||||||
|
|
||||||
|
with patch.object(query_executor, 'execute_query') as mock_execute:
|
||||||
|
# Mock results for three SQL statements
|
||||||
|
mock_execute.side_effect = [
|
||||||
|
QueryResult(
|
||||||
|
data=[{"id": 1, "name": "张三"}],
|
||||||
|
row_count=1,
|
||||||
|
execution_time=0.1,
|
||||||
|
sql="SELECT id, name FROM users WHERE id = 1",
|
||||||
|
metadata={"columns": ["id", "name"]}
|
||||||
|
),
|
||||||
|
QueryResult(
|
||||||
|
data=[{"id": 2, "name": "李四"}],
|
||||||
|
row_count=1,
|
||||||
|
execution_time=0.12,
|
||||||
|
sql="SELECT id, name FROM users WHERE id = 2",
|
||||||
|
metadata={"columns": ["id", "name"]}
|
||||||
|
),
|
||||||
|
QueryResult(
|
||||||
|
data=[{"count": 100}],
|
||||||
|
row_count=1,
|
||||||
|
execution_time=0.08,
|
||||||
|
sql="SELECT COUNT(*) as count FROM users",
|
||||||
|
metadata={"columns": ["count"]}
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Execute multiple SQL statements separated by semicolons
|
||||||
|
multi_sql = """
|
||||||
|
SELECT id, name FROM users WHERE id = 1;
|
||||||
|
SELECT id, name FROM users WHERE id = 2;
|
||||||
|
SELECT COUNT(*) as count FROM users;
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = await query_executor.execute_sql_for_mcp(multi_sql)
|
||||||
|
|
||||||
|
# Verify the result structure for multiple statements
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["multiple_results"] is True
|
||||||
|
assert "results" in result
|
||||||
|
assert len(result["results"]) == 3
|
||||||
|
|
||||||
|
# Verify first query result
|
||||||
|
assert result["results"][0]["data"] == [{"id": 1, "name": "张三"}]
|
||||||
|
assert result["results"][0]["row_count"] == 1
|
||||||
|
assert result["results"][0]["metadata"]["columns"] == ["id", "name"]
|
||||||
|
assert result["results"][0]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 1"
|
||||||
|
|
||||||
|
# Verify second query result
|
||||||
|
assert result["results"][1]["data"] == [{"id": 2, "name": "李四"}]
|
||||||
|
assert result["results"][1]["row_count"] == 1
|
||||||
|
assert result["results"][1]["metadata"]["columns"] == ["id", "name"]
|
||||||
|
assert result["results"][1]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 2"
|
||||||
|
|
||||||
|
# Verify third query result
|
||||||
|
assert result["results"][2]["data"] == [{"count": 100}]
|
||||||
|
assert result["results"][2]["row_count"] == 1
|
||||||
|
assert result["results"][2]["metadata"]["columns"] == ["count"]
|
||||||
|
assert result["results"][2]["metadata"]["query"] == "SELECT COUNT(*) as count FROM users"
|
||||||
|
|
||||||
|
# Verify execute_query was called three times
|
||||||
|
assert mock_execute.call_count == 3
|
||||||
|
|||||||
Reference in New Issue
Block a user