[Performance]Add complete Token, JWT, OAuth authentication system (#52)
* 0.5.1 Version * fix 0.5.1 schema async bug * fix security bug * fix security bug * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system * Add complete Token, JWT, OAuth authentication system
This commit is contained in:
@@ -246,6 +246,40 @@ class DorisServer:
|
||||
self.logger = get_logger(f"{__name__}.DorisServer")
|
||||
self._setup_handlers()
|
||||
|
||||
async def _extract_auth_info_from_scope(self, scope, headers):
|
||||
"""Extract authentication information from ASGI scope and headers"""
|
||||
auth_info = {}
|
||||
|
||||
# Extract client IP
|
||||
client = scope.get("client")
|
||||
if client:
|
||||
auth_info["client_ip"] = client[0]
|
||||
else:
|
||||
auth_info["client_ip"] = "unknown"
|
||||
|
||||
# Extract token from Authorization header
|
||||
authorization = headers.get(b'authorization', b'').decode('utf-8')
|
||||
if authorization:
|
||||
if authorization.startswith('Bearer '):
|
||||
auth_info["token"] = authorization[7:]
|
||||
auth_info["authorization"] = authorization
|
||||
elif authorization.startswith('Token '):
|
||||
auth_info["token"] = authorization[6:]
|
||||
auth_info["authorization"] = authorization
|
||||
|
||||
# Extract token from query parameters (for compatibility)
|
||||
query_string = scope.get("query_string", b"").decode('utf-8')
|
||||
if query_string and "token=" in query_string:
|
||||
import urllib.parse
|
||||
query_params = urllib.parse.parse_qs(query_string)
|
||||
if "token" in query_params:
|
||||
auth_info["token"] = query_params["token"][0]
|
||||
|
||||
# If no token found, this will be handled by the authentication system
|
||||
# (either return anonymous context if auth disabled, or raise error if auth enabled)
|
||||
|
||||
return auth_info
|
||||
|
||||
def _get_mcp_capabilities(self):
|
||||
"""Get MCP capabilities with version compatibility"""
|
||||
try:
|
||||
@@ -390,6 +424,10 @@ class DorisServer:
|
||||
self.logger.info("Starting Doris MCP Server (stdio mode)")
|
||||
|
||||
try:
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
self.logger.info("Connection manager initialization completed")
|
||||
@@ -456,6 +494,10 @@ class DorisServer:
|
||||
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}")
|
||||
|
||||
try:
|
||||
# Initialize security manager first (includes JWT setup if enabled)
|
||||
await self.security_manager.initialize()
|
||||
self.logger.info("Security manager initialization completed")
|
||||
|
||||
# Ensure connection manager is initialized
|
||||
await self.connection_manager.initialize()
|
||||
|
||||
@@ -482,6 +524,44 @@ class DorisServer:
|
||||
async def health_check(request):
|
||||
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
|
||||
|
||||
# OAuth endpoints
|
||||
from .auth.oauth_handlers import OAuthHandlers
|
||||
oauth_handlers = OAuthHandlers(self.security_manager)
|
||||
|
||||
async def oauth_login(request):
|
||||
return await oauth_handlers.handle_login(request)
|
||||
|
||||
async def oauth_callback(request):
|
||||
return await oauth_handlers.handle_callback(request)
|
||||
|
||||
async def oauth_provider_info(request):
|
||||
return await oauth_handlers.handle_provider_info(request)
|
||||
|
||||
async def oauth_demo(request):
|
||||
return await oauth_handlers.handle_demo_page(request)
|
||||
|
||||
# Token management endpoints
|
||||
from .auth.token_handlers import TokenHandlers
|
||||
token_handlers = TokenHandlers(self.security_manager)
|
||||
|
||||
async def token_create(request):
|
||||
return await token_handlers.handle_create_token(request)
|
||||
|
||||
async def token_revoke(request):
|
||||
return await token_handlers.handle_revoke_token(request)
|
||||
|
||||
async def token_list(request):
|
||||
return await token_handlers.handle_list_tokens(request)
|
||||
|
||||
async def token_stats(request):
|
||||
return await token_handlers.handle_token_stats(request)
|
||||
|
||||
async def token_cleanup(request):
|
||||
return await token_handlers.handle_cleanup_tokens(request)
|
||||
|
||||
async def token_demo(request):
|
||||
return await token_handlers.handle_demo_page(request)
|
||||
|
||||
# Lifecycle manager - simplified since we manage session_manager externally
|
||||
@contextlib.asynccontextmanager
|
||||
async def lifespan(app: Starlette) -> AsyncIterator[None]:
|
||||
@@ -497,6 +577,18 @@ class DorisServer:
|
||||
debug=True,
|
||||
routes=[
|
||||
Route("/health", health_check, methods=["GET"]),
|
||||
# OAuth endpoints
|
||||
Route("/auth/login", oauth_login, methods=["GET"]),
|
||||
Route("/auth/callback", oauth_callback, methods=["GET"]),
|
||||
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
|
||||
Route("/auth/demo", oauth_demo, methods=["GET"]),
|
||||
# Token management endpoints
|
||||
Route("/token/create", token_create, methods=["GET", "POST"]),
|
||||
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
|
||||
Route("/token/list", token_list, methods=["GET"]),
|
||||
Route("/token/stats", token_stats, methods=["GET"]),
|
||||
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
|
||||
Route("/token/demo", token_demo, methods=["GET"]),
|
||||
],
|
||||
lifespan=lifespan,
|
||||
)
|
||||
@@ -514,8 +606,10 @@ class DorisServer:
|
||||
self.logger.info(f"Received request for path: {path}")
|
||||
|
||||
try:
|
||||
# Handle health check
|
||||
if path.startswith("/health"):
|
||||
# Handle health check, auth, and token management endpoints
|
||||
if (path.startswith("/health") or
|
||||
path.startswith("/auth/") or
|
||||
path.startswith("/token/")):
|
||||
await starlette_app(scope, receive, send)
|
||||
return
|
||||
|
||||
@@ -528,6 +622,29 @@ class DorisServer:
|
||||
self.logger.info(f"MCP Request - Method: {method}")
|
||||
self.logger.info(f"MCP Request - Headers: {headers}")
|
||||
|
||||
# Authentication check for MCP requests
|
||||
try:
|
||||
# Extract authentication information
|
||||
auth_info = await self._extract_auth_info_from_scope(scope, headers)
|
||||
|
||||
# Authenticate the request
|
||||
auth_context = await self.security_manager.authenticate_request(auth_info)
|
||||
self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
|
||||
|
||||
# Store auth context in scope for potential use by tools/resources
|
||||
scope["auth_context"] = auth_context
|
||||
|
||||
except Exception as auth_error:
|
||||
self.logger.error(f"MCP authentication failed: {auth_error}")
|
||||
# Return 401 Unauthorized
|
||||
from starlette.responses import JSONResponse
|
||||
response = JSONResponse(
|
||||
{"error": "Authentication required", "message": str(auth_error)},
|
||||
status_code=401
|
||||
)
|
||||
await response(scope, receive, send)
|
||||
return
|
||||
|
||||
# Handle Dify compatibility for GET requests
|
||||
if method == "GET":
|
||||
accept_header = headers.get(b'accept', b'').decode('utf-8')
|
||||
@@ -619,6 +736,10 @@ class DorisServer:
|
||||
"""Shutdown server"""
|
||||
self.logger.info("Shutting down Doris MCP Server")
|
||||
try:
|
||||
# Shutdown security manager first (includes JWT cleanup)
|
||||
await self.security_manager.shutdown()
|
||||
self.logger.info("Security manager shutdown completed")
|
||||
|
||||
await self.connection_manager.close()
|
||||
self.logger.info("Doris MCP Server has been shut down")
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user