Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OOMptimizer: bucketing batch size profiles to make GPUs go 🔥 #9763

Merged
merged 70 commits into from
Aug 16, 2024
Merged
Changes from 1 commit
Commits
Show all changes
70 commits
Select commit Hold shift + click to select a range
f71e27e
Initial working draft of the OOMptimizer.
pzelasko Jul 17, 2024
5995dbe
Support model config. Add bucket merging.
pzelasko Jul 18, 2024
561e674
fix
pzelasko Jul 18, 2024
5970c34
code review
pzelasko Jul 18, 2024
b4ab721
Support bucket_batch_size option for lhotse dataloading
pzelasko Jul 19, 2024
9e632e4
Ability to force a memory fraction to be unused in OOMptimizer
pzelasko Jul 19, 2024
4b009bd
Ability to force a memory fraction to be unused in OOMptimizer
pzelasko Jul 19, 2024
a386fa8
Fix for autocast and configurable dtype
pzelasko Jul 19, 2024
a4e2c66
Allow token-per-second filtering
pzelasko Jul 19, 2024
0cdc58d
Fix an issue with canary tokenizer
pzelasko Jul 22, 2024
9c3e625
Lift the requirement to use CanaryTokenizer with canary prompt format
pzelasko Jul 22, 2024
aaa05a5
Fixes
pzelasko Jul 23, 2024
e7556fb
Initial 2D bucketing draft
pzelasko Jul 23, 2024
8497a25
Separate script for 2D bucket estimation
pzelasko Jul 23, 2024
bc60b5f
Full 2D bucketing support: estimate_uduration_bins_2d, oomptimizer, t…
pzelasko Jul 23, 2024
10c2ada
fix
pzelasko Jul 23, 2024
bb0bc4f
fix
pzelasko Jul 23, 2024
5e442bf
fix
pzelasko Jul 23, 2024
97a800c
fix
pzelasko Jul 23, 2024
21588ba
fix
pzelasko Jul 23, 2024
77d2851
fix
pzelasko Jul 23, 2024
5b704a8
Unit tests for bucket_batch_size and 2D bucketing for audio
pzelasko Jul 23, 2024
1155135
Docs for 2D estimate duration bins
pzelasko Jul 23, 2024
2f43313
Fixes
pzelasko Jul 23, 2024
81420df
Preliminary support for prompt format in estimate_duration_bins_2d
pzelasko Jul 23, 2024
572f2be
fixes
pzelasko Jul 24, 2024
ade45ea
fix for bucket selection edge case
pzelasko Jul 24, 2024
8d607e1
Add more info about the distribution to estimate_duration_bins_2d.py
pzelasko Jul 24, 2024
7ffdd96
Include CUDA RAM usage tracking in OOMptimizer
pzelasko Jul 24, 2024
5644f43
Track batch_size, num frames/tokens, and their padding ratio for AED …
pzelasko Jul 24, 2024
3b532a9
OOMptimizer documentation
pzelasko Jul 25, 2024
616036f
Resolve TODOs and support any combination of (audio|text)->(audio|tex…
pzelasko Jul 25, 2024
968a00f
Add missing property decorator
pzelasko Jul 25, 2024
639df62
fixes
pzelasko Jul 25, 2024
ec4206f
Add docs about 2D bucketing with tokenizer and prompts
pzelasko Jul 25, 2024
d64a726
Fix bucket allocation logic for 2D bucketing
pzelasko Jul 25, 2024
0d2cbe5
Merge branch 'main' into oomptimizer
pzelasko Jul 26, 2024
764c3f1
Bump lhotse version
pzelasko Jul 26, 2024
14ed8be
fix...
pzelasko Jul 29, 2024
888c343
Merge branch 'main' into oomptimizer
pzelasko Jul 29, 2024
5c1e096
Reverse bucket iteration order; move oomptimizer_schema to AsrModel
pzelasko Aug 1, 2024
e3aa624
Merge branch 'main' into oomptimizer
pzelasko Aug 1, 2024
41beffd
Make OOMptimizer compatible with dataclass mini-batches
pzelasko Aug 1, 2024
4f6859e
Refine the schema
pzelasko Aug 1, 2024
b237f96
fixes after merging main
pzelasko Aug 1, 2024
c4a25ea
fix oomptimizer with pretrained models; verified canary, parakeet tdt…
pzelasko Aug 1, 2024
eeebf19
Disable concurrent bucketing to prevent spawning extra threads in tests
pzelasko Aug 1, 2024
731fda0
fix tests and make life more colorful
pzelasko Aug 2, 2024
5d5d9e1
formatting
pzelasko Aug 2, 2024
4a41e66
more reasonable starting batch size settings
pzelasko Aug 2, 2024
87d0ea7
Disable clearing of cuda memory cache
pzelasko Aug 2, 2024
f7198bc
Even more conservative profile by incorporating DDP overhead simulation
pzelasko Aug 5, 2024
44ef482
Merge branch 'main' into oomptimizer
pzelasko Aug 7, 2024
fc8a8c7
Bucket selection fix and an extended unit test
pzelasko Aug 7, 2024
2bd282c
Refactor registered_prompt_format_fn to enable prompt formatting befo…
pzelasko Aug 8, 2024
4546a1f
porting fix
pzelasko Aug 8, 2024
88f4d21
Fixes, move fast-path to prompted dataset
pzelasko Aug 8, 2024
3f892b9
Merge branch 'main' into oomptimizer
pzelasko Aug 9, 2024
81288e1
Changes from Daniel's review
pzelasko Aug 13, 2024
0d10556
Merge branch 'main' into oomptimizer
pzelasko Aug 14, 2024
dddc2ef
OOMptimizer tests + fixes for 1D bucketing case
pzelasko Aug 14, 2024
80cb49b
estimate duration bins tests
pzelasko Aug 14, 2024
7a1bf71
address Daniel's review
pzelasko Aug 14, 2024
cbd3da8
fix CPU unit test
pzelasko Aug 14, 2024
9bb2693
Merge branch 'main' into oomptimizer
pzelasko Aug 15, 2024
81b4d92
try to fix CI test
pzelasko Aug 15, 2024
10444f8
Merge branch 'oomptimizer' of https://github.com/nvidia/nemo into oom…
pzelasko Aug 15, 2024
02c88f5
Apply suggestions from code review
pzelasko Aug 15, 2024
2383c93
Merge remote-tracking branch 'origin/main' into oomptimizer
pzelasko Aug 15, 2024
6f066e0
Disable 2D bucketing test with prompt due to quoting issue
pzelasko Aug 15, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 283 additions & 0 deletions scripts/speech_recognition/oomptimizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
#!/usr/bin/env python
import math
Fixed Show fixed Hide fixed
Fixed Show fixed Hide fixed

import click
import pytorch_lightning as pl
import torch
from lhotse import compute_num_samples

from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel
from nemo.core.neural_types import AudioSignal, LabelsType, LengthsType, NeuralType
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MaskType may also need to be supported at some point, as an alternative to LengthsType. I don't think it's a big deal, though.

from nemo.utils import logging


class ProfilingBatchGenerator:
"""
ProfilingBatchGenerator is used to generate artificial mini-batches for model training
and tracking the progress of batch size optimization.

The high-level usage API is the following::

>>> gen = ProfilingBatchGenerator(schema)
... finished = False
... while not finished:
... batch = gen(input_seq_len, output_seq_len)
... try:
... training_step(model, batch)
... oom = False
... except torch.cuda.OutOfMemoryError:
... oom = True
... finished = gen.advance(oom)
... solution = gen.max_batch_size # The solution of the search problem.
... gen.reset() # Can re-use for other sequence lengths now.


In order to generate mini-batches compatible with a given model, the generator:

* accepts a ``schema`` argument in its constructor, and

* accepts input/output sequence lengths in each call to generate a mini-batch.

``schema`` has the following structure::

>>> [{"type": NeuralType(...) | str, "seq_length": "input|output", "vocab_size": int}, {...}, ...]

Each item in ``schema`` specifies a NeMo NeuralType which needs to have a defined ``elements_type``.
The supported types are ``AudioSignal``, ``LengthsType`` and ``LabelsType``.
If "type" is not a NeuralType, we interpret that as a placeholder tensor that's not relevant but expect by the model.

In addition, ``"seq_length"`` key is used to determine whether we should apply input or output sequence length
to a given tensor, and "vocab_size" is required for ``LabelsType`` so that we can generate proper label values.

"""

def __init__(
self,
schema: list[dict] = None,
start_batch_size: int = 1024,
rel_gap_thresh: float = 0.1,
device: str = "cuda",
):
self.schema = schema
self.start_batch_size = start_batch_size
self.rel_gap_thresh = rel_gap_thresh
self.device = device
self.reset()

def __call__(self, input_seq_len: int, output_seq_len: int):
B = self._current
batch = []
for item in self.schema:
nt = item["type"]
if not isinstance(nt, NeuralType): # placeholder
tnsr = torch.tensor([])
elif isinstance(nt.elements_type, AudioSignal):
tnsr = torch.randn(B, input_seq_len, dtype=torch.float32, device=self.device)
elif isinstance(nt.elements_type, LengthsType):
seq_len = input_seq_len if item["seq_length"] == "input" else output_seq_len
tnsr = torch.ones(B, dtype=torch.long, device=self.device) * seq_len
elif isinstance(nt.elements_type, LabelsType):
tnsr = torch.randint(0, item["vocab_size"], size=(B, output_seq_len), device=self.device)
else:
raise RuntimeError("Unexpected item in oomptimizer schema: {item}")
batch.append(tnsr)
return tuple(batch)

@property
def max_batch_size(self) -> int | None:
if (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs a bit of doc for all the cases

self._max_ok is not None
and self._min_err is not None
and (self.current_rel_gap <= self.rel_gap_thresh or self._min_err - self._max_ok <= 1)
):
return self._max_ok
return None

@property
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does relative gap mean?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added doc to explain

def current_rel_gap(self) -> int | None:
if self._min_err is None or self._max_ok is None:
return None
return (self._min_err - self._max_ok) / self._min_err

def reset(self):
self._current = self.start_batch_size
self._max_ok = None # max batch size that works
self._min_err = None # min batch size that doesn't work

def advance(self, oom: bool) -> bool:
"""
Adjusts the current batch size based on the outcome.
Returns a bool indicating whether the calibration is complete.
"""
if self.max_batch_size is not None:
return True

if oom:
# Training step failed with OOM.
# Update the minimum known batch size that causes an error.
self._min_err = min(float("inf") if self._min_err is None else self._min_err, self._current)
# Training step failed on OOM
if self._max_ok is None:
# We haven't found a batch size that works yet, keep going 2x down.
self._current = round(self._current / 2)
else:
# Try the middle-point between the known extremes.
self._current = round((self._max_ok + self._min_err) / 2)
else:
# Training step successful.
# Update the maximum known batch size that works.
self._max_ok = max(-1 if self._max_ok is None else self._max_ok, self._current)
if self._min_err is None:
# We haven't found a batch size that causes an error yet, keep going 2x higher
self._current *= 2
else:
# Try the middle-point between the known extremes.
self._current = round((self._max_ok + self._min_err) / 2)

return False


class FloatList(click.Option):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's this used for ? Might as well use hydra with a dataclass than click

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I went with click out of an old habit. This auto-parses bucket duration bins [1,2,3,4] to list of floats.

name = "list[float]"

def type_cast_value(self, ctx, value):
if isinstance(value, list) and all(isinstance(v, float) for v in value):
return value
try:
import ast

return ast.literal_eval(value)
except ValueError:
raise click.BadParameter(value)


@click.command(context_settings={'show_default': True})
@click.option("-m", "--model-name", type=str, required=True, help="Name of model to use, e.g. 'nvidia/canary-1b'.")
@click.option("-o", "--optimizer-name", type=str, default="adamw", help="Name of optimizer to use.")
@click.option(
"-b",
"--buckets",
cls=FloatList,
default=[5.0, 10.0, 15.0, 20.0, 25.0, 30.0],
help="List of upper-bound bucket bins (i.e. first bucket is [0.0 - item0), second bucket is [item0 - item1), etc.)",
)
@click.option(
"-t",
"--threshold",
type=float,
default=0.1,
help="Search stopping criterion in range [0, 1], lower is more precise. Interpret as the uncerainty gap, i.e. (min_oom_batch_size - max_ok_batch_size) / min_oom_batch_size.",
)
@click.option("-s", "--start-batch-size", type=int, default=1024, help="Initial batch size to start the search from.")
@click.option(
"-l",
"--labels-per-second",
type=int,
default=10,
help="How many labels/second should we simulate. More means longer output text sequences, and can increase memory consumption.",
)
def oomptimizer(
model_name: str,
optimizer_name: str,
buckets: list[float],
threshold: float,
start_batch_size: int,
labels_per_second: int,
):
"""
OOMptimizer finds the optimal batch sizes for training your model with bucketing dataloading.

Dynamic bucketing is notoriously difficult to tune as you risk running into CUDA OOM many steps into the training.
In order to simplify finding the optimal settings, OOMptimizer scans each bucket to find the maximum possible
batch size that doesn't trigger a CUDA OOM.

\b
The suggested workflow is the following:
1) Run scripts/speech_recognition/estimate_duration_bins.py to get the duration distribution of your data.
2) Run OOMptimizer to find the optimal batch sizes for your specific model, optimizer, and GPU.
3) Use these optimal settings in your actual training script and enjoy optimal GPU utilization OOM-free.
"""
logging.setLevel(logging.CRITICAL)
device = "cuda"

print("Intializing ASR model.")
# TODO(pzelasko): This currently only supports "from_pretrained".
# We need to be able to read a model training configuration and instantiate the model
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use restore_from(..., return_config=True)...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ended up going with --module-name and --config-path like discussed offline. It works well.

# and figure out all the necessary details from that.
trainer = pl.Trainer(barebones=True)
trainer.log_every_n_steps = 1000000
model = ASRModel.from_pretrained(model_name, trainer=trainer).to(device)

# TODO(pzelasko): ideally move into model @property e.g. "oomptimizer_schema" :D
schema = [
{"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"},
{"type": NeuralType(("B",), LengthsType()), "seq_length": "input"},
{
"type": NeuralType(("B", "T"), LabelsType()),
"seq_length": "output",
"vocab_size": model.tokenizer.vocab_size,
},
{"type": NeuralType(("B",), LengthsType()), "seq_length": "output"},
]
if isinstance(model, EncDecMultiTaskModel):
schema.extend(
[{"type": "dummy"}, {"type": "dummy"}]
) # multi-task has 2 extra tensors not needed for batch size tuning

print("Setting up the optimizers.")
optimizer, _ = model.setup_optimization({"name": optimizer_name, "lr": 1e-7, "weight_decay": 0.0})

def get_max_seq_lens(buckets):
# TODO(pzelasko): support text data inputs.
return [
(
compute_num_samples(d, sampling_rate=16000), # num_samples
math.ceil(labels_per_second * d), # num_labels; might need to go data-driven for optimal tuning
)
for d in buckets
]

print("Starting profiling.")
max_seq_lens = get_max_seq_lens(buckets)
gen = ProfilingBatchGenerator(schema=schema, start_batch_size=start_batch_size, rel_gap_thresh=threshold)
profile = {}

for seq_len_in, seq_len_out in max_seq_lens:
print(f"The current sequence lengths are: input={seq_len_in} output={seq_len_out}.")
gen.reset()
batch_idx = 0

def step():
batch = gen(seq_len_in, seq_len_out)
oom = False
try:
print(f"Current gap: {gen.current_rel_gap}. Attempting shapes: {[b.shape for b in batch]}", end=" ")
optimizer.zero_grad()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is theoretically possible to do these three lines in a cuda stream capture with "relaxed" mode to avoid doing any sort of GPU-side computation. However, it will work only for code that has no data-dependent shapes (like torch.nonzero). Note that I haven't run your code and don't know how slow it is right now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is surprisingly fast - for ~30 buckets the total runtime seems within 1-2 minutes. If CUDA graph "relaxed" mode would be "ok" with skipping NCCL ops then we might even incorporate this as a training time calibration (which we can't do now because these steps trigger NCCL syncs, if one GPU dies and other doesn't, it would hang). But even as-is I think this is a viable approach.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For lots of buckets (i.e. 100+) it takes a while. We should try the "relaxed" CUDA graph trick, and if it works, make a follow up PR.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just curious how long "a while" is.

The relaxed cuda graph trick definitely won't always work unfortunately... I spoke with someone who works on end-to-end training and he told me that there is a cudaStreamSynchronize() is the torch.amp.GradScaler, which will prevent using relaxed stream capture for models that do gradient scaling in mixed precision training.

Copy link
Collaborator Author

@pzelasko pzelasko Jul 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's around 15 minutes for 150 buckets.

model.training_step(batch, batch_idx)
optimizer.step()
except torch.cuda.OutOfMemoryError:
print(f"- OOM!")
torch.cuda.memory.empty_cache()
oom = True
else:
print(f"- OK!")
return oom

with torch.autocast(device, torch.bfloat16):
oom = step()
while not (finished := gen.advance(oom)):
oom = step()

print(f"Optimal setting for input={seq_len_in} output={seq_len_out} is max_batch_size={gen.max_batch_size}")
profile[(seq_len_in, seq_len_out)] = gen.max_batch_size
gen.start_batch_size = gen.max_batch_size

# TODO(pzelasko): Output the profile as a copy-pastable configuration/CLI option.
print("The final profile is:")
for (seq_len_in, seq_len_out), v in profile.items():
print(f"Optimal setting for input={seq_len_in} output={seq_len_out} is max_batch_size={v}")


if __name__ == "__main__":
oomptimizer()
Loading