Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full DPO Distributed #2275

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 99 additions & 0 deletions recipes/configs/llama3_1/8B_full_dpo.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
# Config for multi-device full DPO alignment in full_dpo_distributed.py
# using a Llama3.1 8B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3.1-8B-Instruct --output-dir /tmp/Meta-Llama-3.1-8B-Instruct --ignore-patterns "original/consolidated.00.pth"
#
# To launch on 4 devices, run the following command from root:
# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 4 full_dpo_distributed --config llama3_1/8B_full_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works best when the model is being fine-tuned on 2+ GPUs.
# For single device full DPO alignment please use llama3_1/8B_full_dpo_single_device

output_dir: /tmp/torchtune/llama3_1_8B/full_dpo # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
_component_: torchtune.models.llama3_1.llama3_1_8b

# Tokenizer
tokenizer:
_component_: torchtune.models.llama3.llama3_tokenizer
path: /tmp/Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
max_seq_len: 1024 # higher increases memory

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
checkpoint_files: [
model-00001-of-00004.safetensors,
model-00002-of-00004.safetensors,
model-00003-of-00004.safetensors,
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3
resume_from_checkpoint: False

ref_checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3.1-8B-Instruct/
checkpoint_files: [
model-00001-of-00004.safetensors,
model-00002-of-00004.safetensors,
model-00003-of-00004.safetensors,
model-00004-of-00004.safetensors
]
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: LLAMA3

# Dataset and Sampler
dataset:
_component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: True
batch_size: 4

# Optimizer and Scheduler
optimizer:
_component_: torch.optim.AdamW
fused: True
weight_decay: 0.05
lr: 1e-6
lr_scheduler:
_component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
num_warmup_steps: 100

loss:
_component_: torchtune.rlhf.loss.DPOLoss
beta: 0.05
label_smoothing: 0

# Training
epochs: 1
max_steps_per_epoch: 1000
gradient_accumulation_steps: 8 # Use to increase effective batch size
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory
Loading