[Performance]Add complete Token, JWT, OAuth authentication system (#52)

* 0.5.1 Version

* fix 0.5.1 schema async bug

* fix security bug

* fix security bug

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system

* Add complete Token, JWT, OAuth authentication system
This commit is contained in:
Yijia Su
2025-09-02 17:01:43 +08:00
committed by GitHub
parent c1e3b13851
commit c3d487ccdd
17 changed files with 4689 additions and 44 deletions

View File

@@ -246,6 +246,40 @@ class DorisServer:
self.logger = get_logger(f"{__name__}.DorisServer")
self._setup_handlers()
async def _extract_auth_info_from_scope(self, scope, headers):
"""Extract authentication information from ASGI scope and headers"""
auth_info = {}
# Extract client IP
client = scope.get("client")
if client:
auth_info["client_ip"] = client[0]
else:
auth_info["client_ip"] = "unknown"
# Extract token from Authorization header
authorization = headers.get(b'authorization', b'').decode('utf-8')
if authorization:
if authorization.startswith('Bearer '):
auth_info["token"] = authorization[7:]
auth_info["authorization"] = authorization
elif authorization.startswith('Token '):
auth_info["token"] = authorization[6:]
auth_info["authorization"] = authorization
# Extract token from query parameters (for compatibility)
query_string = scope.get("query_string", b"").decode('utf-8')
if query_string and "token=" in query_string:
import urllib.parse
query_params = urllib.parse.parse_qs(query_string)
if "token" in query_params:
auth_info["token"] = query_params["token"][0]
# If no token found, this will be handled by the authentication system
# (either return anonymous context if auth disabled, or raise error if auth enabled)
return auth_info
def _get_mcp_capabilities(self):
"""Get MCP capabilities with version compatibility"""
try:
@@ -390,6 +424,10 @@ class DorisServer:
self.logger.info("Starting Doris MCP Server (stdio mode)")
try:
# Initialize security manager first (includes JWT setup if enabled)
await self.security_manager.initialize()
self.logger.info("Security manager initialization completed")
# Ensure connection manager is initialized
await self.connection_manager.initialize()
self.logger.info("Connection manager initialization completed")
@@ -456,6 +494,10 @@ class DorisServer:
self.logger.info(f"Starting Doris MCP Server (Streamable HTTP mode) - {host}:{port}, workers: {workers}")
try:
# Initialize security manager first (includes JWT setup if enabled)
await self.security_manager.initialize()
self.logger.info("Security manager initialization completed")
# Ensure connection manager is initialized
await self.connection_manager.initialize()
@@ -482,6 +524,44 @@ class DorisServer:
async def health_check(request):
return JSONResponse({"status": "healthy", "service": "doris-mcp-server"})
# OAuth endpoints
from .auth.oauth_handlers import OAuthHandlers
oauth_handlers = OAuthHandlers(self.security_manager)
async def oauth_login(request):
return await oauth_handlers.handle_login(request)
async def oauth_callback(request):
return await oauth_handlers.handle_callback(request)
async def oauth_provider_info(request):
return await oauth_handlers.handle_provider_info(request)
async def oauth_demo(request):
return await oauth_handlers.handle_demo_page(request)
# Token management endpoints
from .auth.token_handlers import TokenHandlers
token_handlers = TokenHandlers(self.security_manager)
async def token_create(request):
return await token_handlers.handle_create_token(request)
async def token_revoke(request):
return await token_handlers.handle_revoke_token(request)
async def token_list(request):
return await token_handlers.handle_list_tokens(request)
async def token_stats(request):
return await token_handlers.handle_token_stats(request)
async def token_cleanup(request):
return await token_handlers.handle_cleanup_tokens(request)
async def token_demo(request):
return await token_handlers.handle_demo_page(request)
# Lifecycle manager - simplified since we manage session_manager externally
@contextlib.asynccontextmanager
async def lifespan(app: Starlette) -> AsyncIterator[None]:
@@ -497,6 +577,18 @@ class DorisServer:
debug=True,
routes=[
Route("/health", health_check, methods=["GET"]),
# OAuth endpoints
Route("/auth/login", oauth_login, methods=["GET"]),
Route("/auth/callback", oauth_callback, methods=["GET"]),
Route("/auth/provider", oauth_provider_info, methods=["GET"]),
Route("/auth/demo", oauth_demo, methods=["GET"]),
# Token management endpoints
Route("/token/create", token_create, methods=["GET", "POST"]),
Route("/token/revoke", token_revoke, methods=["GET", "DELETE"]),
Route("/token/list", token_list, methods=["GET"]),
Route("/token/stats", token_stats, methods=["GET"]),
Route("/token/cleanup", token_cleanup, methods=["GET", "POST"]),
Route("/token/demo", token_demo, methods=["GET"]),
],
lifespan=lifespan,
)
@@ -514,8 +606,10 @@ class DorisServer:
self.logger.info(f"Received request for path: {path}")
try:
# Handle health check
if path.startswith("/health"):
# Handle health check, auth, and token management endpoints
if (path.startswith("/health") or
path.startswith("/auth/") or
path.startswith("/token/")):
await starlette_app(scope, receive, send)
return
@@ -528,6 +622,29 @@ class DorisServer:
self.logger.info(f"MCP Request - Method: {method}")
self.logger.info(f"MCP Request - Headers: {headers}")
# Authentication check for MCP requests
try:
# Extract authentication information
auth_info = await self._extract_auth_info_from_scope(scope, headers)
# Authenticate the request
auth_context = await self.security_manager.authenticate_request(auth_info)
self.logger.info(f"MCP request authenticated: token_id={auth_context.token_id}, client_ip={auth_context.client_ip}")
# Store auth context in scope for potential use by tools/resources
scope["auth_context"] = auth_context
except Exception as auth_error:
self.logger.error(f"MCP authentication failed: {auth_error}")
# Return 401 Unauthorized
from starlette.responses import JSONResponse
response = JSONResponse(
{"error": "Authentication required", "message": str(auth_error)},
status_code=401
)
await response(scope, receive, send)
return
# Handle Dify compatibility for GET requests
if method == "GET":
accept_header = headers.get(b'accept', b'').decode('utf-8')
@@ -619,6 +736,10 @@ class DorisServer:
"""Shutdown server"""
self.logger.info("Shutting down Doris MCP Server")
try:
# Shutdown security manager first (includes JWT cleanup)
await self.security_manager.shutdown()
self.logger.info("Security manager shutdown completed")
await self.connection_manager.close()
self.logger.info("Doris MCP Server has been shut down")
except Exception as e: