diff --git a/jupyter_server/auth/identity.py b/jupyter_server/auth/identity.py index adeb567b5b..12e53937df 100644 --- a/jupyter_server/auth/identity.py +++ b/jupyter_server/auth/identity.py @@ -17,8 +17,10 @@ import uuid from dataclasses import asdict, dataclass from http.cookies import Morsel +from urllib.parse import unquote from tornado import escape, httputil, web +from tornado.websocket import WebSocketHandler from traitlets import Bool, Dict, Type, Unicode, default from traitlets.config import LoggingConfigurable @@ -106,6 +108,9 @@ def _backward_compat_user(got_user: t.Any) -> User: raise ValueError(msg) +_TOKEN_SUBPROTOCOL = "v1.token.websocket.jupyter.org" + + class IdentityProvider(LoggingConfigurable): """ Interface for providing identity management and authentication. @@ -424,6 +429,21 @@ def get_token(self, handler: web.RequestHandler) -> str | None: m = self.auth_header_pat.match(handler.request.headers.get("Authorization", "")) if m: user_token = m.group(2) + if not user_token and isinstance(handler, WebSocketHandler): + subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol") + if subprotocol_header: + subprotocols = [s.strip() for s in subprotocol_header.split(",")] + for subprotocol in subprotocols: + if subprotocol.startswith(_TOKEN_SUBPROTOCOL + "."): + user_token = subprotocol[len(_TOKEN_SUBPROTOCOL) + 1 :] + try: + user_token = unquote(user_token) + except ValueError: + # leave tokens that fail to decode + # these won't be accepted, but proceed with validation + pass + break + return user_token async def get_user_token(self, handler: web.RequestHandler) -> User | None: diff --git a/jupyter_server/base/websocket.py b/jupyter_server/base/websocket.py index 8780d7afc4..a2b8812346 100644 --- a/jupyter_server/base/websocket.py +++ b/jupyter_server/base/websocket.py @@ -1,4 +1,7 @@ """Base websocket classes.""" + +from __future__ import annotations + import re import warnings from typing import Optional, no_type_check @@ -164,3 +167,14 @@ def send_ping(self): def on_pong(self, data): """Handle a pong message.""" self.last_pong = ioloop.IOLoop.current().time() + + def select_subprotocol(self, subprotocols: list[str]) -> str | None: + # default subprotocol + # some clients (Chrome) + # require selected subprotocol to match one of the requested subprotocols + # otherwise connection is rejected + token_subprotocol = "v1.token.websocket.jupyter.org" + if token_subprotocol in subprotocols: + return token_subprotocol + else: + return None diff --git a/jupyter_server/services/events/handlers.py b/jupyter_server/services/events/handlers.py index ce580048f2..7672c4aeda 100644 --- a/jupyter_server/services/events/handlers.py +++ b/jupyter_server/services/events/handlers.py @@ -14,6 +14,7 @@ from jupyter_server.auth.decorator import authorized, ws_authenticated from jupyter_server.base.handlers import JupyterHandler +from jupyter_server.base.websocket import WebSocketMixin from ...base.handlers import APIHandler @@ -21,6 +22,7 @@ class SubscribeWebsocket( + WebSocketMixin, JupyterHandler, websocket.WebSocketHandler, ): diff --git a/jupyter_server/services/kernels/websocket.py b/jupyter_server/services/kernels/websocket.py index 374df76f3e..c5682fca7c 100644 --- a/jupyter_server/services/kernels/websocket.py +++ b/jupyter_server/services/kernels/websocket.py @@ -90,6 +90,12 @@ def select_subprotocol(self, subprotocols): preferred_protocol = "v1.kernel.websocket.jupyter.org" elif preferred_protocol == "": preferred_protocol = None - selected_subprotocol = preferred_protocol if preferred_protocol in subprotocols else None + + # super() subprotocol enables token authentication via subprotocol + selected_subprotocol = ( + preferred_protocol + if preferred_protocol in subprotocols + else super().select_subprotocol(subprotocols) + ) # None is the default, "legacy" protocol return selected_subprotocol diff --git a/tests/base/test_websocket.py b/tests/base/test_websocket.py index 17ee009227..690cd763c0 100644 --- a/tests/base/test_websocket.py +++ b/tests/base/test_websocket.py @@ -10,7 +10,7 @@ from tornado.websocket import WebSocketClosedError, WebSocketHandler from jupyter_server.auth import IdentityProvider, User -from jupyter_server.auth.decorator import allow_unauthenticated +from jupyter_server.auth.decorator import allow_unauthenticated, ws_authenticated from jupyter_server.base.handlers import JupyterHandler from jupyter_server.base.websocket import WebSocketMixin from jupyter_server.serverapp import ServerApp @@ -75,6 +75,12 @@ class NoAuthRulesWebsocketHandler(MockJupyterHandler): pass +class AuthenticatedWebsocketHandler(MockJupyterHandler): + @ws_authenticated + def get(self, *args, **kwargs) -> None: + return super().get(*args, **kwargs) + + class PermissiveWebsocketHandler(MockJupyterHandler): @allow_unauthenticated def get(self, *args, **kwargs) -> None: @@ -126,6 +132,31 @@ async def test_websocket_auth_required(jp_serverapp, jp_ws_fetch): assert exception.value.code == 403 +async def test_websocket_token_subprotocol_auth(jp_serverapp, jp_ws_fetch): + app: ServerApp = jp_serverapp + app.web_app.add_handlers( + ".*$", + [ + (url_path_join(app.base_url, "ws"), AuthenticatedWebsocketHandler), + ], + ) + + with pytest.raises(HTTPClientError) as exception: + ws = await jp_ws_fetch("ws", headers={"Authorization": ""}) + assert exception.value.code == 403 + token = jp_serverapp.identity_provider.token + ws = await jp_ws_fetch( + "ws", + headers={ + "Authorization": "", + "Sec-WebSocket-Protocol": "v1.kernel.websocket.jupyter.org, v1.token.websocket.jupyter.org, v1.token.websocket.jupyter.org." + + token, + }, + ) + assert ws.protocol.selected_subprotocol == "v1.token.websocket.jupyter.org" + ws.close() + + class IndiscriminateIdentityProvider(IdentityProvider): async def get_user(self, handler): return User(username="test")