diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 99af24fe80b69..c75c82127afac 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -58,6 +58,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `LightningCLI` loading of hyperparameters from `ckpt_path` failing for subclass model mode ([#21246](https://github.com/Lightning-AI/pytorch-lightning/pull/21246)) +- Fixed how `ThroughputMonitor` calculated training time ([#21291](https://github.com/Lightning-AI/pytorch-lightning/pull/21291)) + + --- ## [2.5.5] - 2025-09-05 diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py index b88ee90bde38d..d38928d33de75 100644 --- a/src/lightning/pytorch/callbacks/throughput_monitor.py +++ b/src/lightning/pytorch/callbacks/throughput_monitor.py @@ -109,7 +109,9 @@ def _start(self, trainer: "Trainer") -> None: stage = trainer.state.stage assert stage is not None - if stage not in self._samples: + reset_needed = trainer.state.fn == TrainerFn.FITTING or stage not in self._samples + + if reset_needed: self._throughputs[stage].reset() self._lengths[stage] = 0 self._samples[stage] = 0 @@ -202,10 +204,17 @@ def on_validation_batch_end( def on_validation_end(self, trainer: "Trainer", *_: Any) -> None: if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING: return + + train_times = self._throughputs[RunningStage.TRAINING]._time + val_times = self._throughputs[RunningStage.VALIDATING]._time + + train_elapsed = train_times[-1] if train_times else 0.0 + val_elapsed = val_times[-1] if val_times else 0.0 + # add the validation time to the training time before continuing to avoid sinking the training throughput - training_finished = self._t0s[RunningStage.TRAINING] + sum(self._throughputs[RunningStage.TRAINING]._time) + training_finished = self._t0s[RunningStage.TRAINING] + train_elapsed time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished - val_time = sum(self._throughputs[RunningStage.VALIDATING]._time) + val_time = val_elapsed self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time @override diff --git a/tests/tests_pytorch/callbacks/test_throughput_monitor.py b/tests/tests_pytorch/callbacks/test_throughput_monitor.py index 7dda7875a43c7..93bfe4e844c3a 100644 --- a/tests/tests_pytorch/callbacks/test_throughput_monitor.py +++ b/tests/tests_pytorch/callbacks/test_throughput_monitor.py @@ -420,3 +420,65 @@ def variable_batch_size_fn(batch): train_samples.append(metrics["train/samples"]) elif "train|samples" in metrics: train_samples.append(metrics["train|samples"]) + + +def test_throughput_monitor_validation_with_many_epochs(tmp_path): + """Ensure ThroughputMonitor handles many epochs with validation and time increases monotonically.""" + + logger_mock = Mock() + logger_mock.save_dir = tmp_path + monitor = ThroughputMonitor(batch_size_fn=lambda x: 1) + model = BoringModel() + model.flops_per_batch = 10 + num_epochs = 100 + + trainer = Trainer( + devices=1, + logger=logger_mock, + callbacks=[monitor], + max_epochs=num_epochs, + limit_train_batches=2, + limit_val_batches=1, + log_every_n_steps=1, + enable_checkpointing=False, + enable_model_summary=False, + enable_progress_bar=False, + ) + + timings = [] + t = 0.0 + for _ in range(num_epochs): + timings += [ + t, # train batch 1 start + t + 3.0, # train batch 1 end and start batch 2 + t + 6.0, # train batch 2 end + t + 7.0, # val start + t + 8.0, # val end + ] + t += 10.0 + + with mock.patch("time.perf_counter", side_effect=timings): + try: + trainer.fit(model) + except Exception as e: + pytest.fail(f"ThroughputMonitor raised an unexpected exception: {e}") + + start_train_timings_idx, end_train_timings_idx = 0, 1 + batch_num = 1 + cur_train = timings[end_train_timings_idx] - timings[start_train_timings_idx] + for c in logger_mock.log_metrics.mock_calls: + metrics = getattr(c, "kwargs", None) or {} + metrics = metrics.get("metrics", metrics) + for k, v in metrics.items(): + if k.endswith("train/time"): + assert v == cur_train, f"Expected train/time {cur_train}, got {v}" + if batch_num == 1: + start_train_timings_idx += 1 + end_train_timings_idx += 1 + batch_num = 2 + else: + start_train_timings_idx += 3 + end_train_timings_idx += 3 + batch_num = 1 + if end_train_timings_idx < len(timings): + cur_train += timings[end_train_timings_idx] - timings[start_train_timings_idx]