Releases: Lightning-AI/pytorch-lightning
Patch release v2.3.1
Includes minor bugfixes and stability improvements.
Full Changelog: 2.3.0...2.3.1
Lightning v2.3: Tensor Parallelism and 2D Parallelism
Lightning AI is excited to announce the release of Lightning 2.3 ⚡
Did you know? The Lightning philosophy extends beyond a boilerplate-free deep learning framework: We've been hard at work bringing you Lightning Studio. Code together, prototype, train, deploy, host AI web apps. All from your browser, with zero setup.
This release introduces experimental support for Tensor Parallelism and 2D Parallelism, PyTorch 2.3 support, and several bugfixes and stability improvements.
Highlights
Tensor Parallelism (beta)
Tensor parallelism (TP) is a technique that splits up the computation of selected layers across GPUs to save memory and speed up distributed models. To enable TP as well as other forms of parallelism, we introduce a ModelParallelStrategy for both Lightning Trainer and Fabric. Under the hood, TP is enabled through new experimental PyTorch APIs like DTensor and torch.distributed.tensor.parallel.
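The arithmetic behind a column-wise/row-wise split can be checked without any GPUs. The following is a minimal pure-Python sketch, not Lightning's or PyTorch's implementation. The helper names (`matmul`, `split_columns`, `split_rows`) and the toy weights are made up for illustration. Weights are stored as (in_features, out_features) so that y = x @ W, and the final elementwise sum stands in for the all-reduce that a real row-wise layer would perform across devices.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def silu(v):
    return v / (1.0 + math.exp(-v))


def feed_forward(x, w1, w2, w3):
    """Gated feed-forward: (silu(x @ w1) * (x @ w3)) @ w2."""
    h1 = matmul(x, w1)
    h3 = matmul(x, w3)
    gated = [[silu(a) * b for a, b in zip(r1, r3)] for r1, r3 in zip(h1, h3)]
    return matmul(gated, w2)


def split_columns(w, parts):
    """Column-wise sharding: each part keeps a slice of the output features."""
    step = len(w[0]) // parts
    return [[row[i * step:(i + 1) * step] for row in w] for i in range(parts)]


def split_rows(w, parts):
    """Row-wise sharding: each part keeps a slice of the input features."""
    step = len(w) // parts
    return [w[i * step:(i + 1) * step] for i in range(parts)]


# Toy sizes (hypothetical): batch=2, dim=2, hidden_dim=4
x = [[1.0, -2.0], [0.5, 3.0]]
w1 = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]
w3 = [[-0.4, 0.3, -0.2, 0.1], [0.2, -0.1, 0.4, -0.3]]
w2 = [[0.1, -0.2], [0.3, 0.4], [-0.5, 0.6], [0.7, -0.8]]

reference = feed_forward(x, w1, w2, w3)

# Each simulated "device" holds one shard of every weight and computes a partial output
world_size = 2
partials = [
    feed_forward(x, w1_shard, w2_shard, w3_shard)
    for w1_shard, w2_shard, w3_shard in zip(
        split_columns(w1, world_size),
        split_rows(w2, world_size),
        split_columns(w3, world_size),
    )
]

# Stand-in for the all-reduce: sum the partial outputs elementwise
combined = [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]
```

Because the activation is applied elementwise, each device only needs its own slice of the hidden features, and the summed partial outputs match the unsharded result up to floating-point rounding.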
PyTorch Lightning
Enabling TP in a model with PyTorch Lightning requires you to implement the LightningModule.configure_model() method, where you convert selected layers of a model to parallelized layers. This is an advanced feature, because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.
import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the `configure_model()` method in LightningModule
class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        ...


# 2. Create the strategy
strategy = ModelParallelStrategy()

# 3. Configure devices and set the strategy in Trainer
trainer = L.Trainer(accelerator="cuda", devices=2, strategy=strategy)
trainer.fit(...)
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.pytorch.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = FeedForward(8192, 8192)

    def configure_model(self):
        if self.device_mesh is None:
            return

        # Lightning will set up a `self.device_mesh` for you
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Use PyTorch's distributed tensor APIs to parallelize the model
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(),
            "w3": ColwiseParallel(),
        }
        parallelize_module(self.model, tp_mesh, plan)

    def training_step(self, batch):
        output = self.model(batch)
        loss = output.sum()
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.model.parameters(), lr=3e-3)

    def train_dataloader(self):
        # Trainer configures the sampler automatically for you such that
        # all batches in a tensor-parallel group are identical
        dataset = RandomDataset(8192, 64)
        return torch.utils.data.DataLoader(dataset, batch_size=8, num_workers=2)


strategy = ModelParallelStrategy()
trainer = L.Trainer(
    accelerator="cuda",
    devices=2,
    strategy=strategy,
    max_epochs=1,
)

model = LitModel()
trainer.fit(model)

trainer.print(f"Peak memory usage: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
Lightning Fabric
Applying TP in a model with Fabric requires you to implement a special function where you convert selected layers of a model to parallelized layers. This is an advanced feature, because it requires a deep understanding of the model architecture. Open the tutorial Studio to learn the basics of Tensor Parallelism.
import lightning as L
from lightning.fabric.strategies import ModelParallelStrategy
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module


# 1. Implement the parallelization function for your model
def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


# 2. Pass the parallelization function to the strategy
strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)

# 3. Configure devices and set the strategy in Fabric
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()
Full training example (requires at least 2 GPUs).
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel
from torch.distributed.tensor.parallel import parallelize_module

import lightning as L
from lightning.pytorch.demos.boring_classes import RandomDataset
from lightning.fabric.strategies import ModelParallelStrategy


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


def parallelize_feedforward(model, device_mesh):
    # Lightning will set up a device mesh for you
    tp_mesh = device_mesh["tensor_parallel"]
    # Use PyTorch's distributed tensor APIs to parallelize the model
    plan = {
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(),
        "w3": ColwiseParallel(),
    }
    parallelize_module(model, tp_mesh, plan)
    return model


strategy = ModelParallelStrategy(parallelize_fn=parallelize_feedforward)
fabric = L.Fabric(accelerator="cuda", devices=2, strategy=strategy)
fabric.launch()

# Initialize the model
model = FeedForward(8192, 8192)
model = fabric.setup(model)

# Define the optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-3)
optimizer = fabric.setup_optimizers(optimizer)

# Define dataset/dataloader
dataset = RandomDataset(8192, 64)
dataloader = torch.utils.data.DataLoader(dataset, batch_si...
Patch release v2.2.5
PyTorch Lightning + Fabric
Fixed
- Fixed a matrix shape mismatch issue when running a model loaded from a quantized checkpoint (bitsandbytes) (#19886)
Full Changelog: 2.2.4...2.2.5
Patch release v2.2.4
App
Fixed
- Fixed HTTPClient retry for flow/work queue (#19837)
PyTorch
No Changes.
Fabric
No Changes.
Full Changelog: 2.2.3...2.2.4
Patch release v2.2.3
PyTorch
Fixed
- Fixed WandbLogger.log_hyperparameters() raising an error if hyperparameters are not JSON serializable (#19769)
Fabric
No Changes.
Full Changelog: 2.2.2...2.2.3
Patch release v2.2.2
PyTorch
Fixed
- Fixed an issue causing a TypeError when using torch.compile as a decorator (#19627)
- Fixed a KeyError when saving a FSDP sharded checkpoint and setting save_weights_only=True (#19524)
Fabric
Fixed
- Fixed an issue causing a TypeError when using torch.compile as a decorator (#19627)
- Fixed issue where some model methods couldn't be monkeypatched after being Fabric wrapped (#19705)
- Fixed an issue causing weights to be reset in Fabric.setup() when using FSDP (#19755)
Full Changelog: 2.2.1...2.2.2
Contributors
@ankitgola005 @awaelchli @Borda @carmocca @dmitsf @dvoytan-spark @fnhirwa
Patch release v2.2.1
PyTorch
Fixed
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually (#19446)
- Fixed the divisibility check for Trainer.accumulate_grad_batches and Trainer.log_every_n_steps in ThroughputMonitor (#19470)
- Fixed support for Remote Stop and Remote Abort with NeptuneLogger (#19130)
- Fixed infinite recursion error in precision plugin graveyard (#19542)
Fabric
Fixed
- Fixed an issue with CSVLogger trying to append to file from a previous run when the version is set manually (#19446)
Full Changelog: 2.2.0post...2.2.1
Contributors
@Raalsky @awaelchli @carmocca @Borda
If we forgot someone due to not matching commit email with GitHub account, let us know :]
Minor release correction
Full Changelog: 2.2.0...2.2.0.post0
Lightning v2.2
Lightning AI is excited to announce the release of Lightning 2.2 ⚡
While our previous release was packed with many big new features, this time around we're rolling out mainly improvements based on feedback from the community. And of course, as the name implies, this release fully supports the latest PyTorch 2.2 🎉
Highlights
Monitoring Throughput
Lightning now has built-in utilities to measure throughput metrics such as batches/sec, samples/sec and Model FLOP Utilization (MFU) (#18848).
Trainer:
For the Trainer, this comes in the form of a ThroughputMonitor callback. To track samples/sec, you need to provide a function that tells the monitor how to extract the batch dimension from your input. Furthermore, if you want to track MFU, you can provide a sample forward pass and the ThroughputMonitor will automatically estimate the utilization based on the hardware you are running on:
import torch
import lightning as L
from lightning.pytorch.callbacks import ThroughputMonitor
from lightning.fabric.utilities.throughput import measure_flops


class MyModel(L.LightningModule):
    def setup(self, stage):
        # Instantiate on the meta device so that no real memory is allocated
        with torch.device("meta"):
            model = MyModel()

            def sample_forward():
                batch = torch.randn(..., device="meta")
                return model(batch)

            self.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)


throughput = ThroughputMonitor(
    batch_size_fn=lambda batch: batch.size(0),
    # optional, if your samples have a length (like number of tokens)
    length_fn=lambda batch: batch.size(1),
)
trainer = L.Trainer(log_every_n_steps=10, callbacks=throughput, logger=...)
model = MyModel()
trainer.fit(model)
The results get automatically sent to the logger if one is configured on the Trainer.
Fabric:
For Fabric, the ThroughputMonitor is a simple utility object on which you call .update() and .compute_and_log() during the training loop:
from time import time

import torch
import lightning as L
from lightning.fabric.utilities import ThroughputMonitor

fabric = L.Fabric(logger=...)
throughput = ThroughputMonitor(fabric)

t0 = time()
# Count from 1 so that `batches` is the number of batches completed so far
for batch_idx, batch in enumerate(train_dataloader, start=1):
    do_work()
    torch.cuda.synchronize()  # required or else time() won't be correct
    throughput.update(
        time=(time() - t0),
        batches=batch_idx,
        samples=(batch_idx * batch_size),
    )
    if batch_idx % 10 == 0:
        throughput.compute_and_log(step=batch_idx)
Check out our TinyLlama LLM pretraining script for a full example using Fabric's ThroughputMonitor.
The throughput utilities can report:
- batches per second (per process and across processes)
- samples per second (per process and across processes)
- items per second (e.g. tokens) (per process and across processes)
- flops per second (per process and across processes)
- model flops utilization (MFU) (per process)
- total time, total samples, total batches, and total items (per process)
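To make the reported quantities concrete, here is a minimal sketch of how such rates can be derived from cumulative counters. It is not Lightning's implementation: the `ThroughputSample` class, the `rates` and `mfu` helpers, and the peak-FLOPs figure are hypothetical names and numbers chosen for illustration. MFU is taken here as achieved FLOPs per second divided by the hardware's theoretical peak.

```python
from dataclasses import dataclass


@dataclass
class ThroughputSample:
    """Cumulative counters recorded at one point in training (hypothetical structure)."""
    elapsed_s: float
    batches: int
    samples: int


def rates(first: ThroughputSample, last: ThroughputSample) -> dict:
    """Average per-process rates over the window between two samples."""
    dt = last.elapsed_s - first.elapsed_s
    if dt <= 0:
        raise ValueError("elapsed time must increase between samples")
    return {
        "batches_per_sec": (last.batches - first.batches) / dt,
        "samples_per_sec": (last.samples - first.samples) / dt,
    }


def mfu(flops_per_sec: float, peak_flops_per_sec: float) -> float:
    """Model FLOP utilization: achieved FLOPs/sec relative to the theoretical peak."""
    return flops_per_sec / peak_flops_per_sec


window = rates(ThroughputSample(0.0, 0, 0), ThroughputSample(2.0, 10, 80))
utilization = mfu(50e12, 200e12)  # assumed peak of 200 TFLOP/s, for illustration only
```

Measuring over a window of recent samples rather than from the very start of training keeps warm-up iterations from skewing the reported rates.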
Improved Handling of Evaluation Mode
When you train a model and have validation enabled, the Trainer automatically calls .eval() when transitioning to the validation loop, and .train() when validation ends. Until now, this had the unfortunate side effect that any submodules in your LightningModule that were in evaluation mode got reset to train mode. In Lightning 2.2, the Trainer now captures the mode of every submodule before switching to validation, and restores the mode the modules were in when validation ends (#18951). This improvement helps users avoid silent correctness bugs and removes boilerplate code for managing frozen layers.
import lightning as L


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.trainable_module = ...

        # This will now stay in eval mode
        self.frozen_module = ...
        self.frozen_module.eval()

    def training_step(self, batch):
        # Previously, modules were all in train mode
        # Now: Modules are in mode they were set up with
        assert self.trainable_module.training
        assert not self.frozen_module.training
        ...

    def validation_step(self, batch):
        # All modules are in eval mode
        ...


model = LitModel()
trainer = L.Trainer()
trainer.fit(model)
If you have overridden any of the LightningModule.on_{validation,test,predict}_model_{eval,train} hooks, they will still get called and execute your custom logic, but they are no longer required if you added them to preserve the eval mode of frozen modules.
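The capture-and-restore idea can be illustrated in plain PyTorch. This is a minimal sketch under stated assumptions, not the Trainer's actual code: `capture_modes` and `restore_modes` are hypothetical helper names, and the toy `nn.ModuleDict` stands in for a LightningModule with one trainable and one frozen part.

```python
import torch.nn as nn


def capture_modes(root: nn.Module) -> dict:
    """Record the `training` flag of every submodule, keyed by its qualified name."""
    return {name: module.training for name, module in root.named_modules()}


def restore_modes(root: nn.Module, modes: dict) -> None:
    """Put every submodule back into the mode recorded by `capture_modes`."""
    for name, module in root.named_modules():
        if name in modes:
            # Set the flag directly: `module.train(mode)` would also overwrite children
            module.training = modes[name]


model = nn.ModuleDict({
    "trainable": nn.Linear(4, 4),
    "frozen": nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5)),
})
model["frozen"].eval()  # the user freezes this part before training

saved = capture_modes(model)
model.eval()  # entering "validation": every submodule is now in eval mode
all_eval_during_validation = not any(m.training for m in model.modules())
restore_modes(model, saved)  # leaving "validation": the original modes come back
```

A plain `model.train()` at the end of validation would have flipped the frozen part back to train mode, which is the side effect described above.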
Important
In some libraries, for example HuggingFace, models are created in evaluation mode by default (e.g. HFModel.from_pretrained(...)). Starting from 2.2, you will have to set .train() on these models if you intend to train them.
Converting FSDP Checkpoints
In the previous release, we introduced distributed checkpointing with FSDP to speed up saving and loading checkpoints for big models. These checkpoints are in a special format saved in a folder with shards from each GPU in a separate file. While these checkpoints can be loaded back with Lightning Trainer or Fabric very easily, they aren't easy to load or process externally. In Lightning 2.2, we introduced a CLI utility that lets you consolidate the checkpoint folder to a single file that can be loaded in raw PyTorch with torch.load() for example (#19213).
Given you saved a distributed checkpoint, you can then convert it like so:
# For Trainer checkpoints:
python -m lightning.pytorch.utilities.consolidate_checkpoint path/to/my/checkpoint
# For Fabric checkpoints:
python -m lightning.fabric.utilities.consolidate_checkpoint path/to/my/checkpoint
Read more about distributed checkpointing in our documentation: Trainer, Fabric.
Improvements to Compiling DDP/FSDP in Fabric
PyTorch 2.0+ introduced torch.compile, a powerful tool to speed up your models without changing the code.
We have now added a comprehensive guide on how to use torch.compile correctly, with tips and tricks to help you troubleshoot common issues. On top of that, Fabric.setup() will now reapply torch.compile on top of DDP/FSDP if you are enabling these strategies (#19280).
import lightning as L
# Select a distributed strategy (DDP, FSDP, ...)
fabric = L.Fabric(strategy="ddp", devices=8)
# Compile your model before `.setup()`
model = torch.compile(model)
# Now automatically handles compiling also over DDP/FSDP
model = fabric.setup(model)
# You can opt-out if it is causing trouble
model = fabric.setup(model, _reapply_compile=False)
You might see fewer graph breaks, but there won't be any significant speed-ups with this. We introduced this mainly to make Fabric ready for future improvements from PyTorch to optimizing distributed operations.
Saving and Loading DataLoader State
If you use a dataloader/iterable that implements the .state_dict() and .load_state_dict() interface, the Trainer will now automatically save and load their state in the checkpoint (#19361).
import lightning as L
class MyDataLoa...
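As an illustration of the .state_dict() / .load_state_dict() interface mentioned above, here is a minimal pure-Python sketch of a resumable iterable. The `ResumableRange` class and its `index` field are hypothetical and not taken from the release notes; the save and load calls are made by hand here, whereas the Trainer would make them when writing and reading a checkpoint.

```python
class ResumableRange:
    """Yields 0..n-1 and can resume from a saved position (hypothetical example)."""

    def __init__(self, n: int) -> None:
        self.n = n
        self.index = 0

    def __iter__(self):
        # Deliberately does not reset `index`, so iteration continues after a restore
        return self

    def __next__(self) -> int:
        if self.index >= self.n:
            raise StopIteration
        value = self.index
        self.index += 1
        return value

    def state_dict(self) -> dict:
        return {"index": self.index}

    def load_state_dict(self, state_dict: dict) -> None:
        self.index = state_dict["index"]


# Consume three items, then snapshot the position
loader = ResumableRange(5)
consumed = [next(loader) for _ in range(3)]
checkpoint = loader.state_dict()

# Later: a fresh object picks up where the old one stopped
resumed = ResumableRange(5)
resumed.load_state_dict(checkpoint)
remaining = list(resumed)
```

Keeping the state a small, plain dictionary makes it straightforward to store alongside the rest of a checkpoint.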
Lightning 2.2 Release Candidate
This is a preview release for Lightning 2.2.0.