diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index a48f0d83..43c580ef 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -22,6 +22,7 @@
 
 from zeroband.comms import ElasticDeviceMesh
 from zeroband.utils import GPUMemoryMonitor, PerfCounter, get_module_signature, get_sharding_strategy
+from zeroband.utils.activation_ckpt import apply_ac_ckpt
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
 from zeroband.models.llama import get_model
@@ -51,7 +52,7 @@ class TrainConfig(BaseConfig):
     micro_bs: int
     torch_compile: bool = True
     sharding_strategy: str = "SHARD_GRAD_OP"
-
+    ac_ckpt: bool | int = False
 
     log_model_hash: bool = False
 
@@ -137,6 +138,10 @@ def train(config: Config):
         config.data.seq_length,
     )
 
+    if config.train.ac_ckpt:
+        num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt
+        apply_ac_ckpt(model, num)
+
     elastic_device_mesh = ElasticDeviceMesh()
 
     model = FSDP(
diff --git a/src/zeroband/utils/activation_ckpt.py b/src/zeroband/utils/activation_ckpt.py
new file mode 100644
index 00000000..18410b28
--- /dev/null
+++ b/src/zeroband/utils/activation_ckpt.py
@@ -0,0 +1,25 @@
+from zeroband.models.llama.model import Transformer
+
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
+
+from zeroband.utils.logging import get_logger
+
+
+def apply_ac_ckpt(model: Transformer, num: int):
+    """Apply activation checkpointing to the model.
+
+    Checkpointing is applied to every `num`-th layer; for example, with `num=2`
+    only half of the layers are checkpointed.
+    """
+    logger = get_logger()
+
+    layers_ckpt = 0
+
+    for layer_id, transformer_block in model.layers.named_children():
+        # `layer_id` is the stringified layer index, so this wraps every `num`-th layer
+        if int(layer_id) % num == 0:
+            transformer_block = checkpoint_wrapper(transformer_block, preserve_rng_state=False)
+            model.layers.register_module(layer_id, transformer_block)
+            layers_ckpt += 1
+
+    logger.info(f"Applied activation checkpointing to {layers_ckpt} layers")
diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py
index 79bbc61c..b6047270 100644
--- a/tests/test_torchrun/test_train.py
+++ b/tests/test_torchrun/test_train.py
@@ -71,3 +71,13 @@ def test_multi_gpu_diloco_non_full_shard(strategy):
     # we don't test 1,1 and 2,1 because 1 solo gpu failed with fsdp
     num_gpus = [2, 2]
     _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--train.sharding_strategy", strategy])
+
+
+def test_act_ckpt():
+    num_gpus = [1, 2]
+    _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--train.ac_ckpt"])
+
+
+def test_act_ckpt_num():
+    num_gpus = [1, 2]
+    _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--train.ac_ckpt", "2"])
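
Note (not part of the diff): below is a minimal standalone sketch of the selective activation checkpointing pattern that apply_ac_ckpt applies to the Transformer layers. The helper name apply_selective_ckpt and the toy nn.ModuleList model are hypothetical, for illustration only; it assumes nothing beyond PyTorch's checkpoint_wrapper, which the new module already imports.

# Standalone sketch (hypothetical helper, not part of this PR): wrap every
# `num`-th block of a ModuleList with PyTorch's checkpoint_wrapper, mirroring
# the selection logic used by apply_ac_ckpt.
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper


def apply_selective_ckpt(layers: nn.ModuleList, num: int) -> int:
    """Wrap every `num`-th child of `layers` in place; return the number wrapped."""
    wrapped = 0
    for layer_id, block in layers.named_children():
        # named_children() yields stringified indices ("0", "1", ...) for a ModuleList
        if int(layer_id) % num == 0:
            layers.register_module(layer_id, checkpoint_wrapper(block, preserve_rng_state=False))
            wrapped += 1
    return wrapped


if __name__ == "__main__":
    toy = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])
    n_wrapped = apply_selective_ckpt(toy, num=2)

    x = torch.randn(2, 16, requires_grad=True)
    for block in toy:
        x = block(x)
    x.sum().backward()  # activations of wrapped blocks are recomputed here

    print(f"checkpointed {n_wrapped} of {len(toy)} layers")

Wrapped layers recompute their activations during the backward pass, trading extra compute for lower activation memory. With the new config option, --train.ac_ckpt checkpoints every layer (num=1) and --train.ac_ckpt 2 every second layer, as exercised by the added tests.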