From 5e98e5ba4177f4bd3fbecdc3ccc046b35db3860d Mon Sep 17 00:00:00 2001 From: FreeOnePlus Date: Fri, 6 Jun 2025 14:35:53 +0800 Subject: [PATCH] support Multi-Catalog --- doris_mcp_server/mcp_core.py | 60 +++++--- doris_mcp_server/sse_server.py | 3 +- doris_mcp_server/tools/__init__.py | 6 +- doris_mcp_server/tools/mcp_doris_tools.py | 88 +++++++---- doris_mcp_server/tools/tool_initializer.py | 62 +++++--- doris_mcp_server/utils/schema_extractor.py | 145 +++++++++++++++---- doris_mcp_server/utils/sql_executor_tools.py | 7 +- uv.lock | 92 ++++++------ 8 files changed, 312 insertions(+), 151 deletions(-) diff --git a/doris_mcp_server/mcp_core.py b/doris_mcp_server/mcp_core.py index 72116c4..2ccc509 100644 --- a/doris_mcp_server/mcp_core.py +++ b/doris_mcp_server/mcp_core.py @@ -52,78 +52,85 @@ def run_stdio(): sys.exit(1) # Register Tool: Execute SQL Query -@stdio_mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command (executed by the client).\n +@stdio_mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command with catalog federation support.\n [Parameter Content]:\n -- sql (string) [Required] - SQL statement to execute\n +- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'\n - db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog\n - max_rows (integer) [Optional] - Maximum number of rows to return, default 100\n - timeout (integer) [Optional] - Query timeout in seconds, default 30\n""") -async def exec_query_tool(sql: str, db_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]: +async def exec_query_tool(sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]: """Wrapper: Execute SQL query and return result command""" from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_exec_query - return await mcp_doris_exec_query(sql=sql, db_name=db_name, max_rows=max_rows, timeout=timeout) + return await mcp_doris_exec_query(sql=sql, db_name=db_name, catalog_name=catalog_name, max_rows=max_rows, timeout=timeout) # Register Tool: Get Table Schema @stdio_mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n [Parameter Content]:\n - table_name (string) [Required] - Name of the table to query\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") -async def get_table_schema_tool(table_name: str, db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") +async def get_table_schema_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get table schema""" from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_schema if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]} - return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name) + return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Database Table List @stdio_mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n [Parameter Content]:\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") -async def get_db_table_list_tool(db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") +async def get_db_table_list_tool(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get database table list""" from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_table_list - return await mcp_doris_get_db_table_list(db_name=db_name) + return await mcp_doris_get_db_table_list(db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Database List @stdio_mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n [Parameter Content]:\n -- random_string (string) [Required] - Unique identifier for the tool call\n""") -async def get_db_list_tool() -> Dict[str, Any]: +- random_string (string) [Required] - Unique identifier for the tool call\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") +async def get_db_list_tool(catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get database list""" from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_db_list - return await mcp_doris_get_db_list() + return await mcp_doris_get_db_list(catalog_name=catalog_name) # Register Tool: Get Table Comment @stdio_mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n [Parameter Content]:\n - table_name (string) [Required] - Name of the table to query\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") -async def get_table_comment_tool(table_name: str, db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") +async def get_table_comment_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get table comment""" from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_comment if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]} - return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name) + return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Table Column Comments @stdio_mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n [Parameter Content]:\n - table_name (string) [Required] - Name of the table to query\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") -async def get_table_column_comments_tool(table_name: str, db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") +async def get_table_column_comments_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get table column comments""" from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_column_comments if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]} - return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name) + return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Table Indexes @stdio_mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table. [Parameter Content]:\n - table_name (string) [Required] - Name of the table to query\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") -async def get_table_indexes_tool(table_name: str, db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") +async def get_table_indexes_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get table indexes""" from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_table_indexes if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]} - return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name) + return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Recent Audit Logs @stdio_mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n @@ -140,4 +147,13 @@ async def get_recent_audit_logs_tool(days: int = 7, limit: int = 100) -> Dict[st return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]} return await mcp_doris_get_recent_audit_logs(days=days, limit=limit) +# Register Tool: Get Catalog List +@stdio_mcp.tool("get_catalog_list", description="""[Function Description]: Get a list of all catalog names on the server.\n +[Parameter Content]:\n +- random_string (string) [Required] - Unique identifier for the tool call\n""") +async def get_catalog_list_tool() -> Dict[str, Any]: + """Wrapper: Get catalog list""" + from doris_mcp_server.tools.mcp_doris_tools import mcp_doris_get_catalog_list + return await mcp_doris_get_catalog_list() + # --- Register Tools --- diff --git a/doris_mcp_server/sse_server.py b/doris_mcp_server/sse_server.py index 8a646e6..c155b0f 100644 --- a/doris_mcp_server/sse_server.py +++ b/doris_mcp_server/sse_server.py @@ -1009,7 +1009,8 @@ class DorisMCPSseServer: "get_table_comment": "mcp_doris_get_table_comment", "get_table_column_comments": "mcp_doris_get_table_column_comments", "get_table_indexes": "mcp_doris_get_table_indexes", - "get_recent_audit_logs": "mcp_doris_get_recent_audit_logs" + "get_recent_audit_logs": "mcp_doris_get_recent_audit_logs", + "get_catalog_list": "mcp_doris_get_catalog_list" } # If it's a standard name, convert to MCP name diff --git a/doris_mcp_server/tools/__init__.py b/doris_mcp_server/tools/__init__.py index 4b63b39..2f3fa4c 100644 --- a/doris_mcp_server/tools/__init__.py +++ b/doris_mcp_server/tools/__init__.py @@ -6,7 +6,8 @@ from .mcp_doris_tools import ( mcp_doris_get_table_comment, mcp_doris_get_table_column_comments, mcp_doris_get_table_indexes, - mcp_doris_get_recent_audit_logs + mcp_doris_get_recent_audit_logs, + mcp_doris_get_catalog_list ) # The __all__ list should reflect the registered tool names, @@ -19,5 +20,6 @@ __all__ = [ "get_table_comment", "get_table_column_comments", "get_table_indexes", - "get_recent_audit_logs" + "get_recent_audit_logs", + "get_catalog_list" ] \ No newline at end of file diff --git a/doris_mcp_server/tools/mcp_doris_tools.py b/doris_mcp_server/tools/mcp_doris_tools.py index ccaa9f6..3c93776 100644 --- a/doris_mcp_server/tools/mcp_doris_tools.py +++ b/doris_mcp_server/tools/mcp_doris_tools.py @@ -55,20 +55,31 @@ def _format_response(success: bool, result: Any = None, error: str = None, messa ] } -async def mcp_doris_exec_query(sql: str = None, db_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]: +async def mcp_doris_exec_query(sql: str = None, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]: """ - Executes an SQL query and returns the result. + Executes an SQL query and returns the result with catalog federation support. Args: - sql (str): The SQL query to execute. - db_name (str, optional): Target database name. Defaults to the configured default database. + sql (str): The SQL query to execute. MUST use three-part naming for table references: + - Internal tables: internal.db_name.table_name (e.g., "SELECT * FROM internal.ssb.customer") + - External tables: catalog_name.db_name.table_name (e.g., "SELECT * FROM mysql.ssb.customer") + - Cross-catalog queries: "SELECT * FROM mysql.ssb.customer m JOIN internal.ssb.orders o ON m.id = o.customer_id" + + Examples: + - Query internal catalog: "SELECT COUNT(*) FROM internal.ssb.customer" + - Query MySQL catalog: "SELECT COUNT(*) FROM mysql.ssb.customer" + - Cross-catalog join: "SELECT * FROM internal.ssb.customer c JOIN mysql.test.user_info u ON c.id = u.customer_id" + + db_name (str, optional): Target database name. Only used for connection context, table names in SQL must be fully qualified. + catalog_name (str, optional): Reference catalog name for context. Does not affect SQL execution - table names in SQL must be fully qualified. + Available catalogs can be found using get_catalog_list tool. max_rows (int, optional): Maximum number of rows to return. Defaults to 100. timeout (int, optional): Query timeout in seconds. Defaults to 30. Returns: Dict[str, Any]: A dictionary containing the query result or an error. """ - logger.info(f"MCP Tool Call: mcp_doris_exec_query, SQL: {sql}, DB: {db_name}, MaxRows: {max_rows}, Timeout: {timeout}") + logger.info(f"MCP Tool Call: mcp_doris_exec_query, SQL: {sql}, DB: {db_name}, Catalog: {catalog_name}, MaxRows: {max_rows}, Timeout: {timeout}") try: if not sql: return _format_response(success=False, error="SQL statement not provided", message="Please provide the SQL statement to execute") @@ -78,6 +89,7 @@ async def mcp_doris_exec_query(sql: str = None, db_name: str = None, max_rows: i "params": { "sql": sql, "db_name": db_name, + "catalog_name": catalog_name, "max_rows": max_rows, "timeout": timeout } @@ -121,71 +133,71 @@ async def mcp_doris_exec_query(sql: str = None, db_name: str = None, max_rows: i return _format_response(success=False, error=str(e), message="Error executing SQL query") -async def mcp_doris_get_table_schema(table_name: str, db_name: str = None) -> Dict[str, Any]: - logger.info(f"MCP Tool Call: mcp_doris_get_table_schema, Table: {table_name}, DB: {db_name}") +async def mcp_doris_get_table_schema(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: + logger.info(f"MCP Tool Call: mcp_doris_get_table_schema, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}") if not table_name: return _format_response(success=False, error="Missing table_name parameter") try: - extractor = MetadataExtractor(db_name=db_name) - schema = extractor.get_table_schema(table_name=table_name, db_name=db_name) + extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name) + schema = extractor.get_table_schema(table_name=table_name, db_name=db_name, catalog_name=catalog_name) if not schema: - return _format_response(success=False, error="Table not found or has no columns", message=f"Could not get schema for table {db_name or extractor.db_name}.{table_name}") + return _format_response(success=False, error="Table not found or has no columns", message=f"Could not get schema for table {catalog_name or 'default'}.{db_name or extractor.db_name}.{table_name}") return _format_response(success=True, result=schema) except Exception as e: logger.error(f"MCP tool execution failed mcp_doris_get_table_schema: {str(e)}", exc_info=True) return _format_response(success=False, error=str(e), message="Error getting table schema") -async def mcp_doris_get_db_table_list(db_name: str = None) -> Dict[str, Any]: - logger.info(f"MCP Tool Call: mcp_doris_get_db_table_list, DB: {db_name}") +async def mcp_doris_get_db_table_list(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: + logger.info(f"MCP Tool Call: mcp_doris_get_db_table_list, DB: {db_name}, Catalog: {catalog_name}") try: - extractor = MetadataExtractor(db_name=db_name) - tables = extractor.get_database_tables(db_name=db_name) + extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name) + tables = extractor.get_database_tables(db_name=db_name, catalog_name=catalog_name) return _format_response(success=True, result=tables) except Exception as e: logger.error(f"MCP tool execution failed mcp_doris_get_db_table_list: {str(e)}", exc_info=True) return _format_response(success=False, error=str(e), message="Error getting database table list") -async def mcp_doris_get_db_list() -> Dict[str, Any]: - logger.info(f"MCP Tool Call: mcp_doris_get_db_list") +async def mcp_doris_get_db_list(catalog_name: str = None) -> Dict[str, Any]: + logger.info(f"MCP Tool Call: mcp_doris_get_db_list, Catalog: {catalog_name}") try: - extractor = MetadataExtractor() - databases = extractor.get_all_databases() + extractor = MetadataExtractor(catalog_name=catalog_name) + databases = extractor.get_all_databases(catalog_name=catalog_name) return _format_response(success=True, result=databases) except Exception as e: logger.error(f"MCP tool execution failed mcp_doris_get_db_list: {str(e)}", exc_info=True) return _format_response(success=False, error=str(e), message="Error getting database list") -async def mcp_doris_get_table_comment(table_name: str, db_name: str = None) -> Dict[str, Any]: - logger.info(f"MCP Tool Call: mcp_doris_get_table_comment, Table: {table_name}, DB: {db_name}") +async def mcp_doris_get_table_comment(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: + logger.info(f"MCP Tool Call: mcp_doris_get_table_comment, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}") if not table_name: return _format_response(success=False, error="Missing table_name parameter") try: - extractor = MetadataExtractor(db_name=db_name) - comment = extractor.get_table_comment(table_name=table_name, db_name=db_name) + extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name) + comment = extractor.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name) return _format_response(success=True, result=comment) except Exception as e: logger.error(f"MCP tool execution failed mcp_doris_get_table_comment: {str(e)}", exc_info=True) return _format_response(success=False, error=str(e), message="Error getting table comment") -async def mcp_doris_get_table_column_comments(table_name: str, db_name: str = None) -> Dict[str, Any]: - logger.info(f"MCP Tool Call: mcp_doris_get_table_column_comments, Table: {table_name}, DB: {db_name}") +async def mcp_doris_get_table_column_comments(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: + logger.info(f"MCP Tool Call: mcp_doris_get_table_column_comments, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}") if not table_name: return _format_response(success=False, error="Missing table_name parameter") try: - extractor = MetadataExtractor(db_name=db_name) - comments = extractor.get_column_comments(table_name=table_name, db_name=db_name) + extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name) + comments = extractor.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name) return _format_response(success=True, result=comments) except Exception as e: logger.error(f"MCP tool execution failed mcp_doris_get_table_column_comments: {str(e)}", exc_info=True) return _format_response(success=False, error=str(e), message="Error getting column comments") -async def mcp_doris_get_table_indexes(table_name: str, db_name: str = None) -> Dict[str, Any]: - logger.info(f"MCP Tool Call: mcp_doris_get_table_indexes, Table: {table_name}, DB: {db_name}") +async def mcp_doris_get_table_indexes(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: + logger.info(f"MCP Tool Call: mcp_doris_get_table_indexes, Table: {table_name}, DB: {db_name}, Catalog: {catalog_name}") if not table_name: return _format_response(success=False, error="Missing table_name parameter") try: - extractor = MetadataExtractor(db_name=db_name) - indexes = extractor.get_table_indexes(table_name=table_name, db_name=db_name) + extractor = MetadataExtractor(db_name=db_name, catalog_name=catalog_name) + indexes = extractor.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name) return _format_response(success=True, result=indexes) except Exception as e: logger.error(f"MCP tool execution failed mcp_doris_get_table_indexes: {str(e)}", exc_info=True) @@ -200,3 +212,19 @@ async def mcp_doris_get_recent_audit_logs(days: int = 7, limit: int = 100) -> Di except Exception as e: logger.error(f"MCP tool execution failed mcp_doris_get_recent_audit_logs: {str(e)}", exc_info=True) return _format_response(success=False, error=str(e), message="Error getting audit logs") + +async def mcp_doris_get_catalog_list() -> Dict[str, Any]: + """ + Get Doris catalog list + + Returns: + Dict[str, Any]: Dictionary containing catalog list or error information + """ + logger.info(f"MCP Tool Call: mcp_doris_get_catalog_list") + try: + extractor = MetadataExtractor() + catalogs = extractor.get_catalog_list() + return _format_response(success=True, result=catalogs, message="Successfully retrieved catalog list") + except Exception as e: + logger.error(f"MCP tool execution failed mcp_doris_get_catalog_list: {str(e)}", exc_info=True) + return _format_response(success=False, error=str(e), message="Error getting catalog list") diff --git a/doris_mcp_server/tools/tool_initializer.py b/doris_mcp_server/tools/tool_initializer.py index cebef12..5e19133 100644 --- a/doris_mcp_server/tools/tool_initializer.py +++ b/doris_mcp_server/tools/tool_initializer.py @@ -26,7 +26,8 @@ from doris_mcp_server.tools.mcp_doris_tools import ( mcp_doris_get_table_comment, mcp_doris_get_table_column_comments, mcp_doris_get_table_indexes, - mcp_doris_get_recent_audit_logs + mcp_doris_get_recent_audit_logs, + mcp_doris_get_catalog_list ) # Get logger @@ -42,79 +43,86 @@ async def register_mcp_tools(mcp): try: # Register Tool: Execute SQL Query (Using long description string including parameters) - @mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command (executed by the client).\n + @mcp.tool("exec_query", description="""[Function Description]: Execute SQL query and return result command with catalog federation support.\n [Parameter Content]:\n - random_string (string) [Required] - Unique identifier for the tool call\n -- sql (string) [Required] - SQL statement to execute\n +- sql (string) [Required] - SQL statement to execute. MUST use three-part naming for all table references: 'catalog_name.db_name.table_name'. For internal tables use 'internal.db_name.table_name', for external tables use 'catalog_name.db_name.table_name'\n - db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Reference catalog name for context, defaults to current catalog\n - max_rows (integer) [Optional] - Maximum number of rows to return, default 100 - timeout (integer) [Optional] - Query timeout in seconds, default 30""") - async def exec_query_tool(sql: str, db_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]: + async def exec_query_tool(sql: str, db_name: str = None, catalog_name: str = None, max_rows: int = 100, timeout: int = 30) -> Dict[str, Any]: """Wrapper: Execute SQL query and return result command""" # Note: ctx parameter is no longer needed here as we receive named parameters directly - return await mcp_doris_exec_query(sql=sql, db_name=db_name, max_rows=max_rows, timeout=timeout) + return await mcp_doris_exec_query(sql=sql, db_name=db_name, catalog_name=catalog_name, max_rows=max_rows, timeout=timeout) # Register Tool: Get Table Schema (Keep long description string including parameters) @mcp.tool("get_table_schema", description="""[Function Description]: Get detailed structure information of the specified table (columns, types, comments, etc.).\n [Parameter Content]:\n - random_string (string) [Required] - Unique identifier for the tool call\n - table_name (string) [Required] - Name of the table to query\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") - async def get_table_schema_tool(table_name: str, db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") + async def get_table_schema_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get table schema""" if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]} - return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name) + return await mcp_doris_get_table_schema(table_name=table_name, db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Database Table List (Keep long description string including parameters) @mcp.tool("get_db_table_list", description="""[Function Description]: Get a list of all table names in the specified database.\n [Parameter Content]:\n - random_string (string) [Required] - Unique identifier for the tool call\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") - async def get_db_table_list_tool(db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") + async def get_db_table_list_tool(db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get database table list""" - return await mcp_doris_get_db_table_list(db_name=db_name) + return await mcp_doris_get_db_table_list(db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Database List (Keep long description string including parameters) # Note: Although the description mentions random_string, the wrapper function signature does not. See how mcp handles this. @mcp.tool("get_db_list", description="""[Function Description]: Get a list of all database names on the server.\n [Parameter Content]:\n -- random_string (string) [Required] - Unique identifier for the tool call\n""") - async def get_db_list_tool() -> Dict[str, Any]: # Function signature has no parameters +- random_string (string) [Required] - Unique identifier for the tool call\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") + async def get_db_list_tool(catalog_name: str = None) -> Dict[str, Any]: # Function signature has no parameters """Wrapper: Get database list""" - return await mcp_doris_get_db_list() + return await mcp_doris_get_db_list(catalog_name=catalog_name) # Register Tool: Get Table Comment (Keep long description string including parameters) @mcp.tool("get_table_comment", description="""[Function Description]: Get the comment information for the specified table.\n [Parameter Content]:\n - random_string (string) [Required] - Unique identifier for the tool call\n - table_name (string) [Required] - Name of the table to query\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") - async def get_table_comment_tool(table_name: str, db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") + async def get_table_comment_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get table comment""" if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]} - return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name) + return await mcp_doris_get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Table Column Comments (Keep long description string including parameters) @mcp.tool("get_table_column_comments", description="""[Function Description]: Get comment information for all columns in the specified table.\n [Parameter Content]:\n - random_string (string) [Required] - Unique identifier for the tool call\n - table_name (string) [Required] - Name of the table to query\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") - async def get_table_column_comments_tool(table_name: str, db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") + async def get_table_column_comments_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get table column comments""" if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]} - return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name) + return await mcp_doris_get_table_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Table Indexes (Keep long description string including parameters) @mcp.tool("get_table_indexes", description="""[Function Description]: Get index information for the specified table.\n [Parameter Content]:\n - random_string (string) [Required] - Unique identifier for the tool call\n - table_name (string) [Required] - Name of the table to query\n -- db_name (string) [Optional] - Target database name, defaults to the current database\n""") - async def get_table_indexes_tool(table_name: str, db_name: str = None) -> Dict[str, Any]: +- db_name (string) [Optional] - Target database name, defaults to the current database\n +- catalog_name (string) [Optional] - Target catalog name for federation queries, defaults to current catalog\n""") + async def get_table_indexes_tool(table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, Any]: """Wrapper: Get table indexes""" if not table_name: return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "Missing table_name parameter"})}]} - return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name) + return await mcp_doris_get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name) # Register Tool: Get Recent Audit Logs (Keep long description string including parameters) @mcp.tool("get_recent_audit_logs", description="""[Function Description]: Get audit log records for a recent period.\n @@ -131,6 +139,14 @@ async def register_mcp_tools(mcp): return {"content": [{"type": "text", "text": json.dumps({"success": False, "error": "days and limit parameters must be integers"})}]} return await mcp_doris_get_recent_audit_logs(days=days, limit=limit) + # Register Tool: Get Catalog List (Keep long description string including parameters) + @mcp.tool("get_catalog_list", description="""[Function Description]: Get a list of all catalog names on the server.\n +[Parameter Content]:\n +- random_string (string) [Required] - Unique identifier for the tool call\n""") + async def get_catalog_list_tool() -> Dict[str, Any]: + """Wrapper: Get catalog list""" + return await mcp_doris_get_catalog_list() + # Get tool count tools_count = len(await mcp.list_tools()) logger.info(f"Registered all MCP tools, total {tools_count} tools") diff --git a/doris_mcp_server/utils/schema_extractor.py b/doris_mcp_server/utils/schema_extractor.py index ec90a28..b9cd244 100644 --- a/doris_mcp_server/utils/schema_extractor.py +++ b/doris_mcp_server/utils/schema_extractor.py @@ -31,15 +31,17 @@ from doris_mcp_server.utils.db import execute_query_df, execute_query class MetadataExtractor: """Apache Doris Metadata Extractor""" - def __init__(self, db_name: str = None): + def __init__(self, db_name: str = None, catalog_name: str = None): """ Initialize the metadata extractor Args: db_name: Default database name, uses the currently connected database if not specified + catalog_name: Default catalog name for federation queries, uses the current catalog if not specified """ # Get configuration from environment variables self.db_name = db_name or os.getenv("DB_DATABASE", "") + self.catalog_name = catalog_name # Store catalog name for federation support self.metadata_db = METADATA_DB_NAME # Use constant # Caching system @@ -118,14 +120,18 @@ class MetadataExtractor: default_patterns = ["^ads_.*$", "^dim_.*$", "^dws_.*$", "^dwd_.*$", "^ods_.*$", "^.*$"] return default_patterns - def get_all_databases(self) -> List[str]: + def get_all_databases(self, catalog_name: str = None) -> List[str]: """ Get a list of all databases + Args: + catalog_name: Catalog name for federation queries, uses instance catalog if None + Returns: List of database names """ - cache_key = "databases" + effective_catalog = catalog_name or self.catalog_name + cache_key = f"databases_{effective_catalog or 'default'}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] @@ -142,13 +148,13 @@ class MetadataExtractor: SCHEMA_NAME """ - result = execute_query(query) + result = self._execute_query_with_catalog(query, self.db_name, effective_catalog) if not result: databases = [] else: databases = [db["SCHEMA_NAME"] for db in result] - logger.info(f"Retrieved database list: {databases}") + logger.info(f"Retrieved database list from catalog {effective_catalog or 'default'}: {databases}") # Update cache self.metadata_cache[cache_key] = databases @@ -205,22 +211,24 @@ class MetadataExtractor: logger.warning(f"Current database {self.db_name} is in the excluded list, metadata retrieval might not work properly") return [self.db_name] if self.db_name else [] - def get_database_tables(self, db_name: Optional[str] = None) -> List[str]: + def get_database_tables(self, db_name: Optional[str] = None, catalog_name: str = None) -> List[str]: """ Get a list of all tables in the database Args: db_name: Database name, uses current database if None + catalog_name: Catalog name for federation queries, uses instance catalog if None Returns: List of table names """ db_name = db_name or self.db_name + effective_catalog = catalog_name or self.catalog_name if not db_name: logger.warning("Database name not specified") return [] - cache_key = f"tables_{db_name}" + cache_key = f"tables_{effective_catalog or 'default'}_{db_name}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] @@ -236,14 +244,14 @@ class MetadataExtractor: AND TABLE_TYPE = 'BASE TABLE' """ - result = execute_query(query, db_name) - logger.info(f"{db_name}.information_schema.tables query result: {result}") + result = self._execute_query_with_catalog(query, db_name, effective_catalog) + logger.info(f"{effective_catalog or 'default'}.{db_name}.information_schema.tables query result: {result}") if not result: tables = [] else: tables = [table['TABLE_NAME'] for table in result] - logger.info(f"Table names retrieved from {db_name}.information_schema.tables: {tables}") + logger.info(f"Table names retrieved from {effective_catalog or 'default'}.{db_name}.information_schema.tables: {tables}") # Sort tables by hierarchy matching (if enabled) if self.enable_table_hierarchy and tables: @@ -385,23 +393,25 @@ class MetadataExtractor: return matches - def get_table_schema(self, table_name: str, db_name: Optional[str] = None) -> Dict[str, Any]: + def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]: """ Get the schema information for a table Args: table_name: Table name db_name: Database name, uses current database if None + catalog_name: Catalog name for federation queries, uses instance catalog if None Returns: Table schema information, including column names, types, nullability, defaults, comments, etc. """ db_name = db_name or self.db_name + effective_catalog = catalog_name or self.catalog_name if not db_name: logger.warning("Database name not specified") return {} - cache_key = f"schema_{db_name}_{table_name}" + cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] @@ -426,10 +436,10 @@ class MetadataExtractor: ORDINAL_POSITION """ - result = execute_query(query) + result = self._execute_query_with_catalog(query, db_name, effective_catalog) if not result: - logger.warning(f"Table {db_name}.{table_name} does not exist or has no columns") + logger.warning(f"Table {effective_catalog or 'default'}.{db_name}.{table_name} does not exist or has no columns") return {} # Create structured table schema information @@ -449,7 +459,7 @@ class MetadataExtractor: columns.append(column_info) # Get table comment - table_comment = self.get_table_comment(table_name, db_name) + table_comment = self.get_table_comment(table_name, db_name, effective_catalog) # Build complete structure schema = { @@ -488,23 +498,25 @@ class MetadataExtractor: logger.error(f"Error getting table schema: {str(e)}") return {} - def get_table_comment(self, table_name: str, db_name: Optional[str] = None) -> str: + def get_table_comment(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> str: """ Get the comment for a table Args: table_name: Table name db_name: Database name, uses current database if None + catalog_name: Catalog name for federation queries, uses instance catalog if None Returns: Table comment """ db_name = db_name or self.db_name + effective_catalog = catalog_name or self.catalog_name if not db_name: logger.warning("Database name not specified") return "" - cache_key = f"table_comment_{db_name}_{table_name}" + cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] @@ -520,7 +532,7 @@ class MetadataExtractor: AND TABLE_NAME = '{table_name}' """ - result = execute_query(query) + result = self._execute_query_with_catalog(query, db_name, effective_catalog) if not result or not result[0]: comment = "" @@ -536,23 +548,25 @@ class MetadataExtractor: logger.error(f"Error getting table comment: {str(e)}") return "" - def get_column_comments(self, table_name: str, db_name: Optional[str] = None) -> Dict[str, str]: + def get_column_comments(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, str]: """ Get comments for all columns in a table Args: table_name: Table name db_name: Database name, uses current database if None + catalog_name: Catalog name for federation queries, uses instance catalog if None Returns: Dictionary of column names and comments """ db_name = db_name or self.db_name + effective_catalog = catalog_name or self.catalog_name if not db_name: logger.warning("Database name not specified") return {} - cache_key = f"column_comments_{db_name}_{table_name}" + cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] @@ -571,7 +585,7 @@ class MetadataExtractor: ORDINAL_POSITION """ - result = execute_query(query) + result = self._execute_query_with_catalog(query, db_name, effective_catalog) comments = {} for col in result: @@ -589,28 +603,36 @@ class MetadataExtractor: logger.error(f"Error getting column comments: {str(e)}") return {} - def get_table_indexes(self, table_name: str, db_name: Optional[str] = None) -> List[Dict[str, Any]]: + def get_table_indexes(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> List[Dict[str, Any]]: """ Get the index information for a table Args: table_name: Table name db_name: Database name, uses the database specified during initialization if None + catalog_name: Catalog name for federation queries, uses instance catalog if None Returns: List[Dict[str, Any]]: List of index information """ db_name = db_name or self.db_name + effective_catalog = catalog_name or self.catalog_name if not db_name: logger.error("Database name not specified") return [] - cache_key = f"indexes_{db_name}_{table_name}" + cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}" if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: return self.metadata_cache[cache_key] try: - query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`" + # Build query with catalog prefix if specified + if effective_catalog: + query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`" + logger.info(f"Using three-part naming for index query: {query}") + else: + query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`" + df = execute_query_df(query) # Process results @@ -732,6 +754,52 @@ class MetadataExtractor: logger.error(f"Error getting audit logs: {str(e)}") return pd.DataFrame() + def get_catalog_list(self) -> List[Dict[str, Any]]: + """ + Get a list of all catalogs in Doris with detailed information + + Returns: + List[Dict[str, Any]]: List of catalog information including CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment + """ + cache_key = "catalogs" + if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl: + return self.metadata_cache[cache_key] + + try: + # Use SHOW CATALOGS command to get catalog list + query = "SHOW CATALOGS" + result = execute_query(query) + + if not result: + catalogs = [] + else: + # Extract catalog information from the result + # SHOW CATALOGS returns: CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment + catalogs = [] + for row in result: + if isinstance(row, dict): + catalog_info = { + "catalog_id": row.get("CatalogId", ""), + "catalog_name": row.get("CatalogName", ""), + "type": row.get("Type", ""), + "is_current": row.get("IsCurrent", ""), + "create_time": row.get("CreateTime", ""), + "last_update_time": row.get("LastUpdateTime", ""), + "comment": row.get("Comment", "") + } + catalogs.append(catalog_info) + + logger.info(f"Retrieved catalog list: {catalogs}") + + # Update cache + self.metadata_cache[cache_key] = catalogs + self.metadata_cache_time[cache_key] = datetime.now() + + return catalogs + except Exception as e: + logger.error(f"Error getting catalog list: {str(e)}") + return [] + def extract_sql_comments(self, sql: str) -> str: """ Extract comments from SQL @@ -1010,4 +1078,31 @@ class MetadataExtractor: return partition_info except Exception as e: logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}") - return {} \ No newline at end of file + return {} + + def _execute_query_with_catalog(self, query: str, db_name: str = None, catalog_name: str = None): + """ + Execute query with catalog-aware metadata operations using three-part naming + + Args: + query: SQL query to execute + db_name: Database name to use + catalog_name: Catalog name for three-part naming + + Returns: + Query result + """ + try: + # If catalog_name is specified, modify the query to use three-part naming + # for information_schema queries + if catalog_name and 'information_schema' in query.lower(): + # Replace 'information_schema' with 'catalog_name.information_schema' + modified_query = query.replace('information_schema', f'{catalog_name}.information_schema') + logger.info(f"Modified query for catalog {catalog_name}: {modified_query}") + return execute_query(modified_query, db_name) + else: + # Execute the original query + return execute_query(query, db_name) + except Exception as e: + logger.error(f"Error executing query with catalog: {str(e)}") + raise \ No newline at end of file diff --git a/doris_mcp_server/utils/sql_executor_tools.py b/doris_mcp_server/utils/sql_executor_tools.py index 1b6346b..aded854 100644 --- a/doris_mcp_server/utils/sql_executor_tools.py +++ b/doris_mcp_server/utils/sql_executor_tools.py @@ -42,6 +42,7 @@ async def execute_sql_query(ctx) -> Dict[str, Any]: sql = params.get("sql") db_name = params.get("db_name", os.getenv("DB_DATABASE", "")) + catalog_name = params.get("catalog_name", None) # Add catalog parameter support max_rows = params.get("max_rows", 1000) # Maximum number of rows to return timeout = params.get("timeout", 30) # Timeout in seconds @@ -103,6 +104,9 @@ async def execute_sql_query(ctx) -> Dict[str, Any]: # Execute query try: + # For federation queries, SQL must use three-part naming: catalog_name.db_name.table_name + # This is enforced at the tool description level + result = execute_query(sql, db_name) # Calculate execution time @@ -264,8 +268,6 @@ async def _check_sql_security(sql: str) -> Dict[str, Any]: (r'\bexec\b', "EXECUTE stored procedure"), (r'\bxp_', "Extended stored procedure, potential security risk"), (r'\bshutdown\b', "SHUTDOWN database operation"), - (r'\bunion\s+all\s+select\b', "UNION statement, potential SQL injection"), - (r'\bunion\s+select\b', "UNION statement, potential SQL injection"), (r'\binto\s+outfile\b', "Write to file operation"), (r'\bload_file\b', "Load file operation") ] @@ -315,6 +317,7 @@ async def _check_sql_security(sql: str) -> Dict[str, Any]: "security_issues": security_issues } + def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]: """ Convert special types in row data (like date, time, Decimal) to JSON serializable format diff --git a/uv.lock b/uv.lock index d847d22..735ff53 100644 --- a/uv.lock +++ b/uv.lock @@ -123,6 +123,52 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 }, ] +[[package]] +name = "doris-mcp" +version = "0.2.0" +source = { editable = "." } +dependencies = [ + { name = "fastapi" }, + { name = "mcp", extra = ["cli"] }, + { name = "numpy" }, + { name = "openai" }, + { name = "pandas" }, + { name = "pydantic" }, + { name = "pymysql" }, + { name = "python-dotenv" }, + { name = "requests" }, + { name = "scikit-learn" }, + { name = "simplejson" }, + { name = "uvicorn" }, +] + +[package.optional-dependencies] +dev = [ + { name = "black" }, + { name = "isort" }, + { name = "pytest" }, +] + +[package.metadata] +requires-dist = [ + { name = "black", marker = "extra == 'dev'", specifier = ">=23.0.0" }, + { name = "fastapi", specifier = ">=0.95.0" }, + { name = "isort", marker = "extra == 'dev'", specifier = ">=5.12.0" }, + { name = "mcp", extras = ["cli"], specifier = ">=1.0.0" }, + { name = "numpy", specifier = ">=1.20.0" }, + { name = "openai", specifier = ">=1.66.3" }, + { name = "pandas", specifier = ">=1.5.0" }, + { name = "pydantic", specifier = ">=1.10.0" }, + { name = "pymysql", specifier = ">=1.0.2" }, + { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, + { name = "python-dotenv", specifier = ">=0.19.0" }, + { name = "requests", specifier = ">=2.28.0" }, + { name = "scikit-learn", specifier = ">=1.0.0" }, + { name = "simplejson", specifier = ">=3.17.0" }, + { name = "uvicorn", specifier = ">=0.21.0" }, +] +provides-extras = ["dev"] + [[package]] name = "fastapi" version = "0.115.12" @@ -291,52 +337,6 @@ cli = [ { name = "typer" }, ] -[[package]] -name = "mcp-doris" -version = "0.1.0" -source = { editable = "." } -dependencies = [ - { name = "fastapi" }, - { name = "mcp", extra = ["cli"] }, - { name = "numpy" }, - { name = "openai" }, - { name = "pandas" }, - { name = "pydantic" }, - { name = "pymysql" }, - { name = "python-dotenv" }, - { name = "requests" }, - { name = "scikit-learn" }, - { name = "simplejson" }, - { name = "uvicorn" }, -] - -[package.optional-dependencies] -dev = [ - { name = "black" }, - { name = "isort" }, - { name = "pytest" }, -] - -[package.metadata] -requires-dist = [ - { name = "black", marker = "extra == 'dev'", specifier = ">=23.0.0" }, - { name = "fastapi", specifier = ">=0.95.0" }, - { name = "isort", marker = "extra == 'dev'", specifier = ">=5.12.0" }, - { name = "mcp", extras = ["cli"], specifier = ">=1.0.0" }, - { name = "numpy", specifier = ">=1.20.0" }, - { name = "openai", specifier = ">=1.66.3" }, - { name = "pandas", specifier = ">=1.5.0" }, - { name = "pydantic", specifier = ">=1.10.0" }, - { name = "pymysql", specifier = ">=1.0.2" }, - { name = "pytest", marker = "extra == 'dev'", specifier = ">=7.0.0" }, - { name = "python-dotenv", specifier = ">=0.19.0" }, - { name = "requests", specifier = ">=2.28.0" }, - { name = "scikit-learn", specifier = ">=1.0.0" }, - { name = "simplejson", specifier = ">=3.17.0" }, - { name = "uvicorn", specifier = ">=0.21.0" }, -] -provides-extras = ["dev"] - [[package]] name = "mdurl" version = "0.1.2"