Commit f810fcc

Merge pull request #159 from Modalities/make_activation_checkpoininting_felixble

Make Activation Checkpointing Configurable
mali-git authored Jun 28, 2024
2 parents 4aa2e88 + 8605dbe commit f810fcc
Showing 16 changed files with 50 additions and 26 deletions.
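In short, the boolean training setting do_apply_activation_checkpointing is replaced by the list-valued setting activation_checkpointing_modules: an empty list disables activation checkpointing, while a list of module class names (e.g. [GPT2Block]) enables it for all submodules of those types. A minimal sketch of the resulting call path, mirroring the src/modalities/__main__.py change below (the surrounding training setup, including the FSDP-wrapped model, is assumed):

# Sketch only: wrapped_model is assumed to be an FSDP-wrapped model built elsewhere in the training setup.
activation_checkpointing_modules = ["GPT2Block"]  # [] would skip activation checkpointing entirely

if len(activation_checkpointing_modules) > 0:
    apply_activation_checkpointing_inplace(
        model=wrapped_model,
        activation_checkpointing_modules=activation_checkpointing_modules,
    )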
2 changes: 1 addition & 1 deletion config_files/training/config_example_coca.yaml
@@ -10,7 +10,7 @@ settings:
 global_evaluation_interval_in_steps: 2
 global_num_training_samples: 12
 global_num_seen_steps: 0
-do_apply_activation_checkpointing: true
+activation_checkpointing_modules: []
 gradient_acc_steps: 1
 local_train_micro_batch_size: 3
 sequence_length: 256
@@ -8,7 +8,7 @@ settings:
 callback_interval_in_samples: 2048
 global_num_training_samples: 2048
 global_num_seen_samples: 0
-do_apply_activation_checkpointing: true
+activation_checkpointing_modules: []
 gradient_acc_steps: 1
 local_train_micro_batch_size: 1
 sequence_length: 4096
@@ -9,7 +9,7 @@ settings:
 global_checkpointing_interval_in_steps: 128
 global_evaluation_interval_in_steps: 64
 global_num_seen_steps: 0
-do_apply_activation_checkpointing: false
+activation_checkpointing_modules: []
 gradient_acc_steps: 1
 local_train_micro_batch_size: 16
 sequence_length: 2048
@@ -9,7 +9,7 @@ settings:
 global_checkpointing_interval_in_steps: 128
 global_evaluation_interval_in_steps: 64
 global_num_seen_steps: 0
-do_apply_activation_checkpointing: false
+activation_checkpointing_modules: []
 gradient_acc_steps: 1
 local_train_micro_batch_size: 16
 sequence_length: 2048
@@ -9,7 +9,7 @@ settings:
 global_checkpointing_interval_in_steps: 8192
 global_evaluation_interval_in_steps: 1024
 global_num_seen_steps: 0
-do_apply_activation_checkpointing: false
+activation_checkpointing_modules: []
 gradient_acc_steps: 1
 local_train_micro_batch_size: 16
 sequence_length: 2048
2 changes: 1 addition & 1 deletion config_files/training/config_lorem_ipsum.yaml
@@ -9,7 +9,7 @@ settings:
 global_checkpointing_interval_in_steps: 3
 global_evaluation_interval_in_steps: 2
 global_num_seen_steps: 0
-do_apply_activation_checkpointing: true
+activation_checkpointing_modules: [GPT2Block]
 gradient_acc_steps: 1
 local_train_micro_batch_size: 1
 sequence_length: 256
2 changes: 1 addition & 1 deletion config_files/training/config_mem_map_mamba.yaml
@@ -8,7 +8,7 @@ settings:
 callback_interval_in_samples: 32768
 global_num_training_samples: 2048
 global_num_seen_samples: 0
-do_apply_activation_checkpointing: false
+activation_checkpointing_modules: []
 gradient_acc_steps: 1
 local_train_micro_batch_size: 16
 sequence_length: 4096
@@ -9,7 +9,7 @@ settings:
 global_checkpointing_interval_in_steps: 1000
 global_evaluation_interval_in_steps: 64
 global_num_seen_steps: 0
-do_apply_activation_checkpointing: false
+activation_checkpointing_modules: []
 gradient_acc_steps: 1
 local_train_micro_batch_size: 4
 sequence_length: 2048
@@ -9,7 +9,7 @@ settings:
 global_checkpointing_interval_in_steps: 1000
 global_evaluation_interval_in_steps: 64
 global_num_seen_steps: 0
-do_apply_activation_checkpointing: false
+activation_checkpointing_modules: []
 gradient_acc_steps: 1
 local_train_micro_batch_size: 4
 sequence_length: 0 # TODO: Is sequence_length used in training?
1 change: 0 additions & 1 deletion pyproject.toml
@@ -11,7 +11,6 @@ dependencies = [
     "datasets",
     "protobuf",
     "SentencePiece",
-    "accelerate",
     "rich",
     "omegaconf",
     "pydantic",
7 changes: 5 additions & 2 deletions src/modalities/__main__.py
@@ -221,8 +221,11 @@ def run(self, components: TrainingComponentsInstantiationModel):
         wrapped_model = components.wrapped_model
         logging.info(f"Training model with {compute_number_of_trainable_parameters(wrapped_model)} parameters.")
 
-        if components.settings.training.do_apply_activation_checkpointing:
-            apply_activation_checkpointing_inplace(wrapped_model)
+        if len(components.settings.training.activation_checkpointing_modules) > 0:
+            apply_activation_checkpointing_inplace(
+                model=wrapped_model,
+                activation_checkpointing_modules=components.settings.training.activation_checkpointing_modules,
+            )
 
         gym.run(
             train_data_loader=components.train_dataloader,
18 changes: 12 additions & 6 deletions src/modalities/activation_checkpointing.py
@@ -1,4 +1,5 @@
 from functools import partial
+from typing import List
 
 import torch
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -8,17 +9,22 @@
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
 
-from modalities.models.gpt2.gpt2_model import GPT2Block
 
+import torch
+from typing import List
+
+from modalities.util import get_module_class_from_name
 
 
-def is_module_to_apply_activation_checkpointing(submodule: torch.nn.Module) -> bool:
-    return isinstance(submodule, GPT2Block)
+def is_module_to_apply_activation_checkpointing(submodule: torch.nn.Module, activation_checkpointing_modules: List[type]) -> bool:
+    return isinstance(submodule, tuple(activation_checkpointing_modules))
 
 
-def apply_activation_checkpointing_inplace(model: torch.nn.Module):
-    assert isinstance(model, FSDP), "activation checkpointing can only be applied to FSDP wrapped models!"
+def apply_activation_checkpointing_inplace(model: torch.nn.Module, activation_checkpointing_modules: List[str]):
+    activation_checkpointing_module_types = [get_module_class_from_name(model, m) for m in activation_checkpointing_modules]
+    if not isinstance(model, FSDP):
+        raise ValueError("activation checkpointing can only be applied to FSDP wrapped models!")
     non_reentrant_wrapper = partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT, debug=False)
 
     apply_activation_checkpointing(
-        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=is_module_to_apply_activation_checkpointing
+        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=lambda submodule: is_module_to_apply_activation_checkpointing(submodule, activation_checkpointing_module_types)
     )
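The check_fn now tests submodules against a configurable tuple of module types instead of hard-coding GPT2Block. A standalone sketch of that filtering logic using plain nn.Module stand-ins, so no FSDP or distributed setup is required (Block is a hypothetical placeholder for a transformer block such as GPT2Block):

import torch.nn as nn

class Block(nn.Module):  # hypothetical stand-in for a configured block type
    pass

types_to_checkpoint = [Block]  # resolved from the configured class names
submodules = [Block(), nn.Linear(4, 4)]
print([isinstance(m, tuple(types_to_checkpoint)) for m in submodules])  # [True, False] -> only Block gets wrapped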
2 changes: 1 addition & 1 deletion src/modalities/config/instantiation_models.py
@@ -32,7 +32,7 @@ class Training(BaseModel):
     global_training_log_interval_in_steps: Annotated[int, Field(strict=True, ge=1)]
     global_checkpointing_interval_in_steps: Annotated[int, Field(strict=True, ge=1)]
     global_evaluation_interval_in_steps: Annotated[int, Field(strict=True, ge=1)]
-    do_apply_activation_checkpointing: bool
+    activation_checkpointing_modules: Optional[List[str]] = Field(default_factory=list)
     gradient_acc_steps: Annotated[int, Field(strict=True, ge=1)]
     local_train_micro_batch_size: Annotated[int, Field(strict=True, ge=1)]
     sequence_length: Annotated[int, Field(strict=True, ge=1)]
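Since the new field defaults to an empty list, configs that omit it simply leave activation checkpointing disabled. A minimal sketch of that behaviour with a stripped-down Pydantic model (TrainingSketch is illustrative, not part of the codebase):

from typing import List, Optional

from pydantic import BaseModel, Field

class TrainingSketch(BaseModel):
    activation_checkpointing_modules: Optional[List[str]] = Field(default_factory=list)

print(TrainingSketch().activation_checkpointing_modules)  # []
print(TrainingSketch(activation_checkpointing_modules=["GPT2Block"]).activation_checkpointing_modules)  # ['GPT2Block']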
8 changes: 2 additions & 6 deletions src/modalities/running_env/fsdp/fsdp_auto_wrapper.py
@@ -4,10 +4,10 @@
 from typing import Callable, List
 
 import torch.nn as nn
-from accelerate import FullyShardedDataParallelPlugin
 from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
 from modalities.config.lookup_enum import LookupEnum
+from modalities.util import get_module_class_from_name
 
 
 class FSDPAutoWrapFactoryIF(ABC):
@@ -28,12 +28,8 @@ def _get_fsdp_blocks_from_block_names(model: nn.Module, block_names: List[str])
         for cls_block_name in block_names:
             # TODO FullyShardedDataParallelPlugin from Accelerate uses string matching to find the correct
             # block class. In the long-term we should implmement this ourselves in a robuster fashion.
-            try:
-                block_type = FullyShardedDataParallelPlugin.get_module_class_from_name(model, cls_block_name)
-            except AttributeError:
-                from accelerate.utils.dataclasses import get_module_class_from_name
-                block_type = get_module_class_from_name(model, cls_block_name)
 
+            block_type = get_module_class_from_name(model, cls_block_name)
             if block_type is None:
                 raise ValueError(f"Could not find block with name {cls_block_name} in model")
             fsdp_block_types.append(block_type)
20 changes: 20 additions & 0 deletions src/modalities/util.py
@@ -136,3 +136,23 @@ def get_all_reduced_value(
         post_processing_fun=postprocessing_fun,  # lambda t: t[0] / t[1],
     )
     return value
+
+
+def get_module_class_from_name(module: torch.nn.Module, name: str) -> Type[torch.nn.Module] | None:
+    """From Accelerate source code
+    (https://github.com/huggingface/accelerate/blob/1f7a79b428749f45187ec69485f2c966fe21926e/src/accelerate/utils/dataclasses.py#L1902)
+    Gets a class from a module by its name.
+    Args:
+        module (`torch.nn.Module`): The module to get the class from.
+        name (`str`): The name of the class.
+    """
+    modules_children = list(module.children())
+    if module.__class__.__name__ == name:
+        return module.__class__
+    elif len(modules_children) == 0:
+        return
+    else:
+        for child_module in modules_children:
+            module_class = get_module_class_from_name(child_module, name)
+            if module_class is not None:
+                return module_class
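For reference, a quick standalone check of the recursive lookup on a toy model; InnerBlock is a hypothetical class, and the import assumes the modalities package containing the function above is installed:

import torch.nn as nn

from modalities.util import get_module_class_from_name

class InnerBlock(nn.Module):  # hypothetical nested block
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

model = nn.Sequential(nn.Linear(8, 8), InnerBlock())
assert get_module_class_from_name(model, "InnerBlock") is InnerBlock
assert get_module_class_from_name(model, "Linear") is nn.Linear
assert get_module_class_from_name(model, "MissingBlock") is None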
2 changes: 1 addition & 1 deletion tests/models/coca/coca_config.yaml
@@ -33,7 +33,7 @@ text_decoder_config:
 n_embd: 768
 dropout: 0.0
 bias: true
-activation: fused_swiglu
+activation: swiglu
 epsilon: 1e-5
 n_pool_head: 8
 n_vision_queries: 256
