support Multi-Catalog

This commit is contained in:
FreeOnePlus
2025-06-06 14:35:53 +08:00
parent 466fcbdb45
commit 5e98e5ba41
8 changed files with 312 additions and 151 deletions

View File

@@ -31,15 +31,17 @@ from doris_mcp_server.utils.db import execute_query_df, execute_query
class MetadataExtractor:
"""Apache Doris Metadata Extractor"""
def __init__(self, db_name: str = None):
def __init__(self, db_name: str = None, catalog_name: str = None):
"""
Initialize the metadata extractor
Args:
db_name: Default database name, uses the currently connected database if not specified
catalog_name: Default catalog name for federation queries, uses the current catalog if not specified
"""
# Get configuration from environment variables
self.db_name = db_name or os.getenv("DB_DATABASE", "")
self.catalog_name = catalog_name # Store catalog name for federation support
self.metadata_db = METADATA_DB_NAME # Use constant
# Caching system
@@ -118,14 +120,18 @@ class MetadataExtractor:
default_patterns = ["^ads_.*$", "^dim_.*$", "^dws_.*$", "^dwd_.*$", "^ods_.*$", "^.*$"]
return default_patterns
def get_all_databases(self) -> List[str]:
def get_all_databases(self, catalog_name: str = None) -> List[str]:
"""
Get a list of all databases
Args:
catalog_name: Catalog name for federation queries, uses instance catalog if None
Returns:
List of database names
"""
cache_key = "databases"
effective_catalog = catalog_name or self.catalog_name
cache_key = f"databases_{effective_catalog or 'default'}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
@@ -142,13 +148,13 @@ class MetadataExtractor:
SCHEMA_NAME
"""
result = execute_query(query)
result = self._execute_query_with_catalog(query, self.db_name, effective_catalog)
if not result:
databases = []
else:
databases = [db["SCHEMA_NAME"] for db in result]
logger.info(f"Retrieved database list: {databases}")
logger.info(f"Retrieved database list from catalog {effective_catalog or 'default'}: {databases}")
# Update cache
self.metadata_cache[cache_key] = databases
@@ -205,22 +211,24 @@ class MetadataExtractor:
logger.warning(f"Current database {self.db_name} is in the excluded list, metadata retrieval might not work properly")
return [self.db_name] if self.db_name else []
def get_database_tables(self, db_name: Optional[str] = None) -> List[str]:
def get_database_tables(self, db_name: Optional[str] = None, catalog_name: str = None) -> List[str]:
"""
Get a list of all tables in the database
Args:
db_name: Database name, uses current database if None
catalog_name: Catalog name for federation queries, uses instance catalog if None
Returns:
List of table names
"""
db_name = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
if not db_name:
logger.warning("Database name not specified")
return []
cache_key = f"tables_{db_name}"
cache_key = f"tables_{effective_catalog or 'default'}_{db_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
@@ -236,14 +244,14 @@ class MetadataExtractor:
AND TABLE_TYPE = 'BASE TABLE'
"""
result = execute_query(query, db_name)
logger.info(f"{db_name}.information_schema.tables query result: {result}")
result = self._execute_query_with_catalog(query, db_name, effective_catalog)
logger.info(f"{effective_catalog or 'default'}.{db_name}.information_schema.tables query result: {result}")
if not result:
tables = []
else:
tables = [table['TABLE_NAME'] for table in result]
logger.info(f"Table names retrieved from {db_name}.information_schema.tables: {tables}")
logger.info(f"Table names retrieved from {effective_catalog or 'default'}.{db_name}.information_schema.tables: {tables}")
# Sort tables by hierarchy matching (if enabled)
if self.enable_table_hierarchy and tables:
@@ -385,23 +393,25 @@ class MetadataExtractor:
return matches
def get_table_schema(self, table_name: str, db_name: Optional[str] = None) -> Dict[str, Any]:
def get_table_schema(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, Any]:
"""
Get the schema information for a table
Args:
table_name: Table name
db_name: Database name, uses current database if None
catalog_name: Catalog name for federation queries, uses instance catalog if None
Returns:
Table schema information, including column names, types, nullability, defaults, comments, etc.
"""
db_name = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
if not db_name:
logger.warning("Database name not specified")
return {}
cache_key = f"schema_{db_name}_{table_name}"
cache_key = f"schema_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
@@ -426,10 +436,10 @@ class MetadataExtractor:
ORDINAL_POSITION
"""
result = execute_query(query)
result = self._execute_query_with_catalog(query, db_name, effective_catalog)
if not result:
logger.warning(f"Table {db_name}.{table_name} does not exist or has no columns")
logger.warning(f"Table {effective_catalog or 'default'}.{db_name}.{table_name} does not exist or has no columns")
return {}
# Create structured table schema information
@@ -449,7 +459,7 @@ class MetadataExtractor:
columns.append(column_info)
# Get table comment
table_comment = self.get_table_comment(table_name, db_name)
table_comment = self.get_table_comment(table_name, db_name, effective_catalog)
# Build complete structure
schema = {
@@ -488,23 +498,25 @@ class MetadataExtractor:
logger.error(f"Error getting table schema: {str(e)}")
return {}
def get_table_comment(self, table_name: str, db_name: Optional[str] = None) -> str:
def get_table_comment(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> str:
"""
Get the comment for a table
Args:
table_name: Table name
db_name: Database name, uses current database if None
catalog_name: Catalog name for federation queries, uses instance catalog if None
Returns:
Table comment
"""
db_name = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
if not db_name:
logger.warning("Database name not specified")
return ""
cache_key = f"table_comment_{db_name}_{table_name}"
cache_key = f"table_comment_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
@@ -520,7 +532,7 @@ class MetadataExtractor:
AND TABLE_NAME = '{table_name}'
"""
result = execute_query(query)
result = self._execute_query_with_catalog(query, db_name, effective_catalog)
if not result or not result[0]:
comment = ""
@@ -536,23 +548,25 @@ class MetadataExtractor:
logger.error(f"Error getting table comment: {str(e)}")
return ""
def get_column_comments(self, table_name: str, db_name: Optional[str] = None) -> Dict[str, str]:
def get_column_comments(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> Dict[str, str]:
"""
Get comments for all columns in a table
Args:
table_name: Table name
db_name: Database name, uses current database if None
catalog_name: Catalog name for federation queries, uses instance catalog if None
Returns:
Dictionary of column names and comments
"""
db_name = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
if not db_name:
logger.warning("Database name not specified")
return {}
cache_key = f"column_comments_{db_name}_{table_name}"
cache_key = f"column_comments_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
@@ -571,7 +585,7 @@ class MetadataExtractor:
ORDINAL_POSITION
"""
result = execute_query(query)
result = self._execute_query_with_catalog(query, db_name, effective_catalog)
comments = {}
for col in result:
@@ -589,28 +603,36 @@ class MetadataExtractor:
logger.error(f"Error getting column comments: {str(e)}")
return {}
def get_table_indexes(self, table_name: str, db_name: Optional[str] = None) -> List[Dict[str, Any]]:
def get_table_indexes(self, table_name: str, db_name: Optional[str] = None, catalog_name: str = None) -> List[Dict[str, Any]]:
"""
Get the index information for a table
Args:
table_name: Table name
db_name: Database name, uses the database specified during initialization if None
catalog_name: Catalog name for federation queries, uses instance catalog if None
Returns:
List[Dict[str, Any]]: List of index information
"""
db_name = db_name or self.db_name
effective_catalog = catalog_name or self.catalog_name
if not db_name:
logger.error("Database name not specified")
return []
cache_key = f"indexes_{db_name}_{table_name}"
cache_key = f"indexes_{effective_catalog or 'default'}_{db_name}_{table_name}"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
try:
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
# Build query with catalog prefix if specified
if effective_catalog:
query = f"SHOW INDEX FROM `{effective_catalog}`.`{db_name}`.`{table_name}`"
logger.info(f"Using three-part naming for index query: {query}")
else:
query = f"SHOW INDEX FROM `{db_name}`.`{table_name}`"
df = execute_query_df(query)
# Process results
@@ -732,6 +754,52 @@ class MetadataExtractor:
logger.error(f"Error getting audit logs: {str(e)}")
return pd.DataFrame()
def get_catalog_list(self) -> List[Dict[str, Any]]:
"""
Get a list of all catalogs in Doris with detailed information
Returns:
List[Dict[str, Any]]: List of catalog information including CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
"""
cache_key = "catalogs"
if cache_key in self.metadata_cache and (datetime.now() - self.metadata_cache_time.get(cache_key, datetime.min)).total_seconds() < self.cache_ttl:
return self.metadata_cache[cache_key]
try:
# Use SHOW CATALOGS command to get catalog list
query = "SHOW CATALOGS"
result = execute_query(query)
if not result:
catalogs = []
else:
# Extract catalog information from the result
# SHOW CATALOGS returns: CatalogId, CatalogName, Type, IsCurrent, CreateTime, LastUpdateTime, Comment
catalogs = []
for row in result:
if isinstance(row, dict):
catalog_info = {
"catalog_id": row.get("CatalogId", ""),
"catalog_name": row.get("CatalogName", ""),
"type": row.get("Type", ""),
"is_current": row.get("IsCurrent", ""),
"create_time": row.get("CreateTime", ""),
"last_update_time": row.get("LastUpdateTime", ""),
"comment": row.get("Comment", "")
}
catalogs.append(catalog_info)
logger.info(f"Retrieved catalog list: {catalogs}")
# Update cache
self.metadata_cache[cache_key] = catalogs
self.metadata_cache_time[cache_key] = datetime.now()
return catalogs
except Exception as e:
logger.error(f"Error getting catalog list: {str(e)}")
return []
def extract_sql_comments(self, sql: str) -> str:
"""
Extract comments from SQL
@@ -1010,4 +1078,31 @@ class MetadataExtractor:
return partition_info
except Exception as e:
logger.error(f"Error getting partition information for table {db_name}.{table_name}: {str(e)}")
return {}
return {}
def _execute_query_with_catalog(self, query: str, db_name: str = None, catalog_name: str = None):
"""
Execute query with catalog-aware metadata operations using three-part naming
Args:
query: SQL query to execute
db_name: Database name to use
catalog_name: Catalog name for three-part naming
Returns:
Query result
"""
try:
# If catalog_name is specified, modify the query to use three-part naming
# for information_schema queries
if catalog_name and 'information_schema' in query.lower():
# Replace 'information_schema' with 'catalog_name.information_schema'
modified_query = query.replace('information_schema', f'{catalog_name}.information_schema')
logger.info(f"Modified query for catalog {catalog_name}: {modified_query}")
return execute_query(modified_query, db_name)
else:
# Execute the original query
return execute_query(query, db_name)
except Exception as e:
logger.error(f"Error executing query with catalog: {str(e)}")
raise

View File

@@ -42,6 +42,7 @@ async def execute_sql_query(ctx) -> Dict[str, Any]:
sql = params.get("sql")
db_name = params.get("db_name", os.getenv("DB_DATABASE", ""))
catalog_name = params.get("catalog_name", None) # Add catalog parameter support
max_rows = params.get("max_rows", 1000) # Maximum number of rows to return
timeout = params.get("timeout", 30) # Timeout in seconds
@@ -103,6 +104,9 @@ async def execute_sql_query(ctx) -> Dict[str, Any]:
# Execute query
try:
# For federation queries, SQL must use three-part naming: catalog_name.db_name.table_name
# This is enforced at the tool description level
result = execute_query(sql, db_name)
# Calculate execution time
@@ -264,8 +268,6 @@ async def _check_sql_security(sql: str) -> Dict[str, Any]:
(r'\bexec\b', "EXECUTE stored procedure"),
(r'\bxp_', "Extended stored procedure, potential security risk"),
(r'\bshutdown\b', "SHUTDOWN database operation"),
(r'\bunion\s+all\s+select\b', "UNION statement, potential SQL injection"),
(r'\bunion\s+select\b', "UNION statement, potential SQL injection"),
(r'\binto\s+outfile\b', "Write to file operation"),
(r'\bload_file\b', "Load file operation")
]
@@ -315,6 +317,7 @@ async def _check_sql_security(sql: str) -> Dict[str, Any]:
"security_issues": security_issues
}
def _serialize_row_data(row_data: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert special types in row data (like date, time, Decimal) to JSON serializable format