Full DPO Distributed #2275
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2275
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
@joecummings Please take a look and let me know if you have feedback!
Hey @sam-pi! Thanks so much for adding this. I had a quick skim through and it looked good to me. I'll have a closer look soon. First, a couple of high-level points. Did you manage to train using these configs? If so, could you attach some evidence of successful runs (e.g. WandB links)? I'm particularly interested in the hardware requirements for the 70B config. We may want to think about offering some additional memory performance improvements for this recipe in particular, such as different parallelization configurations for the reference model (which doesn't need gradients to be sharded), offloading the entire reference model to CPU, etc.
@SalmanMohammadi Please take a look at my training run screenshots and configs at the bottom of the PR summary (I tried re-uploading the screenshot of my WandB run). I tried showing a comparison of a rank/alpha 256 LoRA DPO run against a full DPO run (only 100 iterations).
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run --nnodes 1 --nproc_per_node 2 full_dpo_distributed --config llama3_1/70B_full_dpo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
Since you mentioned you trained on 2 nodes, it'd be good to add the command you used here.
Separately, I'm going to try to see if I can find a config that can train on a single node with reasonable speeds.
I looked into running this on 1 node and I couldn't find a way to get it to fit - if you do please feel free to update. Otherwise, maybe it's not worth including this 70B_full_dpo.yaml in the PR since technically I only got this working with some custom scripts using sbatch and torchrun with --nnodes 2.
I know I'm late to this discussion, but at least for now I would leave out any config that cannot run on a single node. Now that #2301 is open, we do have a playbook on how to run our recipes on multiple nodes. At the same time, I don't want us to be in the business of maintaining a bunch of separate slurm scripts for every recipe. So the way I would sequence this is:
- Land this PR without the 70B config (but keep it in our back pocket)
- Figure out whether we can generalize our slurm script to be parametrized by recipe/config (it seems feasible to me, but admittedly I haven't tried and haven't used slurm in a while)
- If (2) works, add in 70B_full_dpo.yaml with similar run instructions to what's in Multinode support in torchtune #2301's 70B_full_multinode.yaml
Separately, at least for 70B full finetune, we can fit on a single node with CPU offload (see the config fsdp_cpu_offload). Not sure if it's sufficient here (or the perf implications). There is also optimizer-in-backward and 8-bit optimizers (maybe model quality implications for the latter though). And while I'm leaving random suggestions.. if we are gonna do a 70B Llama model, why not 3.3?
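As a rough illustration of the 8-bit optimizer suggestion (a hedged sketch, not the recipe's actual config; it assumes bitsandbytes is installed and a CUDA device is available), swapping in an 8-bit optimizer is essentially a drop-in optimizer change:

```python
import torch
import bitsandbytes as bnb  # assumes bitsandbytes is installed and CUDA is available

model = torch.nn.Linear(8, 1).cuda()

# Drop-in replacement for torch.optim.AdamW that keeps optimizer states in 8 bits
# (and pages them under memory pressure), sharply reducing optimizer memory at a
# possible small cost to model quality.
optimizer = bnb.optim.PagedAdamW8bit(model.parameters(), lr=2e-5)

loss = model(torch.randn(4, 8, device="cuda")).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad(set_to_none=True)
```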
Thanks, I removed the 70B config for now! Fair point on using 3.3 - I stuck to 3.1 to keep it simple for now and I hope it could be adapted relatively easily to 3.3.
Any updates on merging this to main? Really excited to use it 😄
_,
) = self.concatenated_forward(self._ref_model, batch)

loss, chosen_rewards, rejected_rewards = self._loss_fn(
Another heads up: we log these below but we're not taking GAS into account.
(lmk if these comments are unhelpful btw and I'll stop 🙂 -- just trying to get this PR to run / verify on our setup and commenting as I find discrepancies)
(lmk if these comments are unhelpful btw and I'll stop 🙂 -- just trying to get this PR to run / verify on our setup and commenting as I find discrepancies)
Not at all, your comments are incredibly helpful and more than welcome! Thanks for taking the time to help review.
Another heads up: we log these below but we're not taking GAS into account.
noob q: what's GAS?
Gradient Accumulation Steps
Yeah you're totally right. We should update to correct for gradient accumulation steps.
In that case I assume the same holds for the LoRA DPO recipe too, right?
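To make the gradient-accumulation point concrete, here is a small standalone sketch (not the recipe's actual code; names like running_loss are illustrative) of logging a loss averaged over gradient-accumulation steps rather than only the last micro-batch:

```python
import torch

# Toy model and loop showing metrics normalized by gradient accumulation steps (GAS).
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
grad_accum_steps = 4

running_loss = 0.0
for idx in range(16):
    batch = torch.randn(2, 8)
    # Scale each micro-batch loss so the accumulated gradient matches the mean
    # over the effective batch.
    loss = model(batch).pow(2).mean() / grad_accum_steps
    loss.backward()
    running_loss += loss.detach().item()
    if (idx + 1) % grad_accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        # running_loss already reflects all micro-batches in this optimizer step.
        print({"loss": running_loss})
        running_loss = 0.0
```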
self._resume_from_checkpoint = cfg.resume_from_checkpoint
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
Looks like we're missing the actual grad clipping logic in the train step.
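For reference, a minimal standalone sketch of the missing step, clipping gradients just before the optimizer step (a toy example; in the recipe the clip value would come from self._clip_grad_norm and apply to self._model.parameters()):

```python
import torch

# Toy example of gradient clipping right before the optimizer step.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
clip_grad_norm = 1.0  # stands in for self._clip_grad_norm

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()
# clip_grad_norm_ returns the total norm *before* clipping, handy for logging.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
optimizer.step()
optimizer.zero_grad(set_to_none=True)
print(f"grad norm (pre-clip): {grad_norm:.4f}")
```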
With the changes I mentioned in the comments I was able to get parity with NeMo's DPO using the same data / hparams. E.g. here are the loss curves: Really awesome work! Pretty excited to use this.
I'm going to try investigating some alternative sharding strategies for the reference model, and see if I can get single-node training working for 70B. Will update soon. @sam-pi would you be up for looking into @EugenHotaj's comments above?
OK, so we're not blocking this PR, I'm going to leave exploring different parallelism strategies for a follow-up. Let's make the necessary fixes to this recipe and bring it in line with our other distributed recipes. @sam-pi If the 70B config doesn't work on a single node, I'd also suggest we remove it for now and add it back in after patching in the changes from #2301. What do you think?
# formed by concatenating an equal number of "chosen" and "rejected".
len_chosen = concatenated_input_ids.shape[0] // 2

all_logits = model(concatenated_input_ids)
One way to reduce memory and potentially fit this on a single node is to call model(...) twice. Right now we're effectively doubling the batch size here, which might be causing OOMs.
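A rough sketch of that suggestion (toy model and shapes, not the recipe's actual concatenated_forward): split the concatenated batch and run two smaller forwards, trading an extra forward pass for lower peak activation memory:

```python
import torch

model = torch.nn.Linear(16, 32)  # stand-in for the policy/reference model
chosen_input = torch.randn(2, 16)
rejected_input = torch.randn(2, 16)

# Current approach: one forward over the concatenated batch (2x batch size).
all_logits = model(torch.cat([chosen_input, rejected_input], dim=0))
chosen_logits, rejected_logits = all_logits.chunk(2, dim=0)

# Suggested alternative: two smaller forwards, halving peak activation memory
# at the cost of extra compute.
chosen_logits_2 = model(chosen_input)
rejected_logits_2 = model(rejected_input)

assert torch.allclose(chosen_logits, chosen_logits_2)
assert torch.allclose(rejected_logits, rejected_logits_2)
```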
One scenario I can think of is if a user is OOMing and can't go below an effective batch size of 2 (by configuring a batch size of 1). I'd be interested in seeing the tradeoff here vs. the additional computation from two extra model forward passes (both the policy and reference models) - though I have a feeling the memory savings may not be worth it.
I think that's a great idea - but I will plan to leave it out of this PR if that's alright. That seems like a useful addition to all DPO recipes
Thanks, I will look into all these fixes today and also remove the 70B config for now
This is looking great, thanks so much for adding this @sam-pi! Aside from my inline comments it'd be good to confirm that various features like compile, optimizer-in-backward, etc are working and doing what we'd expect (we can even add e.g. compile to the recipe test)
# Train for two epochs
cmd_1 = f"""
tune run --nnodes 1 --nproc_per_node 1 full_dpo_distributed \
This feels a bit weird.. why are we testing a distributed recipe on a single device (similar comment for the other commands in this file)?
Ah good point, I more or less copied this from the lora DPO testing - I will look into updating it
# epoch_folder = get_largest_iter_folder(tmpdir)
# epoch_folder_minus_one = f"epoch_{int(epoch_folder.split('_')[-1]) - 1}"
Why are these commented out? Can they be removed?
I will remove, thanks for catching that. These are copied from the lora DPO tests
)

@pytest.mark.integration_test
def test_save_and_load_weights(self, tmpdir, monkeypatch):
This test I don't fully understand.. it makes sense why we would do something like this for LoRA where we are merging the weights in the final checkpoint. But here the model arch is the same, right? Do we see it as likely that something will go wrong during save and load (that's not already accounted for by the resume from checkpoint test above)?
I was guessing it's good practice to make sure save/load works in general, but happy to remove if it doesn't make sense.
Yeah mainly I am cognizant of us not having too too many recipe tests (they take a bit of time to run and run on every PR). In this case I claim save and load is already pretty well-covered by test_training_state_on_resume
self._ref_model = self._setup_reference_model(
    cfg_model=cfg.model,
    fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
    reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
It may be worth running some memory profiling on this recipe (especially since it's already enabled). Like it seems to me that by setting reshard_after_forward=True here, we never have both the reference model and the policy model weights gathered at the same time. Is that the correct understanding? If so, worth confirming that it happens in practice (especially given the discussion around fitting on a single node)
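One cheap way to sanity-check that claim (assuming a CUDA device; the surrounding step structure is illustrative) is to reset and read the CUDA peak-memory counter around the policy step and the reference forward on each rank:

```python
import torch

def log_peak_memory(tag: str) -> None:
    # Peak memory allocated on the current device since the last reset.
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    print(f"[{tag}] peak allocated: {peak_gib:.2f} GiB")

torch.cuda.reset_peak_memory_stats()
# ... policy forward/backward/step would run here ...
log_peak_memory("policy step")

torch.cuda.reset_peak_memory_stats()
# ... reference-model forward would run here ...
log_peak_memory("reference forward")
```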
@sam-pi you mentioned you had to set this to True, right? What did you find?
I was finding that if this was set to False I was getting OOM issues. I didn't debug further than noting that it just works for me when set to True.
# deleting logits here helps reduce (peak) memory usage - we only need them for metric logging
del policy_chosen_logits, policy_rejected_logits

with torch.no_grad():
I wonder whether we can just run in inference mode? Since the reference model never has grad updates, stuff like view tracking etc. afforded by no_grad shouldn't be relevant, right? Lmk if I'm way off base, otherwise worth a try imo.
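For illustration, the two options look like this on a toy module (torch.inference_mode additionally skips autograd bookkeeping such as view/version tracking, though as noted below it can interact badly with compiled layers):

```python
import torch

ref_model = torch.nn.Linear(8, 8).eval()  # stand-in for the frozen reference model
x = torch.randn(2, 8)

with torch.no_grad():
    out_no_grad = ref_model(x)

with torch.inference_mode():
    out_inference = ref_model(x)

# Neither output carries grad; inference_mode tensors also can't be reused in
# autograd later, which is fine for a frozen reference model.
assert not out_no_grad.requires_grad
assert not out_inference.requires_grad
```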
I thought that by running model.eval() we are setting inference mode - am I mistaken?
model.eval mainly disables behaviour like dropout and batch norm. I think setting no grad on the params is probably overkill though.
I've had issues with inference mode causing additional recompiles in the inlined compiled transformer layers.
I've had issues with inference mode causing additional recompiles in the inlined compiled transformer layers.
Ah fair enough. Yeah mainly I was thinking it could potentially give some minor speedups. But if it messes with compile then agree it's not worth it
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
    self._optimizer.step()
I don't think we've actually enabled optimizer in backward here either
@sam-pi relevant code snippet to enable:
if not self._optimizer_in_bwd:
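For context, a generic optimizer-in-backward sketch (this is not torchtune's helper code, which has its own utilities for setting this up; it only shows the mechanism via PyTorch's post-accumulate-grad hooks):

```python
import torch

model = torch.nn.Linear(8, 1)
# One optimizer per parameter so each can step as soon as its grad is ready.
optim_dict = {p: torch.optim.AdamW([p], lr=1e-3) for p in model.parameters()}

def _step_and_clear(param: torch.Tensor) -> None:
    optim_dict[param].step()
    optim_dict[param].zero_grad(set_to_none=True)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(_step_and_clear)

loss = model(torch.randn(4, 8)).pow(2).mean()
loss.backward()  # parameters are updated (and grads freed) during backward
# Note: with this setup there is no separate optimizer.step() in the train loop,
# which is why the recipe needs the `if not self._optimizer_in_bwd:` guard.
```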
time_per_step = time.perf_counter() - t0
log_dict = {
    "loss": loss_to_log,
    "lr": self._optimizer.param_groups[0]["lr"],
I believe we also have this utility now in case that's helpful here
recipes/full_dpo_distributed.py
_,
_,
One final heads up: I was getting OOMs on 70B unless I deleted these logits as well. I also had to call gc.collect(); torch.cuda.empty_cache(), otherwise I'd OOM mid-training sometimes.
Since all we're doing with the logits is calling .mean() below, we could return the means directly from concatenated_forward. Then you don't need to do any explicit deletion (but maybe still have to call gc.collect() / torch.cuda.empty_cache()).
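A minimal sketch of that idea (illustrative signature, not the recipe's actual concatenated_forward, which also returns the log-probs the loss needs): compute the means inside the forward helper and drop the full logits before returning, with the gc/empty_cache workaround shown afterwards:

```python
import gc
import torch

def concatenated_forward_means(model, chosen, rejected):
    # Illustrative helper: return only the scalar means used for metric logging,
    # so the full logits tensors can be freed immediately.
    all_logits = model(torch.cat([chosen, rejected], dim=0))
    chosen_logits, rejected_logits = all_logits.chunk(2, dim=0)
    chosen_mean = chosen_logits.detach().mean()
    rejected_mean = rejected_logits.detach().mean()
    del all_logits, chosen_logits, rejected_logits
    return chosen_mean, rejected_mean

model = torch.nn.Linear(16, 32)
chosen_mean, rejected_mean = concatenated_forward_means(
    model, torch.randn(2, 16), torch.randn(2, 16)
)

# If fragmentation still causes mid-training OOMs, the workaround mentioned above:
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```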
Thanks, I added in deleting the logits for now at least
…ora_dpo fix: Running metrics and tokens_per_second_per_gpu fixes for DPO recipes
@ebsmothers @EugenHotaj @SalmanMohammadi Please take a look at the updates from @bogdansalyp to fix metric syncing/averaging across ranks and accounting for gradient accumulation in metrics.
fix: num_tokens all_reduce crash in DPO recipes
name="llama3_1/8B_full_dpo", | ||
file_path="llama3_1/8B_full_dpo.yaml", | ||
), | ||
Config( |
This can now be removed.
Context
Adapted from the great work in #1966
What is the purpose of this PR? Is it to
Please link to any issues this PR addresses: relates to #2082
Changelog
What are the changes made in this PR?
Test plan
Please make sure to do each of the following if applicable to your PR. If you're unsure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.
pre-commit install
pytest tests
pytest tests -m integration_test
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Here is a docstring example
and a tutorial example
Commands and Sample Outputs
Full DPO Config
LoRA DPO Config