diff --git a/pytorch_pfn_extras/training/extensions/print_report_notebook.py b/pytorch_pfn_extras/training/extensions/print_report_notebook.py index fb617f57..a817e5ea 100644 --- a/pytorch_pfn_extras/training/extensions/print_report_notebook.py +++ b/pytorch_pfn_extras/training/extensions/print_report_notebook.py @@ -56,4 +56,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None: if self._infer_entries: # --- update entries --- self._update_entries(log_report) + for entry in self._entries: + if entry not in df.columns: + df[entry] = None self._widget.value = df[self._entries].to_html(index=False, na_rep="") diff --git a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py index 1eb8a222..82ae43e6 100644 --- a/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py +++ b/tests/pytorch_pfn_extras_tests/training_tests/extensions_tests/test_print_report_notebook.py @@ -37,5 +37,37 @@ def test_run_print_report_notebook(tmp_path: pathlib.Path): pass +@pytest.mark.skipif( + not _ipython_module_available or not _pandas_available, + reason="print report notebook import failed, " + "maybe ipython is not installed", +) +def test_run_print_report_notebook_with_entry(tmp_path: pathlib.Path): + max_epochs = 5 + iters_per_epoch = 5 + manager = ppe.training.ExtensionsManager( + {}, + {}, + max_epochs, + iters_per_epoch=iters_per_epoch, + out_dir=str(tmp_path), + ) + + out = io.StringIO() + log_report = ppe.training.extensions.LogReport() + manager.extend(log_report) + extension = ppe.training.extensions.PrintReportNotebook( + ["epoch", "iteration"], out=out + ) + manager.extend(extension) + + for _ in range(max_epochs): + for _ in range(iters_per_epoch): + with manager.run_iteration(): + # Only test it runs without fail + # The value is not tested now... + pass + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"])