Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,7 @@ def __call__(self, manager: ExtensionsManagerProtocol) -> None:
if self._infer_entries:
# --- update entries ---
self._update_entries(log_report)
for entry in self._entries:
if entry not in df.columns:
df[entry] = None
self._widget.value = df[self._entries].to_html(index=False, na_rep="")
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,37 @@ def test_run_print_report_notebook(tmp_path: pathlib.Path):
pass


@pytest.mark.skipif(
not _ipython_module_available or not _pandas_available,
reason="print report notebook import failed, "
"maybe ipython is not installed",
)
def test_run_print_report_notebook_with_entry(tmp_path: pathlib.Path):
max_epochs = 5
iters_per_epoch = 5
manager = ppe.training.ExtensionsManager(
{},
{},
max_epochs,
iters_per_epoch=iters_per_epoch,
out_dir=str(tmp_path),
)

out = io.StringIO()
log_report = ppe.training.extensions.LogReport()
manager.extend(log_report)
extension = ppe.training.extensions.PrintReportNotebook(
["epoch", "iteration"], out=out
)
manager.extend(extension)

for _ in range(max_epochs):
for _ in range(iters_per_epoch):
with manager.run_iteration():
# Only test it runs without fail
# The value is not tested now...
pass


if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])
Loading