diff --git a/src/modalities/activation_checkpointing.py b/src/modalities/activation_checkpointing.py
index cd3bbbb8..288dc09f 100644
--- a/src/modalities/activation_checkpointing.py
+++ b/src/modalities/activation_checkpointing.py
@@ -11,14 +11,14 @@ from modalities.models.gpt2.gpt2_model import GPT2Block
 
 
-def is_module_to_apply_activation_checkpointing(submodule: torch.nn.Module):
+def is_module_to_apply_activation_checkpointing(submodule: torch.nn.Module) -> bool:
     return isinstance(submodule, GPT2Block)
 
 
-def apply_activation_checkpointing_inplace(model: torch.nn.Module) -> None:
+def apply_activation_checkpointing_inplace(model: torch.nn.Module):
     assert isinstance(model, FSDP), "activation checkpointing can only be applied to FSDP wrapped models!"
-    non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, debug=True)
-    return apply_activation_checkpointing(
+    non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, debug=False)
+    apply_activation_checkpointing(
         model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=is_module_to_apply_activation_checkpointing
     )
diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py
index 707d544d..12e8193c 100644
--- a/src/modalities/trainer.py
+++ b/src/modalities/trainer.py
@@ -58,6 +58,7 @@ def train(
         optimizer,
         loss_fun: Loss,
         callback_interval_in_batches: int,
+        # TODO: remove
         epoch_done_callback: Callable[[int], None],
         local_sample_id_to_global_sample_id: Callable[[int], int],
     ):
@@ -67,12 +68,14 @@ def train(
 
         # batch loop
        batch: DatasetBatch
+        # TODO: why do we need a barrier here?
         dist.barrier()
         forward_backward_time_recorder = TimeRecorder()
         forward_backward_time_recorder.start()
         for batch_id, batch in enumerate(train_loader):
+            # Because we might resume training, we add the starting batch id of the data loader
             local_train_batch_id = batch_id + train_loader.fast_forward_batch_id
-            # train single batch
+            # Train single batch
             batch_loss = self._train_batch(
                 batch=batch,
                 model=model,
@@ -82,7 +85,7 @@
                 data_loader=train_loader,
             )
             forward_backward_time_recorder.stop()
-            # save the batch loss
+            # Save the batch loss
             cummulated_loss[0] += batch_loss.item()
             cummulated_loss[1] += len(batch)
             batch_length_tensor = torch.tensor(len(batch)).to(torch.device(self.local_rank))
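
A minimal sketch of how the updated `apply_activation_checkpointing_inplace` might be called, assuming an already initialized process group and a hypothetical `build_gpt2_model` factory (neither is taken from this patch); the helper asserts the model is FSDP-wrapped, then wraps every GPT2Block in a non-reentrant `checkpoint_wrapper`.

```python
# Sketch only: assumed call order; build_gpt2_model and the NCCL setup are hypothetical.
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from modalities.activation_checkpointing import apply_activation_checkpointing_inplace

dist.init_process_group(backend="nccl")          # assumed distributed setup
model = build_gpt2_model()                       # hypothetical factory returning the GPT2 module
model = FSDP(model)                              # must be FSDP-wrapped, per the assert in the helper
apply_activation_checkpointing_inplace(model)    # wraps each GPT2Block with a non-reentrant checkpoint_wrapper
```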