Skip to content
15 changes: 12 additions & 3 deletions src/lightning/pytorch/callbacks/throughput_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def _start(self, trainer: "Trainer") -> None:
stage = trainer.state.stage
assert stage is not None

if stage not in self._samples:
reset_needed = trainer.state.fn == TrainerFn.FITTING or stage not in self._samples

if reset_needed:
self._throughputs[stage].reset()
self._lengths[stage] = 0
self._samples[stage] = 0
Expand Down Expand Up @@ -202,10 +204,17 @@ def on_validation_batch_end(
def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
    """Shift the training start time past a validation run that happened while fitting.

    Does nothing while sanity checking or when the trainer is not fitting.
    Otherwise ``self._t0s[RunningStage.TRAINING]`` is advanced by the gap
    between the end of training and the start of validation, plus the
    validation time itself.

    Args:
        trainer: Trainer whose ``sanity_checking`` flag and ``state.fn`` are read.
        *_: Ignored.
    """
    if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
        return

    train_times = self._throughputs[RunningStage.TRAINING]._time
    val_times = self._throughputs[RunningStage.VALIDATING]._time

    # Only the most recent entry of each series is used; 0.0 when nothing has
    # been recorded yet.
    # NOTE(review): presumably `_time` stores cumulative elapsed times, which is
    # why the entries must not be summed -- confirm against `Throughput.update`.
    train_elapsed = train_times[-1] if train_times else 0.0
    val_elapsed = val_times[-1] if val_times else 0.0

    # add the validation time to the training time before continuing to avoid sinking the training throughput
    training_finished = self._t0s[RunningStage.TRAINING] + train_elapsed
    time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
    self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_elapsed

@override
Expand Down
23 changes: 23 additions & 0 deletions tests/tests_pytorch/callbacks/test_throughput_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,26 @@ def variable_batch_size_fn(batch):
train_samples.append(metrics["train/samples"])
elif "train|samples" in metrics:
train_samples.append(metrics["train|samples"])


def test_throughput_monitor_validation_sum_overflow_real(tmp_path):
    """Smoke-test that a 100-epoch fit with ``ThroughputMonitor`` attached completes without raising."""
    logger_mock = Mock()
    logger_mock.save_dir = tmp_path
    monitor = ThroughputMonitor(batch_size_fn=lambda x: 1)
    model = BoringModel()
    model.flops_per_batch = 10

    trainer = Trainer(
        devices=1,
        logger=logger_mock,
        callbacks=[monitor],
        max_epochs=100,
        enable_checkpointing=False,
        enable_model_summary=False,
        enable_progress_bar=False,
    )

    # No try/except wrapper: an uncaught exception already fails the test, and
    # letting it propagate keeps the original exception type and traceback in
    # the report instead of collapsing them into a one-line message.
    trainer.fit(model)
Loading