Full DPO Distributed #2275

Open: wants to merge 18 commits into main.
Changes shown from 5 of 18 commits.
92 changes: 92 additions & 0 deletions recipes/configs/llama3_1/70B_full_dpo.yaml
@@ -0,0 +1,92 @@
# Config for multi-device full DPO alignment in full_dpo_distributed.py
# using a Llama3.1 70B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.1-70B-Instruct --output-dir /tmp/Meta-Llama-3.1-70B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 2 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed --config llama3_1/70B_full_dpo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed --config llama3_1/70B_full_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Collaborator commented:

Since you mentioned you trained on 2 nodes, it'd be good to add the command you used here.

Separately, I'm going to see if I can find a config that can train on a single node at reasonable speed.

Author (@sam-pi) commented on Jan 23, 2025:

I looked into running this on 1 node and couldn't find a way to get it to fit; if you do, please feel free to update. Otherwise, maybe it's not worth including this 70B_full_dpo.yaml in the PR, since technically I only got this working with some custom scripts using sbatch and torchrun with --nnodes 2.
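
For reference, a minimal sketch of what such a two-node launch might look like, assuming tune run forwards torchrun's rendezvous flags; the MASTER_ADDR variable and port below are hypothetical placeholders, not taken from this PR:

```bash
# Hypothetical two-node launch: run once on each node.
# MASTER_ADDR (the first node's hostname) and the port are placeholders.
tune run \
  --nnodes 2 \
  --nproc_per_node 8 \
  --rdzv_backend c10d \
  --rdzv_endpoint "${MASTER_ADDR}:29500" \
  full_dpo_distributed \
  --config llama3_1/70B_full_dpo
```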

Contributor commented:

I know I'm late to this discussion, but at least for now I would leave out any config that cannot run on a single node. Now that #2301 is open, we do have a playbook for running our recipes on multiple nodes. At the same time, I don't want us to be in the business of maintaining a bunch of separate slurm scripts for every recipe. So the way I would sequence this is:

  1. Land this PR without the 70B config (but keep it in our back pocket)
  2. Figure out whether we can generalize our slurm script to be parametrized by recipe/config (it seems feasible to me, but admittedly I haven't tried it and haven't used slurm in a while; see the sketch after this list)
  3. If (2) works, add in 70B_full_dpo.yaml with run instructions similar to those for 70B_full_multinode.yaml in #2301 (Multinode support in torchtune)
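
As a rough illustration of (2), a parametrized slurm launcher could look something like the sketch below. This is an assumption about how such a script might be structured, not the script from #2301; the job parameters, port, and node counts are placeholders.

```bash
#!/bin/bash
# Hypothetical parametrized launcher: sbatch launch.sh <recipe> <config> [overrides...]
# e.g. sbatch launch.sh full_dpo_distributed llama3_1/70B_full_dpo
#SBATCH --job-name=torchtune
#SBATCH --nodes=2
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=8

RECIPE=$1
CONFIG=$2

# Use the first node in the allocation as the rendezvous host.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)

srun tune run \
  --nnodes "$SLURM_NNODES" \
  --nproc_per_node 8 \
  --rdzv_backend c10d \
  --rdzv_endpoint "${MASTER_ADDR}:29500" \
  "$RECIPE" --config "$CONFIG" "${@:3}"  # remaining args pass through as config overrides
```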

Contributor commented:

Separately, at least for a 70B full finetune, we can fit on a single node with CPU offload (see the config fsdp_cpu_offload). Not sure if it's sufficient here, or what the perf implications are. There is also optimizer-in-backward, and there are 8-bit optimizers (though maybe with model-quality implications for the latter). And while I'm leaving random suggestions: if we are gonna do a 70B Llama model, why not 3.3?
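
For what it's worth, in torchtune's full-finetune distributed configs these memory levers are plain config keys (fsdp_cpu_offload, optimizer_in_bwd), so if this DPO recipe wires them up the same way, which is an assumption on my part, they could be toggled at launch roughly like this:

```bash
# Hypothetical overrides, assuming the DPO recipe supports the same memory
# levers as full_finetune_distributed (it may not):
tune run --nnodes 1 --nproc_per_node 8 full_dpo_distributed \
  --config llama3_1/70B_full_dpo \
  fsdp_cpu_offload=True \
  optimizer_in_bwd=True \
  gradient_accumulation_steps=1  # optimizer-in-backward is incompatible with gradient accumulation
```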

Author (@sam-pi) commented:

Thanks, I removed the 70B config for now! Fair point on using 3.3; I stuck with 3.1 to keep things simple for now, and I hope it can be adapted to 3.3 relatively easily.

#
# This config works best when fine-tuning on 2+ nodes, each with 8 H100 GPUs.

output_dir: /tmp/torchtune/llama3_1_70B/full_dpo # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_70b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-70B-Instruct/original/tokenizer.model
  max_seq_len: 1024 # higher increases memory

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-70B-Instruct/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.05
  lr: 1e-6
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.05
  label_smoothing: 0

# Training
epochs: 1
max_steps_per_epoch: 1000
gradient_accumulation_steps: 8 # Use to increase effective batch size
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
99 changes: 99 additions & 0 deletions recipes/configs/llama3_1/8B_full_dpo.yaml
@@ -0,0 +1,99 @@
# Config for multi-device full DPO alignment in full_dpo_distributed.py
# using a Llama3.1 8B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device full DPO alignment please use llama3_1/8B_full_dpo_single_device

output_dir: /tmp/torchtune/llama3_1_8B/full_dpo # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024 # higher increases memory

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.05
  lr: 1e-6
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.05
  label_smoothing: 0

# Training
epochs: 1
max_steps_per_epoch: 1000
gradient_accumulation_steps: 8 # Use to increase effective batch size
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
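
As a usage note, any key in these configs can be overridden at launch, which makes it easy to sweep DPO hyperparameters without editing the file. The values below are arbitrary illustrations, not recommendations:

```bash
# Example command-line overrides for the 8B config; values are illustrative only.
tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed \
  --config llama3_1/8B_full_dpo \
  loss.beta=0.1 \
  batch_size=2 \
  gradient_accumulation_steps=16 \
  max_steps_per_epoch=500
```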