feat: add batch SQL execution support for MCP (#70)

* feat: add batch SQL execution support for MCP

- Add sql field to QueryResult to track executed query
- Implement execute_batch_sqls_for_mcp for executing multiple SQL statements
- Use sqlparse to split and execute multiple SQL statements in a single request
- Improve error handling in execute_batch_queries
- Return multiple results format when batch queries are detected

* test: add multi-SQL statements test for query executor
This commit is contained in:
zzzzwc
2025-12-24 12:45:29 +08:00
committed by GitHub
parent e58361e04b
commit 43143f0b30
3 changed files with 137 additions and 6 deletions

View File

@@ -59,6 +59,7 @@ class QueryResult:
metadata: dict[str, Any]
execution_time: float
row_count: int
sql: str
class DorisConnection:
@@ -132,6 +133,7 @@ class DorisConnection:
metadata=metadata,
execution_time=execution_time,
row_count=row_count,
sql=sql
)
except Exception as e:

View File

@@ -33,6 +33,8 @@ from datetime import datetime, timedelta, date
from typing import Any, Dict
from decimal import Decimal
import sqlparse
from .db import DorisConnectionManager, QueryResult
from .logger import get_logger
from .sql_security_utils import get_auth_context
@@ -468,6 +470,51 @@ class DorisQueryExecutor:
f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
)
async def execute_batch_sqls_for_mcp(
    self, sqls: list[str],
    timeout: int = 30,
    session_id: str = "mcp_session",
    user_id: str = "mcp_user",
    auth_context=None
) -> dict[str, Any]:
    """Run several SQL statements as one batch and return an MCP-style payload.

    Each statement is wrapped in its own ``QueryRequest`` (caching disabled)
    and the whole list is handed to ``execute_batch_queries``. Exceptions
    raised by that call are not caught here and propagate to the caller.

    Args:
        sqls: The individual SQL statements to execute.
        timeout: Timeout value forwarded to every ``QueryRequest``.
        session_id: Session identifier forwarded to every ``QueryRequest``.
        user_id: User identifier forwarded to every ``QueryRequest``.
        auth_context: Passed through unchanged to ``execute_batch_queries``.

    Returns:
        For an empty ``sqls``: a dict with ``success`` False, an ``error``
        message and ``data`` None. Otherwise: a dict with ``success`` True,
        ``multiple_results`` True and ``results``, a list holding one entry
        per returned query result (serialized rows, row count, execution
        time, and metadata carrying the column list and the executed SQL).
    """
    # Guard clause: nothing to execute.
    if not sqls:
        return {
            "success": False,
            "error": "SQL query is required",
            "data": None
        }

    # One request per statement; all of them share the session/user/timeout.
    batch = []
    for statement in sqls:
        batch.append(
            QueryRequest(
                sql=statement,
                session_id=session_id,
                user_id=user_id,
                timeout=timeout,
                cache_enabled=False
            )
        )

    executed = await self.execute_batch_queries(batch, auth_context)

    # Convert every result into a JSON-serializable entry.
    serialized = []
    for item in executed:
        rows = [self._serialize_row_data(row) for row in item.data]
        entry = {
            "data": rows,
            "row_count": item.row_count,
            "execution_time": item.execution_time,
            "metadata": {
                "columns": item.metadata.get("columns", []),
                "query": item.sql
            }
        }
        serialized.append(entry)

    return {
        "success": True,
        "multiple_results": True,
        "results": serialized
    }
async def execute_batch_queries(
self, query_requests: list[QueryRequest], auth_context=None
) -> list[QueryResult]:
@@ -485,13 +532,16 @@ class DorisQueryExecutor:
self.execute_query(request, auth_context) for request in query_requests
]
try:
results = await asyncio.gather(*tasks, return_exceptions=True)
except Exception as e:
self.logger.error(f"Batch query execution failed: {e}")
raise
query_results = []
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, Exception):
self.logger.error(f"Batch query execution failed: {result}")
raise result
else:
query_results.append(result)
return results
return query_results
async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
"""Get query execution plan"""
@@ -635,6 +685,15 @@ class DorisQueryExecutor:
sql = sql[:-1]
sql = f"{sql} LIMIT {limit}"
all_statements = [
s.strip()
for s in sqlparse.split(sql)
if s.strip()
]
if len(all_statements) > 1:
return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
session_id=session_id, user_id=user_id,
auth_context=auth_context)
# Create query request
query_request = QueryRequest(
sql=sql,