Full DPO Distributed #2275

Open · wants to merge 18 commits into base: main

Conversation

@sam-pi commented Jan 17, 2025

Context

Adapted from the great work in #1966

What is the purpose of this PR? Is it to

  • add a new feature

Please link to any issues this PR addresses: relates to #2082

Changelog

What are the changes made in this PR?

  • Adds full DPO distributed training configs and recipes, adapted from the LoRA DPO training recipe
  • Includes integration tests
  • Includes configs for Llama 3.1 8B and 70B models

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example and a tutorial example.

  • I did not change any public API

Commands and Sample Outputs

Full DPO Config

output_dir: .../Meta-Llama-3.1-8B-Instruct/full_dpo
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: .../Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: false
ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: true
batch_size: 4
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  weight_decay: 0.05
  lr: 1.0e-06
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.05
  label_smoothing: 0
epochs: 1
max_steps_per_epoch: 2000
gradient_accumulation_steps: 4
compile: false
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
  name: llama3.1-8B-dpo_3605
log_every_n_steps: 1
log_peak_memory_stats: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
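
For readers comparing the two configs, here is a minimal sketch of how beta and label_smoothing enter a DPO-style loss (the standard formulation; the actual implementation used here is torchtune.rlhf.loss.DPOLoss, so treat the function below as illustrative):

import torch.nn.functional as F

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps,
             beta=0.05, label_smoothing=0.0):
    # Difference of policy and reference log-ratios for chosen vs. rejected
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        ref_chosen_logps - ref_rejected_logps
    )
    # label_smoothing > 0 gives the "conservative" DPO variant
    losses = (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps).detach()
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps).detach()
    return losses, chosen_rewards, rejected_rewards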

LoRA DPO Config

output_dir: .../Meta-Llama-3.1-8B-Instruct/lora_dpo
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_rank: 256
  lora_alpha: 256
  lora_dropout: 0.0
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: .../Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: false
save_adapter_weights_only: false
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: true
batch_size: 4
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  weight_decay: 0.05
  lr: 1.0e-05
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.1
  label_smoothing: 0
epochs: 1
max_steps_per_epoch: 100
gradient_accumulation_steps: 4
compile: false
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
  name: llama3.1-8Blora-dpo_3603
log_every_n_steps: 1
log_peak_memory_stats: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
[Screenshot, 2025-01-16: WandB run comparison of full DPO vs. LoRA DPO]


@facebook-github-bot added the CLA Signed label on Jan 17, 2025
@sam-pi (Author) commented Jan 17, 2025

@joecummings Please take a look and let me know if you have feedback!

@SalmanMohammadi (Collaborator) commented Jan 20, 2025

Hey @sam-pi! Thanks so much for adding this. I had a quick skim through and it looked good to me. I'll have a closer look soon. First, a couple of high level points.

Did you manage to train using these configs? If so, could you attach some evidence of successful runs (e.g. WandB links)?

I'm particularly interested in the hardware requirements for the 70B config. We may want to think about offering some additional memory performance improvements for this recipe in particular, such as different parallelization configurations for the reference model (which doesn't need gradients to be sharded), offloading the entire reference model to CPU, etc.

@sam-pi (Author) commented Jan 21, 2025

@SalmanMohammadi Please take a look at my training run screenshots and configs at the bottom of the PR summary (I tried re-uploading the screenshot of my WandB run). It shows a comparison of a rank/alpha 256 LoRA DPO run against a full DPO run (only 100 iterations).
For Llama3.1-70B-Instruct, I was able to run using 2 nodes with 8x H100 GPUs (I think this is just 2x the HW requirements for running a single non-quantized 70B).

@RdoubleA mentioned this pull request on Jan 21, 2025
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed --config llama3_1/70B_full_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Collaborator:

Since you mentioned you trained on 2 nodes it'd be good to add the command you used here.

Separately, I'm going to try to see if I can find a config that can train on a single node with reasonable speeds.

@sam-pi (Author), Jan 23, 2025:

I looked into running this on 1 node and I couldn't find a way to get it to fit; if you do, please feel free to update. Otherwise, maybe it's not worth including 70B_full_dpo.yaml in the PR, since technically I only got this working with some custom scripts using sbatch and torchrun with --nnodes 2.

@ebsmothers (Contributor):

I know I'm late to this discussion, but at least for now I would leave out any config that cannot run on a single node. Now that #2301 is open, we do have a playbook on how to run our recipes on multiple nodes. At the same time, I don't want us to be in the business of maintaining a bunch of separate slurm scripts for every recipe. So the way I would sequence this is:

  1. Land this PR without the 70B config (but keep it in our back pocket)
  2. Figure out whether we can generalize our slurm script to be parametrized by recipe/config (it seems feasible to me, but admittedly I haven't tried and haven't used slurm in a while)
  3. If (2) works, add in 70B_full_dpo.yaml with similar run instructions to what's in Multinode support in torchtune #2301's 70B_full_multinode.yaml

@ebsmothers (Contributor):

Separately, at least for 70B full finetune, we can fit on a single node with CPU offload (see the fsdp_cpu_offload config). Not sure if it's sufficient here (or what the perf implications are). There are also optimizer-in-backward and 8-bit optimizers (maybe model quality implications for the latter though). And while I'm leaving random suggestions... if we are gonna do a 70B Llama model, why not 3.3?

@sam-pi (Author):

Thanks, I removed the 70B config for now! Fair point on using 3.3 - I stuck to 3.1 to keep it simple for now and I hope it could be adapted relatively easily to 3.3.

@EugenHotaj (Contributor):

Any updates on merging this to main? Really excited to use it 😄

_,
) = self.concatenated_forward(self._ref_model, batch)

loss, chosen_rewards, rejected_rewards = self._loss_fn(
Contributor:

Another heads up: we log these below but we're not taking GAS into account.

(lmk if these comments are unhelpful btw and I'll stop 🙂 -- just trying to get this PR to run / verify on our setup and commenting as I find discrepancies)

Collaborator:

> (lmk if these comments are unhelpful btw and I'll stop 🙂 -- just trying to get this PR to run / verify on our setup and commenting as I find discrepancies)

Not at all, your comments are incredibly helpful and more than welcome! Thanks for taking the time to help review.

Collaborator:

> Another heads up: we log these below but we're not taking GAS into account.

noob q: what's GAS?

Contributor:

Gradient Accumulation Steps

Collaborator:

Yeah you're totally right. We should update to correct for gradient accumulation steps.
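
A rough sketch of the kind of correction being discussed (variable names are illustrative, not the recipe's exact ones): accumulate the per-micro-batch metrics, normalize by the number of gradient accumulation steps, and only log once per optimizer step.

running_loss += scaled_loss.detach()  # scaled_loss is already divided by GAS for backward
running_chosen_rewards += chosen_rewards.mean().detach() / gradient_accumulation_steps
running_rejected_rewards += rejected_rewards.mean().detach() / gradient_accumulation_steps

if (idx + 1) % gradient_accumulation_steps == 0:
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    metric_logger.log_dict(
        {
            "loss": running_loss.item(),
            "rewards/chosen": running_chosen_rewards.item(),
            "rewards/rejected": running_rejected_rewards.item(),
        },
        step=global_step,
    )
    running_loss = running_chosen_rewards = running_rejected_rewards = 0.0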

Contributor:

In that case I assume the same holds for the LoRA DPO recipe too, right?

self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
Contributor:

Looks like we're missing the actual grad clipping logic in the train step.
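
For reference, a minimal sketch of the missing step (mirroring other distributed recipes; exact placement in this recipe may differ): clip right before the optimizer step, using the clip_grad_norm value read from the config above.

if (idx + 1) % self._gradient_accumulation_steps == 0:
    if self._clip_grad_norm is not None:
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self._model.parameters(), max_norm=float(self._clip_grad_norm)
        )
    self._optimizer.step()
    self._optimizer.zero_grad(set_to_none=True)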

@EugenHotaj (Contributor):

With the changes I mentioned in the comments I was able to get parity with NeMo's DPO using the same data / hparams. E.g. here are the loss curves:

[Screenshot, 2025-01-29: loss curves comparing this recipe with NeMo's DPO]

Really awesome work! Pretty excited to use this.

@SalmanMohammadi (Collaborator) commented Jan 30, 2025

> Any updates on merging this to main? Really excited to use it 😄

I'm going to investigate some alternative sharding strategies for the reference model and see if I can get single-node training working for 70B. Will update soon. @sam-pi, would you be up for looking into @EugenHotaj's comments above?

@SalmanMohammadi (Collaborator):

OK, so we're not blocking this PR: I'm going to leave exploring different parallelism strategies for a follow-up. Let's make the necessary fixes to this recipe and bring it in line with our other distributed recipes.

@sam-pi If the 70B config doesn't work on a single node, I'd also suggest we remove it for now and add it back in after patching in the changes from #2301. What do you think?

# formed by concatenating an equal number of "chosen" and "rejected".
len_chosen = concatenated_input_ids.shape[0] // 2

all_logits = model(concatenated_input_ids)
Contributor:

One way to reduce memory and potentially fit this on a single node is to call model(...) twice. Right now we're effectively doubling the batch size here, and might be causing OOMs.

Collaborator:

One scenario I can think of is if a user is OOMing and can't go below an effective batch size of 2 (by configuring a batch size of 1). I'd be interested in seeing the tradeoff here vs. the additional computation from two extra model forward passes (both the policy and reference models) - though I have a feeling the memory savings may not be worth it.

@sam-pi (Author):

I think that's a great idea, but I plan to leave it out of this PR if that's alright. It seems like a useful addition to all DPO recipes.
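
For anyone picking this up in a follow-up, a hedged sketch of the idea (illustrative only, not part of this PR): split the concatenated batch and run two forward passes, trading extra compute for lower peak activation memory.

def split_forward(model, concatenated_input_ids):
    # The batch is [chosen; rejected] stacked along dim 0, as in concatenated_forward
    len_chosen = concatenated_input_ids.shape[0] // 2
    chosen_logits = model(concatenated_input_ids[:len_chosen])
    rejected_logits = model(concatenated_input_ids[len_chosen:])
    return chosen_logits, rejected_logits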

@sam-pi (Author) commented Jan 30, 2025

> OK, so we're not blocking this PR: I'm going to leave exploring different parallelism strategies for a follow-up. Let's make the necessary fixes to this recipe and bring it in line with our other distributed recipes.
>
> @sam-pi If the 70B config doesn't work on a single node, I'd also suggest we remove it for now and add it back in after patching in the changes from #2301. What do you think?

Thanks, I will look into all these fixes today and also remove the 70B config for now.

@ebsmothers (Contributor) left a comment:

This is looking great, thanks so much for adding this @sam-pi! Aside from my inline comments, it'd be good to confirm that various features like compile, optimizer-in-backward, etc. are working and doing what we'd expect (we can even add e.g. compile to the recipe test).


# Train for two epochs
cmd_1 = f"""
tune run --nnodes 1 --nproc_per_node 1 full_dpo_distributed \
Contributor:

This feels a bit weird... why are we testing a distributed recipe on a single device? (Similar comment for the other commands in this file.)

@sam-pi (Author):

Ah, good point. I more or less copied this from the LoRA DPO tests; I will look into updating it.

Comment on lines 113 to 114
# epoch_folder = get_largest_iter_folder(tmpdir)
# epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
Contributor:

Why are these commented out? Can they be removed?

@sam-pi (Author):

I will remove them, thanks for catching that. These were copied from the LoRA DPO tests.

)

@pytest.mark.integration_test
def test_save_and_load_weights(self, tmpdir, monkeypatch):
Contributor:

I don't fully understand this test... it makes sense why we would do something like this for LoRA, where we are merging the weights in the final checkpoint. But here the model arch is the same, right? Do we see it as likely that something will go wrong during save and load that's not already accounted for by the resume-from-checkpoint test above?

@sam-pi (Author):

I was guessing it's good practice to make sure save/load works in general, but happy to remove if it doesn't make sense.

Contributor:

Yeah, mainly I am cognizant of us not having too many recipe tests (they take a bit of time and run on every PR). In this case I'd claim save and load is already pretty well covered by test_training_state_on_resume.

self._ref_model = self._setup_reference_model(
    cfg_model=cfg.model,
    fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
    reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
Contributor:

It may be worth running some memory profiling on this recipe (especially since it's already enabled). It seems to me that by setting reshard_after_forward=True here, we never have both the reference model and the policy model weights gathered at the same time. Is that the correct understanding? If so, it's worth confirming that it happens in practice (especially given the discussion around fitting on a single node).

Collaborator:

@sam-pi you mentioned you had to set this to True, right? What did you find?

@sam-pi (Author):

I found that if this was set to False I was getting OOM issues. I didn't debug further than noting that it just works for me when set to True.
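
For context, a hedged sketch of the FSDP2-style wiring implied here (the import path and per-layer loop vary by PyTorch/torchtune version, so treat this as illustrative): with reshard_after_forward=True, each layer's unsharded reference-model weights are freed right after its forward, so they are not held alongside the policy model's.

from torch.distributed._composable.fsdp import fully_shard  # path may differ across versions

for layer in ref_model.layers:
    fully_shard(layer, reshard_after_forward=True)
fully_shard(ref_model, reshard_after_forward=True)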

# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits

with torch.no_grad():
Contributor:

I wonder whether we can just run in inference mode? Since the reference model never has grad updates, stuff like view tracking etc. afforded by no_grad shouldn't be relevant, right? Lmk if I'm way off base, otherwise worth a try imo.

@sam-pi (Author):

I thought that by running model.eval() we are setting inference mode - am I mistaken?

Collaborator:

model.eval() mainly disables behaviour like dropout and batch norm updates. I think setting no_grad on the params is probably overkill though.

I've had issues with inference mode causing additional recompiles in the inlined compiled transformer layers.

Contributor:

> I've had issues with inference mode causing additional recompiles in the inlined compiled transformer layers.

Ah, fair enough. Yeah, mainly I was thinking it could potentially give some minor speedups. But if it messes with compile then I agree it's not worth it.
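
Summarizing the thread, a minimal sketch of how the reference model ends up being treated (assumed from the discussion, not copied from the diff): eval mode plus a plain no_grad forward, avoiding inference_mode because of the recompile issue mentioned above.

self._ref_model.eval()  # disables dropout; does not affect gradient tracking
for p in self._ref_model.parameters():
    p.requires_grad_(False)  # possibly overkill per the comment above, but harmless

with torch.no_grad():  # no_grad rather than inference_mode, to play nicely with compile
    ref_chosen_logps, ref_rejected_logps, *_ = self.concatenated_forward(
        self._ref_model, batch
    )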


# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
    self._optimizer.step()
Contributor:

I don't think we've actually enabled optimizer-in-backward here either.

Collaborator:

@sam-pi relevant code snippet to enable it:

if not self._optimizer_in_bwd:
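
For reference, a generic PyTorch sketch of what optimizer-in-backward does (not torchtune's exact helper; the recipe would wire it via the snippet linked above): one optimizer per parameter, stepped from a post-accumulate-grad hook so each gradient can be freed as soon as it is produced. Note that this is incompatible with gradient accumulation, since every micro-batch would step the optimizer.

import torch

optim_dict = {
    p: torch.optim.AdamW([p], lr=1e-6, weight_decay=0.05)
    for p in model.parameters()
}

def _optim_step(param: torch.Tensor) -> None:
    # Runs right after the gradient for this parameter is accumulated
    optim_dict[param].step()
    optim_dict[param].zero_grad(set_to_none=True)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(_optim_step)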

time_per_step = time.perf_counter() - t0
log_dict = {
    "loss": loss_to_log,
    "lr": self._optimizer.param_groups[0]["lr"],
Contributor:

I believe we also have this utility now in case that's helpful here

Comment on lines 915 to 916
_,
_,
Contributor:

One final heads up: I was getting OOMs on 70B unless I deleted these logits as well. I also had to call gc.collect(); torch.cuda.empty_cache(), otherwise I'd sometimes OOM mid-training.

Since all we're doing with the logits is calling .mean() below, we could return the means directly from concatenated_forward. Then you don't need to do any explicit deletion (but maybe still have to call gc.collect() / torch.cuda.empty_cache()).
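
A sketch of that combination (names mirror the recipe loosely; illustrative only): keep only the scalar means for logging, drop the full logit tensors, and release cached memory between steps.

import gc
import torch

policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()
del policy_chosen_logits, policy_rejected_logits

gc.collect()
torch.cuda.empty_cache()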

@sam-pi (Author):

Thanks, I added in deleting the logits for now at least

@sam-pi (Author) commented Feb 1, 2025

@ebsmothers @EugenHotaj @SalmanMohammadi Please take a look at the updates from @bogdansalyp to fix metric syncing/averaging across ranks and accounting for gradient accumulation in metrics.
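
For reviewers skimming the diff, a minimal sketch of the kind of cross-rank averaging described (the helper name is hypothetical, not necessarily what the update adds): all-reduce the running metric and divide by world size before logging from rank zero.

import torch
import torch.distributed as dist

def mean_across_ranks(value: torch.Tensor) -> torch.Tensor:
    value = value.detach().clone()
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    return value / dist.get_world_size()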

name="llama3_1/8B_full_dpo",
file_path="llama3_1/8B_full_dpo.yaml",
),
Config(
Collaborator:

This can now be removed.
