From bdf3f62e36ec92ea0baedfe3860a42dde44b4e6d Mon Sep 17 00:00:00 2001 From: dangzerong <429714019@qq.com> Date: Wed, 22 Oct 2025 09:07:07 +0800 Subject: [PATCH] =?UTF-8?q?mcp=E6=8E=A5=E5=8F=A3flask=E6=94=B9=E6=88=90fas?= =?UTF-8?q?tapi?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api/apps/__init___fastapi.py | 2 + api/apps/mcp_server_app.py | 204 ++++++++++++++++++----------------- 2 files changed, 106 insertions(+), 100 deletions(-) diff --git a/api/apps/__init___fastapi.py b/api/apps/__init___fastapi.py index d3af198..2668905 100644 --- a/api/apps/__init___fastapi.py +++ b/api/apps/__init___fastapi.py @@ -126,12 +126,14 @@ def setup_routes(app: FastAPI): from api.apps.document_app import router as document_router from api.apps.file_app import router as file_router from api.apps.file2document_app import router as file2document_router + from api.apps.mcp_server_app import router as mcp_router app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"]) app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"]) app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"]) app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"]) app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"]) + app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"]) def get_current_user_from_token(authorization: str): """从token获取当前用户""" diff --git a/api/apps/mcp_server_app.py b/api/apps/mcp_server_app.py index 33eb952..17ad887 100644 --- a/api/apps/mcp_server_app.py +++ b/api/apps/mcp_server_app.py @@ -17,15 +17,15 @@ from typing import List, Optional, Dict, Any from fastapi import APIRouter, Depends, HTTPException, Query from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from api import settings from api.db import VALID_MCP_SERVER_TYPES -from api.db.db_models import MCPServer +from api.db.db_models import MCPServer, User from api.db.services.mcp_server_service import MCPServerService from api.db.services.user_service import TenantService from api.settings import RetCode from api.utils import get_uuid -from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, \ - get_mcp_tools +from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_mcp_tools from api.utils.web_utils import get_float, safe_json_parse from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from pydantic import BaseModel @@ -52,7 +52,7 @@ class UpdateMCPRequest(BaseModel): server_type: Optional[str] = None headers: Optional[Dict[str, Any]] = None variables: Optional[Dict[str, Any]] = None - timeout: Optional[float] = None + timeout: Optional[float] = 10 class RemoveMCPRequest(BaseModel): mcp_ids: List[str] @@ -64,26 +64,30 @@ class ImportMCPRequest(BaseModel): class ExportMCPRequest(BaseModel): mcp_ids: List[str] + class ListToolsRequest(BaseModel): mcp_ids: List[str] timeout: Optional[float] = 10 + class TestToolRequest(BaseModel): mcp_id: str tool_name: str arguments: Dict[str, Any] timeout: Optional[float] = 10 + class CacheToolsRequest(BaseModel): mcp_id: str tools: List[Dict[str, Any]] + class TestMCPRequest(BaseModel): url: str server_type: str + timeout: Optional[float] = 10 headers: Optional[Dict[str, Any]] = {} variables: Optional[Dict[str, Any]] = {} - timeout: Optional[float] = 10 # Dependency injection async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)): @@ -153,17 +157,16 @@ router = APIRouter() @router.post("/list") async def list_mcp( - keywords: str = Query(""), - page: int = Query(0), - page_size: int = Query(0), - orderby: str = Query("create_time"), - desc: bool = Query(True), - req: ListMCPRequest = None, - current_user = Depends(get_current_user) + request: ListMCPRequest, + keywords: str = Query("", description="Search keywords"), + page: int = Query(0, description="Page number"), + page_size: int = Query(0, description="Items per page"), + orderby: str = Query("create_time", description="Order by field"), + desc: bool = Query(True, description="Sort descending"), + current_user: User = Depends(get_current_user) ): - mcp_ids = req.mcp_ids if req else [] try: - servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or [] + servers = MCPServerService.get_servers(current_user.id, request.mcp_ids, 0, 0, orderby, desc, keywords) or [] total = len(servers) if page and page_size: @@ -176,8 +179,8 @@ async def list_mcp( @router.get("/detail") async def detail( - mcp_id: str = Query(...), - current_user = Depends(get_current_user) + mcp_id: str = Query(..., description="MCP server ID"), + current_user: User = Depends(get_current_user) ): try: mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id) @@ -192,14 +195,14 @@ async def detail( @router.post("/create") async def create( - req: CreateMCPRequest, - current_user = Depends(get_current_user) + request: CreateMCPRequest, + current_user: User = Depends(get_current_user) ): - server_type = req.server_type + server_type = request.server_type if server_type not in VALID_MCP_SERVER_TYPES: return get_data_error_result(message="Unsupported MCP server type.") - server_name = req.name + server_name = request.name if not server_name or len(server_name.encode("utf-8")) > 255: return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.") @@ -207,20 +210,27 @@ async def create( if e: return get_data_error_result(message="Duplicated MCP server name.") - url = req.url + url = request.url if not url: return get_data_error_result(message="Invalid url.") - headers = safe_json_parse(req.headers) - variables = safe_json_parse(req.variables) + headers = safe_json_parse(request.headers or {}) + variables = safe_json_parse(request.variables or {}) variables.pop("tools", None) - timeout = req.timeout + timeout = request.timeout or 10 try: - req_dict = req.dict() - req_dict["id"] = get_uuid() - req_dict["tenant_id"] = current_user.id + req_data = { + "id": get_uuid(), + "tenant_id": current_user.id, + "name": server_name, + "url": url, + "server_type": server_type, + "headers": headers, + "variables": variables, + "timeout": timeout + } e, _ = TenantService.get_by_id(current_user.id) if not e: @@ -234,47 +244,55 @@ async def create( tools = server_tools[server_name] tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} variables["tools"] = tools - req_dict["variables"] = variables + req_data["variables"] = variables - if not MCPServerService.insert(**req_dict): + if not MCPServerService.insert(**req_data): return get_data_error_result("Failed to create MCP server.") - return get_json_result(data=req_dict) + return get_json_result(data=req_data) except Exception as e: return server_error_response(e) @router.post("/update") async def update( - req: UpdateMCPRequest, - current_user = Depends(get_current_user) + request: UpdateMCPRequest, + current_user: User = Depends(get_current_user) ): - mcp_id = req.mcp_id + mcp_id = request.mcp_id e, mcp_server = MCPServerService.get_by_id(mcp_id) if not e or mcp_server.tenant_id != current_user.id: return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") - server_type = req.server_type if req.server_type is not None else mcp_server.server_type + server_type = request.server_type or mcp_server.server_type if server_type and server_type not in VALID_MCP_SERVER_TYPES: return get_data_error_result(message="Unsupported MCP server type.") - server_name = req.name if req.name is not None else mcp_server.name + + server_name = request.name or mcp_server.name if server_name and len(server_name.encode("utf-8")) > 255: return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.") - url = req.url if req.url is not None else mcp_server.url + + url = request.url or mcp_server.url if not url: return get_data_error_result(message="Invalid url.") - headers = safe_json_parse(req.headers if req.headers is not None else mcp_server.headers) - variables = safe_json_parse(req.variables if req.variables is not None else mcp_server.variables) + headers = safe_json_parse(request.headers or mcp_server.headers) + variables = safe_json_parse(request.variables or mcp_server.variables) variables.pop("tools", None) - timeout = req.timeout if req.timeout is not None else 10 + timeout = request.timeout or 10 try: - req_dict = req.dict(exclude_unset=True) - req_dict["tenant_id"] = current_user.id - req_dict.pop("mcp_id", None) - req_dict["id"] = mcp_id + req_data = { + "tenant_id": current_user.id, + "id": mcp_id, + "name": server_name, + "url": url, + "server_type": server_type, + "headers": headers, + "variables": variables, + "timeout": timeout + } mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers) server_tools, err_message = get_mcp_tools([mcp_server], timeout) @@ -284,12 +302,12 @@ async def update( tools = server_tools[server_name] tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} variables["tools"] = tools - req_dict["variables"] = variables + req_data["variables"] = variables - if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req_dict): + if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req_data): return get_data_error_result(message="Failed to updated MCP server.") - e, updated_mcp = MCPServerService.get_by_id(req_dict["id"]) + e, updated_mcp = MCPServerService.get_by_id(req_data["id"]) if not e: return get_data_error_result(message="Failed to fetch updated MCP server.") @@ -298,16 +316,14 @@ async def update( return server_error_response(e) -@manager.route("/rm", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("mcp_ids") -def rm() -> Response: - req = request.get_json() - mcp_ids = req.get("mcp_ids", []) +@router.post("/rm") +async def rm( + request: RemoveMCPRequest, + current_user: User = Depends(get_current_user) +): + mcp_ids = request.mcp_ids try: - req["tenant_id"] = current_user.id - if not MCPServerService.delete_by_ids(mcp_ids): return get_data_error_result(message=f"Failed to delete MCP servers {mcp_ids}") @@ -316,16 +332,16 @@ def rm() -> Response: return server_error_response(e) -@manager.route("/import", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("mcpServers") -def import_multiple() -> Response: - req = request.get_json() - servers = req.get("mcpServers", {}) +@router.post("/import") +async def import_multiple( + request: ImportMCPRequest, + current_user: User = Depends(get_current_user) +): + servers = request.mcpServers if not servers: return get_data_error_result(message="No MCP servers provided.") - timeout = get_float(req, "timeout", 10) + timeout = request.timeout or 10 results = [] try: @@ -383,12 +399,12 @@ def import_multiple() -> Response: return server_error_response(e) -@manager.route("/export", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("mcp_ids") -def export_multiple() -> Response: - req = request.get_json() - mcp_ids = req.get("mcp_ids", []) +@router.post("/export") +async def export_multiple( + request: ExportMCPRequest, + current_user: User = Depends(get_current_user) +): + mcp_ids = request.mcp_ids if not mcp_ids: return get_data_error_result(message="No MCP server IDs provided.") @@ -415,16 +431,13 @@ def export_multiple() -> Response: return server_error_response(e) -@manager.route("/list_tools", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("mcp_ids") -def list_tools() -> Response: - req = request.get_json() - mcp_ids = req.get("mcp_ids", []) +@router.post("/list_tools") +async def list_tools(req: ListToolsRequest, current_user: User = Depends(get_current_user)): + mcp_ids = req.mcp_ids if not mcp_ids: return get_data_error_result(message="No MCP server IDs provided.") - timeout = get_float(req, "timeout", 10) + timeout = req.timeout results = {} tool_call_sessions = [] @@ -462,19 +475,16 @@ def list_tools() -> Response: close_multiple_mcp_toolcall_sessions(tool_call_sessions) -@manager.route("/test_tool", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("mcp_id", "tool_name", "arguments") -def test_tool() -> Response: - req = request.get_json() - mcp_id = req.get("mcp_id", "") +@router.post("/test_tool") +async def test_tool(req: TestToolRequest, current_user: User = Depends(get_current_user)): + mcp_id = req.mcp_id if not mcp_id: return get_data_error_result(message="No MCP server ID provided.") - timeout = get_float(req, "timeout", 10) + timeout = req.timeout - tool_name = req.get("tool_name", "") - arguments = req.get("arguments", {}) + tool_name = req.tool_name + arguments = req.arguments if not all([tool_name, arguments]): return get_data_error_result(message="Require provide tool name and arguments.") @@ -495,15 +505,12 @@ def test_tool() -> Response: return server_error_response(e) -@manager.route("/cache_tools", methods=["POST"]) # noqa: F821 -@login_required -@validate_request("mcp_id", "tools") -def cache_tool() -> Response: - req = request.get_json() - mcp_id = req.get("mcp_id", "") +@router.post("/cache_tools") +async def cache_tool(req: CacheToolsRequest, current_user: User = Depends(get_current_user)): + mcp_id = req.mcp_id if not mcp_id: return get_data_error_result(message="No MCP server ID provided.") - tools = req.get("tools", []) + tools = req.tools e, mcp_server = MCPServerService.get_by_id(mcp_id) if not e or mcp_server.tenant_id != current_user.id: @@ -519,22 +526,19 @@ def cache_tool() -> Response: return get_json_result(data=tools) -@manager.route("/test_mcp", methods=["POST"]) # noqa: F821 -@validate_request("url", "server_type") -def test_mcp() -> Response: - req = request.get_json() - - url = req.get("url", "") +@router.post("/test_mcp") +async def test_mcp(req: TestMCPRequest): + url = req.url if not url: return get_data_error_result(message="Invalid MCP url.") - server_type = req.get("server_type", "") + server_type = req.server_type if server_type not in VALID_MCP_SERVER_TYPES: return get_data_error_result(message="Unsupported MCP server type.") - timeout = get_float(req, "timeout", 10) - headers = safe_json_parse(req.get("headers", {})) - variables = safe_json_parse(req.get("variables", {})) + timeout = req.timeout + headers = req.headers + variables = req.variables mcp_server = MCPServer(id=f"{server_type}: {url}", server_type=server_type, url=url, headers=headers, variables=variables)