feat: add batch SQL execution support for MCP (#70)

* feat: add batch SQL execution support for MCP

- Add sql field to QueryResult to track executed query
- Implement execute_batch_sqls_for_mcp for executing multiple SQL
- Use sqlparse to split and execute multiple SQL in single request
- Improve error handling in execute_batch_queries
- Return multiple results format when batch queries are detected

* test: add multi-SQL statements test for query executor
This commit is contained in:
zzzzwc
2025-12-24 12:45:29 +08:00
committed by GitHub
parent e58361e04b
commit 43143f0b30
3 changed files with 137 additions and 6 deletions

View File

@@ -59,6 +59,7 @@ class QueryResult:
metadata: dict[str, Any] metadata: dict[str, Any]
execution_time: float execution_time: float
row_count: int row_count: int
sql: str
class DorisConnection: class DorisConnection:
@@ -132,6 +133,7 @@ class DorisConnection:
metadata=metadata, metadata=metadata,
execution_time=execution_time, execution_time=execution_time,
row_count=row_count, row_count=row_count,
sql=sql
) )
except Exception as e: except Exception as e:

View File

@@ -33,6 +33,8 @@ from datetime import datetime, timedelta, date
from typing import Any, Dict from typing import Any, Dict
from decimal import Decimal from decimal import Decimal
import sqlparse
from .db import DorisConnectionManager, QueryResult from .db import DorisConnectionManager, QueryResult
from .logger import get_logger from .logger import get_logger
from .sql_security_utils import get_auth_context from .sql_security_utils import get_auth_context
@@ -468,6 +470,51 @@ class DorisQueryExecutor:
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)" f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
) )
async def execute_batch_sqls_for_mcp(
self, sqls: list[str],
timeout: int = 30,
session_id: str = "mcp_session",
user_id: str = "mcp_user",
auth_context=None
) -> dict[str, Any]:
"""Execute multiple sqls in batch"""
if not sqls:
return {
"success": False,
"error": "SQL query is required",
"data": None
}
query_requests = [
QueryRequest(
sql=sql,
session_id=session_id,
user_id=user_id,
timeout=timeout,
cache_enabled=False
)
for sql in sqls
]
query_results = await self.execute_batch_queries(query_requests, auth_context)
# Serialize data for JSON response
results = [
{
"data": [self._serialize_row_data(data) for data in result.data],
"row_count": result.row_count,
"execution_time": result.execution_time,
"metadata": {
"columns": result.metadata.get("columns", []),
"query": result.sql
}
}
for result in query_results
]
return {
"success": True,
"multiple_results": True,
"results": results
}
async def execute_batch_queries( async def execute_batch_queries(
self, query_requests: list[QueryRequest], auth_context=None self, query_requests: list[QueryRequest], auth_context=None
) -> list[QueryResult]: ) -> list[QueryResult]:
@@ -485,13 +532,16 @@ class DorisQueryExecutor:
self.execute_query(request, auth_context) for request in query_requests self.execute_query(request, auth_context) for request in query_requests
] ]
try: query_results = []
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e: for result in results:
self.logger.error(f"Batch query execution failed: {e}") if isinstance(result, Exception):
raise self.logger.error(f"Batch query execution failed: {result}")
raise result
else:
query_results.append(result)
return results return query_results
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]: async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
"""Get query execution plan""" """Get query execution plan"""
@@ -635,6 +685,15 @@ class DorisQueryExecutor:
sql = sql[:-1] sql = sql[:-1]
sql = f"{sql} LIMIT {limit}" sql = f"{sql} LIMIT {limit}"
all_statements = [
s.strip()
for s in sqlparse.split(sql)
if s.strip()
]
if len(all_statements) > 1:
return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
session_id=session_id, user_id=user_id,
auth_context=auth_context)
# Create query request # Create query request
query_request = QueryRequest( query_request = QueryRequest(
sql=sql, sql=sql,

View File

@@ -201,3 +201,73 @@ class TestDorisQueryExecutor:
if result["success"]: if result["success"]:
assert "data" in result assert "data" in result
assert "row_count" in result assert "row_count" in result
@pytest.mark.asyncio
async def test_execute_multi_sql_statements(self, query_executor):
"""Test execution of multiple SQL statements"""
from doris_mcp_server.utils.query_executor import QueryResult
# Disable security check for this test
query_executor.connection_manager.config.security.enable_security_check = False
with patch.object(query_executor, 'execute_query') as mock_execute:
# Mock results for three SQL statements
mock_execute.side_effect = [
QueryResult(
data=[{"id": 1, "name": "张三"}],
row_count=1,
execution_time=0.1,
sql="SELECT id, name FROM users WHERE id = 1",
metadata={"columns": ["id", "name"]}
),
QueryResult(
data=[{"id": 2, "name": "李四"}],
row_count=1,
execution_time=0.12,
sql="SELECT id, name FROM users WHERE id = 2",
metadata={"columns": ["id", "name"]}
),
QueryResult(
data=[{"count": 100}],
row_count=1,
execution_time=0.08,
sql="SELECT COUNT(*) as count FROM users",
metadata={"columns": ["count"]}
)
]
# Execute multiple SQL statements separated by semicolons
multi_sql = """
SELECT id, name FROM users WHERE id = 1;
SELECT id, name FROM users WHERE id = 2;
SELECT COUNT(*) as count FROM users;
"""
result = await query_executor.execute_sql_for_mcp(multi_sql)
# Verify the result structure for multiple statements
assert result["success"] is True
assert result["multiple_results"] is True
assert "results" in result
assert len(result["results"]) == 3
# Verify first query result
assert result["results"][0]["data"] == [{"id": 1, "name": "张三"}]
assert result["results"][0]["row_count"] == 1
assert result["results"][0]["metadata"]["columns"] == ["id", "name"]
assert result["results"][0]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 1"
# Verify second query result
assert result["results"][1]["data"] == [{"id": 2, "name": "李四"}]
assert result["results"][1]["row_count"] == 1
assert result["results"][1]["metadata"]["columns"] == ["id", "name"]
assert result["results"][1]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 2"
# Verify third query result
assert result["results"][2]["data"] == [{"count": 100}]
assert result["results"][2]["row_count"] == 1
assert result["results"][2]["metadata"]["columns"] == ["count"]
assert result["results"][2]["metadata"]["query"] == "SELECT COUNT(*) as count FROM users"
# Verify execute_query was called three times
assert mock_execute.call_count == 3