mcp接口flask改成fastapi

This commit is contained in:
2025-10-22 09:07:07 +08:00
parent e57c7350fb
commit bdf3f62e36
2 changed files with 106 additions and 100 deletions

View File

@@ -126,12 +126,14 @@ def setup_routes(app: FastAPI):
from api.apps.document_app import router as document_router from api.apps.document_app import router as document_router
from api.apps.file_app import router as file_router from api.apps.file_app import router as file_router
from api.apps.file2document_app import router as file2document_router from api.apps.file2document_app import router as file2document_router
from api.apps.mcp_server_app import router as mcp_router
app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"]) app.include_router(user_router, prefix=f"/{API_VERSION}/user", tags=["User"])
app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"]) app.include_router(kb_router, prefix=f"/{API_VERSION}/kb", tags=["KB"])
app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"]) app.include_router(document_router, prefix=f"/{API_VERSION}/document", tags=["Document"])
app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"]) app.include_router(file_router, prefix=f"/{API_VERSION}/file", tags=["File"])
app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"]) app.include_router(file2document_router, prefix=f"/{API_VERSION}/file2document", tags=["File2Document"])
app.include_router(mcp_router, prefix=f"/{API_VERSION}/mcp", tags=["MCP"])
def get_current_user_from_token(authorization: str): def get_current_user_from_token(authorization: str):
"""从token获取当前用户""" """从token获取当前用户"""

View File

@@ -17,15 +17,15 @@ from typing import List, Optional, Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from api import settings
from api.db import VALID_MCP_SERVER_TYPES from api.db import VALID_MCP_SERVER_TYPES
from api.db.db_models import MCPServer from api.db.db_models import MCPServer, User
from api.db.services.mcp_server_service import MCPServerService from api.db.services.mcp_server_service import MCPServerService
from api.db.services.user_service import TenantService from api.db.services.user_service import TenantService
from api.settings import RetCode from api.settings import RetCode
from api.utils import get_uuid from api.utils import get_uuid
from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, \ from api.utils.api_utils import get_data_error_result, get_json_result, server_error_response, get_mcp_tools
get_mcp_tools
from api.utils.web_utils import get_float, safe_json_parse from api.utils.web_utils import get_float, safe_json_parse
from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions from rag.utils.mcp_tool_call_conn import MCPToolCallSession, close_multiple_mcp_toolcall_sessions
from pydantic import BaseModel from pydantic import BaseModel
@@ -52,7 +52,7 @@ class UpdateMCPRequest(BaseModel):
server_type: Optional[str] = None server_type: Optional[str] = None
headers: Optional[Dict[str, Any]] = None headers: Optional[Dict[str, Any]] = None
variables: Optional[Dict[str, Any]] = None variables: Optional[Dict[str, Any]] = None
timeout: Optional[float] = None timeout: Optional[float] = 10
class RemoveMCPRequest(BaseModel): class RemoveMCPRequest(BaseModel):
mcp_ids: List[str] mcp_ids: List[str]
@@ -64,26 +64,30 @@ class ImportMCPRequest(BaseModel):
class ExportMCPRequest(BaseModel): class ExportMCPRequest(BaseModel):
mcp_ids: List[str] mcp_ids: List[str]
class ListToolsRequest(BaseModel): class ListToolsRequest(BaseModel):
mcp_ids: List[str] mcp_ids: List[str]
timeout: Optional[float] = 10 timeout: Optional[float] = 10
class TestToolRequest(BaseModel): class TestToolRequest(BaseModel):
mcp_id: str mcp_id: str
tool_name: str tool_name: str
arguments: Dict[str, Any] arguments: Dict[str, Any]
timeout: Optional[float] = 10 timeout: Optional[float] = 10
class CacheToolsRequest(BaseModel): class CacheToolsRequest(BaseModel):
mcp_id: str mcp_id: str
tools: List[Dict[str, Any]] tools: List[Dict[str, Any]]
class TestMCPRequest(BaseModel): class TestMCPRequest(BaseModel):
url: str url: str
server_type: str server_type: str
timeout: Optional[float] = 10
headers: Optional[Dict[str, Any]] = {} headers: Optional[Dict[str, Any]] = {}
variables: Optional[Dict[str, Any]] = {} variables: Optional[Dict[str, Any]] = {}
timeout: Optional[float] = 10
# Dependency injection # Dependency injection
async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)): async def get_current_user(credentials: HTTPAuthorizationCredentials = Depends(security)):
@@ -153,17 +157,16 @@ router = APIRouter()
@router.post("/list") @router.post("/list")
async def list_mcp( async def list_mcp(
keywords: str = Query(""), request: ListMCPRequest,
page: int = Query(0), keywords: str = Query("", description="Search keywords"),
page_size: int = Query(0), page: int = Query(0, description="Page number"),
orderby: str = Query("create_time"), page_size: int = Query(0, description="Items per page"),
desc: bool = Query(True), orderby: str = Query("create_time", description="Order by field"),
req: ListMCPRequest = None, desc: bool = Query(True, description="Sort descending"),
current_user = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
mcp_ids = req.mcp_ids if req else []
try: try:
servers = MCPServerService.get_servers(current_user.id, mcp_ids, 0, 0, orderby, desc, keywords) or [] servers = MCPServerService.get_servers(current_user.id, request.mcp_ids, 0, 0, orderby, desc, keywords) or []
total = len(servers) total = len(servers)
if page and page_size: if page and page_size:
@@ -176,8 +179,8 @@ async def list_mcp(
@router.get("/detail") @router.get("/detail")
async def detail( async def detail(
mcp_id: str = Query(...), mcp_id: str = Query(..., description="MCP server ID"),
current_user = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
try: try:
mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id) mcp_server = MCPServerService.get_or_none(id=mcp_id, tenant_id=current_user.id)
@@ -192,14 +195,14 @@ async def detail(
@router.post("/create") @router.post("/create")
async def create( async def create(
req: CreateMCPRequest, request: CreateMCPRequest,
current_user = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
server_type = req.server_type server_type = request.server_type
if server_type not in VALID_MCP_SERVER_TYPES: if server_type not in VALID_MCP_SERVER_TYPES:
return get_data_error_result(message="Unsupported MCP server type.") return get_data_error_result(message="Unsupported MCP server type.")
server_name = req.name server_name = request.name
if not server_name or len(server_name.encode("utf-8")) > 255: if not server_name or len(server_name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.") return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.")
@@ -207,20 +210,27 @@ async def create(
if e: if e:
return get_data_error_result(message="Duplicated MCP server name.") return get_data_error_result(message="Duplicated MCP server name.")
url = req.url url = request.url
if not url: if not url:
return get_data_error_result(message="Invalid url.") return get_data_error_result(message="Invalid url.")
headers = safe_json_parse(req.headers) headers = safe_json_parse(request.headers or {})
variables = safe_json_parse(req.variables) variables = safe_json_parse(request.variables or {})
variables.pop("tools", None) variables.pop("tools", None)
timeout = req.timeout timeout = request.timeout or 10
try: try:
req_dict = req.dict() req_data = {
req_dict["id"] = get_uuid() "id": get_uuid(),
req_dict["tenant_id"] = current_user.id "tenant_id": current_user.id,
"name": server_name,
"url": url,
"server_type": server_type,
"headers": headers,
"variables": variables,
"timeout": timeout
}
e, _ = TenantService.get_by_id(current_user.id) e, _ = TenantService.get_by_id(current_user.id)
if not e: if not e:
@@ -234,47 +244,55 @@ async def create(
tools = server_tools[server_name] tools = server_tools[server_name]
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
variables["tools"] = tools variables["tools"] = tools
req_dict["variables"] = variables req_data["variables"] = variables
if not MCPServerService.insert(**req_dict): if not MCPServerService.insert(**req_data):
return get_data_error_result("Failed to create MCP server.") return get_data_error_result("Failed to create MCP server.")
return get_json_result(data=req_dict) return get_json_result(data=req_data)
except Exception as e: except Exception as e:
return server_error_response(e) return server_error_response(e)
@router.post("/update") @router.post("/update")
async def update( async def update(
req: UpdateMCPRequest, request: UpdateMCPRequest,
current_user = Depends(get_current_user) current_user: User = Depends(get_current_user)
): ):
mcp_id = req.mcp_id mcp_id = request.mcp_id
e, mcp_server = MCPServerService.get_by_id(mcp_id) e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id: if not e or mcp_server.tenant_id != current_user.id:
return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}") return get_data_error_result(message=f"Cannot find MCP server {mcp_id} for user {current_user.id}")
server_type = req.server_type if req.server_type is not None else mcp_server.server_type server_type = request.server_type or mcp_server.server_type
if server_type and server_type not in VALID_MCP_SERVER_TYPES: if server_type and server_type not in VALID_MCP_SERVER_TYPES:
return get_data_error_result(message="Unsupported MCP server type.") return get_data_error_result(message="Unsupported MCP server type.")
server_name = req.name if req.name is not None else mcp_server.name
server_name = request.name or mcp_server.name
if server_name and len(server_name.encode("utf-8")) > 255: if server_name and len(server_name.encode("utf-8")) > 255:
return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.") return get_data_error_result(message=f"Invalid MCP name or length is {len(server_name)} which is large than 255.")
url = req.url if req.url is not None else mcp_server.url
url = request.url or mcp_server.url
if not url: if not url:
return get_data_error_result(message="Invalid url.") return get_data_error_result(message="Invalid url.")
headers = safe_json_parse(req.headers if req.headers is not None else mcp_server.headers) headers = safe_json_parse(request.headers or mcp_server.headers)
variables = safe_json_parse(req.variables if req.variables is not None else mcp_server.variables) variables = safe_json_parse(request.variables or mcp_server.variables)
variables.pop("tools", None) variables.pop("tools", None)
timeout = req.timeout if req.timeout is not None else 10 timeout = request.timeout or 10
try: try:
req_dict = req.dict(exclude_unset=True) req_data = {
req_dict["tenant_id"] = current_user.id "tenant_id": current_user.id,
req_dict.pop("mcp_id", None) "id": mcp_id,
req_dict["id"] = mcp_id "name": server_name,
"url": url,
"server_type": server_type,
"headers": headers,
"variables": variables,
"timeout": timeout
}
mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers) mcp_server = MCPServer(id=server_name, name=server_name, url=url, server_type=server_type, variables=variables, headers=headers)
server_tools, err_message = get_mcp_tools([mcp_server], timeout) server_tools, err_message = get_mcp_tools([mcp_server], timeout)
@@ -284,12 +302,12 @@ async def update(
tools = server_tools[server_name] tools = server_tools[server_name]
tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool} tools = {tool["name"]: tool for tool in tools if isinstance(tool, dict) and "name" in tool}
variables["tools"] = tools variables["tools"] = tools
req_dict["variables"] = variables req_data["variables"] = variables
if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req_dict): if not MCPServerService.filter_update([MCPServer.id == mcp_id, MCPServer.tenant_id == current_user.id], req_data):
return get_data_error_result(message="Failed to updated MCP server.") return get_data_error_result(message="Failed to updated MCP server.")
e, updated_mcp = MCPServerService.get_by_id(req_dict["id"]) e, updated_mcp = MCPServerService.get_by_id(req_data["id"])
if not e: if not e:
return get_data_error_result(message="Failed to fetch updated MCP server.") return get_data_error_result(message="Failed to fetch updated MCP server.")
@@ -298,16 +316,14 @@ async def update(
return server_error_response(e) return server_error_response(e)
@manager.route("/rm", methods=["POST"]) # noqa: F821 @router.post("/rm")
@login_required async def rm(
@validate_request("mcp_ids") request: RemoveMCPRequest,
def rm() -> Response: current_user: User = Depends(get_current_user)
req = request.get_json() ):
mcp_ids = req.get("mcp_ids", []) mcp_ids = request.mcp_ids
try: try:
req["tenant_id"] = current_user.id
if not MCPServerService.delete_by_ids(mcp_ids): if not MCPServerService.delete_by_ids(mcp_ids):
return get_data_error_result(message=f"Failed to delete MCP servers {mcp_ids}") return get_data_error_result(message=f"Failed to delete MCP servers {mcp_ids}")
@@ -316,16 +332,16 @@ def rm() -> Response:
return server_error_response(e) return server_error_response(e)
@manager.route("/import", methods=["POST"]) # noqa: F821 @router.post("/import")
@login_required async def import_multiple(
@validate_request("mcpServers") request: ImportMCPRequest,
def import_multiple() -> Response: current_user: User = Depends(get_current_user)
req = request.get_json() ):
servers = req.get("mcpServers", {}) servers = request.mcpServers
if not servers: if not servers:
return get_data_error_result(message="No MCP servers provided.") return get_data_error_result(message="No MCP servers provided.")
timeout = get_float(req, "timeout", 10) timeout = request.timeout or 10
results = [] results = []
try: try:
@@ -383,12 +399,12 @@ def import_multiple() -> Response:
return server_error_response(e) return server_error_response(e)
@manager.route("/export", methods=["POST"]) # noqa: F821 @router.post("/export")
@login_required async def export_multiple(
@validate_request("mcp_ids") request: ExportMCPRequest,
def export_multiple() -> Response: current_user: User = Depends(get_current_user)
req = request.get_json() ):
mcp_ids = req.get("mcp_ids", []) mcp_ids = request.mcp_ids
if not mcp_ids: if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.") return get_data_error_result(message="No MCP server IDs provided.")
@@ -415,16 +431,13 @@ def export_multiple() -> Response:
return server_error_response(e) return server_error_response(e)
@manager.route("/list_tools", methods=["POST"]) # noqa: F821 @router.post("/list_tools")
@login_required async def list_tools(req: ListToolsRequest, current_user: User = Depends(get_current_user)):
@validate_request("mcp_ids") mcp_ids = req.mcp_ids
def list_tools() -> Response:
req = request.get_json()
mcp_ids = req.get("mcp_ids", [])
if not mcp_ids: if not mcp_ids:
return get_data_error_result(message="No MCP server IDs provided.") return get_data_error_result(message="No MCP server IDs provided.")
timeout = get_float(req, "timeout", 10) timeout = req.timeout
results = {} results = {}
tool_call_sessions = [] tool_call_sessions = []
@@ -462,19 +475,16 @@ def list_tools() -> Response:
close_multiple_mcp_toolcall_sessions(tool_call_sessions) close_multiple_mcp_toolcall_sessions(tool_call_sessions)
@manager.route("/test_tool", methods=["POST"]) # noqa: F821 @router.post("/test_tool")
@login_required async def test_tool(req: TestToolRequest, current_user: User = Depends(get_current_user)):
@validate_request("mcp_id", "tool_name", "arguments") mcp_id = req.mcp_id
def test_tool() -> Response:
req = request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id: if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.") return get_data_error_result(message="No MCP server ID provided.")
timeout = get_float(req, "timeout", 10) timeout = req.timeout
tool_name = req.get("tool_name", "") tool_name = req.tool_name
arguments = req.get("arguments", {}) arguments = req.arguments
if not all([tool_name, arguments]): if not all([tool_name, arguments]):
return get_data_error_result(message="Require provide tool name and arguments.") return get_data_error_result(message="Require provide tool name and arguments.")
@@ -495,15 +505,12 @@ def test_tool() -> Response:
return server_error_response(e) return server_error_response(e)
@manager.route("/cache_tools", methods=["POST"]) # noqa: F821 @router.post("/cache_tools")
@login_required async def cache_tool(req: CacheToolsRequest, current_user: User = Depends(get_current_user)):
@validate_request("mcp_id", "tools") mcp_id = req.mcp_id
def cache_tool() -> Response:
req = request.get_json()
mcp_id = req.get("mcp_id", "")
if not mcp_id: if not mcp_id:
return get_data_error_result(message="No MCP server ID provided.") return get_data_error_result(message="No MCP server ID provided.")
tools = req.get("tools", []) tools = req.tools
e, mcp_server = MCPServerService.get_by_id(mcp_id) e, mcp_server = MCPServerService.get_by_id(mcp_id)
if not e or mcp_server.tenant_id != current_user.id: if not e or mcp_server.tenant_id != current_user.id:
@@ -519,22 +526,19 @@ def cache_tool() -> Response:
return get_json_result(data=tools) return get_json_result(data=tools)
@manager.route("/test_mcp", methods=["POST"]) # noqa: F821 @router.post("/test_mcp")
@validate_request("url", "server_type") async def test_mcp(req: TestMCPRequest):
def test_mcp() -> Response: url = req.url
req = request.get_json()
url = req.get("url", "")
if not url: if not url:
return get_data_error_result(message="Invalid MCP url.") return get_data_error_result(message="Invalid MCP url.")
server_type = req.get("server_type", "") server_type = req.server_type
if server_type not in VALID_MCP_SERVER_TYPES: if server_type not in VALID_MCP_SERVER_TYPES:
return get_data_error_result(message="Unsupported MCP server type.") return get_data_error_result(message="Unsupported MCP server type.")
timeout = get_float(req, "timeout", 10) timeout = req.timeout
headers = safe_json_parse(req.get("headers", {})) headers = req.headers
variables = safe_json_parse(req.get("variables", {})) variables = req.variables
mcp_server = MCPServer(id=f"{server_type}: {url}", server_type=server_type, url=url, headers=headers, variables=variables) mcp_server = MCPServer(id=f"{server_type}: {url}", server_type=server_type, url=url, headers=headers, variables=variables)