15 changes: 12 additions & 3 deletions src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -109,7 +109,9 @@ def _start(self, trainer: "Trainer") -> None:
         stage = trainer.state.stage
         assert stage is not None
 
-        if stage not in self._samples:
+        reset_needed = trainer.state.fn == TrainerFn.FITTING or stage not in self._samples
+
+        if reset_needed:
             self._throughputs[stage].reset()
             self._lengths[stage] = 0
             self._samples[stage] = 0
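Note on the hunk above: previously the counters were reset only the first time a stage was seen (stage not in self._samples), so during fitting the training window was never restarted across train/validation rounds. The new flag also forces a reset on every _start() while the trainer is fitting. A minimal sketch of the difference, using hypothetical simplified names (only the FITTING check and the membership test mirror the diff):

    # Hypothetical, simplified illustration -- not the callback's actual code.
    samples = {}

    def _start(stage: str, fitting: bool) -> tuple:
        old_reset = stage not in samples  # old condition: reset on first visit only
        reset_needed = fitting or stage not in samples  # new condition
        if reset_needed:
            samples[stage] = 0  # begin a fresh measurement window
        return old_reset, reset_needed

    print(_start("train", fitting=True))     # (True, True): first visit resets either way
    print(_start("validate", fitting=True))  # (True, True)
    print(_start("validate", fitting=True))  # (False, True): old code would keep stale counters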
@@ -202,10 +204,17 @@ def on_validation_batch_end(
     def on_validation_end(self, trainer: "Trainer", *_: Any) -> None:
         if trainer.sanity_checking or trainer.state.fn != TrainerFn.FITTING:
             return
+
+        train_times = self._throughputs[RunningStage.TRAINING]._time
+        val_times = self._throughputs[RunningStage.VALIDATING]._time
+
+        train_elapsed = train_times[-1] if train_times else 0.0
+        val_elapsed = val_times[-1] if val_times else 0.0
+
         # add the validation time to the training time before continuing to avoid sinking the training throughput
-        training_finished = self._t0s[RunningStage.TRAINING] + sum(self._throughputs[RunningStage.TRAINING]._time)
+        training_finished = self._t0s[RunningStage.TRAINING] + train_elapsed
         time_between_train_and_val = self._t0s[RunningStage.VALIDATING] - training_finished
-        val_time = sum(self._throughputs[RunningStage.VALIDATING]._time)
+        val_time = val_elapsed
         self._t0s[RunningStage.TRAINING] += time_between_train_and_val + val_time
 
     @override
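Why replacing sum(...) with the last element matters: assuming the Throughput metric's _time deque stores cumulative elapsed seconds per logged batch (which the new "[-1] if ... else 0.0" pattern implies), sum(_time) double-counts every earlier measurement, and the error grows with each batch. A worked example under that assumption:

    # Worked example; the values mirror the test timings below, the names are illustrative.
    train_time = [3.0, 6.0]  # cumulative elapsed after train batches 1 and 2
    val_time = [1.0]         # cumulative elapsed after the single val batch

    old_train_elapsed = sum(train_time)                        # 9.0 -- double-counts batch 1
    new_train_elapsed = train_time[-1] if train_time else 0.0  # 6.0 -- actual elapsed

    t0_train, t0_val = 0.0, 7.0
    training_finished = t0_train + new_train_elapsed       # 6.0
    gap = t0_val - training_finished                       # 1.0 of idle time before val
    val_elapsed = val_time[-1] if val_time else 0.0        # 1.0
    t0_train += gap + val_elapsed                          # t0 becomes 2.0, hiding the val window
    assert t0_train == 2.0

With the old sum(), training_finished would be 9.0 and the gap -2.0, dragging t0 backward to -1.0; every later elapsed-time reading then looks longer than it really is, and the drift compounds each epoch, which is exactly the "sinking" throughput the code comment warns about.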
62 changes: 62 additions & 0 deletions tests/tests_pytorch/callbacks/test_throughput_monitor.py
@@ -420,3 +420,65 @@ def variable_batch_size_fn(batch):
train_samples.append(metrics["train/samples"])
elif "train|samples" in metrics:
train_samples.append(metrics["train|samples"])


+def test_throughput_monitor_validation_with_many_epochs(tmp_path):
+    """Ensure ThroughputMonitor handles many epochs with validation and that train time increases monotonically."""
+
+    logger_mock = Mock()
+    logger_mock.save_dir = tmp_path
+    monitor = ThroughputMonitor(batch_size_fn=lambda x: 1)
+    model = BoringModel()
+    model.flops_per_batch = 10
+    num_epochs = 100
+
+    trainer = Trainer(
+        devices=1,
+        logger=logger_mock,
+        callbacks=[monitor],
+        max_epochs=num_epochs,
+        limit_train_batches=2,
+        limit_val_batches=1,
+        log_every_n_steps=1,
+        enable_checkpointing=False,
+        enable_model_summary=False,
+        enable_progress_bar=False,
+    )
+
+    timings = []
+    t = 0.0
+    for _ in range(num_epochs):
+        timings += [
+            t,  # train batch 1 start
+            t + 3.0,  # train batch 1 end and batch 2 start
+            t + 6.0,  # train batch 2 end
+            t + 7.0,  # val start
+            t + 8.0,  # val end
+        ]
+        t += 10.0
+
+    with mock.patch("time.perf_counter", side_effect=timings):
+        try:
+            trainer.fit(model)
+        except Exception as e:
+            pytest.fail(f"ThroughputMonitor raised an unexpected exception: {e}")
+
+    start_train_timings_idx, end_train_timings_idx = 0, 1
+    batch_num = 1
+    cur_train = timings[end_train_timings_idx] - timings[start_train_timings_idx]
+    for c in logger_mock.log_metrics.mock_calls:
+        metrics = getattr(c, "kwargs", None) or {}
+        metrics = metrics.get("metrics", metrics)
+        for k, v in metrics.items():
+            if k.endswith("train/time"):
+                assert v == cur_train, f"Expected train/time {cur_train}, got {v}"
+                if batch_num == 1:
+                    start_train_timings_idx += 1
+                    end_train_timings_idx += 1
+                    batch_num = 2
+                else:
+                    start_train_timings_idx += 3
+                    end_train_timings_idx += 3
+                    batch_num = 1
+                if end_train_timings_idx < len(timings):
+                    cur_train += timings[end_train_timings_idx] - timings[start_train_timings_idx]
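The verification loop at the end walks the synthetic clock in step with the expected log calls: after each epoch's first train/time entry it advances the measured interval by one timing slot (batch 1 to batch 2), and after the second it skips three slots, jumping past the validation window to the next epoch, so cur_train accumulates only the intervals the monitor should count as training time. Assuming the repository's standard pytest setup, the new test can be run in isolation with:

    pytest tests/tests_pytorch/callbacks/test_throughput_monitor.py -k test_throughput_monitor_validation_with_many_epochs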