diff --git a/doris_mcp_server/utils/db.py b/doris_mcp_server/utils/db.py
index 7d2acb4..8ddf3f3 100644
--- a/doris_mcp_server/utils/db.py
+++ b/doris_mcp_server/utils/db.py
@@ -59,6 +59,7 @@ class QueryResult:
     metadata: dict[str, Any]
     execution_time: float
     row_count: int
+    sql: str = ""  # statement text; defaulted so older constructors keep working
 
 
 class DorisConnection:
@@ -132,6 +133,7 @@ class DorisConnection:
                 metadata=metadata,
                 execution_time=execution_time,
                 row_count=row_count,
+                sql=sql
             )
 
         except Exception as e:
diff --git a/doris_mcp_server/utils/query_executor.py b/doris_mcp_server/utils/query_executor.py
index cf2fc28..1e15c2c 100644
--- a/doris_mcp_server/utils/query_executor.py
+++ b/doris_mcp_server/utils/query_executor.py
@@ -33,6 +33,8 @@ from datetime import datetime, timedelta, date
 from typing import Any, Dict
 from decimal import Decimal
 
+import sqlparse
+
 from .db import DorisConnectionManager, QueryResult
 from .logger import get_logger
 from .sql_security_utils import get_auth_context
@@ -468,6 +470,68 @@ class DorisQueryExecutor:
                 f"Slow query detected: {execution_time:.2f}s (threshold: {self.slow_query_threshold}s)"
             )
 
+    async def execute_batch_sqls_for_mcp(
+        self, sqls: list[str],
+        timeout: int = 30,
+        session_id: str = "mcp_session",
+        user_id: str = "mcp_user",
+        auth_context=None
+    ) -> dict[str, Any]:
+        """Execute several SQL statements and build one MCP response.
+
+        Each statement becomes its own QueryRequest (caching disabled) and
+        the whole list is handed to execute_batch_queries; any exception
+        raised there propagates to the caller.
+
+        Args:
+            sqls: Individual SQL statements to run.
+            timeout: Timeout value set on every QueryRequest.
+            session_id: Session identifier set on every QueryRequest.
+            user_id: User identifier set on every QueryRequest.
+            auth_context: Forwarded unchanged to execute_batch_queries.
+
+        Returns:
+            A failure dict when sqls is empty; otherwise a dict flagged
+            with "multiple_results" that holds one entry per statement
+            under "results".
+        """
+        if not sqls:
+            return {
+                "success": False,
+                "error": "SQL query is required",
+                "data": None
+            }
+        query_requests = [
+            QueryRequest(
+                sql=sql,
+                session_id=session_id,
+                user_id=user_id,
+                timeout=timeout,
+                cache_enabled=False
+            )
+            for sql in sqls
+        ]
+        query_results = await self.execute_batch_queries(query_requests, auth_context)
+        # Serialize data for JSON response
+        results = [
+            {
+                "data": [self._serialize_row_data(data) for data in result.data],
+                "row_count": result.row_count,
+                "execution_time": result.execution_time,
+                "metadata": {
+                    "columns": result.metadata.get("columns", []),
+                    "query": result.sql
+                }
+            }
+            for result in query_results
+        ]
+
+        return {
+            "success": True,
+            "multiple_results": True,
+            "results": results
+        }
+
     async def execute_batch_queries(
         self, query_requests: list[QueryRequest], auth_context=None
     ) -> list[QueryResult]:
@@ -485,13 +549,18 @@ class DorisQueryExecutor:
             self.execute_query(request, auth_context) for request in query_requests
         ]
 
-        try:
-            results = await asyncio.gather(*tasks, return_exceptions=True)
-        except Exception as e:
-            self.logger.error(f"Batch query execution failed: {e}")
-            raise
+        query_results = []
+        results = await asyncio.gather(*tasks, return_exceptions=True)
+        for result in results:
+            # gather(return_exceptions=True) can also hand back BaseException
+            # subclasses such as asyncio.CancelledError, so test the base type.
+            if isinstance(result, BaseException):
+                self.logger.error(f"Batch query execution failed: {result}")
+                raise result
+            else:
+                query_results.append(result)
 
-        return results
+        return query_results
 
     async def explain_query(self, sql: str, session_id: str) -> dict[str, Any]:
         """Get query execution plan"""
@@ -635,6 +704,18 @@ class DorisQueryExecutor:
                     sql = sql[:-1]
                 sql = f"{sql} LIMIT {limit}"
 
+            # NOTE(review): the statement list is built after the LIMIT
+            # rewrite above, which acts on the whole string - confirm that
+            # every statement of a multi-statement input ends up bounded.
+            all_statements = [
+                s.strip()
+                for s in sqlparse.split(sql)
+                if s.strip()
+            ]
+            if len(all_statements) > 1:
+                return await self.execute_batch_sqls_for_mcp(sqls=all_statements, timeout=timeout,
+                                                             session_id=session_id, user_id=user_id,
+                                                             auth_context=auth_context)
             # Create query request
             query_request = QueryRequest(
                 sql=sql,
diff --git a/test/utils/test_query_executor.py b/test/utils/test_query_executor.py
index a9a9b79..ef3c958 100644
--- a/test/utils/test_query_executor.py
+++ b/test/utils/test_query_executor.py
@@ -201,3 +201,73 @@ class TestDorisQueryExecutor:
         if result["success"]:
             assert "data" in result
             assert "row_count" in result
+
+    @pytest.mark.asyncio
+    async def test_execute_multi_sql_statements(self, query_executor):
+        """Test execution of multiple SQL statements"""
+        from doris_mcp_server.utils.query_executor import QueryResult
+
+        # Disable security check for this test
+        query_executor.connection_manager.config.security.enable_security_check = False
+
+        with patch.object(query_executor, 'execute_query') as mock_execute:
+            # Mock results for three SQL statements
+            mock_execute.side_effect = [
+                QueryResult(
+                    data=[{"id": 1, "name": "张三"}],
+                    row_count=1,
+                    execution_time=0.1,
+                    sql="SELECT id, name FROM users WHERE id = 1",
+                    metadata={"columns": ["id", "name"]}
+                ),
+                QueryResult(
+                    data=[{"id": 2, "name": "李四"}],
+                    row_count=1,
+                    execution_time=0.12,
+                    sql="SELECT id, name FROM users WHERE id = 2",
+                    metadata={"columns": ["id", "name"]}
+                ),
+                QueryResult(
+                    data=[{"count": 100}],
+                    row_count=1,
+                    execution_time=0.08,
+                    sql="SELECT COUNT(*) as count FROM users",
+                    metadata={"columns": ["count"]}
+                )
+            ]
+
+            # Execute multiple SQL statements separated by semicolons
+            multi_sql = """
+            SELECT id, name FROM users WHERE id = 1;
+            SELECT id, name FROM users WHERE id = 2;
+            SELECT COUNT(*) as count FROM users;
+            """
+
+            result = await query_executor.execute_sql_for_mcp(multi_sql)
+
+            # Verify the result structure for multiple statements
+            assert result["success"] is True
+            assert result["multiple_results"] is True
+            assert "results" in result
+            assert len(result["results"]) == 3
+
+            # Verify first query result
+            assert result["results"][0]["data"] == [{"id": 1, "name": "张三"}]
+            assert result["results"][0]["row_count"] == 1
+            assert result["results"][0]["metadata"]["columns"] == ["id", "name"]
+            assert result["results"][0]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 1"
+
+            # Verify second query result
+            assert result["results"][1]["data"] == [{"id": 2, "name": "李四"}]
+            assert result["results"][1]["row_count"] == 1
+            assert result["results"][1]["metadata"]["columns"] == ["id", "name"]
+            assert result["results"][1]["metadata"]["query"] == "SELECT id, name FROM users WHERE id = 2"
+
+            # Verify third query result
+            assert result["results"][2]["data"] == [{"count": 100}]
+            assert result["results"][2]["row_count"] == 1
+            assert result["results"][2]["metadata"]["columns"] == ["count"]
+            assert result["results"][2]["metadata"]["query"] == "SELECT COUNT(*) as count FROM users"
+
+            # Verify execute_query was called three times
+            assert mock_execute.call_count == 3