0.3.0 Release Version
This commit is contained in:
322
doris_mcp_client/README.md
Normal file
322
doris_mcp_client/README.md
Normal file
@@ -0,0 +1,322 @@
|
||||
# Doris Unified MCP Client
|
||||
|
||||
This is a unified Doris MCP client that supports both **stdio** and **Streamable HTTP** transport modes, providing complete MCP protocol support.
|
||||
|
||||
## 🚀 Features
|
||||
|
||||
- ✅ **Dual Mode Support**: Both stdio and HTTP transport methods
|
||||
- ✅ **Complete MCP Support**: Resources, Tools, and Prompts primitives
|
||||
- ✅ **Unified API**: Same interface for different transport modes
|
||||
- ✅ **Asynchronous Design**: High-performance async client based on asyncio
|
||||
- ✅ **Enterprise Features**: Connection pooling, error handling, logging
|
||||
- ✅ **Convenience Methods**: High-level wrappers for common database operations
|
||||
|
||||
## 📦 Install Dependencies
|
||||
|
||||
```bash
|
||||
pip install mcp
|
||||
```
|
||||
|
||||
## 🎯 Quick Start
|
||||
|
||||
### 1. stdio Mode
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from client import create_stdio_client
|
||||
|
||||
async def main():
|
||||
# Create stdio client
|
||||
client = await create_stdio_client(
|
||||
"python",
|
||||
["-m", "doris_mcp_server.main", "--transport", "stdio"]
|
||||
)
|
||||
|
||||
async def test_client(client):
|
||||
# Get database list
|
||||
db_result = await client.get_database_list()
|
||||
print(f"Databases: {db_result}")
|
||||
|
||||
# Execute SQL query
|
||||
query_result = await client.execute_sql("SELECT 1 as test")
|
||||
print(f"Query result: {query_result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
### 2. HTTP Mode
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from unified_client import create_http_client
|
||||
|
||||
async def main():
|
||||
# Create HTTP client
|
||||
client = await create_http_client("http://localhost:3000/mcp")
|
||||
|
||||
async def test_client(client):
|
||||
# Get all tools
|
||||
tools = await client.list_all_tools()
|
||||
print(f"Available tools: {len(tools)}")
|
||||
|
||||
# Execute query
|
||||
result = await client.execute_sql(
|
||||
"SELECT COUNT(*) FROM internal.ssb.lineorder LIMIT 1"
|
||||
)
|
||||
print(f"Query result: {result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
## 🔧 API Reference
|
||||
|
||||
### Client Creation
|
||||
|
||||
```python
|
||||
# stdio mode
|
||||
client = await create_stdio_client(command, args)
|
||||
|
||||
# HTTP mode
|
||||
client = await create_http_client(server_url, timeout=60)
|
||||
```
|
||||
|
||||
### Basic Operations
|
||||
|
||||
```python
|
||||
async def test_client(client):
|
||||
# Get server capabilities
|
||||
tools = await client.list_all_tools()
|
||||
resources = await client.list_all_resources()
|
||||
prompts = await client.list_all_prompts()
|
||||
|
||||
# Call tool
|
||||
result = await client.call_tool("tool_name", {"param": "value"})
|
||||
|
||||
# Read resource
|
||||
content = await client.read_resource("resource://uri")
|
||||
|
||||
# Get prompt
|
||||
prompt = await client.get_prompt("prompt_name", {"param": "value"})
|
||||
```
|
||||
|
||||
### Advanced Database Operations
|
||||
|
||||
```python
|
||||
async def database_operations(client):
|
||||
# Execute SQL query
|
||||
result = await client.execute_sql("SELECT * FROM table LIMIT 10")
|
||||
|
||||
# Get database list
|
||||
databases = await client.get_database_list()
|
||||
|
||||
# Get table schema
|
||||
schema = await client.get_table_schema("table_name", "db_name")
|
||||
|
||||
# Column data analysis
|
||||
analysis = await client.analyze_column("table", "column", "basic")
|
||||
```
|
||||
|
||||
## 🧪 Testing
|
||||
|
||||
### Run Test Suite
|
||||
|
||||
```bash
|
||||
# Interactive testing
|
||||
python test_unified_client.py
|
||||
|
||||
# Test stdio mode
|
||||
python test_unified_client.py stdio
|
||||
|
||||
# Test HTTP mode
|
||||
python test_unified_client.py http
|
||||
|
||||
# Test both modes
|
||||
python test_unified_client.py both
|
||||
|
||||
# Performance benchmark
|
||||
python test_unified_client.py benchmark
|
||||
```
|
||||
|
||||
### Test Output Example
|
||||
|
||||
```
|
||||
🎯 Doris Unified Client Test Suite
|
||||
============================================================
|
||||
|
||||
🚀 Testing HTTP Mode
|
||||
==================================================
|
||||
📋 Getting server capabilities...
|
||||
✅ Found 11 tools
|
||||
✅ Found 0 resources
|
||||
✅ Found 0 prompts
|
||||
|
||||
🔧 Available tools:
|
||||
1. get_db_list: Get database list
|
||||
2. get_table_list: Get table list for specified database
|
||||
3. get_table_schema: Get table structure information
|
||||
4. exec_query: Execute SQL query
|
||||
5. column_analysis: Analyze column data distribution and statistics
|
||||
...
|
||||
|
||||
🧪 Testing basic functionality...
|
||||
1️⃣ Getting database list...
|
||||
✅ Success: 3 databases
|
||||
2️⃣ Executing simple query...
|
||||
✅ Query successful
|
||||
3️⃣ Executing SSB data query...
|
||||
✅ SSB query successful
|
||||
4️⃣ Getting table structure...
|
||||
✅ Table structure retrieved successfully
|
||||
5️⃣ Column data analysis...
|
||||
✅ Column analysis successful
|
||||
|
||||
✅ HTTP mode testing completed!
|
||||
```
|
||||
|
||||
## 🏗️ Architecture Design
|
||||
|
||||
### Unified Client Architecture
|
||||
|
||||
```
|
||||
DorisUnifiedClient
|
||||
├── DorisResourceClient # Resource management
|
||||
├── DorisToolsClient # Tool invocation
|
||||
├── DorisPromptClient # Prompt management
|
||||
└── Transport Layer
|
||||
├── stdio mode # Standard input/output
|
||||
└── HTTP mode # Streamable HTTP
|
||||
```
|
||||
|
||||
### Key Features
|
||||
|
||||
1. **Unified Interface**: Same API for different transport modes
|
||||
2. **Async Context**: Proper resource management and connection cleanup
|
||||
3. **Error Handling**: Comprehensive exception handling and error recovery
|
||||
4. **Performance Optimization**: Connection reuse and request caching
|
||||
|
||||
## 📚 Usage Examples
|
||||
|
||||
### Complete Example
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from client import DorisUnifiedClient, DorisClientConfig
|
||||
|
||||
async def comprehensive_example():
|
||||
# Create configuration
|
||||
config = DorisClientConfig.stdio(
|
||||
"python",
|
||||
["-m", "doris_mcp_server.main"]
|
||||
)
|
||||
|
||||
client = DorisUnifiedClient(config)
|
||||
|
||||
async def demo_operations(client):
|
||||
print("🔍 Discovering server capabilities...")
|
||||
|
||||
# List all available tools
|
||||
tools = await client.list_all_tools()
|
||||
print(f"Available tools: {[tool.name for tool in tools]}")
|
||||
|
||||
# Get database list
|
||||
print("\n📊 Getting database information...")
|
||||
db_result = await client.get_database_list()
|
||||
print(f"Databases: {db_result}")
|
||||
|
||||
# Execute queries
|
||||
print("\n🔍 Executing queries...")
|
||||
|
||||
# Simple query
|
||||
result1 = await client.execute_sql("SELECT 1 as test_column")
|
||||
print(f"Simple query result: {result1}")
|
||||
|
||||
# Get table schema
|
||||
schema_result = await client.get_table_schema("lineorder", "ssb")
|
||||
print(f"Table schema: {schema_result}")
|
||||
|
||||
# Column analysis
|
||||
analysis_result = await client.analyze_column(
|
||||
"lineorder", "lo_orderkey", "basic"
|
||||
)
|
||||
print(f"Column analysis: {analysis_result}")
|
||||
|
||||
await client.connect_and_run(demo_operations)
|
||||
|
||||
# Run the example
|
||||
asyncio.run(comprehensive_example())
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
|
||||
```python
|
||||
async def error_handling_example(client):
|
||||
try:
|
||||
# This might fail
|
||||
result = await client.execute_sql("INVALID SQL")
|
||||
except Exception as e:
|
||||
print(f"SQL execution failed: {e}")
|
||||
|
||||
# Check result status
|
||||
result = await client.get_database_list()
|
||||
if result.get("success", True):
|
||||
print("Operation successful")
|
||||
else:
|
||||
print(f"Operation failed: {result.get('error')}")
|
||||
```
|
||||
|
||||
## 🔧 Configuration
|
||||
|
||||
### Client Configuration Options
|
||||
|
||||
```python
|
||||
# stdio mode with custom arguments
|
||||
config = DorisClientConfig(
|
||||
transport="stdio",
|
||||
server_command="python",
|
||||
server_args=["-m", "doris_mcp_server.main", "--debug"],
|
||||
timeout=30
|
||||
)
|
||||
|
||||
# HTTP mode with custom timeout
|
||||
config = DorisClientConfig(
|
||||
transport="http",
|
||||
server_url="http://localhost:8080/mcp",
|
||||
timeout=60
|
||||
)
|
||||
```
|
||||
|
||||
### Environment Variables
|
||||
|
||||
```bash
|
||||
# Set default server URL
|
||||
export DORIS_MCP_SERVER_URL="http://localhost:8080"
|
||||
|
||||
# Set default timeout
|
||||
export DORIS_MCP_TIMEOUT=60
|
||||
|
||||
# Enable debug logging
|
||||
export DORIS_MCP_DEBUG=true
|
||||
```
|
||||
|
||||
## 🚀 Performance Tips
|
||||
|
||||
1. **Connection Reuse**: Use the same client instance for multiple operations
|
||||
2. **Batch Operations**: Group related queries together
|
||||
3. **Async Context**: Always use proper async context management
|
||||
4. **Error Recovery**: Implement retry logic for transient failures
|
||||
|
||||
## 🤝 Contributing
|
||||
|
||||
1. Fork the repository
|
||||
2. Create a feature branch
|
||||
3. Make your changes
|
||||
4. Add tests
|
||||
5. Submit a pull request
|
||||
|
||||
## 📄 License
|
||||
|
||||
This project is licensed under the Apache 2.0 License.
|
||||
9
doris_mcp_client/__init__.py
Normal file
9
doris_mcp_client/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Doris MCP Client Package
|
||||
|
||||
Unified MCP client supporting both stdio and HTTP transport modes
|
||||
"""
|
||||
|
||||
from .client import DorisUnifiedClient, DorisClientConfig
|
||||
|
||||
__all__ = ["DorisUnifiedClient", "DorisClientConfig"]
|
||||
497
doris_mcp_client/client.py
Normal file
497
doris_mcp_client/client.py
Normal file
@@ -0,0 +1,497 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Unified Doris MCP Client - Supports both stdio and Streamable HTTP modes
|
||||
|
||||
Combines the correct HTTP implementation from http_client.py and the complete architecture from client.py
|
||||
Provides complete support for the three major primitives: Resources, Tools, and Prompts
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Callable
|
||||
from datetime import timedelta
|
||||
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.stdio import stdio_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp import StdioServerParameters
|
||||
from mcp.types import (
|
||||
Prompt,
|
||||
Resource,
|
||||
Tool,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DorisClientConfig:
|
||||
"""Doris client configuration class"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: str = "stdio",
|
||||
server_command: str | None = None,
|
||||
server_args: list[str] | None = None,
|
||||
server_url: str | None = None,
|
||||
timeout: int = 60,
|
||||
):
|
||||
self.transport = transport
|
||||
self.server_command = server_command
|
||||
self.server_args = server_args or []
|
||||
self.server_url = server_url
|
||||
self.timeout = timeout
|
||||
|
||||
@classmethod
|
||||
def stdio(cls, command: str, args: list[str] = None) -> "DorisClientConfig":
|
||||
"""Create stdio connection configuration"""
|
||||
return cls(
|
||||
transport="stdio",
|
||||
server_command=command,
|
||||
server_args=args or []
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def http(cls, url: str, timeout: int = 60) -> "DorisClientConfig":
|
||||
"""Create HTTP connection configuration"""
|
||||
return cls(
|
||||
transport="http",
|
||||
server_url=url,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
|
||||
class DorisResourceClient:
|
||||
"""Doris resource client - Handles Resources related operations"""
|
||||
|
||||
def __init__(self, session: ClientSession):
|
||||
self.session = session
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisResourceClient")
|
||||
self._resources_cache = None
|
||||
|
||||
async def list_resources(self) -> list[Resource]:
|
||||
"""Get list of all available resources"""
|
||||
try:
|
||||
self.logger.info("Getting resource list")
|
||||
response = await self.session.list_resources()
|
||||
resources = response.resources if hasattr(response, "resources") else []
|
||||
self._resources_cache = resources
|
||||
self.logger.info(f"Retrieved {len(resources)} resources")
|
||||
return resources
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get resource list: {e}")
|
||||
return []
|
||||
|
||||
async def read_resource(self, uri: str) -> str | None:
|
||||
"""Read specified resource content"""
|
||||
try:
|
||||
self.logger.info(f"Reading resource: {uri}")
|
||||
response = await self.session.read_resource(uri)
|
||||
|
||||
if hasattr(response, "contents") and response.contents:
|
||||
# Merge all content
|
||||
content_parts = []
|
||||
for content in response.contents:
|
||||
if hasattr(content, "text"):
|
||||
content_parts.append(content.text)
|
||||
content = "\n".join(content_parts)
|
||||
self.logger.info(f"Successfully read resource content: {len(content)} characters")
|
||||
return content
|
||||
elif hasattr(response, "content"):
|
||||
return str(response.content)
|
||||
else:
|
||||
self.logger.warning(f"Resource {uri} returned no content")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to read resource {uri}: {e}")
|
||||
return None
|
||||
|
||||
async def filter_resources_by_type(self, resource_type: str) -> list[Resource]:
|
||||
"""Filter resources by type"""
|
||||
if not self._resources_cache:
|
||||
await self.list_resources()
|
||||
|
||||
if resource_type == "table":
|
||||
return [r for r in self._resources_cache if "table" in r.uri]
|
||||
elif resource_type == "view":
|
||||
return [r for r in self._resources_cache if "view" in r.uri]
|
||||
elif resource_type == "database":
|
||||
return [
|
||||
r for r in self._resources_cache
|
||||
if "database" in r.uri and "table" not in r.uri
|
||||
]
|
||||
else:
|
||||
return self._resources_cache
|
||||
|
||||
|
||||
class DorisToolsClient:
|
||||
"""Doris tools client - Handles Tools related operations"""
|
||||
|
||||
def __init__(self, session: ClientSession):
|
||||
self.session = session
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisToolsClient")
|
||||
self._tools_cache = None
|
||||
|
||||
async def list_tools(self) -> list[Tool]:
|
||||
"""Get list of all available tools"""
|
||||
try:
|
||||
self.logger.info("Getting tool list")
|
||||
response = await self.session.list_tools()
|
||||
tools = response.tools if hasattr(response, "tools") else []
|
||||
self._tools_cache = tools
|
||||
self.logger.info(f"Retrieved {len(tools)} tools")
|
||||
return tools
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get tool list: {e}")
|
||||
return []
|
||||
|
||||
async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call specified tool"""
|
||||
try:
|
||||
self.logger.info(f"Calling tool: {name}")
|
||||
self.logger.debug(f"Tool arguments: {arguments}")
|
||||
|
||||
response = await self.session.call_tool(name, arguments)
|
||||
|
||||
if hasattr(response, "content") and response.content:
|
||||
# Parse response content
|
||||
result_text = ""
|
||||
for content in response.content:
|
||||
if hasattr(content, "text"):
|
||||
result_text += content.text
|
||||
|
||||
# Try to parse as JSON
|
||||
try:
|
||||
result = json.loads(result_text)
|
||||
self.logger.info(f"Tool call successful: {name}")
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
# If not JSON format, return text directly
|
||||
return {"success": True, "data": result_text}
|
||||
|
||||
self.logger.warning(f"Tool {name} returned no content")
|
||||
return {"success": False, "error": "No response content"}
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Tool call failed {name}: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
async def get_tool_by_name(self, name: str) -> Tool | None:
|
||||
"""Get tool definition by name"""
|
||||
if not self._tools_cache:
|
||||
await self.list_tools()
|
||||
|
||||
for tool in self._tools_cache:
|
||||
if tool.name == name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
async def get_tools_by_category(self, category: str) -> list[Tool]:
|
||||
"""Filter tools by category"""
|
||||
if not self._tools_cache:
|
||||
await self.list_tools()
|
||||
|
||||
category_lower = category.lower()
|
||||
return [
|
||||
tool for tool in self._tools_cache
|
||||
if category_lower in tool.description.lower()
|
||||
or category_lower in tool.name.lower()
|
||||
]
|
||||
|
||||
|
||||
class DorisPromptClient:
|
||||
"""Doris prompt client - Handles Prompts related operations"""
|
||||
|
||||
def __init__(self, session: ClientSession):
|
||||
self.session = session
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisPromptClient")
|
||||
self._prompts_cache = None
|
||||
|
||||
async def list_prompts(self) -> list[Prompt]:
|
||||
"""Get list of all available prompts"""
|
||||
try:
|
||||
self.logger.info("Getting prompt list")
|
||||
response = await self.session.list_prompts()
|
||||
prompts = response.prompts if hasattr(response, "prompts") else []
|
||||
self._prompts_cache = prompts
|
||||
self.logger.info(f"Retrieved {len(prompts)} prompts")
|
||||
return prompts
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get prompt list: {e}")
|
||||
return []
|
||||
|
||||
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None:
|
||||
"""Get specified prompt content"""
|
||||
try:
|
||||
self.logger.info(f"Getting prompt: {name}")
|
||||
self.logger.debug(f"Prompt arguments: {arguments}")
|
||||
|
||||
response = await self.session.get_prompt(name, arguments)
|
||||
|
||||
if hasattr(response, "messages") and response.messages:
|
||||
# Merge all message content
|
||||
content_parts = []
|
||||
for message in response.messages:
|
||||
if hasattr(message, "content"):
|
||||
if hasattr(message.content, "text"):
|
||||
content_parts.append(message.content.text)
|
||||
else:
|
||||
content_parts.append(str(message.content))
|
||||
|
||||
content = "\n".join(content_parts)
|
||||
self.logger.info(f"Successfully retrieved prompt content: {len(content)} characters")
|
||||
return content
|
||||
|
||||
self.logger.warning(f"Prompt {name} returned no content")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to get prompt {name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
class DorisUnifiedClient:
|
||||
"""Unified Doris MCP client - Provides complete MCP functionality"""
|
||||
|
||||
def __init__(self, config: DorisClientConfig):
|
||||
self.config = config
|
||||
self.logger = logging.getLogger(f"{__name__}.DorisUnifiedClient")
|
||||
self.session = None
|
||||
self.resources = None
|
||||
self.tools = None
|
||||
self.prompts = None
|
||||
|
||||
async def connect_and_run(self, callback_func: Callable):
|
||||
"""Connect to server and execute callback function"""
|
||||
if self.config.transport == "stdio":
|
||||
await self._run_stdio_mode(callback_func)
|
||||
elif self.config.transport == "http":
|
||||
await self._run_http_mode(callback_func)
|
||||
else:
|
||||
raise ValueError(f"Unsupported transport type: {self.config.transport}")
|
||||
|
||||
async def _run_stdio_mode(self, callback_func: Callable):
|
||||
"""Run in stdio mode"""
|
||||
try:
|
||||
self.logger.info(f"Starting stdio client: {self.config.server_command}")
|
||||
|
||||
server_params = StdioServerParameters(
|
||||
command=self.config.server_command,
|
||||
args=self.config.server_args,
|
||||
)
|
||||
|
||||
async with stdio_client(server_params) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
self.session = session
|
||||
self._init_sub_clients()
|
||||
|
||||
# Initialize server
|
||||
await session.initialize()
|
||||
self.logger.info("Server initialized successfully")
|
||||
|
||||
# Execute callback function
|
||||
await callback_func(self)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"stdio mode execution failed: {e}")
|
||||
raise
|
||||
|
||||
async def _run_http_mode(self, callback_func: Callable):
|
||||
"""Run in HTTP mode"""
|
||||
try:
|
||||
self.logger.info(f"Starting HTTP client: {self.config.server_url}")
|
||||
|
||||
async with streamablehttp_client(
|
||||
self.config.server_url,
|
||||
timeout=timedelta(seconds=self.config.timeout)
|
||||
) as (read, write):
|
||||
async with ClientSession(read, write) as session:
|
||||
self.session = session
|
||||
self._init_sub_clients()
|
||||
|
||||
# Initialize server
|
||||
await session.initialize()
|
||||
self.logger.info("Server initialized successfully")
|
||||
|
||||
# Execute callback function
|
||||
await callback_func(self)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"HTTP mode execution failed: {e}")
|
||||
raise
|
||||
|
||||
def _init_sub_clients(self):
|
||||
"""Initialize sub-clients"""
|
||||
self.resources = DorisResourceClient(self.session)
|
||||
self.tools = DorisToolsClient(self.session)
|
||||
self.prompts = DorisPromptClient(self.session)
|
||||
|
||||
# Convenience methods
|
||||
async def list_all_resources(self) -> list[Resource]:
|
||||
"""Get all resources"""
|
||||
return await self.resources.list_resources()
|
||||
|
||||
async def list_all_tools(self) -> list[Tool]:
|
||||
"""Get all tools"""
|
||||
return await self.tools.list_tools()
|
||||
|
||||
async def list_all_prompts(self) -> list[Prompt]:
|
||||
"""Get all prompts"""
|
||||
return await self.prompts.list_prompts()
|
||||
|
||||
async def call_tool(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call tool"""
|
||||
return await self.tools.call_tool(name, arguments)
|
||||
|
||||
async def read_resource(self, uri: str) -> str | None:
|
||||
"""Read resource"""
|
||||
return await self.resources.read_resource(uri)
|
||||
|
||||
async def get_prompt(self, name: str, arguments: dict[str, Any]) -> str | None:
|
||||
"""Get prompt"""
|
||||
return await self.prompts.get_prompt(name, arguments)
|
||||
|
||||
# Smart tool finding methods
|
||||
async def _find_tool_by_pattern(self, patterns: list[str]) -> str | None:
|
||||
"""Find tool by name pattern"""
|
||||
tools = await self.list_all_tools()
|
||||
for pattern in patterns:
|
||||
for tool in tools:
|
||||
if pattern in tool.name:
|
||||
return tool.name
|
||||
return None
|
||||
|
||||
async def _find_tool_by_function(self, function_keywords: list[str]) -> str | None:
|
||||
"""Find tool by function keywords"""
|
||||
tools = await self.list_all_tools()
|
||||
for tool in tools:
|
||||
tool_desc = tool.description.lower()
|
||||
tool_name = tool.name.lower()
|
||||
for keyword in function_keywords:
|
||||
if keyword.lower() in tool_desc or keyword.lower() in tool_name:
|
||||
return tool.name
|
||||
return None
|
||||
|
||||
# High-level business methods
|
||||
async def execute_sql(self, sql: str, **kwargs) -> dict[str, Any]:
|
||||
"""Execute SQL query"""
|
||||
tool_name = await self._find_tool_by_pattern(["exec_query", "execute", "query"])
|
||||
if not tool_name:
|
||||
return {"success": False, "error": "SQL execution tool not found"}
|
||||
|
||||
arguments = {"sql": sql, **kwargs}
|
||||
return await self.call_tool(tool_name, arguments)
|
||||
|
||||
async def get_table_schema(self, table_name: str, db_name: str = None, **kwargs) -> dict[str, Any]:
|
||||
"""Get table schema"""
|
||||
tool_name = await self._find_tool_by_pattern(["get_table_schema", "table_schema", "schema"])
|
||||
if not tool_name:
|
||||
return {"success": False, "error": "Table schema tool not found"}
|
||||
|
||||
arguments = {"table_name": table_name}
|
||||
if db_name:
|
||||
arguments["db_name"] = db_name
|
||||
arguments.update(kwargs)
|
||||
|
||||
return await self.call_tool(tool_name, arguments)
|
||||
|
||||
async def get_database_list(self, **kwargs) -> dict[str, Any]:
|
||||
"""Get database list"""
|
||||
tool_name = await self._find_tool_by_pattern(["get_db_list", "database_list", "db_list"])
|
||||
if not tool_name:
|
||||
return {"success": False, "error": "Database list tool not found"}
|
||||
|
||||
return await self.call_tool(tool_name, kwargs)
|
||||
|
||||
async def analyze_column(self, table_name: str, column_name: str, analysis_type: str = "basic", **kwargs) -> dict[str, Any]:
|
||||
"""Analyze column"""
|
||||
tool_name = await self._find_tool_by_pattern(["column_analysis", "analyze_column", "column"])
|
||||
if not tool_name:
|
||||
return {"success": False, "error": "Column analysis tool not found"}
|
||||
|
||||
arguments = {
|
||||
"table_name": table_name,
|
||||
"column_name": column_name,
|
||||
"analysis_type": analysis_type,
|
||||
**kwargs
|
||||
}
|
||||
return await self.call_tool(tool_name, arguments)
|
||||
|
||||
async def call_tool_by_function(self, function_description: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Call tool by function description"""
|
||||
# Try to find appropriate tool based on function description
|
||||
function_keywords = function_description.lower().split()
|
||||
tool_name = await self._find_tool_by_function(function_keywords)
|
||||
|
||||
if not tool_name:
|
||||
return {
|
||||
"success": False,
|
||||
"error": f"No tool found for function: {function_description}"
|
||||
}
|
||||
|
||||
return await self.call_tool(tool_name, arguments)
|
||||
|
||||
|
||||
# Convenience factory functions
|
||||
async def create_stdio_client(command: str, args: list[str] = None) -> DorisUnifiedClient:
|
||||
"""Create stdio client"""
|
||||
config = DorisClientConfig.stdio(command, args)
|
||||
return DorisUnifiedClient(config)
|
||||
|
||||
|
||||
async def create_http_client(server_url: str, timeout: int = 60) -> DorisUnifiedClient:
|
||||
"""Create HTTP client"""
|
||||
config = DorisClientConfig.http(server_url, timeout)
|
||||
return DorisUnifiedClient(config)
|
||||
|
||||
|
||||
# Example usage
|
||||
async def example_stdio():
|
||||
"""stdio mode example"""
|
||||
client = await create_stdio_client("python", ["doris_mcp_server/main.py"])
|
||||
|
||||
async def test_client(client: DorisUnifiedClient):
|
||||
# Get server capabilities
|
||||
resources = await client.list_all_resources()
|
||||
tools = await client.list_all_tools()
|
||||
prompts = await client.list_all_prompts()
|
||||
|
||||
print(f"Resources: {len(resources)}")
|
||||
print(f"Tools: {len(tools)}")
|
||||
print(f"Prompts: {len(prompts)}")
|
||||
|
||||
# Test SQL execution
|
||||
result = await client.execute_sql("SELECT 1 as test")
|
||||
print(f"SQL execution result: {result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
|
||||
async def example_http():
|
||||
"""HTTP mode example"""
|
||||
client = await create_http_client("http://localhost:8080")
|
||||
|
||||
async def test_client(client: DorisUnifiedClient):
|
||||
# Get server capabilities
|
||||
resources = await client.list_all_resources()
|
||||
tools = await client.list_all_tools()
|
||||
|
||||
print(f"Resources: {len(resources)}")
|
||||
print(f"Tools: {len(tools)}")
|
||||
|
||||
# Test database list
|
||||
result = await client.get_database_list()
|
||||
print(f"Database list: {result}")
|
||||
|
||||
await client.connect_and_run(test_client)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run stdio example
|
||||
asyncio.run(example_stdio())
|
||||
|
||||
# Run HTTP example
|
||||
# asyncio.run(example_http())
|
||||
Reference in New Issue
Block a user