[Models] Activation checkpointing from TorchTune #2954
Conversation
Nice! When it works, can you also add a few lines in https://huggingface.co/docs/trl/en/reducing_memory_usage? 🙏
lewtun left a comment:
Thanks for adding this @kashif! Overall it looks great :) In addition to Quentin's comment about the docs, would you mind running a benchmark with e.g. SFT and GRPO so we can have a rough idea of both the memory saved and the impact on throughput?
sure!
```python
self.activation_offload_context = get_act_offloading_ctx_manager(
    model=self.model,
    enable_activation_offloading=self.args.enable_activation_offloading,
)
```
I'd find it more readable like this:
```diff
-self.activation_offload_context = get_act_offloading_ctx_manager(
-    model=self.model,
-    enable_activation_offloading=self.args.enable_activation_offloading,
-)
+if self.args.enable_activation_offloading:
+    self.activation_offload_context = get_act_offloading_ctx_manager(
+        model=self.model,
+        enable_activation_offloading=self.args.enable_activation_offloading,
+    )
```
and later in the code:
```python
context = self.activation_offload_context if self.args.enable_activation_offloading else nullcontext()
```
`get_act_offloading_ctx_manager` returns a `nullcontext` when the flag `enable_activation_offloading=False`, which is why I don't need to put this in an `if` statement. Or should we remove that logic from the `get_act_offloading_ctx_manager` method?
Yes, I find it more explicit to disable outside the function than inside. But it's really not very important.
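For context, the pattern being discussed looks roughly like this (a minimal sketch, assuming torchtune's `get_act_offloading_ctx_manager` from `torchtune.training`; the helper name `resolve_offload_context` is invented for illustration):

```python
from contextlib import nullcontext

from torchtune.training import get_act_offloading_ctx_manager


def resolve_offload_context(model, enable_activation_offloading: bool):
    """Hypothetical helper: decide outside the factory, as suggested above,
    and fall back to a no-op context so call sites need no branching."""
    if enable_activation_offloading:
        # torchtune's context manager moves saved activations to CPU during
        # the forward pass and prefetches them back for the backward pass.
        return get_act_offloading_ctx_manager(
            model=model, enable_activation_offloading=True
        )
    return nullcontext()
```

Either way, the training step reads the same: `with self.activation_offload_context: ...`.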
Looks like this PR was parked for now. @kashif, did the implementation not work? This is super relevant to me if I am going to use TRL for training long-context reasoners.
@casper-hansen So during training I see a reduction in memory, but the memory jumps back up to the default as soon as the eval steps start, and I am investigating why...
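One way to pin down where the eval-time spike comes from is to reset the CUDA peak counter after each training step and read it back once evaluation finishes. A minimal sketch using standard `transformers` callback hooks (`MemoryProbeCallback` is an invented name, not part of this PR):

```python
import torch
from transformers import TrainerCallback


class MemoryProbeCallback(TrainerCallback):
    """Hypothetical probe that isolates the peak GPU memory of the eval pass."""

    def on_step_end(self, args, state, control, **kwargs):
        # Reset the peak counter after every training step so the next
        # reading covers only what ran since then (e.g. the eval loop).
        torch.cuda.reset_peak_memory_stats()

    def on_evaluate(self, args, state, control, **kwargs):
        peak_gb = torch.cuda.max_memory_allocated() / 1024**3
        print(f"Peak GPU memory since last train step: {peak_gb:.2f} GB")
```

It can be registered with `trainer.add_callback(MemoryProbeCallback())`.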
@kashif The implementation relies on …
@casper-hansen I have added the documentation as well as a check that disables the activation offloading when the …
In my personal opinion, it's suboptimal to disable it in all cases. The only case where it's incompatible is when using the fused linear cross-entropy. Maybe this can be fixed in a follow-up PR.
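A narrower guard along those lines could look like this (purely illustrative: `activation_offloading` and `use_liger_kernel` mirror TRL config flags, but `uses_fused_linear_ce` is an assumed helper, not an existing API):

```python
import warnings


def maybe_disable_offloading(args, uses_fused_linear_ce: bool) -> bool:
    """Hypothetical check: disable offloading only in the one incompatible
    case (Liger's fused linear cross-entropy) instead of for all Liger use."""
    if args.activation_offloading and args.use_liger_kernel and uses_fused_linear_ce:
        warnings.warn(
            "Activation offloading is incompatible with Liger's fused linear "
            "cross-entropy loss; disabling it for this run."
        )
        return False
    return args.activation_offloading
```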
here is the plot @casper-hansen using:

```python
#!/usr/bin/env python
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Script to benchmark GPU memory usage with and without activation offloading.

Example:
    python trl/scripts/activation_offloading_benchmark.py \
        --model_name_or_path Qwen/Qwen2-0.5B \
        --dataset_name trl-lib/Capybara \
        --start_length 128 \
        --max_length 4096 \
        --step_size 128 \
        --num_train_steps 5 \
        --per_device_train_batch_size 1
"""

import argparse
import gc
import json
import os
import sys
import traceback
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from trl import SFTConfig, SFTTrainer


def measure_memory():
    """Measure current and peak GPU memory in GB."""
    current = torch.cuda.memory_allocated() / (1024 ** 3)
    peak = torch.cuda.max_memory_allocated() / (1024 ** 3)
    return current, peak


def reset_memory_stats():
    """Reset the peak memory stats."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    gc.collect()


def train_with_config(config, model_name, dataset, sequence_length, num_steps):
    """Run training with a specific configuration and measure memory usage."""
    try:
        # Create model and tokenizer
        model = AutoModelForCausalLM.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Set up trainer
        trainer = SFTTrainer(
            model=model,
            args=config,
            train_dataset=dataset,
            processing_class=tokenizer,
        )

        # Reset memory stats before training
        reset_memory_stats()

        # Train for a few steps - no arguments needed since max_steps is in the config
        trainer.train()

        # Measure peak memory after training
        _, peak = measure_memory()

        # Clean up
        del trainer
        del model
        del tokenizer
        reset_memory_stats()

        # Return peak memory during training
        return peak, True  # Success
    except Exception as e:
        is_oom = "CUDA out of memory" in str(e)
        error_msg = "OOM" if is_oom else str(e)
        print(f"Error with sequence length {sequence_length}: {error_msg}")
        if not is_oom:
            traceback.print_exc()

        # Clean up whatever was created before the failure
        try:
            if "trainer" in locals():
                del trainer
            if "model" in locals():
                del model
            if "tokenizer" in locals():
                del tokenizer
        except Exception:
            pass
        reset_memory_stats()
        return None, False  # Failure


def run_benchmark(args):
    """Run the complete benchmark with increasing sequence lengths."""
    # Load dataset
    dataset = load_dataset(args.dataset_name, split=f"train[:{args.dataset_size}]")

    # Results storage
    results = {
        "with_offloading": {"seq_lengths": [], "memory": []},
        "without_offloading": {"seq_lengths": [], "memory": []},
    }

    # Directory for output
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    # Test both with and without activation offloading
    for use_offloading in [True, False]:
        mode = "with_offloading" if use_offloading else "without_offloading"
        print(f"\n{'=' * 50}\nTesting {mode}\n{'=' * 50}")

        seq_length = args.start_length
        while seq_length <= args.max_length:
            print(f"\nTesting sequence length: {seq_length}")

            # Create config with current sequence length
            config = SFTConfig(
                output_dir=str(output_dir / f"temp-{mode}-{seq_length}"),
                max_length=seq_length,
                packing=True,
                per_device_train_batch_size=args.per_device_train_batch_size,
                gradient_accumulation_steps=1,
                learning_rate=args.learning_rate,
                logging_steps=1,
                max_steps=args.num_train_steps,
                activation_offloading=use_offloading,
                remove_unused_columns=False,
                report_to="none",
            )

            # Train with this config
            peak_memory, success = train_with_config(
                config, args.model_name_or_path, dataset, seq_length, args.num_train_steps
            )
            if peak_memory is not None:
                results[mode]["seq_lengths"].append(seq_length)
                results[mode]["memory"].append(peak_memory)
                print(f"Sequence length {seq_length}: Peak memory {peak_memory:.2f} GB")
            if not success:
                print(f"Failed at sequence length {seq_length}, stopping {mode} tests")
                break

            # Increase sequence length for the next iteration
            seq_length += args.step_size

    return results


def plot_results(results, output_path):
    """Plot memory usage for both configurations."""
    plt.figure(figsize=(10, 6))

    # Plot both configurations
    for mode, data in results.items():
        if data["seq_lengths"]:  # Only plot if we have data
            label = "With Activation Offloading" if mode == "with_offloading" else "Without Activation Offloading"
            plt.plot(data["seq_lengths"], data["memory"], "o-", label=label)

    plt.xlabel("Sequence Length")
    plt.ylabel("Peak GPU Memory Usage (GB)")
    plt.title("GPU Memory Usage vs Sequence Length")
    plt.grid(True)
    plt.legend()

    # Save the plot
    plt.savefig(output_path)
    print(f"Plot saved to {output_path}")


def main():
    parser = argparse.ArgumentParser(description="Benchmark activation offloading memory usage")
    # Model and dataset arguments
    parser.add_argument("--model_name_or_path", type=str, required=True, help="Path to pretrained model")
    parser.add_argument("--dataset_name", type=str, default="trl-lib/Capybara", help="Dataset to use for training")
    parser.add_argument("--dataset_size", type=int, default=100, help="Number of examples to use from dataset")
    # Sequence length parameters
    parser.add_argument("--start_length", type=int, default=128, help="Starting sequence length")
    parser.add_argument("--max_length", type=int, default=4096, help="Maximum sequence length to try")
    parser.add_argument("--step_size", type=int, default=128, help="Increment for sequence length per step")
    # Training parameters
    parser.add_argument("--num_train_steps", type=int, default=5, help="Number of training steps per sequence length")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1, help="Batch size per device")
    parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate")
    # Output parameters
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./activation_offloading_benchmark",
        help="Directory to save temporary files and output",
    )
    args = parser.parse_args()

    if not torch.cuda.is_available():
        print("CUDA is not available. This benchmark requires a GPU.")
        return 1

    # Run the benchmark
    results = run_benchmark(args)

    # Plot the results
    plot_path = os.path.join(args.output_dir, "memory_vs_sequence_length.png")
    plot_results(results, plot_path)

    # Save the raw results
    with open(os.path.join(args.output_dir, "results.json"), "w") as f:
        json.dump(results, f, indent=2)

    return 0


if __name__ == "__main__":
    sys.exit(main())
```

run with:

```bash
python trl/scripts/activation_offloading_benchmark.py \
    --model_name_or_path Qwen/Qwen2-0.5B \
    --dataset_name trl-lib/Capybara \
    --start_length 128 \
    --max_length 2048 \
    --step_size 128 \
    --num_train_steps 5
```
lewtun left a comment:
Very clean PR and integration with TRL @kashif !
Regarding the benchmark, could you run this with a 7B model so we can see when SFT goes OOM for a given sequence length without activation offloading? I guess it may also be simpler to use packing in your benchmark so you are guaranteed to have fixed chunks of size max_length.
Note that this isn't a blocking requirement to get this merged once @qgallouedec has approved - I think it would mainly be nice to include in the docs as a plot :)
```python
    activation_offloading (`bool`, *optional*, defaults to `False`):
        Whether to offload the activations to the CPU.
```
Could you move this to the section > Parameters that control the training 🙏
```python
# Disable offloading for any Liger modules
for name, module in unwrapped_model.named_modules():
    if "liger" in name.lower():
```
Does this function support liger or not? If the user is prevented from using it with Liger, perhaps this part is useless?

What does this PR do?

Adapts torchtune's activation checkpointing for HF models.
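For reference, enabling the feature from user code would look something like this (a minimal sketch based on the `activation_offloading` flag added by this PR; the model and dataset are just examples):

```python
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

dataset = load_dataset("trl-lib/Capybara", split="train")

# `activation_offloading=True` wraps the training step in torchtune's
# activation-offloading context, trading GPU memory for CPU<->GPU transfers.
config = SFTConfig(output_dir="qwen2-sft-offload", activation_offloading=True)

trainer = SFTTrainer(
    model="Qwen/Qwen2-0.5B",
    args=config,
    train_dataset=dataset,
)
trainer.train()
```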