diff --git a/rsconnect_jupyter/__init__.py b/rsconnect_jupyter/__init__.py index 86eb7df7..45826afd 100644 --- a/rsconnect_jupyter/__init__.py +++ b/rsconnect_jupyter/__init__.py @@ -30,6 +30,8 @@ from ssl import SSLError +from rsconnect_jupyter.managers import get_model + try: from rsconnect_jupyter.version import version as __version__ # noqa except ImportError: @@ -74,7 +76,7 @@ def md5(s): # https://github.com/jupyter/notebook/blob/master/notebook/base/handlers.py class EndpointHandler(APIHandler): @web.authenticated - def post(self, action): + async def post(self, action): data = self.get_json_body() if action == "verify_server": @@ -161,7 +163,8 @@ def post(self, action): hide_all_input = data.get("hide_all_input", False) hide_tagged_input = data.get("hide_tagged_input", False) - model = self.contents_manager.get(path=nb_path) + model = await get_model(self.contents_manager, nb_path) + if model["type"] != "notebook": # not a notebook raise web.HTTPError(400, "Not a notebook: %s" % nb_path) diff --git a/rsconnect_jupyter/managers.py b/rsconnect_jupyter/managers.py new file mode 100644 index 00000000..823d1d35 --- /dev/null +++ b/rsconnect_jupyter/managers.py @@ -0,0 +1,21 @@ +from inspect import isawaitable +from typing import Union, Awaitable + +from jupyter_server.services.contents.manager import ContentsManager + + +async def get_model(manager: ContentsManager, path: str) -> dict: + """ + Gets the model via the ContentsManager. + + If the ContentsManager is async (e.g., AsyncContentsManager), then an await is issued. Otherwise, + the model is returned under synchronous expectations. + + :param manager: A Jupyter ContentsManager + :param path: The model path + :return: The model + """ + model: Union[dict, Awaitable[dict]] = manager.get(path) + if isawaitable(model): + model = await model + return model diff --git a/setup.cfg b/setup.cfg index d3781c2a..7a11e16d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,3 +28,8 @@ packages = rsconnect_jupyter python_requires = >=3.7 include_package_data = true zip_safe = false + +[options.extras_require] +test = + black + pytest \ No newline at end of file diff --git a/tests/test_managers.py b/tests/test_managers.py new file mode 100644 index 00000000..78883349 --- /dev/null +++ b/tests/test_managers.py @@ -0,0 +1,30 @@ +from unittest import TestCase +from unittest.mock import Mock, MagicMock, AsyncMock + +from rsconnect_jupyter.managers import ContentsManager, get_model, isawaitable + + +class GetModelTestCase(TestCase): + async def test_synchronous(self): + model = AsyncMock() + manager = MagicMock(spec=ContentsManager) + manager.get = Mock(return_value=model) + path = "path" + spy = Mock(wraps=isawaitable, return_value=False) + res = await get_model(manager, path) + self.assertEqual(res, model) + model.assert_not_awaited() + manager.get.assert_called_once_with(path) + spy.assert_called_once_with(model) + + async def test_asynchronous(self): + model = AsyncMock() + manager = MagicMock(spec=ContentsManager) + manager.get = Mock(return_value=model) + path = "path" + spy = Mock(wraps=isawaitable, return_value=True) + res = await get_model(manager, path) + self.assertEqual(res, model) + model.assert_awaited() + manager.get.assert_called_once_with(path) + spy.assert_called_once_with(model)