From b961d4eb499071c0c60e24f429c20d1e6a908a32 Mon Sep 17 00:00:00 2001 From: Gonzalo Gasca Meza Date: Fri, 12 Jul 2024 09:42:06 -0700 Subject: [PATCH] Pass session_id during Websocket connect (#1440) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- jupyter_server/gateway/connections.py | 2 ++ tests/test_gateway.py | 36 +++++++++++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/jupyter_server/gateway/connections.py b/jupyter_server/gateway/connections.py index 8027a822cc..d4dde730fa 100644 --- a/jupyter_server/gateway/connections.py +++ b/jupyter_server/gateway/connections.py @@ -47,6 +47,8 @@ async def connect(self): url_escape(self.kernel_id), "channels", ) + if self.session_id: + ws_url += f"?session_id={url_escape(self.session_id)}" self.log.info(f"Connecting to {ws_url}") kwargs: dict[str, Any] = {} kwargs = GatewayClient.instance().load_connection_args(**kwargs) diff --git a/tests/test_gateway.py b/tests/test_gateway.py index f0033c278e..569268d833 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -200,6 +200,7 @@ async def mock_gateway_request(url, **kwargs): mocked_gateway = patch("jupyter_server.gateway.managers.gateway_request", mock_gateway_request) +mock_gateway_ws_url = "ws://mock-gateway-server:8889" mock_gateway_url = "http://mock-gateway-server:8889" mock_http_user = "alice" @@ -733,6 +734,41 @@ async def test_websocket_connection_closed(init_gateway, jp_serverapp, jp_fetch, pytest.fail(f"Logs contain an error: {message}") +@patch("tornado.websocket.websocket_connect", mock_websocket_connect()) +async def test_websocket_connection_with_session_id(init_gateway, jp_serverapp, jp_fetch, caplog): + # Create the session and kernel and get the kernel manager... + kernel_id = await create_kernel(jp_fetch, "kspec_foo") + km: GatewayKernelManager = jp_serverapp.kernel_manager.get_kernel(kernel_id) + + # Create the KernelWebsocketHandler... + request = HTTPServerRequest("foo", "GET") + request.connection = MagicMock() + handler = KernelWebsocketHandler(jp_serverapp.web_app, request) + # Create the GatewayWebSocketConnection and attach it to the handler... + with mocked_gateway: + conn = GatewayWebSocketConnection(parent=km, websocket_handler=handler) + handler.connection = conn + await conn.connect() + assert conn.session_id != None + expected_ws_url = ( + f"{mock_gateway_ws_url}/api/kernels/{kernel_id}/channels?session_id={conn.session_id}" + ) + assert ( + expected_ws_url in caplog.text + ), "WebSocket URL does not contain the expected session_id." + + # Processing websocket messages happens in separate coroutines and any + # errors in that process will show up in logs, but not bubble up to the + # caller. + # + # To check for these, we wait for the server to stop and then check the + # logs for errors. + await jp_serverapp._cleanup() + for _, level, message in caplog.record_tuples: + if level >= logging.ERROR: + pytest.fail(f"Logs contain an error: {message}") + + # # Test methods below... #