Skip to content

Commit 90f7c03

Browse files
committed
Enhance logger metrics retrieval to conditionally include progress bar metrics based on the presence of a progress bar callback.
Signed-off-by: Wil Kong <[email protected]>
1 parent e088694 commit 90f7c03

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,9 @@ def metrics(self) -> _METRICS:
234234
"""This function returns either batch or epoch metrics."""
235235
on_step = self._first_loop_iter is not None
236236
assert self.trainer._results is not None
237-
return self.trainer._results.metrics(on_step)
237+
# Only include progress bar metrics if a progress bar callback is present
238+
include_pbar_metrics = self.trainer.progress_bar_callback is not None
239+
return self.trainer._results.metrics(on_step, include_pbar_metrics=include_pbar_metrics)
238240

239241
@property
240242
def callback_metrics(self) -> _OUT_DICT:

src/lightning/pytorch/trainer/connectors/logger_connector/result.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def _forked_name(self, result_metric: _ResultMetric, on_step: bool) -> tuple[str
468468
forked_name += dataloader_suffix
469469
return name, forked_name
470470

471-
def metrics(self, on_step: bool) -> _METRICS:
471+
def metrics(self, on_step: bool, *, include_pbar_metrics: bool = True) -> _METRICS:
472472
metrics = _METRICS(callback={}, log={}, pbar={})
473473

474474
for _, result_metric in self.valid_items():
@@ -489,7 +489,7 @@ def metrics(self, on_step: bool) -> _METRICS:
489489
metrics["callback"][forked_name] = value
490490

491491
# populate progress_bar metrics. convert tensors to numbers
492-
if result_metric.meta.prog_bar:
492+
if result_metric.meta.prog_bar and include_pbar_metrics:
493493
metrics["pbar"][forked_name] = convert_tensors_to_scalars(value)
494494

495495
return metrics

0 commit comments

Comments
 (0)