Multinode support in torchtune #2301
base: main
@@ -0,0 +1,98 @@
.. _multinode_tutorial:

=====================
Multi-node finetuning
=====================

Congratulations! You've finally escaped the struggles of being "GPU poor" and now have access to a multi-node setup.
You can bid farewell to the days of sweating over memory-efficient optimizations, but get ready for new challenges as you navigate the complexities of distributed computing.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` You will learn:

        * Why multi-node training is useful
        * How to set up the torchtune package on a SLURM cluster
        * How to fine-tune a Llama3.3 70B model w/ full parameter updates (not LoRA)

    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

        * Be familiar with distributed training in torchtune
        * Already know basic SLURM commands

.. _advantages_multi_node_label:

Advantages of multi-node training
---------------------------------

More machines means more memory! This is cool for several reasons:

1. **Bigger models**: With more memory, you can train larger models such as `Llama3.1 405B <https://ai.meta.com/blog/meta-llama-3-1/>`_, `Deepseek-V3 <https://www.deepseek.com/>`_, and more.
2. **Longer data**: For many fine-tuning tasks like writing code, it's helpful to have long context lengths; however, longer context lengths mean more memory is needed for activations.
3. **Higher quality**: With more memory, you can do full parameter updates (not LoRA) and use optimizers like `AdamW <https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html>`_ (not low-precision optimizers), both of which can potentially improve the quality of your training.
4. **Faster training**: With the ability to fit more data in memory, you can use higher batch sizes *and* turn off memory optimizations like :ref:`activation checkpointing<glossary_act_ckpt>`, thereby decreasing the time it takes for training to complete.

.. note::

    **Low inter-node bandwidth & FSDP** We utilize PyTorch's **Fully Sharded Data Parallel** to distribute models over multiple devices. In order to distribute training, FSDP runs an `all-gather <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather>`_ operation
    for each forward pass and an all-gather (usually) plus a `reduce-scatter <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backward pass. These operations (usually) block training from continuing until completed, and with a slow
    inter-node connection, training speed may be reduced. For more on this, please refer to `this Github Issue <https://github.com/pytorch/pytorch/issues/102434>`_.
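If you're unsure how fast your inter-node connection actually is, it can be worth measuring it before kicking off a long run. One way to do this (a rough sketch, assuming you've built NVIDIA's `nccl-tests <https://github.com/NVIDIA/nccl-tests>`_ with MPI support on your cluster and that each node has 8 GPUs) is to run one of its collective benchmarks across both nodes:

.. code-block:: bash

    # Run the all-gather benchmark with one task per GPU across 2 nodes.
    # -b/-e sweep the message size from 8MB to 1GB, -f 2 doubles it each step,
    # and -g 1 uses one GPU per task.
    $ srun --nodes=2 --ntasks-per-node=8 --gpus-per-task=1 \
          ./build/all_gather_perf -b 8M -e 1G -f 2 -g 1

If the reported bus bandwidth is far below what your interconnect advertises, it's worth revisiting your network and NCCL settings before blaming the training code.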
Training Llama3.3 70B on 2 nodes
--------------------------------

Let's get training! We'll be utilizing a common cluster workflow manager called `SLURM <https://slurm.schedmd.com/documentation.html>`_ and assume you have a decent working knowledge of SLURM for this tutorial.
First, we need to install torchtune. Installation is pretty much as straightforward as the :ref:`normal install instructions<install_label>`,
but it's recommended that you install the package into a virtual environment that is accessible from all nodes in your cluster, e.g. one that lives on a shared filesystem.
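For example, a minimal setup might look something like the following (a sketch; the shared filesystem path and environment location are placeholders for whatever your cluster uses):

.. code-block:: bash

    # Create a virtual environment on the shared filesystem so every node sees the same install
    $ python -m venv /mnt/shared/envs/torchtune
    $ source /mnt/shared/envs/torchtune/bin/activate

    # Install PyTorch and torchtune into the shared environment
    $ pip install torch torchvision torchao
    $ pip install torchtune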
Next, we need to download the `Llama3.3 70B <https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct>`_ model to your shared filesystem. You'll need to make sure you have the correct credentials by following the steps
outlined :ref:`here<tune_download_label>`.

.. code-block:: bash

    $ tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "consolidated/*.pth" --output-dir SHARED_FS/Llama-3.3-70B-Instruct
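Since every node will read these checkpoint files, it's worth double-checking that the download actually landed on storage that the worker nodes can see. A quick sanity check (a sketch; the node name and path are placeholders for your own cluster):

.. code-block:: bash

    # List the downloaded files from a different node to confirm the shared filesystem is mounted there
    $ srun --nodes=1 --ntasks=1 -w slurm-worker-2 ls SHARED_FS/Llama-3.3-70B-Instruct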
Now that we have a downloaded model, let's check out our example SLURM bash script.

.. literalinclude:: ../../../recipes/full_finetune_multinode.slurm

**There's a lot of information in this script, but here are the high-level parts:**

* We utilize SLURM-specific directives to request the number of nodes, tasks, CPUs available, etc.
* We are using `torchrun <https://pytorch.org/docs/stable/elastic/run.html>`_ and the `full_finetune_distributed <https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py>`_ recipe to train just like on a single node
* You can set several cluster-specific environment variables (``NCCL_BUFFSIZE``, ``NCCL_DEBUG``, ``FI_PROVIDER``, etc.) in order to maximize GPU utilization, aid debugging, and more (see the example below)
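For instance, a few commonly tweaked variables might look like this (a sketch; the exact values are cluster-dependent and none of these are required for the tutorial to work):

.. code-block:: bash

    # Print NCCL initialization and communication info to the job logs
    export NCCL_DEBUG=INFO

    # Increase the NCCL communication buffer size (in bytes); useful on some network fabrics
    export NCCL_BUFFSIZE=8388608

    # On AWS EFA-based clusters, tell libfabric which provider to use
    export FI_PROVIDER=efa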
.. note::

    You may need to explicitly set the network interface for distributed backends. You can read more about `PyTorch distributed backends here <https://pytorch.org/docs/stable/distributed.html#common-environment-variables>`_,
    but it's also helpful to know that you can find your network interfaces by running ``ip addr`` (or ``ifconfig``) on a specific node.
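Concretely, this usually means exporting the socket-interface variables (already stubbed out in the SLURM script above) before launching. A sketch, where ``ens3`` stands in for whatever interface name your nodes actually report:

.. code-block:: bash

    # List the available interfaces on a worker node
    $ srun --nodes=1 --ntasks=1 -w slurm-worker-1 ip addr

    # Point NCCL and Gloo at the interface that carries inter-node traffic
    $ export NCCL_SOCKET_IFNAME=ens3
    $ export GLOO_SOCKET_IFNAME=ens3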
After we update the shared filesystem path in the bash script, we can launch using `sbatch <https://slurm.schedmd.com/sbatch.html>`_.

.. code-block:: bash

    sbatch full_finetune_multinode.slurm

And the output of `squeue <https://slurm.schedmd.com/squeue.html>`_ should show our job running:

.. code-block:: bash

    $ squeue
    JOBID PARTITION     NAME     USER ST  TIME  NODES NODELIST(REASON)
        1     train torchtun    slurm  R  0:03      2 slurm-worker-[1-2]
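While the job runs, you can keep an eye on it with standard SLURM tooling (a sketch, assuming the default ``slurm-<jobid>.out`` output file naming and a job ID of 1):

.. code-block:: bash

    # Stream the job's combined stdout/stderr, where torchrun and the training logs end up by default
    $ tail -f slurm-1.out

    # Cancel the job if something looks off
    $ scancel 1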
Once training has completed, which should take roughly seven minutes in total (880 tok/s) with the default config, we can follow the :ref:`instructions here<use_model_in_wild>` in order to upload our beautiful new model to the Hugging Face Hub!
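As a rough sketch of what that upload can look like (assuming you have the ``huggingface_hub`` CLI installed and are logged in; the repo name is a placeholder, and the exact checkpoint subdirectory to upload depends on your checkpointer settings):

.. code-block:: bash

    # Push the fine-tuned checkpoint directory to a (placeholder) repo on the Hugging Face Hub
    $ huggingface-cli upload my-username/llama3.3-70b-finetuned /mnt/slurm/Llama3.3-70B-fft-output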
Future development
------------------

We've covered the basics of how to launch a fine-tuning job with SLURM on two nodes with FSDP. There are still more things we're cooking up,
including...

**2D parallelism**: Utilizing both FSDP *and* tensor parallelism in what is commonly referred to as `2D parallelism <https://pytorch.org/tutorials/intermediate/TP_tutorial.html>`_ will decrease memory requirements even further, allowing us to lean even harder
into the advantages listed :ref:`above<advantages_multi_node_label>`.

**Longer context (ring attention, etc.)**: More memory and more machines means we can train on longer sequences and take advantage of neat tricks like ring attention, where tokens are split across
GPUs. You can read more about our plans for torchtune in `this Github RFC <https://github.com/pytorch/torchtune/issues/1244>`_.

**Want other optimizations?** Feel free to let us know by `opening up a Github Issue <https://github.com/pytorch/torchtune/issues/new?q=sort%3Aupdated-desc+is%3Aissue+is%3Aopen&template=Blank+issue>`_ on our repo or `dropping us a line in Discord <https://discord.gg/Zsf8xgT7>`_!
@@ -0,0 +1,102 @@
# Config for multi-node full finetuning in full_finetune_distributed.py
# using a Llama3.3 70B Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --output-dir SHARED_CLUSTER_FS
#
# To launch on 2 nodes w/ 8 devices on a SLURM cluster, run the following command:
# sbatch full_finetune_multinode.slurm

Review discussion:
- Is there a way to make full_finetune_multinode.slurm take in an argument to specify which config/model to run, instead of creating a new config for multinode?
- Arghh, this would be a good idea. I'm leaning towards just trying to get this up there as an example of how to run since you'll really need to modify the SLURM file itself in order to set the correct number of nodes, etc. Open to thoughts though. cc @ebsmothers @pbontrager
- Yeah I agree with @acisseJZhong's suggestion. Also I think the concept of recipes + configs breaks down a bit here. I think we should either very explicitly say "this is just a demo and is not a real recipe" (i.e. we don't even list it in recipes), or we should properly integrate with tune run -- i.e. if one specifies
- Bleh
- Make them copy it?
- Okay, will make them copy it for now and not add to the recipe registry, but I will keep the script there.

#
# This config has only been tested on 2 nodes, each with 8 H100 GPUs.
output_dir: /tmp/torchtune/llama3_3_70B/full

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model
  max_seq_len: 1024

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: True  # True increases speed
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_3.llama3_3_70b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4

Review comment: Can u go bigger though

epochs: 1

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  fused: True

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1  # Use to increase effective batch size

# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True  # True reduces memory

Review comment: Can we disable on two nodes?

enable_activation_offloading: False  # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output']  # Layers to shard separately (useful for large vocab size models). Lower memory, but lower speed.
fsdp_cpu_offload: False
clip_grad_norm: null
compile: True  # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False  # True saves memory. Requires gradient_accumulation_steps=1

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
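Note on usage: as the launch command in the SLURM script below shows, individual fields in this config can also be overridden on the command line instead of editing the file (this touches on the review discussion above about passing a config/model as an argument). A sketch, where the batch size and packing overrides are arbitrary examples:

    srun tune run --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" \
        full_finetune_distributed --config llama3_3/70B_full_multinode \
        checkpoint_dir=$CHECKPOINT_DIR output_dir=$OUTPUT_DIR batch_size=2 dataset.packed=False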
@@ -118,31 +118,40 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface):
    """

    def __init__(self, cfg: DictConfig) -> None:
        self._device = utils.get_device(device=cfg.device)
        device_type = cfg.device
        self._device = utils.get_device(device=device_type)
        self._dtype = training.get_dtype(cfg.dtype, device=self._device)

        if self._dtype == torch.float16:
            raise ValueError(
                "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
            )

        # logging attributes
        # Set up the backend for distributed training (NCCL, GLOO, etc.)
        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
        self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
        self.distributed_backend = training.get_distributed_backend(
            device_type,
            offload_ops_to_cpu=self.fsdp_cpu_offload
            or self._enable_async_checkpointing,
        )
        init_process_group(self.distributed_backend)
        _, rank = utils.get_world_size_and_rank()
        self._is_rank_zero = rank == 0

        # Logging attributes
        self._output_dir = cfg.output_dir
        self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
        self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

        if self._log_peak_memory_stats and self._device.type != "cuda":
        if self._log_peak_memory_stats and device_type != "cuda":
            log.info(
                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
            )
            self._log_peak_memory_stats = False

        _, rank = utils.get_world_size_and_rank()
        self._is_rank_zero = rank == 0

        # Training cfg
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
        self._enable_async_checkpointing = cfg.get("enable_async_checkpointing", False)
        self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
        self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
        self._clip_grad_norm = cfg.get("clip_grad_norm", None)

@@ -169,7 +178,7 @@ def __init__(self, cfg: DictConfig) -> None:
            "enable_activation_offloading", False
        )
        if self._enable_activation_offloading:
            if self._device.type != "cuda":
            if device_type != "cuda":
                raise RuntimeError(
                    "enable_activation_offloading should only be True when training on CUDA"
                )

@@ -240,9 +249,13 @@ def setup(self, cfg: DictConfig) -> None:
        Setup the recipe. This includes training state (if resume_from_checkpoint is True),
        model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
        """
        if self.fsdp_cpu_offload:
            # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
            # speed up when benchmarking fused AdamW on CPU
            training.set_torch_num_threads()

Review discussion:
- do we always want to set this?
- That's a good point. Looks like this was added by Rohan, so not sure who to follow up with here. Let me dig into it.
- Yeah this is a heuristic for fused Adam on CPU when CPU offload is enabled. I don't think it's optimal, but I do think that without it CPU offload training may be much slower
- Should it be set for async offload too? Or pure CPU training?
- Afaik it shouldn't matter for async offload and mostly has to do with fused optimizer. For pure CPU training I guess the optimizer step also happens on CPU so in that case we would potentially want it

        if self._is_rank_zero:
            self._metric_logger = config.instantiate(cfg.metric_logger)

            # log config with parameter override
            self._metric_logger.log_config(cfg)

@@ -255,7 +268,7 @@ def setup(self, cfg: DictConfig) -> None:
            enable_activation_checkpointing=self._enable_activation_checkpointing,
            enable_activation_offloading=self._enable_activation_offloading,
            custom_sharded_layers=cfg.get("custom_sharded_layers", None),
            fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
            fsdp_cpu_offload=self.fsdp_cpu_offload,
            reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
            model_state_dict=checkpoint_dict[training.MODEL_KEY],
            ac_mode=cfg.get("ac_mode", None),

@@ -890,19 +903,7 @@ def recipe_main(cfg: DictConfig) -> None:
    - Parameters specified in config (see available configs through ``tune ls``)
    - Overwritten by arguments from the command-line
    """
    if not training.is_distributed():
        raise RuntimeError(
            "Distributed finetune recipe should be run via a distributed launcher."
            "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
        )
    init_process_group("cuda:nccl,cpu:gloo")
    if cfg.get("fsdp_cpu_offload", False):
        # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
        # speed up when benchmarking fused AdamW on CPU
        training.set_torch_num_threads()

    config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)

    recipe = FullFinetuneRecipeDistributed(cfg=cfg)
    recipe.setup(cfg=cfg)
    recipe.train()
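For context, here is how the same recipe is typically launched on a single node via the tune CLI (the removed error message above used to point at exactly these flags; a sketch, assuming the stock single-node llama3_3/70B_full config):

    # Single-node launch for comparison with the multi-node sbatch flow below
    tune run --nnodes 1 --nproc_per_node 8 full_finetune_distributed --config llama3_3/70B_full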
@@ -0,0 +1,45 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# ---------- SBATCH commands ---------- #
#SBATCH --job-name=torchtune-multi-node
#SBATCH --ntasks=2
#SBATCH --nodes=2
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=96
#SBATCH --partition=train

# ---------- Set env variables ---------- #
# Grab the IP for head node:
# You may need to set this to the fully qualified domain name of your head node
nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
echo Node IP: $head_node_ip

# You might need to explicitly set the network interface for distributed backends:
# export NCCL_SOCKET_IFNAME=...
# export GLOO_SOCKET_IFNAME=...

export TORCH_DIST_INIT_BARRIER=1
export LOGLEVEL=INFO

# ---------- Launch training ---------- #
# You probably want to load in a virtual env w/ conda...
# module load conda
# conda activate torchtune
# ...or venv
# source torchtune/bin/activate
Review discussion (on lines +33 to +37):
- Why is this commented out? Is it because we don't know the user's venv/conda env? I remember wasting a bunch of time myself on this kinda stuff before, might be worth explicitly calling it out in the tutorial
- Yeah, we can't make any assumptions about how they initialize their virtual env
- sidenote:
SHARED_FS=/mnt/slurm  # <-- Replace w/ your filesystem
CHECKPOINT_DIR="$SHARED_FS/Llama-3.3-70B-Instruct"
OUTPUT_DIR="$SHARED_FS/Llama3.3-70B-fft-output"

# Adjust sbatch --ntasks and sbatch --nodes above and --nnodes below to your specific node count
srun tune run --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" \
    full_finetune_distributed --config llama3_3/70B_full_multinode checkpoint_dir=$CHECKPOINT_DIR output_dir=$OUTPUT_DIR
Review discussion:
- Cool, didn't know you could do this. But one nit is that it includes the license, which looks a little weird in the docs imo
- I can take it out of the recipes and just have people copy and paste from the tutorial? Less findable from Github tho.
- Discussed offline, this is fine as is