From 0d287e6f0617382e5fdbe1213e7e08cc75f1668c Mon Sep 17 00:00:00 2001 From: Michael Marchetti Date: Fri, 24 Feb 2023 09:50:16 -0500 Subject: [PATCH 1/2] Fixes await management for AsyncContentsManager. --- rsconnect_jupyter/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/rsconnect_jupyter/__init__.py b/rsconnect_jupyter/__init__.py index e0bde7c4..dd532bd4 100644 --- a/rsconnect_jupyter/__init__.py +++ b/rsconnect_jupyter/__init__.py @@ -1,4 +1,5 @@ import hashlib +import inspect import json import os import sys @@ -73,7 +74,7 @@ def md5(s): # https://github.com/jupyter/notebook/blob/master/notebook/base/handlers.py class EndpointHandler(APIHandler): @web.authenticated - def post(self, action): + async def post(self, action): data = self.get_json_body() if action == "verify_server": @@ -161,6 +162,11 @@ def post(self, action): hide_tagged_input = data.get("hide_tagged_input", False) model = self.contents_manager.get(path=nb_path) + if inspect.isawaitable(model): + # The default ContentsManager is now async, + # but we handle both cases. + model = await model + if model["type"] != "notebook": # not a notebook raise web.HTTPError(400, "Not a notebook: %s" % nb_path) From 51a6ca023e7957630b98d4dae063facf9ec23cd8 Mon Sep 17 00:00:00 2001 From: tdstein Date: Mon, 27 Feb 2023 15:06:27 -0500 Subject: [PATCH 2/2] Moves ContextManager model fetching to a module. Moves ContextManager model fetching to a seperate module (`rsconnect_jupyter.managers`) to faciliate isolated unit testing. --- rsconnect_jupyter/__init__.py | 9 +++------ rsconnect_jupyter/managers.py | 21 +++++++++++++++++++++ setup.cfg | 5 +++++ tests/test_managers.py | 30 ++++++++++++++++++++++++++++++ 4 files changed, 59 insertions(+), 6 deletions(-) create mode 100644 rsconnect_jupyter/managers.py create mode 100644 tests/test_managers.py diff --git a/rsconnect_jupyter/__init__.py b/rsconnect_jupyter/__init__.py index dd532bd4..b90d8848 100644 --- a/rsconnect_jupyter/__init__.py +++ b/rsconnect_jupyter/__init__.py @@ -1,5 +1,4 @@ import hashlib -import inspect import json import os import sys @@ -30,6 +29,8 @@ from ssl import SSLError +from rsconnect_jupyter.managers import get_model + try: from rsconnect_jupyter.version import version as __version__ # noqa except ImportError: @@ -161,11 +162,7 @@ async def post(self, action): hide_all_input = data.get("hide_all_input", False) hide_tagged_input = data.get("hide_tagged_input", False) - model = self.contents_manager.get(path=nb_path) - if inspect.isawaitable(model): - # The default ContentsManager is now async, - # but we handle both cases. - model = await model + model = await get_model(self.contents_manager, nb_path) if model["type"] != "notebook": # not a notebook diff --git a/rsconnect_jupyter/managers.py b/rsconnect_jupyter/managers.py new file mode 100644 index 00000000..823d1d35 --- /dev/null +++ b/rsconnect_jupyter/managers.py @@ -0,0 +1,21 @@ +from inspect import isawaitable +from typing import Union, Awaitable + +from jupyter_server.services.contents.manager import ContentsManager + + +async def get_model(manager: ContentsManager, path: str) -> dict: + """ + Gets the model via the ContentsManager. + + If the ContentsManager is async (e.g., AsyncContentsManager), then an await is issued. Otherwise, + the model is returned under synchronous expectations. + + :param manager: A Jupyter ContentsManager + :param path: The model path + :return: The model + """ + model: Union[dict, Awaitable[dict]] = manager.get(path) + if isawaitable(model): + model = await model + return model diff --git a/setup.cfg b/setup.cfg index 9a604451..6f8e37da 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,3 +28,8 @@ packages = rsconnect_jupyter python_requires = >=3.7 include_package_data = true zip_safe = false + +[options.extras_require] +test = + black + pytest \ No newline at end of file diff --git a/tests/test_managers.py b/tests/test_managers.py new file mode 100644 index 00000000..78883349 --- /dev/null +++ b/tests/test_managers.py @@ -0,0 +1,30 @@ +from unittest import TestCase +from unittest.mock import Mock, MagicMock, AsyncMock + +from rsconnect_jupyter.managers import ContentsManager, get_model, isawaitable + + +class GetModelTestCase(TestCase): + async def test_synchronous(self): + model = AsyncMock() + manager = MagicMock(spec=ContentsManager) + manager.get = Mock(return_value=model) + path = "path" + spy = Mock(wraps=isawaitable, return_value=False) + res = await get_model(manager, path) + self.assertEqual(res, model) + model.assert_not_awaited() + manager.get.assert_called_once_with(path) + spy.assert_called_once_with(model) + + async def test_asynchronous(self): + model = AsyncMock() + manager = MagicMock(spec=ContentsManager) + manager.get = Mock(return_value=model) + path = "path" + spy = Mock(wraps=isawaitable, return_value=True) + res = await get_model(manager, path) + self.assertEqual(res, model) + model.assert_awaited() + manager.get.assert_called_once_with(path) + spy.assert_called_once_with(model)