From 5e68f792e23b217936f8d0ea3a280a0f673291db Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Wed, 2 Oct 2024 02:11:00 +0000
Subject: [PATCH 1/2] add activation ckpt

---
 src/zeroband/train.py                 |  5 +++++
 src/zeroband/utils/activation_ckpt.py | 16 ++++++++++++++++
 tests/test_torchrun/test_train.py     |  5 +++++
 3 files changed, 26 insertions(+)
 create mode 100644 src/zeroband/utils/activation_ckpt.py

diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index a48f0d83..0df0464f 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -22,6 +22,7 @@
 from zeroband.comms import ElasticDeviceMesh
 from zeroband.utils import GPUMemoryMonitor, PerfCounter, get_module_signature, get_sharding_strategy
+from zeroband.utils.activation_ckpt import apply_ac_ckpt
 from zeroband.utils.monitor import WandbMonitor, DummyMonitor
 from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
 from zeroband.models.llama import get_model
@@ -51,6 +52,7 @@ class TrainConfig(BaseConfig):
     micro_bs: int
     torch_compile: bool = True
     sharding_strategy: str = "SHARD_GRAD_OP"
+    ac_ckpt: bool = False
 
     log_model_hash: bool = False
 
@@ -137,6 +139,9 @@ def train(config: Config):
         config.data.seq_length,
     )
 
+    if config.train.ac_ckpt:
+        apply_ac_ckpt(model)
+
     elastic_device_mesh = ElasticDeviceMesh()
 
     model = FSDP(
diff --git a/src/zeroband/utils/activation_ckpt.py b/src/zeroband/utils/activation_ckpt.py
new file mode 100644
index 00000000..cd9ff957
--- /dev/null
+++ b/src/zeroband/utils/activation_ckpt.py
@@ -0,0 +1,16 @@
+from zeroband.models.llama.model import Transformer
+
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
+
+from zeroband.utils.logging import get_logger
+
+
+def apply_ac_ckpt(model: Transformer):
+    """Apply activation checkpointing to the model."""
+    logger = get_logger()
+
+    for layer_id, transformer_block in model.layers.named_children():
+        transformer_block = checkpoint_wrapper(transformer_block, preserve_rng_state=False)
+        model.layers.register_module(layer_id, transformer_block)
+
+    logger.info("Applied activation checkpointing to the model")
diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py
index 79bbc61c..d5d80161 100644
--- a/tests/test_torchrun/test_train.py
+++ b/tests/test_torchrun/test_train.py
@@ -71,3 +71,8 @@ def test_multi_gpu_diloco_non_full_shard(strategy):
     # we don't test 1,1 and 2,1 because 1 solo gpu failed with fsdp
     num_gpus = [2, 2]
     _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--train.sharding_strategy", strategy])
+
+
+def test_act_ckpt():
+    num_gpus = [1, 2]
+    _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--train.ac_ckpt"])

From 96a520187fe7fc5a60928541207e51e0ce425a9b Mon Sep 17 00:00:00 2001
From: Sami Jaghouar
Date: Wed, 2 Oct 2024 02:30:29 +0000
Subject: [PATCH 2/2] allow passing an int for ac_ckpt

---
 src/zeroband/train.py                 |  6 +++---
 src/zeroband/utils/activation_ckpt.py | 18 +++++++++++++-----
 tests/test_torchrun/test_train.py     |  5 +++++
 3 files changed, 21 insertions(+), 8 deletions(-)

diff --git a/src/zeroband/train.py b/src/zeroband/train.py
index 0df0464f..43c580ef 100644
--- a/src/zeroband/train.py
+++ b/src/zeroband/train.py
@@ -52,8 +52,7 @@ class TrainConfig(BaseConfig):
     micro_bs: int
     torch_compile: bool = True
     sharding_strategy: str = "SHARD_GRAD_OP"
-    ac_ckpt: bool = False
-
+    ac_ckpt: bool | int = False
     log_model_hash: bool = False
 
 
@@ -140,7 +139,8 @@ def train(config: Config):
     )
 
     if config.train.ac_ckpt:
-        apply_ac_ckpt(model)
+        num = 1 if isinstance(config.train.ac_ckpt, bool) else config.train.ac_ckpt
+        apply_ac_ckpt(model, num)
 
     elastic_device_mesh = ElasticDeviceMesh()
 
diff --git a/src/zeroband/utils/activation_ckpt.py b/src/zeroband/utils/activation_ckpt.py
index cd9ff957..18410b28 100644
--- a/src/zeroband/utils/activation_ckpt.py
+++ b/src/zeroband/utils/activation_ckpt.py
@@ -5,12 +5,20 @@
 from zeroband.utils.logging import get_logger
 
 
-def apply_ac_ckpt(model: Transformer):
-    """Apply activation checkpointing to the model."""
+def apply_ac_ckpt(model: Transformer, num: int):
+    """Apply activation checkpointing to the model.
+    Only every `num`-th layer is checkpointed.
+
+    For example, with `num=2` half of the layers are checkpointed.
+    """
     logger = get_logger()
 
+    layers_ckpt = 0
+
     for layer_id, transformer_block in model.layers.named_children():
-        transformer_block = checkpoint_wrapper(transformer_block, preserve_rng_state=False)
-        model.layers.register_module(layer_id, transformer_block)
+        if int(layer_id) % num == 0:  # layer ids are the numeric keys "0", "1", ...
+            transformer_block = checkpoint_wrapper(transformer_block, preserve_rng_state=False)
+            model.layers.register_module(layer_id, transformer_block)
+            layers_ckpt += 1
 
-    logger.info("Applied activation checkpointing to the model")
+    logger.info(f"Applied activation checkpointing to {layers_ckpt} layers")
diff --git a/tests/test_torchrun/test_train.py b/tests/test_torchrun/test_train.py
index d5d80161..b6047270 100644
--- a/tests/test_torchrun/test_train.py
+++ b/tests/test_torchrun/test_train.py
@@ -76,3 +76,8 @@ def test_act_ckpt():
     num_gpus = [1, 2]
     _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--train.ac_ckpt"])
+
+
+def test_act_ckpt_num():
+    num_gpus = [1, 2]
+    _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--train.ac_ckpt", "2"])
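
For readers unfamiliar with checkpoint_wrapper, below is a minimal, self-contained sketch of the selective activation-checkpointing pattern the patches use, applied to a toy module rather than the repo's Transformer. ToyModel and apply_selective_ckpt are made-up names for illustration only; the parts taken from the patches are checkpoint_wrapper, register_module, and the "wrap every num-th layer" idea.

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper


class ToyModel(nn.Module):
    """Stand-in for the Llama Transformer: blocks live in `self.layers`, keyed by id."""

    def __init__(self, n_layers: int = 4, dim: int = 8):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): nn.Linear(dim, dim) for i in range(n_layers)})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.layers.values():
            x = torch.relu(block(x))
        return x


def apply_selective_ckpt(model: ToyModel, num: int = 1) -> int:
    """Wrap every `num`-th block with checkpoint_wrapper; return how many were wrapped."""
    n_wrapped = 0
    for idx, (layer_id, block) in enumerate(model.layers.named_children()):
        if idx % num == 0:
            # Replace the block with a wrapper that recomputes activations in backward.
            model.layers.register_module(layer_id, checkpoint_wrapper(block, preserve_rng_state=False))
            n_wrapped += 1
    return n_wrapped


if __name__ == "__main__":
    model = ToyModel()
    wrapped = apply_selective_ckpt(model, num=2)  # wraps blocks 0 and 2, i.e. half of them
    out = model(torch.randn(3, 8, requires_grad=True))
    out.sum().backward()  # wrapped blocks recompute their activations here instead of storing them
    print(f"wrapped {wrapped} of {len(model.layers)} blocks")

Checkpointing only a subset of blocks trades extra recomputation in the backward pass for lower activation memory, which is why the second patch lets ac_ckpt be an int as well as a bool.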