Multinode support in torchtune #2301

Open

wants to merge 34 commits into base: main

Changes from 24 commits

Commits
bbd81fd - Remove last references to from training (Jan 27, 2025)
c04ebaf - Deprecate and use new function (Jan 27, 2025)
e02d39b - Expose (Jan 27, 2025)
c558f27 - Update API docs (Jan 27, 2025)
454536c - Add tests (Jan 27, 2025)
78bb2ae - Merge remote-tracking branch 'upstream/main' into multi-node-support (Jan 27, 2025)
66b06e1 - Lint (Jan 27, 2025)
0d5aeb4 - Add multinode recipe and sbatch script (Jan 27, 2025)
afc9c2e - Update launch commands (Jan 27, 2025)
c4748a5 - Move env variables around (Jan 27, 2025)
94440f9 - Multi-node tutorial (Jan 27, 2025)
deffeca - Updates (Jan 28, 2025)
f441721 - Update code block (Jan 28, 2025)
9ba9e24 - asdf (Jan 28, 2025)
b36325a - Fix linting errors (Jan 28, 2025)
fc9afbd - Updates (Jan 29, 2025)
373e0c0 - Lint (Jan 29, 2025)
4659938 - Pass test (Jan 29, 2025)
693b8cb - Updates to tutorial (Jan 29, 2025)
3d8d73d - Remove full_finetune_multinode from recipes registry (Jan 29, 2025)
c0345a5 - Lint (Jan 29, 2025)
a3aaeb4 - Last link (Jan 29, 2025)
8e20394 - Merge remote-tracking branch 'upstream/main' into multi-node-support (Jan 29, 2025)
427a290 - Merge remote-tracking branch 'upstream/main' into multi-node-support (Jan 30, 2025)
b56b6be - Evan updates (Jan 31, 2025)
63205da - Merge remote-tracking branch 'upstream/main' into multi-node-support (Jan 31, 2025)
76ea872 - Merge remote-tracking branch 'upstream/main' into multi-node-support (Jan 31, 2025)
63eb274 - Update comment (Jan 31, 2025)
4d027b0 - Move process initialization (joecummings, Jan 31, 2025)
34aa18b - Move init process group to above checkpoint instantiation (joecummings, Feb 1, 2025)
30b7366 - Update intro (joecummings, Feb 1, 2025)
c7fdc21 - Docs r dumb (joecummings, Feb 1, 2025)
900d643 - Wow (joecummings, Feb 1, 2025)
9e230ca - Rework intro (joecummings, Feb 3, 2025)
1 change: 1 addition & 0 deletions docs/source/api_ref_training.rst
@@ -53,6 +53,7 @@ Utilities for enabling and working with distributed training.
init_distributed
is_distributed
gather_cpu_state_dict
get_distributed_backend

.. _ac_label:

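As context for the new API reference entry, here is a minimal sketch of how the newly exposed helper is used, mirroring the ``full_finetune_distributed.py`` change later in this diff; the exact backend strings it returns are an assumption based on the previously hard-coded ``"cuda:nccl,cpu:gloo"`` value:

.. code-block:: python

   from torch.distributed import init_process_group

   from torchtune import training

   # device_type comes from the config (e.g. "cuda"); offload_ops_to_cpu should be
   # True when FSDP CPU offload or async checkpointing is enabled so that CPU
   # collectives have a backend (gloo) alongside the GPU backend (nccl).
   backend = training.get_distributed_backend("cuda", offload_ops_to_cpu=True)
   init_process_group(backend)  # replaces the hard-coded "cuda:nccl,cpu:gloo" string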
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -149,6 +149,7 @@ torchtune tutorials.
tutorials/e2e_flow
tutorials/llama_kd_tutorial
tutorials/memory_optimizations
tutorials/multinode

.. toctree::
:glob:
2 changes: 2 additions & 0 deletions docs/source/tutorials/e2e_flow.rst
@@ -359,6 +359,8 @@ For Llama models, you can run generation directly in torchao on the quantized mo
discussed in `this readme <https://github.com/pytorch/ao/tree/main/torchao/_models/llama>`_. This way you can compare your own results
to those in the previously-linked table.

.. _use_model_in_wild:

Use your model in the wild
--------------------------

100 changes: 100 additions & 0 deletions docs/source/tutorials/multinode.rst
@@ -0,0 +1,100 @@
.. _multinode_tutorial:

=====================
Multi-node finetuning
=====================

Congratulations! After years of being "GPU poor", you've worked hard, saved your hard earned Bitcoin and graduated to the

Contributor: sorry but discussions of crypto are banned on our docs

so-called **"GPU middle class"**. In many ways, your worries of yesteryear are gone (memory efficient training, who??).
But, new problems are on the horizon for you because multi-node is a whole new beast. Come with me as I take you
through your new life, complete with a big backyard, new car, and of course - a nice rack of H100s.

.. grid:: 2

.. grid-item-card:: :octicon:`mortar-board;1em;` You will learn:

* Why multi-node training is useful
* How to set up the torchtune package on a SLURM cluster
* How to fine-tune a Llama3.3 70B model w/ full parameter updates (not LoRA)

.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites

* Be familiar with distributed training in torchtune
* Already know basic SLURM commands

.. _advantages_multi_node_label:

Advantages of multi-node training
---------------------------------

More machines means more memory! This is cool for several reasons:

1. **Bigger models**: With more memory, you can train larger models such as `Llama3.1 405B <https://ai.meta.com/blog/meta-llama-3-1/>`_, `Deepseek-V3 <https://www.deepseek.com/>`_, and more.
2. **Longer data**: More many tasks like writing code, it's helpful to have long context lengths; however longer context length means more memory needed for activations.

Contributor: Separately would be a little bit careful about how we frame this. Like we don't actually have context parallel yet so don't wanna imply that people can continually scale context length with # of nodes.

Suggested change:
2. **Longer data**: More many tasks like writing code, it's helpful to have long context lengths; however longer context length means more memory needed for activations.
2. **Longer data**: For many tasks like writing code, it's helpful to have long context lengths; however longer context length means more memory needed for activations.

3. **Higher quality**: With more memory, you can do full parameter updates (not LoRA) and use optimizers like `AdamW <https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html>`_ (not low-precision optimizers), both of which can potentially improve the quality of your training.
4. **Faster training**: With the ability to fit more data in memory, you can use higher batch sizes *and* turn off memory optimizations like :ref:`activation checkpointing<glossary_act_ckpt>` thereby decreasing the time it takes for training to complete.

.. note::

**Low inter-node bandwidth & FSDP** We utilize `Fully Sharded Data Parallel <https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/>`_ to distribute models over multiple devices. In order to distribute training, FSDP runs an `all-gather <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#allgather>`_ operation

Contributor: Personally I would not point to this FSDP blog post as pretty much all the APIs given there are moot for torchtune's purposes

Author: Fair.

for each forward pass and an all-gather plus a `scatter-reduce <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow

Contributor suggested change:
for each forward pass and an all-gather plus a `scatter-reduce <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow
for each forward pass and an all-gather plus a `reduce-scatter <https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html#reducescatter>`_ operation for each backwards pass. These operations (usually) block training from continuing until completed and with a slow


Author: I've heard it both ways.

Contributor: Ya personally I am just taking NCCL docs as my source of truth here

inter-node connection, training speed may be reduced. For more on this, please refer to `this Github Issue <https://github.com/pytorch/pytorch/issues/102434>`_.
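
To get a rough feel for your cluster's interconnect before launching a full run, you can time a single collective across both nodes. The snippet below is an illustrative standalone script (not part of torchtune, and the payload size is arbitrary); launch one copy per node with ``torchrun``:

.. code-block:: python

   # bench_allgather.py -- rough inter-node all-gather timing (illustrative only)
   import os
   import time

   import torch
   import torch.distributed as dist

   dist.init_process_group(backend="nccl")
   torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

   payload = torch.randn(64 * 1024 * 1024, device="cuda")  # ~256 MB of fp32 per rank
   output = torch.empty(payload.numel() * dist.get_world_size(), device="cuda")

   dist.all_gather_into_tensor(output, payload)  # warm-up
   torch.cuda.synchronize()

   dist.barrier()
   start = time.perf_counter()
   dist.all_gather_into_tensor(output, payload)
   torch.cuda.synchronize()
   elapsed = time.perf_counter() - start

   if dist.get_rank() == 0:
       print(f"all-gather of ~256 MB took {elapsed * 1e3:.1f} ms")
   dist.destroy_process_group()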

Training Llama3.3 70B on 2 nodes
--------------------------------

Let's get training! We'll be utilizing a common cluster workflow manager called `SLURM <https://slurm.schedmd.com/documentation.html>`_ and assume you have a decent working knowledge of SLURM for this tutorial.
First, we need to install torchtune. Installation is pretty much as straightforward as the :ref:`normal install instructions<install_label>`,
but it's recommended that you install the package into a virtual environment that is accessible from all nodes in your cluster, such as one on a shared filesystem.

Next, we need to download the `Llama3.3 70B <https://huggingface.co/meta-llama/Llama-3.3-70B-Instruct>`_ model to your shared filesystem. You'll need to make sure you have the correct credentials following the steps
outlined :ref:`here<tune_download_label>`.

.. code-block:: bash

$ tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "consolidated/*.pth" --output-dir SHARED_FS/Llama-3.3-70B-Instruct

Now that we have a downloaded model, let's check out our example SLURM bash script.

.. literalinclude:: ../../../recipes/full_finetune_multinode.slurm

Contributor: Cool, didn't know you could do this. But one nit is that it includes the license, which looks a little weird in the docs imo

Author: I can take it out of the recipes and just have people copy and paste from the tutorial? Less findable from Github tho.

Contributor: Discussed offline, this is fine as is


**There's a lot of information in this script but here are the high-level parts:**

* We utilize SLURM-specific commands like number of nodes, tasks, CPUs available, etc.
* We are using `torchrun <https://pytorch.org/docs/stable/elastic/run.html>`_ and the `full_finetune_distributed <https://github.com/pytorch/torchtune/blob/main/recipes/full_finetune_distributed.py>`_ recipe to train just like on a single node
* You should consider several cluster-specific environment variables to maximize GPU utilization

Contributor: V vague


.. note::

We may need to explicitly set the network interface for distributed backends. You can read more about `PyTorch distributed backends here <https://pytorch.org/docs/stable/distributed.html#common-environment-variables>`_
but it's also helpful to know that you can find your network interface by running `ipconfig <https://en.wikipedia.org/wiki/Ipconfig#:~:text=ipconfig%20(standing%20for%20%22Internet%20Protocol,ipconfig>`_ from a specific node.
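
If you want to check interface names from Python instead, here is a tiny illustrative snippet (not part of torchtune) that lists the interfaces visible on a node so you can pick a value for ``NCCL_SOCKET_IFNAME`` or ``GLOO_SOCKET_IFNAME``:

.. code-block:: python

   import socket

   # Print every network interface name on this node (Unix only)
   for _, name in socket.if_nameindex():
       print(name)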

After we update the shared filesystem path in the bash script, we can launch using `sbatch <https://slurm.schedmd.com/sbatch.html>`_.

.. code-block:: bash

sbatch full_finetune_multinode.slurm

And the output of `squeue <https://slurm.schedmd.com/squeue.html>`_ should show our job running:

.. code-block:: bash

$ squeue
JOBID PARTITION NAME USER ST TIME NODES NODELIST(REASON)
1 train torchtun slurm R 0:03 2 slurm-worker-[1-2]

Once training has completed, we can follow the :ref:`instructions here<use_model_in_wild>` in order to upload our beautiful new model to the Hugging Face Hub!

Contributor: Is it worth going into where the logs can be found, how to monitor status, or any of that? (It's OK to say no, but I feel like these are common points of confusion for people)

Author: I want to stay far far away from a tutorial on how to use SLURM.


Future development
------------------

We've covered the basics of how to launch a fine-tuning job with SLURM on two nodes with FSDP. There are still more things we're cooking up,
including...

**2D parallelism**: Utilizing both FSDP *and* tensor parallelism in what is commonly referred to as `2D parallelism <https://pytorch.org/tutorials/intermediate/TP_tutorial.html>`_ will decrease memory requirements even further, allowing us to lean even harder
into the advantages listed :ref:`above<advantages_multi_node_label>`.

**Longer context (ring attention, etc)**: More memory and more machines mean we can train on longer sequences and take advantage of neat tricks like ring attention, where tokens are split across
GPUs. You can read more about our plans for torchtune in `this Github RFC <https://github.com/pytorch/torchtune/issues/1244>`_.

**Want other optimizations?** Feel free to let us know by `opening up a Github Issue <https://github.com/pytorch/torchtune/issues/new?q=sort%3Aupdated-desc+is%3Aissue+is%3Aopen&template=Blank+issue>`_ on our repo or `dropping us a line in Discord <https://discord.gg/Zsf8xgT7>`_!
104 changes: 104 additions & 0 deletions recipes/configs/llama3_3/70B_full_multinode.yaml
@@ -0,0 +1,104 @@
# Config for multi-node full finetuning in full_finetune_distributed.py
# using a Llama3.3 70B Instruct model
#
# This config assumes that you've run the following command before launching:
# tune download meta-llama/Llama-3.3-70B-Instruct --ignore-patterns "original/consolidated*" --output-dir SHARED_CLUSTER_FS
#
# To launch on 2 nodes w/ 8 devices on a SLURM cluster, run the following command:
# sbatch full_finetune_multinode.slurm

acisseJZhong (Contributor), Jan 28, 2025: is there a way to make the full_finetune_multinode.slurm takes in an argument to specify which config/model to run, instead of creating a new config for mutlinode?

Author: Arghh, this would be a good idea. I'm leaning towards just trying to get this up there as an example of how to run since you'll really need to modify the SLURM file itself in order to set the correct number of nodes, etc.

Open to thoughts though. cc @ebsmothers @pbontrager

Contributor: Yeah I agree with @acisseJZhong's suggestion. Also I think the concept of recipes + configs breaks down a bit here. I think we should either very explicitly say "this is just a demo and is not a real recipe" (i.e. we don't even list it in recipes), or we should properly integrate with tune run -- i.e. if one specifies tune run --nnodes {>1} ... we dispatch to a generic slurm script on the backend (this is just one UX.. could also require explicit --slurm arg or something like that)

Author: Bleh

Author: Make them copy it?

Author: Okay will make them copy it for now and not add to recipe registry, but I will keep the script there.

#
# This config is only tested on 2 nodes w/ 8 H100 machines.

output_dir: /tmp/torchtune/llama3_3_70B/full

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.3-70B-Instruct/original/tokenizer.model
  max_seq_len: 1024

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: True # True increases speed
seed: null
shuffle: True

# Model Arguments
model:
  _component_: torchtune.models.llama3_3.llama3_3_70b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.3-70B-Instruct/
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: "00030"
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 4

Contributor: Can u go bigger though

epochs: 1

optimizer:
  _component_: torch.optim.AdamW
  lr: 2e-5
  # Note: highly recommended to use fused=True optimizer flag
  # with CPU offload for faster optimizer step.

Contributor: This can be removed, right? Should only be relevant for CPU offload

Author: Yeah true.

  fused: True

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size


# Training env
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory

Contributor: Can we disable on two nodes?

enable_activation_offloading: False # True reduces memory
custom_sharded_layers: ['tok_embeddings', 'output'] # Layers to shard separately (useful for large vocab size models). Lower Memory, but lower speed.
fsdp_cpu_offload: False
clip_grad_norm: null
compile: True # torch.compile the model + loss, True increases speed + decreases memory
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True

# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
38 changes: 20 additions & 18 deletions recipes/full_finetune_distributed.py
@@ -118,20 +118,21 @@ class FullFinetuneRecipeDistributed(FTRecipeInterface):
"""

def __init__(self, cfg: DictConfig) -> None:
self._device = utils.get_device(device=cfg.device)
device_type = cfg.device
self._device = utils.get_device(device=device_type)
self._dtype = training.get_dtype(cfg.dtype, device=self._device)

if self._dtype == torch.float16:
raise ValueError(
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

# logging attributes
# Logging attributes
self._output_dir = cfg.output_dir
self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)

if self._log_peak_memory_stats and self._device.type != "cuda":
if self._log_peak_memory_stats and device_type != "cuda":
log.info(
"log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
)
@@ -147,6 +148,12 @@ def __init__(self, cfg: DictConfig) -> None:
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
self._checkpoint_client = CheckpointClient(cfg)
self.fsdp_cpu_offload = cfg.get("fsdp_cpu_offload", False)
self.distributed_backend = training.get_distributed_backend(
device_type,
offload_ops_to_cpu=self.fsdp_cpu_offload
or self._enable_async_checkpointing,

Contributor: Did you run with async checkpointing? Pretty interested to know how much time it saves on multiple nodes

Author: Nope.

Contributor: K np for now

)

# Optimizer in backward is not compatible with gradient accumulation or gradient clipping
if self._optimizer_in_bwd:
@@ -169,7 +176,7 @@ def __init__(self, cfg: DictConfig) -> None:
"enable_activation_offloading", False
)
if self._enable_activation_offloading:
if self._device.type != "cuda":
if device_type != "cuda":
raise RuntimeError(
"enable_activation_offloading should only be True when training on CUDA"
)
@@ -240,9 +247,16 @@ def setup(self, cfg: DictConfig) -> None:
Setup the recipe. This includes training state (if resume_from_checkpoint is True),
model, tokenizer, loss, optimizer, lr scheduler, sampler, and dataloader.
"""
# Set up the backend for distributed training (NCCL, GLOO, etc.)
init_process_group(self.distributed_backend)

Contributor: curious why do we want to move this block from recipe_main to setup?

Author: In my mind, this is doing actual setup. Therefore it should belong with the rest of the setup code, not buried at the bottom of the recipe where it's hard to find.

Contributor: can you also update generate_v2_distributed recipe?

Author: Follow-up for all distributed recipes :)


if self.fsdp_cpu_offload:
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()

Contributor: do we always want to set this?

Author: That's a good point. Looks like this was added by Rohan, so not sure who to follow up with here. Let me dig into it.

Contributor: Yeah this is a heuristic for fused Adam on CPU when CPU offload is enabled. I don't think it's optimal, but I do think that without it CPU offload training may be much slower

Author: Should it be set for async offload too? Or pure CPU training?

Contributor: Afaik it shouldn't matter for async offload and mostly has to do with fused optimizer. For pure CPU training I guess the optimizer step also happens on CPU so in that case we would potentially want it


if self._is_rank_zero:
self._metric_logger = config.instantiate(cfg.metric_logger)

# log config with parameter override
self._metric_logger.log_config(cfg)

@@ -255,7 +269,7 @@ def setup(self, cfg: DictConfig) -> None:
enable_activation_checkpointing=self._enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
custom_sharded_layers=cfg.get("custom_sharded_layers", None),
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
fsdp_cpu_offload=self.fsdp_cpu_offload,
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
model_state_dict=checkpoint_dict[training.MODEL_KEY],
ac_mode=cfg.get("ac_mode", None),
@@ -890,19 +904,7 @@ def recipe_main(cfg: DictConfig) -> None:
- Parameters specified in config (see available configs through ``tune ls``)
- Overwritten by arguments from the command-line
"""
if not training.is_distributed():
raise RuntimeError(
"Distributed finetune recipe should be run via a distributed launcher."
"If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
)
init_process_group("cuda:nccl,cpu:gloo")
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
training.set_torch_num_threads()

config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg)

recipe = FullFinetuneRecipeDistributed(cfg=cfg)
recipe.setup(cfg=cfg)
recipe.train()
44 changes: 44 additions & 0 deletions recipes/full_finetune_multinode.slurm
@@ -0,0 +1,44 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# ---------- SBATCH commands ---------- #
#SBATCH --job-name=torchtune-multi-node
#SBATCH --ntasks=2
#SBATCH --nodes=2
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=96
#SBATCH --partition=train

# ---------- Set env variables ---------- #
# Grab the IP for head node:
nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
echo Node IP: $head_node_ip

# You might need to explicitly set the network interface for distributed backends:
# export NCCL_SOCKET_IFNAME=...
# export GLOO_SOCKET_IFNAME=...

export TORCH_DIST_INIT_BARRIER=1
export LOGLEVEL=INFO

# ---------- Launch training ---------- #
# You probably want to load in a virtual env w/ conda...
# module load conda
# conda activate torchtune
# ...or venv
# source torchtune/bin/activate
Comment on lines +33 to +37

Contributor: Why is this commented out? Is it because we don't know the user's venv/conda env? I remember wasting a bunch of time myself on this kinda stuff before, might be worth explicitly calling it out in the tutorial

Author: Yeah, we can't make any assumptions about how they initialize their virtual env

Collaborator: sidenote: uv has eliminated all venv pains for me


SHARED_FS=/mnt/slurm # <-- Replace w/ your filesystem
CHECKPOINT_DIR="$SHARED_FS/Llama-3.3-70B-Instruct"
OUTPUT_DIR="$SHARED_FS/Llama3.3-70B-fft-output"

# Adjust sbatch --ntasks and sbatch --nodes above and --nnodes below to your specific node count
srun tune run --nnodes 2 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" \
full_finetune_distributed --config llama3_3/70B_full_multinode checkpoint_dir=$CHECKPOINT_DIR output_dir=$OUTPUT_DIR
2 changes: 1 addition & 1 deletion recipes/lora_finetune_distributed_multi_dataset.py
@@ -138,7 +138,7 @@ def __init__(self, cfg: DictConfig) -> None:
"full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
)

_, rank = training.get_world_size_and_rank()
_, rank = utils.get_world_size_and_rank()

self._is_rank_zero = rank == 0
