diff --git a/src/nbmake/nb_run.py b/src/nbmake/nb_run.py index 697e559..867274d 100644 --- a/src/nbmake/nb_run.py +++ b/src/nbmake/nb_run.py @@ -14,6 +14,10 @@ NB_VERSION = 4 +class CellImportError(Exception): + pass + + class NotebookRun: filename: Path verbose: bool @@ -24,11 +28,13 @@ def __init__( default_timeout: int, verbose: bool = False, kernel: Optional[str] = None, + find_import_errors: bool = False, ) -> None: self.filename = filename self.verbose = verbose self.default_timeout = default_timeout self.kernel = kernel + self.find_import_errors = find_import_errors def execute( self, @@ -56,7 +62,8 @@ def execute( c = NotebookClient( nb, timeout=timeout, - allow_errors=allow_errors, + allow_errors=allow_errors or self.find_import_errors, + interrupt_on_timeout=self.find_import_errors, record_timing=True, **extra_kwargs, ) @@ -68,6 +75,11 @@ async def apply_mocks( if any(o["output_type"] == "error" for o in cell["outputs"]): execute_reply["content"]["status"] = "error" + if "ename" in execute_reply["content"]: + if execute_reply["content"]["ename"] == "ModuleNotFoundError": + if self.find_import_errors: + raise CellImportError() + if c.kc is None: raise Exception("there is no kernelclient") mocks: Dict[str, Any] = ( @@ -85,6 +97,8 @@ async def apply_mocks( c.on_cell_executed = apply_mocks c.execute(cwd=self.filename.parent) + except CellImportError: + error = self._get_error(nb) except CellExecutionError: error = self._get_error(nb) except CellTimeoutError as err: diff --git a/src/nbmake/pytest_items.py b/src/nbmake/pytest_items.py index 01d0b4c..ed69d9b 100644 --- a/src/nbmake/pytest_items.py +++ b/src/nbmake/pytest_items.py @@ -47,6 +47,7 @@ def runtest(self): option.nbmake_timeout, verbose=bool(option.verbose), kernel=option.nbmake_kernel, + find_import_errors=option.nbmake_find_import_errors, ) res: NotebookResult = run.execute() diff --git a/src/nbmake/pytest_plugin.py b/src/nbmake/pytest_plugin.py index e818205..eafa323 100644 --- a/src/nbmake/pytest_plugin.py +++ b/src/nbmake/pytest_plugin.py @@ -42,6 +42,13 @@ def pytest_addoption(parser: Any): type=str, ) + group.addoption( + "--nbmake-find-import-errors", + action="store_true", + help="Runs all cells, only reports import errors", + default=False, + ) + def pytest_collect_file(path: str, parent: Any) -> Optional[Any]: opt = parent.config.option diff --git a/tests/resources/import_errs.ipynb b/tests/resources/import_errs.ipynb new file mode 100644 index 0000000..ef7b904 --- /dev/null +++ b/tests/resources/import_errs.ipynb @@ -0,0 +1,61 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import itertools\n", + "\n", + "1/0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import time\n", + "time.sleep(600)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import lkjlkj" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pickle" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.8 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.8" + }, + "vscode": { + "interpreter": { + "hash": "949777d72b0d2535278d3dc13498b2535136f6dfe0678499012e853ee9abcab1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/test_nb_run.py b/tests/test_nb_run.py index 9d7f7ed..38afe68 100644 --- a/tests/test_nb_run.py +++ b/tests/test_nb_run.py @@ -133,3 +133,10 @@ def test_when_empty_then_succeeds(self, testdir2: Never): run = NotebookRun(nb, 300) res: NotebookResult = run.execute() assert res.error is None + + def test_when_import_error_then_fails(self, testdir2: Never): + nb = Path(__file__).parent / "resources" / "import_errs.ipynb" + run = NotebookRun(nb, 1, find_import_errors=True) + res: NotebookResult = run.execute() + assert res.error is not None + assert "ModuleNotFoundError" in res.error.summary diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 22f22c9..ad6f363 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -184,3 +184,14 @@ def test_when_kernel_passed_then_override(pytester: Pytester, testdir2: Never): hook_recorder = pytester.inline_run("--nbmake", "--nbmake-kernel=python3") assert hook_recorder.ret == ExitCode.OK + + +def test_when_no_import_errs_then_pass(pytester: Pytester, testdir2: Never): + write_nb( + ["import itertools", "1/0", "import pickle"], + Path(pytester.path) / "a.ipynb", + ) + + hook_recorder = pytester.inline_run("--nbmake", "--nbmake-find-import-errors") + + assert hook_recorder.ret == ExitCode.OK