Full DPO Distributed #2275

Merged Feb 7, 2025 · 39 commits

Commits
b96255b
full dpo configs, distributed recipe, and integration tests
sam-pi Jan 17, 2025
761b718
disable dropout, ref model setup, minor doc update
sam-pi Jan 23, 2025
753e822
Merge remote-tracking branch 'upstream/main' into HEAD
SalmanMohammadi Jan 28, 2025
0f90093
updating full recipe
SalmanMohammadi Jan 28, 2025
ebed89c
updating recipe
SalmanMohammadi Jan 30, 2025
aff595f
removing 70B full dpo config until multi-node support is available
sam-pi Jan 30, 2025
431f269
minor update to avoid _ref_model self reference
sam-pi Jan 30, 2025
c63e9e8
clean up rank zero logs and ref_checkpointer
sam-pi Jan 30, 2025
ebf288a
remove unnecessary save/load test and update to 2 GPUs
sam-pi Jan 30, 2025
ba12bb4
fix: Metrics weren't running and synced across devices
bogdansalyp Jan 30, 2025
6139096
fix: Fixed tokens_per_second_per_gpu
bogdansalyp Jan 31, 2025
2a4ca92
fix: Fixed torch.distributed naming
bogdansalyp Jan 31, 2025
7f94b07
fix: tokens_per_second_per_gpu fixed for full dpo
bogdansalyp Jan 31, 2025
1a673df
fix: Added running metrics to full_dpo_distributed
bogdansalyp Jan 31, 2025
16821c4
Merge pull request #2 from bogdansalyp/fix/running_metrics_and_sync_l…
sam-pi Feb 1, 2025
d052271
fix: num_tokens all_reduce crash in DPO recipes
bogdansalyp Feb 3, 2025
f9fedc4
Merge pull request #3 from bogdansalyp/fix/num_tokens-tensor-issue
sam-pi Feb 3, 2025
a0ac5aa
delete ref logits and improved default full dpo config
sam-pi Feb 3, 2025
5af4bed
remove 70B full DPO for now
sam-pi Feb 4, 2025
7dd27bb
Update recipes/full_dpo_distributed.py
sam-pi Feb 4, 2025
86e7639
Update recipes/full_dpo_distributed.py
sam-pi Feb 4, 2025
fb228c6
minor docs update and config comment
sam-pi Feb 4, 2025
85ce53d
fix linting
sam-pi Feb 5, 2025
ad7de59
explicitly specify 2 GPUs for full DPO test
sam-pi Feb 6, 2025
4e67853
reduce dpo test VRAM usage
sam-pi Feb 6, 2025
4667adf
Update recipes/configs/llama3_1/8B_full_dpo.yaml
SalmanMohammadi Feb 6, 2025
f2e9f47
Update docs/source/recipes/dpo.rst
SalmanMohammadi Feb 6, 2025
99a87e8
use get_lr, remove ac_mode, add clip_grad_norm
sam-pi Feb 6, 2025
8e46c93
average running loss across ranks
sam-pi Feb 6, 2025
716efca
adding opt in bwd
SalmanMohammadi Feb 6, 2025
eef1b01
fixing test
SalmanMohammadi Feb 6, 2025
46c59ec
removing grad clip
SalmanMohammadi Feb 6, 2025
6781894
Updating test
SalmanMohammadi Feb 6, 2025
4fd18bf
fixing test... round 2
SalmanMohammadi Feb 7, 2025
9c71083
fixing typo
SalmanMohammadi Feb 7, 2025
b293ce3
updating distributed test to correctly resume from checkpoint
SalmanMohammadi Feb 7, 2025
893398e
updating tests and recipe
SalmanMohammadi Feb 7, 2025
abbdf11
revert optimizer_in_bwd update for now
sam-pi Feb 7, 2025
d463e70
activations handling only for policy forward pass
sam-pi Feb 7, 2025
99 changes: 99 additions & 0 deletions recipes/configs/llama3_1/8B_full_dpo.yaml
# Config for multi-device full DPO alignment in full_dpo_distributed.py
# using a Llama3.1 8B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device full DPO alignment please use llama3_1/8B_full_dpo_single_device

output_dir: /tmp/torchtune/llama3_1_8B/full_dpo # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024 # higher increases memory

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
  checkpoint_files: [
    model-00001-of-00004.safetensors,
    model-00002-of-00004.safetensors,
    model-00003-of-00004.safetensors,
    model-00004-of-00004.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.05
  lr: 2e-5
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 20

loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.05
  label_smoothing: 0

# Training
epochs: 1
max_steps_per_epoch: 1000
gradient_accumulation_steps: 8 # Use to increase effective batch size
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
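For reference, the objective the `loss` section above selects (`torchtune.rlhf.loss.DPOLoss` with `beta=0.05`, `label_smoothing=0`) is the standard direct-preference-optimization loss. A self-contained sketch of the math — a restatement, not torchtune's implementation — taking per-sequence summed log-probs as inputs:

```python
import torch
import torch.nn.functional as F


def dpo_loss(
    policy_chosen_logps: torch.Tensor,
    policy_rejected_logps: torch.Tensor,
    ref_chosen_logps: torch.Tensor,
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.05,
    label_smoothing: float = 0.0,
) -> torch.Tensor:
    # Preference margin: how much more the policy favors the chosen
    # response over the rejected one than the frozen reference does.
    pi_logratios = policy_chosen_logps - policy_rejected_logps
    ref_logratios = ref_chosen_logps - ref_rejected_logps
    logits = pi_logratios - ref_logratios
    # Label-smoothed ("conservative") DPO loss; with label_smoothing=0,
    # as in this config, it reduces to the standard DPO objective.
    return (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    ).mean()
```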