
Commit

add activation ckpt
samsja committed Oct 2, 2024
1 parent 58a9aca commit 5e68f79
Showing 3 changed files with 26 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/zeroband/train.py
@@ -22,6 +22,7 @@
from zeroband.comms import ElasticDeviceMesh

from zeroband.utils import GPUMemoryMonitor, PerfCounter, get_module_signature, get_sharding_strategy
from zeroband.utils.activation_ckpt import apply_ac_ckpt
from zeroband.utils.monitor import WandbMonitor, DummyMonitor
from zeroband.data import TEST_VOCAB_SIZE, get_dataloader
from zeroband.models.llama import get_model
@@ -51,6 +52,7 @@ class TrainConfig(BaseConfig):
    micro_bs: int
    torch_compile: bool = True
    sharding_strategy: str = "SHARD_GRAD_OP"
    ac_ckpt: bool = False

    log_model_hash: bool = False

@@ -137,6 +139,9 @@ def train(config: Config):
        config.data.seq_length,
    )

    if config.train.ac_ckpt:
        apply_ac_ckpt(model)

    elastic_device_mesh = ElasticDeviceMesh()

    model = FSDP(
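As a quick illustration of what the new flag means for users, here is a minimal runnable sketch of the TrainConfig field (plain pydantic.BaseModel stands in for zeroband's BaseConfig; only ac_ckpt is new in this commit, the other fields follow the hunk above):

from pydantic import BaseModel


class TrainConfig(BaseModel):  # stand-in sketch, not zeroband's actual BaseConfig
    micro_bs: int
    torch_compile: bool = True
    sharding_strategy: str = "SHARD_GRAD_OP"
    ac_ckpt: bool = False  # new flag: activation checkpointing is off by default
    log_model_hash: bool = False


# Off by default; setting it to True corresponds to passing --train.ac_ckpt on the CLI.
print(TrainConfig(micro_bs=1).ac_ckpt)                # False
print(TrainConfig(micro_bs=1, ac_ckpt=True).ac_ckpt)  # True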
16 changes: 16 additions & 0 deletions src/zeroband/utils/activation_ckpt.py
@@ -0,0 +1,16 @@
from zeroband.models.llama.model import Transformer

from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

from zeroband.utils.logging import get_logger


def apply_ac_ckpt(model: Transformer):
    """Apply activation checkpointing to the model."""
    logger = get_logger()

    for layer_id, transformer_block in model.layers.named_children():
        transformer_block = checkpoint_wrapper(transformer_block, preserve_rng_state=False)
        model.layers.register_module(layer_id, transformer_block)

    logger.info("Applied activation checkpointing to the model")
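For illustration, a minimal standalone sketch of the same wrapping pattern on a toy model (TinyBlock and TinyModel are hypothetical stand-ins for the Llama Transformer and its blocks; the checkpoint_wrapper call itself matches the one above):

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper


class TinyBlock(nn.Module):
    """Hypothetical stand-in for a single transformer block."""

    def __init__(self, dim: int):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.linear(x))


class TinyModel(nn.Module):
    """Toy model whose .layers mirrors model.layers in the real Transformer."""

    def __init__(self, dim: int = 16, n_layers: int = 4):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): TinyBlock(dim) for i in range(n_layers)})

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x


model = TinyModel()

# Same loop as apply_ac_ckpt: replace each block with a checkpointed wrapper so its
# activations are discarded in the forward pass and recomputed during backward.
for layer_id, block in model.layers.named_children():
    model.layers.register_module(layer_id, checkpoint_wrapper(block, preserve_rng_state=False))

x = torch.randn(2, 16, requires_grad=True)
model(x).sum().backward()  # block activations are recomputed here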
5 changes: 5 additions & 0 deletions tests/test_torchrun/test_train.py
@@ -71,3 +71,8 @@ def test_multi_gpu_diloco_non_full_shard(strategy):
    # we don't test 1,1 and 2,1 because 1 solo gpu failed with fsdp
    num_gpus = [2, 2]
    _test_multi_gpu(num_gpus, "debug/diloco.toml", extra_args=["--train.sharding_strategy", strategy])


def test_act_ckpt():
    num_gpus = [1, 2]
    _test_multi_gpu(num_gpus, "debug/normal.toml", extra_args=["--train.ac_ckpt"])
