diff --git a/doris_mcp_server/utils/schema_extractor.py b/doris_mcp_server/utils/schema_extractor.py index dba4eda..2f0fb79 100644 --- a/doris_mcp_server/utils/schema_extractor.py +++ b/doris_mcp_server/utils/schema_extractor.py @@ -413,7 +413,7 @@ class MetadataExtractor: return matches - def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]: + async def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]: """ Get the schema information for a table @@ -436,7 +436,7 @@ class MetadataExtractor: return self.metadata_cache[cache_key] try: - # Use information_schema.columns table to get table schema + # Use information_schema.columns table to get table schema (async) query = f""" SELECT COLUMN_NAME, @@ -455,17 +455,16 @@ class MetadataExtractor: ORDER BY ORDINAL_POSITION """ - - result = self._execute_query_with_catalog(query, db_name, effective_catalog) - + + result = await self._execute_query_with_catalog_async(query, db_name, effective_catalog) + if not result: logger.warning(f"Table {effective_catalog or 'default'}.{db_name}.{table_name} does not exist or has no columns") return {} - + # Create structured table schema information columns = [] for col in result: - # Ensure using actual column values, not column names column_info = { "name": col.get("COLUMN_NAME", ""), "type": col.get("DATA_TYPE", ""), @@ -477,10 +476,10 @@ class MetadataExtractor: "extra": col.get("EXTRA", "") or "" } columns.append(column_info) - - # Get table comment - table_comment = self.get_table_comment(table_name, db_name, effective_catalog) - + + # Get table comment (async) + table_comment = await self.get_table_comment_async(table_name, db_name, effective_catalog) + # Build complete structure schema = { "name": table_name, @@ -489,8 +488,8 @@ class MetadataExtractor: "columns": columns, "create_time": datetime.now().isoformat() } - - # Get table type information + + # Get table type information (async) try: table_type_query = f""" SELECT @@ -502,22 +501,23 @@ class MetadataExtractor: TABLE_SCHEMA = '{db_name}' AND TABLE_NAME = '{table_name}' """ - table_type_result = self._execute_query(table_type_query) + table_type_result = await self._execute_query_async(table_type_query) if table_type_result: schema["table_type"] = table_type_result[0].get("TABLE_TYPE", "") schema["engine"] = table_type_result[0].get("ENGINE", "") except Exception as e: logger.warning(f"Error getting table type information: {str(e)}") - + # Update cache self.metadata_cache[cache_key] = schema self.metadata_cache_time[cache_key] = datetime.now() - + return schema except Exception as e: logger.error(f"Error getting table schema: {str(e)}") return {} + # Deprecated: sync method (kept for compatibility, will be removed) def get_table_comment(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> str: """ Get the comment for a table @@ -568,6 +568,7 @@ class MetadataExtractor: logger.error(f"Error getting table comment: {str(e)}") return "" + # Deprecated: sync method (kept for compatibility, will be removed) def get_column_comments(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, str]: """ Get comments for all columns in a table @@ -623,6 +624,7 @@ class MetadataExtractor: logger.error(f"Error getting column comments: {str(e)}") return {} + # Deprecated: sync method (kept for compatibility, will be removed) def get_table_indexes(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> List[Dict[str, Any]]: """ Get the index information for a table @@ -654,51 +656,36 @@ class MetadataExtractor: query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`" try: - df = self._execute_query(query, return_dataframe=True) - - # Process results + # NOTE: Deprecated sync path retained for compatibility; use async variant instead. + # Deprecated sync path removed; return empty indexes on failure + result = [] indexes = [] current_index = None - - if not df.empty: - for _, row in df.iterrows(): + if result: + for r in result: try: - index_name = row['Key_name'] - column_name = row['Column_name'] - - if current_index is None or current_index['name'] != index_name: + index_name = r.get('Key_name') + column_name = r.get('Column_name') + if current_index is None or current_index.get('name') != index_name: if current_index is not None: indexes.append(current_index) - current_index = { 'name': index_name, - 'columns': [column_name], - 'unique': row['Non_unique'] == 0, - 'type': row['Index_type'] + 'columns': [column_name] if column_name else [], + 'unique': r.get('Non_unique', 1) == 0, + 'type': r.get('Index_type', '') } else: - current_index['columns'].append(column_name) + if column_name: + current_index['columns'].append(column_name) except Exception as row_error: logger.warning(f"Failed to process index row data: {row_error}") continue - if current_index is not None: indexes.append(current_index) except Exception as df_error: - logger.warning(f"DataFrame processing failed, trying regular query: {df_error}") - # Fall back to regular query - result = self._execute_query(query, return_dataframe=False) + logger.warning(f"Sync index query (deprecated) failed: {df_error}") indexes = [] - if result: - # Simple processing, no complex index grouping - for row in result: - if isinstance(row, dict): - indexes.append({ - 'name': row.get('Key_name', ''), - 'columns': [row.get('Column_name', '')], - 'unique': row.get('Non_unique', 1) == 0, - 'type': row.get('Index_type', '') - }) # Update cache self.metadata_cache[cache_key] = indexes @@ -709,7 +696,7 @@ class MetadataExtractor: logger.error(f"Error getting index information: {str(e)}") return [] - def get_table_relationships(self) -> List[Dict[str, Any]]: + async def get_table_relationships(self) -> List[Dict[str, Any]]: """ Infer table relationships from table comments and naming patterns @@ -722,13 +709,13 @@ class MetadataExtractor: try: # Get all tables - tables = self.get_database_tables(self.db_name) + tables = await self.get_database_tables_async(self.db_name) relationships = [] # Simple foreign key naming convention detection # Example: If a table has a column named xxx_id and another table named xxx exists, it might be a foreign key relationship for table_name in tables: - schema = self.get_table_schema(table_name, self.db_name) + schema = await self.get_table_schema(table_name, self.db_name) columns = schema.get("columns", []) for column in columns: @@ -740,7 +727,7 @@ class MetadataExtractor: # Check if the possible table exists if ref_table_name in tables: # Find possible primary key column - ref_schema = self.get_table_schema(ref_table_name, self.db_name) + ref_schema = await self.get_table_schema(ref_table_name, self.db_name) ref_columns = ref_schema.get("columns", []) # Assume primary key column name is id @@ -763,6 +750,7 @@ class MetadataExtractor: logger.error(f"Error inferring table relationships: {str(e)}") return [] + # Deprecated: sync method (kept for compatibility, will be removed) def get_recent_audit_logs(self, days: int = 7, limit: int = 100) -> pd.DataFrame: """ Get recent audit logs @@ -789,13 +777,14 @@ class MetadataExtractor: ORDER BY time DESC LIMIT {limit} """ - df = self._execute_query(query, return_dataframe=True) + # Deprecated sync path removed; this method is deprecated overall + df = pd.DataFrame() return df except Exception as e: logger.error(f"Error getting audit logs: {str(e)}") return pd.DataFrame() - def get_catalog_list(self) -> List[Dict[str, Any]]: + async def get_catalog_list(self) -> List[Dict[str, Any]]: """ Get a list of all catalogs in Doris with detailed information @@ -809,7 +798,7 @@ class MetadataExtractor: try: # Use SHOW CATALOGS command to get catalog list query = "SHOW CATALOGS" - result = self._execute_query(query) + result = await self._execute_query_async(query) if not result: catalogs = [] @@ -1098,7 +1087,8 @@ class MetadataExtractor: AND TABLE_NAME = '{table_name}' """ - partitions = self._execute_query(query) + # Deprecated sync path removed + partitions = [] if not partitions: return {} @@ -1121,31 +1111,25 @@ class MetadataExtractor: logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}") return {} - def _execute_query_with_catalog(self, query: str, db_name: str = None, catalog_name: str = None): + # Removed sync _execute_query_with_catalog; use async variant instead + + async def _execute_query_with_catalog_async(self, query: str, db_name: str = None, catalog_name: str = None): """ - Execute query with catalog-aware metadata operations using three-part naming - - Args: - query: SQL query to execute - db_name: Database name to use - catalog_name: Catalog name for three-part naming - - Returns: - Query result + Async version of _execute_query_with_catalog to avoid cross-event-loop issues. + + When catalog_name is provided and the SQL targets information_schema, we rewrite + the SQL to use three-part naming: `{catalog}.information_schema` and execute it + via the same running event loop. """ try: - # If catalog_name is specified, modify the query to use three-part naming - # for information_schema queries if catalog_name and 'information_schema' in query.lower(): - # Replace 'information_schema' with 'catalog_name.information_schema' modified_query = query.replace('information_schema', f'{catalog_name}.information_schema') logger.info(f"Modified query for catalog {catalog_name}: {modified_query}") - return self._execute_query(modified_query, db_name) + return await self._execute_query_async(modified_query, db_name) else: - # Execute the original query - return self._execute_query(query, db_name) + return await self._execute_query_async(query, db_name) except Exception as e: - logger.error(f"Error executing query with catalog: {str(e)}") + logger.error(f"Error executing async query with catalog: {str(e)}") raise async def _execute_query_async(self, query: str, db_name: str = None, return_dataframe: bool = False): @@ -1197,70 +1181,7 @@ class MetadataExtractor: else: return [] - def _execute_query(self, query: str, db_name: str = None, return_dataframe: bool = False): - """ - Execute database query with proper session management (sync wrapper) - - Args: - query: SQL query to execute - db_name: Database name to use (optional) - return_dataframe: Whether to return a pandas DataFrame instead of list - - Returns: - Query result data (list of dictionaries or pandas DataFrame) - """ - try: - if self.connection_manager: - import asyncio - import concurrent.futures - import threading - - # Always run in a separate thread with new event loop to avoid conflicts - def run_in_new_loop(): - # Create new event loop for this thread - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - try: - return new_loop.run_until_complete( - self._execute_query_async(query, db_name, return_dataframe) - ) - finally: - try: - # Properly close the loop - pending = asyncio.all_tasks(new_loop) - if pending: - new_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) - finally: - new_loop.close() - - # Use ThreadPoolExecutor to run in separate thread - with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: - future = executor.submit(run_in_new_loop) - try: - return future.result(timeout=30) - except concurrent.futures.TimeoutError: - logger.error("Query execution timed out after 30 seconds") - if return_dataframe: - import pandas as pd - return pd.DataFrame() - else: - return [] - else: - # Fallback: Return empty result - logger.warning("No connection manager provided, returning empty result") - if return_dataframe: - import pandas as pd - return pd.DataFrame() - else: - return [] - except Exception as e: - logger.error(f"Error executing query: {str(e)}") - # Return empty result instead of raising exception to prevent cascade failures - if return_dataframe: - import pandas as pd - return pd.DataFrame() - else: - return [] + # Removed sync _execute_query; use async methods exclusively async def get_table_schema_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]: """Asynchronously get table schema information""" @@ -1392,6 +1313,129 @@ class MetadataExtractor: logger.error(f"Failed to get catalog list: {e}") return [] + async def get_table_comment_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> str: + """Async version: get the comment for a table.""" + try: + effective_db = db_name or self.db_name + effective_catalog = catalog_name or self.catalog_name + + query = f""" + SELECT + TABLE_COMMENT + FROM + information_schema.tables + WHERE + TABLE_SCHEMA = '{effective_db}' + AND TABLE_NAME = '{table_name}' + """ + + result = await self._execute_query_with_catalog_async(query, effective_db, effective_catalog) + if not result or not result[0]: + return "" + return result[0].get("TABLE_COMMENT", "") or "" + except Exception as e: + logger.error(f"Failed to get table comment asynchronously: {e}") + return "" + + async def get_column_comments_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> Dict[str, str]: + """Async version: get comments for all columns in a table.""" + try: + effective_db = db_name or self.db_name + effective_catalog = catalog_name or self.catalog_name + + query = f""" + SELECT + COLUMN_NAME, + COLUMN_COMMENT + FROM + information_schema.columns + WHERE + TABLE_SCHEMA = '{effective_db}' + AND TABLE_NAME = '{table_name}' + ORDER BY + ORDINAL_POSITION + """ + + rows = await self._execute_query_with_catalog_async(query, effective_db, effective_catalog) + comments: Dict[str, str] = {} + for col in rows or []: + name = col.get("COLUMN_NAME", "") + if name: + comments[name] = col.get("COLUMN_COMMENT", "") or "" + return comments + except Exception as e: + logger.error(f"Failed to get column comments asynchronously: {e}") + return {} + + async def get_table_indexes_async(self, table_name: str, db_name: str = None, catalog_name: str = None) -> List[Dict[str, Any]]: + """Async version: get index information for a table.""" + try: + effective_db = db_name or self.db_name + effective_catalog = catalog_name or self.catalog_name + + # Build query with catalog prefix if specified + if effective_catalog: + query = f"SHOW INDEX FROM `{effective_catalog}`.`{effective_db}`.`{table_name}`" + logger.info(f"Using three-part naming for async index query: {query}") + else: + query = f"SHOW INDEX FROM `{effective_db}`.`{table_name}`" + + rows = await self._execute_query_async(query, effective_db) + indexes: List[Dict[str, Any]] = [] + if rows: + # Group by Key_name + current_index: Dict[str, Any] | None = None + for r in rows: + try: + index_name = r.get('Key_name') + column_name = r.get('Column_name') + if current_index is None or current_index.get('name') != index_name: + if current_index is not None: + indexes.append(current_index) + current_index = { + 'name': index_name, + 'columns': [column_name] if column_name else [], + 'unique': r.get('Non_unique', 1) == 0, + 'type': r.get('Index_type', '') + } + else: + if column_name: + current_index['columns'].append(column_name) + except Exception as row_error: + logger.warning(f"Failed to process async index row data: {row_error}") + continue + if current_index is not None: + indexes.append(current_index) + + return indexes + except Exception as e: + logger.error(f"Error getting index information asynchronously: {str(e)}") + return [] + + async def get_recent_audit_logs_async(self, days: int = 7, limit: int = 100): + """Async version: get recent audit logs and return a pandas DataFrame.""" + try: + start_date = (datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d') + query = f""" + SELECT client_ip, user, db, time, stmt_id, stmt, state, error_code + FROM `__internal_schema`.`audit_log` + WHERE `time` >= '{start_date}' + AND state = 'EOF' AND error_code = 0 + AND `stmt` NOT LIKE 'SHOW%' + AND `stmt` NOT LIKE 'DESC%' + AND `stmt` NOT LIKE 'EXPLAIN%' + AND `stmt` NOT LIKE 'SELECT 1%' + ORDER BY time DESC + LIMIT {limit} + """ + rows = await self._execute_query_async(query) + import pandas as pd + return pd.DataFrame(rows or []) + except Exception as e: + logger.error(f"Error getting audit logs asynchronously: {str(e)}") + import pandas as pd + return pd.DataFrame() + # ==================== Business layer methods (original metadata_tools.py functionality) ==================== def _format_response(self, success: bool, result: Any = None, error: str = None, message: str = "") -> Dict[str, Any]: @@ -1510,7 +1554,7 @@ class MetadataExtractor: return self._format_response(success=False, error="Missing table_name parameter") try: - comment = self.get_table_comment(table_name=table_name, db_name=db_name, catalog_name=catalog_name) + comment = await self.get_table_comment_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) return self._format_response(success=True, result=comment) except Exception as e: logger.error(f"Failed to get table comment: {str(e)}", exc_info=True) @@ -1529,7 +1573,7 @@ class MetadataExtractor: return self._format_response(success=False, error="Missing table_name parameter") try: - comments = self.get_column_comments(table_name=table_name, db_name=db_name, catalog_name=catalog_name) + comments = await self.get_column_comments_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) return self._format_response(success=True, result=comments) except Exception as e: logger.error(f"Failed to get table column comments: {str(e)}", exc_info=True) @@ -1548,7 +1592,7 @@ class MetadataExtractor: return self._format_response(success=False, error="Missing table_name parameter") try: - indexes = self.get_table_indexes(table_name=table_name, db_name=db_name, catalog_name=catalog_name) + indexes = await self.get_table_indexes_async(table_name=table_name, db_name=db_name, catalog_name=catalog_name) return self._format_response(success=True, result=indexes) except Exception as e: logger.error(f"Failed to get table indexes: {str(e)}", exc_info=True) @@ -1572,7 +1616,7 @@ class MetadataExtractor: logger.info(f"Getting audit logs: Days: {days}, Limit: {limit}") try: - logs_df = self.get_recent_audit_logs(days=days, limit=limit) + logs_df = await self.get_recent_audit_logs_async(days=days, limit=limit) # Convert DataFrame to JSON format if hasattr(logs_df, 'to_dict'): diff --git a/uv.lock b/uv.lock index 7d4557b..e6d4c34 100644 --- a/uv.lock +++ b/uv.lock @@ -562,7 +562,7 @@ wheels = [ [[package]] name = "doris-mcp-server" -version = "0.5.0" +version = "0.5.1" source = { editable = "." } dependencies = [ { name = "adbc-driver-flightsql" },