@@ -13,52 +13,175 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from flask import Response , request
from fl ask_login import current_user , login_required
from typing import List , Optional , Dict , Any
from fastapi import APIRouter , Depends , HTTPException , Query
from fastapi . security import HTTPBearer , HTTPAuthorizationCredentials
from api import settings
from api . db import VALID_MCP_SERVER_TYPES
from api . db . db_models import MCPServer
from api . db . db_models import MCPServer , User
from api . db . services . mcp_server_service import MCPServerService
from api . db . services . user_service import TenantService
from api . settings import RetCode
from api . utils import get_uuid
from api . utils . api_utils import get_data_error_result , get_json_result , server_error_response , validate_request , \
get_mcp_tools
from api . utils . api_utils import get_data_error_result , get_json_result , server_error_response , get_mcp_tools
from api . utils . web_utils import get_float , safe_json_parse
from rag . utils . mcp_tool_call_conn import MCPToolCallSession , close_multiple_mcp_toolcall_sessions
from pydantic import BaseModel
# Security
security = HTTPBearer ( )
# Pydantic models for request/response
class ListMCPRequest ( BaseModel ) :
mcp_ids : List [ str ] = [ ]
class CreateMCPRequest ( BaseModel ) :
name : str
url : str
server_type : str
headers : Optional [ Dict [ str , Any ] ] = { }
variables : Optional [ Dict [ str , Any ] ] = { }
timeout : Optional [ float ] = 10
class UpdateMCPRequest ( BaseModel ) :
mcp_id : str
name : Optional [ str ] = None
url : Optional [ str ] = None
server_type : Optional [ str ] = None
headers : Optional [ Dict [ str , Any ] ] = None
variables : Optional [ Dict [ str , Any ] ] = None
timeout : Optional [ float ] = 10
class RemoveMCPRequest ( BaseModel ) :
mcp_ids : List [ str ]
class ImportMCPRequest ( BaseModel ) :
mcpServers : Dict [ str , Dict [ str , Any ] ]
timeout : Optional [ float ] = 10
class ExportMCPRequest ( BaseModel ) :
mcp_ids : List [ str ]
@manager.route ( " /list " , methods = [ " POST " ] ) # noqa: F821
@login_required
def list_mcp ( ) - > Response :
keywords = request . args . get ( " keywords " , " " )
page_number = int ( request . args . get ( " page " , 0 ) )
items_per_page = int ( request . args . get ( " page_size " , 0 ) )
orderby = request . args . get ( " orderby " , " create_time " )
if request . args . get ( " desc " , " true " ) . lower ( ) == " false " :
desc = False
else :
desc = True
class ListToolsRequest ( BaseModel ) :
mcp_ids : List [ str ]
timeout : Optional [ float ] = 10
req = request . get_json ( )
mcp_ids = req . get ( " mcp_ids " , [ ] )
class TestToolRequest ( BaseModel ) :
mcp_id : str
tool_name : str
arguments : Dict [ str , Any ]
timeout : Optional [ float ] = 10
class CacheToolsRequest ( BaseModel ) :
mcp_id : str
tools : List [ Dict [ str , Any ] ]
class TestMCPRequest ( BaseModel ) :
url : str
server_type : str
timeout : Optional [ float ] = 10
headers : Optional [ Dict [ str , Any ] ] = { }
variables : Optional [ Dict [ str , Any ] ] = { }
# Dependency injection
async def get_current_user ( credentials : HTTPAuthorizationCredentials = Depends ( security ) ) :
""" 获取当前用户 """
from api . db import StatusEnum
from api . db . services . user_service import UserService
from fastapi import HTTPException , status
import logging
try :
servers = MCPServerService . get_servers ( current_user . id , mcp_ids , 0 , 0 , orderby , desc , keywords ) or [ ]
from itsdangerous . url_safe import URLSafeTimedSerializer as Serializer
except ImportError :
# 如果没有itsdangerous, 使用jwt作为替代
import jwt
Serializer = jwt
jwt = Serializer ( secret_key = settings . SECRET_KEY )
authorization = credentials . credentials
if authorization :
try :
access_token = str ( jwt . loads ( authorization ) )
if not access_token or not access_token . strip ( ) :
raise HTTPException (
status_code = status . HTTP_401_UNAUTHORIZED ,
detail = " Authentication attempt with empty access token "
)
# Access tokens should be UUIDs (32 hex characters)
if len ( access_token . strip ( ) ) < 32 :
raise HTTPException (
status_code = status . HTTP_401_UNAUTHORIZED ,
detail = f " Authentication attempt with invalid token format: { len ( access_token ) } chars "
)
user = UserService . query (
access_token = access_token , status = StatusEnum . VALID . value
)
if user :
if not user [ 0 ] . access_token or not user [ 0 ] . access_token . strip ( ) :
raise HTTPException (
status_code = status . HTTP_401_UNAUTHORIZED ,
detail = f " User { user [ 0 ] . email } has empty access_token in database "
)
return user [ 0 ]
else :
raise HTTPException (
status_code = status . HTTP_401_UNAUTHORIZED ,
detail = " Invalid access token "
)
except Exception as e :
logging . warning ( f " load_user got exception { e } " )
raise HTTPException (
status_code = status . HTTP_401_UNAUTHORIZED ,
detail = " Invalid access token "
)
else :
raise HTTPException (
status_code = status . HTTP_401_UNAUTHORIZED ,
detail = " Authorization header required "
)
# Create router
router = APIRouter ( )
@router.post ( " /list " )
async def list_mcp (
request : ListMCPRequest ,
keywords : str = Query ( " " , description = " Search keywords " ) ,
page : int = Query ( 0 , description = " Page number " ) ,
page_size : int = Query ( 0 , description = " Items per page " ) ,
orderby : str = Query ( " create_time " , description = " Order by field " ) ,
desc : bool = Query ( True , description = " Sort descending " ) ,
current_user : User = Depends ( get_current_user )
) :
try :
servers = MCPServerService . get_servers ( current_user . id , request . mcp_ids , 0 , 0 , orderby , desc , keywords ) or [ ]
total = len ( servers )
if page_number and items_per_pag e:
servers = servers [ ( page_number - 1 ) * items_per_page : page_number * items_per_pag e]
if page and page_siz e:
servers = servers [ ( page - 1 ) * page_size : page * page_siz e]
return get_json_result ( data = { " mcp_servers " : servers , " total " : total } )
except Exception as e :
return server_error_response ( e )
@manager. route ( " /detail " , methods = [ " GET " ] ) # noqa: F821
@login_required
def detail ( ) - > R esponse :
mcp_id = request . args [ " mcp_id " ]
@router.get ( " /detail " )
async def detail (
mcp_id : str = Query ( . . . , d escription = " MCP server ID " ) ,
current_user : User = Depends ( get_current_user )
) :
try :
mcp_server = MCPServerService . get_or_none ( id = mcp_id , tenant_id = current_user . id )
@@ -70,17 +193,16 @@ def detail() -> Response:
return server_error_response ( e )
@manager.route ( " /create " , methods = [ " POST " ] ) # noqa: F821
@login_required
@validate_request ( " name " , " url " , " server_type " )
def create ( ) - > Response :
req = request . get_json ( )
server_type = req . get ( " server_type " , " " )
@router.post ( " /create " )
async def create (
request : CreateMCPRequest ,
current_user : User = Depends ( get_current_user )
) :
server_type = request . server_type
if server_type not in VALID_MCP_SERVER_TYPES :
return get_data_error_result ( message = " Unsupported MCP server type. " )
server_name = req . get ( " name " , " " )
server_name = request . name
if not server_name or len ( server_name . encode ( " utf-8 " ) ) > 255 :
return get_data_error_result ( message = f " Invalid MCP name or length is { len ( server_name ) } which is large than 255. " )
@@ -88,20 +210,27 @@ def create() -> Response:
if e :
return get_data_error_result ( message = " Duplicated MCP server name. " )
url = req . get ( " url " , " " )
url = request . url
if not url :
return get_data_error_result ( message = " Invalid url. " )
headers = safe_json_parse ( req . get ( " headers" , { } ) )
req [ " headers " ] = headers
variables = safe_json_parse ( req . get ( " variables " , { } ) )
headers = safe_json_parse ( request . headers or { } )
variables = safe_json_parse ( request . variables or { } )
variables . pop ( " tools " , None )
timeout = get_float ( req , " timeout" , 10 )
timeout = request . timeout or 10
try :
req[ " id " ] = get_uuid ( )
req [ " tenant_id " ] = current_user . id
req_data = {
" id " : get_uuid ( ) ,
" tenant_id " : current_user . id ,
" name " : server_name ,
" url " : url ,
" server_type " : server_type ,
" headers " : headers ,
" variables " : variables ,
" timeout " : timeout
}
e , _ = TenantService . get_by_id ( current_user . id )
if not e :
@@ -115,49 +244,55 @@ def create() -> Response:
tools = server_tools [ server_name ]
tools = { tool [ " name " ] : tool for tool in tools if isinstance ( tool , dict ) and " name " in tool }
variables [ " tools " ] = tools
req [ " variables " ] = variables
req_data [ " variables " ] = variables
if not MCPServerService . insert ( * * req ) :
if not MCPServerService . insert ( * * req_data ) :
return get_data_error_result ( " Failed to create MCP server. " )
return get_json_result ( data = req )
return get_json_result ( data = req_data )
except Exception as e :
return server_error_response ( e )
@manager.route ( " /update " , methods = [ " POST " ] ) # noqa: F821
@login_required
@validate_request ( " mcp_id " )
def update ( ) - > Response :
req = request . get_json ( )
mcp_id = req . get ( " mcp_id " , " " )
@router.post ( " /update " )
async def update (
request : UpdateMCPRequest ,
current_user : User = Depends ( get_current_user )
) :
mcp_id = request . mcp_id
e , mcp_server = MCPServerService . get_by_id ( mcp_id )
if not e or mcp_server . tenant_id != current_user . id :
return get_data_error_result ( message = f " Cannot find MCP server { mcp_id } for user { current_user . id } " )
server_type = req . get ( " server_type" , mcp_server . server_type )
server_type = request . server_type or mcp_server . server_type
if server_type and server_type not in VALID_MCP_SERVER_TYPES :
return get_data_error_result ( message = " Unsupported MCP server type. " )
server_name = req . get ( " name " , mcp_server . name )
server_name = request . name or mcp_server . name
if server_name and len ( server_name . encode ( " utf-8 " ) ) > 255 :
return get_data_error_result ( message = f " Invalid MCP name or length is { len ( server_name ) } which is large than 255. " )
url = req . get ( " url " , mcp_server . url )
url = request . url or mcp_server . url
if not url :
return get_data_error_result ( message = " Invalid url. " )
headers = safe_json_parse ( req . get ( " headers" , mcp_server . headers ) )
req [ " headers " ] = headers
variables = safe_json_parse ( req . get ( " variables " , mcp_server . variables ) )
headers = safe_json_parse ( request . headers or mcp_server . headers )
variables = safe_json_parse ( request . variables or mcp_server . variables )
variables . pop ( " tools " , None )
timeout = get_float ( req , " timeout" , 10 )
timeout = request . timeout or 10
try :
req[ " tenant_id " ] = current_user . id
req . pop ( " mcp _id" , None )
req [ " id " ] = mcp_id
req_data = {
" tenant _id " : current_user . id ,
" id " : mcp_id ,
" name " : server_name ,
" url " : url ,
" server_type " : server_type ,
" headers " : headers ,
" variables " : variables ,
" timeout " : timeout
}
mcp_server = MCPServer ( id = server_name , name = server_name , url = url , server_type = server_type , variables = variables , headers = headers )
server_tools , err_message = get_mcp_tools ( [ mcp_server ] , timeout )
@@ -167,12 +302,12 @@ def update() -> Response:
tools = server_tools [ server_name ]
tools = { tool [ " name " ] : tool for tool in tools if isinstance ( tool , dict ) and " name " in tool }
variables [ " tools " ] = tools
req [ " variables " ] = variables
req_data [ " variables " ] = variables
if not MCPServerService . filter_update ( [ MCPServer . id == mcp_id , MCPServer . tenant_id == current_user . id ] , req ) :
if not MCPServerService . filter_update ( [ MCPServer . id == mcp_id , MCPServer . tenant_id == current_user . id ] , req_data ) :
return get_data_error_result ( message = " Failed to updated MCP server. " )
e , updated_mcp = MCPServerService . get_by_id ( req [ " id " ] )
e , updated_mcp = MCPServerService . get_by_id ( req_data [ " id " ] )
if not e :
return get_data_error_result ( message = " Failed to fetch updated MCP server. " )
@@ -181,16 +316,14 @@ def update() -> Response:
return server_error_response ( e )
@manager.route ( " /rm " , methods = [ " POST " ] ) # noqa: F821
@login_required
@validate_request ( " mcp_ids " )
def rm ( ) - > Response :
req = request . get_json ( )
mcp_ids = req . get ( " mcp_ids " , [ ] )
@router.post ( " /rm " )
async def rm (
request : RemoveMCPRequest ,
current_user : User = Depends ( get_current_user )
) :
mcp_ids = request . mcp_ids
try :
req [ " tenant_id " ] = current_user . id
if not MCPServerService . delete_by_ids ( mcp_ids ) :
return get_data_error_result ( message = f " Failed to delete MCP servers { mcp_ids } " )
@@ -199,16 +332,16 @@ def rm() -> Response:
return server_error_response ( e )
@manager.route ( " /import " , methods = [ " POST " ] ) # noqa: F821
@login_required
@validate_request ( " mcpServers " )
def import_multiple ( ) - > Respon se:
req = request . get_json ( )
servers = req . get ( " mcpServers " , { } )
@router.post ( " /import " )
async def import_multiple (
request : ImportMCPRequest ,
current_user : User = Depends ( get_current_u ser )
) :
servers = request . mcpServers
if not servers :
return get_data_error_result ( message = " No MCP servers provided. " )
timeout = get_float ( req , " timeout" , 10 )
timeout = request . timeout or 10
results = [ ]
try :
@@ -266,12 +399,12 @@ def import_multiple() -> Response:
return server_error_response ( e )
@manager.route ( " /export " , methods = [ " POST " ] ) # noqa: F821
@login_required
@validate_request ( " mcp_ids " )
def export_multiple ( ) - > Respon se:
req = request . get_json ( )
mcp_ids = req . get ( " mcp_ids " , [ ] )
@router.post ( " /export " )
async def export_multiple (
request : ExportMCPRequest ,
current_user : User = Depends ( get_current_u ser )
) :
mcp_ids = request . mcp_ids
if not mcp_ids :
return get_data_error_result ( message = " No MCP server IDs provided. " )
@@ -298,16 +431,13 @@ def export_multiple() -> Response:
return server_error_response ( e )
@manager.route ( " /list_tools " , methods = [ " POST " ] ) # noqa: F821
@login_required
@validate_request ( " mcp_ids" )
def list_tools ( ) - > Response :
req = request . get_json ( )
mcp_ids = req . get ( " mcp_ids " , [ ] )
@router.post ( " /list_tools " )
async def list_tools ( req : ListToolsRequest , current_user : User = Depends ( get_current_user ) ) :
mcp_ids = req . mcp_ids
if not mcp_ids :
return get_data_error_result ( message = " No MCP server IDs provided. " )
timeout = get_float ( req , " timeout " , 10 )
timeout = req . timeout
results = { }
tool_call_sessions = [ ]
@@ -345,19 +475,16 @@ def list_tools() -> Response:
close_multiple_mcp_toolcall_sessions ( tool_call_sessions )
@manager.route ( " /test_tool " , methods = [ " POST " ] ) # noqa: F821
@login_required
@validate_request ( " mcp_id " , " tool_name " , " arguments " )
def test_tool ( ) - > Response :
req = request . get_json ( )
mcp_id = req . get ( " mcp_id " , " " )
@router.post ( " /test_tool " )
async def test_tool ( req : TestToolRequest , current_user : User = Depends ( get_current_user ) ) :
mcp_id = req . mcp_id
if not mcp_id :
return get_data_error_result ( message = " No MCP server ID provided. " )
timeout = get_float ( req , " timeout " , 10 )
timeout = req . timeout
tool_name = req . get ( " tool_name " , " " )
arguments = req . get ( " arguments " , { } )
tool_name = req . tool_name
arguments = req . arguments
if not all ( [ tool_name , arguments ] ) :
return get_data_error_result ( message = " Require provide tool name and arguments. " )
@@ -378,15 +505,12 @@ def test_tool() -> Response:
return server_error_response ( e )
@manager.route ( " /cache_tools " , methods = [ " POST " ] ) # noqa: F821
@login_required
@validate_request ( " mcp_id " , " tools " )
def cache_tool ( ) - > Response :
req = request . get_json ( )
mcp_id = req . get ( " mcp_id " , " " )
@router.post ( " /cache_tools " )
async def cache_tool ( req : CacheToolsRequest , current_user : User = Depends ( get_current_user ) ) :
mcp_id = req . mcp_id
if not mcp_id :
return get_data_error_result ( message = " No MCP server ID provided. " )
tools = req . get ( " tools " , [ ] )
tools = req . tools
e , mcp_server = MCPServerService . get_by_id ( mcp_id )
if not e or mcp_server . tenant_id != current_user . id :
@@ -402,22 +526,19 @@ def cache_tool() -> Response:
return get_json_result ( data = tools )
@manager.route ( " /test_mcp " , methods = [ " POST " ] ) # noqa: F821
@validate_request ( " url " , " server_type " )
def test_mcp ( ) - > Response :
req = request . get_json ( )
url = req . get ( " url " , " " )
@router.post ( " /test_mcp " )
async def test_mcp ( req : TestMCPRequest ) :
url = req . url
if not url :
return get_data_error_result ( message = " Invalid MCP url. " )
server_type = req . get ( " server_type " , " " )
server_type = req . server_type
if server_type not in VALID_MCP_SERVER_TYPES :
return get_data_error_result ( message = " Unsupported MCP server type. " )
timeout = get_float ( req , " timeout " , 10 )
headers = safe_json_parse ( req . get ( " headers " , { } ) )
variables = safe_json_parse ( req . get ( " variables " , { } ) )
timeout = req . timeout
headers = req . headers
variables = req . variables
mcp_server = MCPServer ( id = f " { server_type } : { url } " , server_type = server_type , url = url , headers = headers , variables = variables )