Support model config. Add bucket merging.
Signed-off-by: Piotr Żelasko <[email protected]>
pzelasko committed Jul 18, 2024
1 parent f71e27e commit 5995dbe
Showing 1 changed file with 61 additions and 14 deletions.
75 changes: 61 additions & 14 deletions scripts/speech_recognition/oomptimizer.py
@@ -1,10 +1,13 @@
#!/usr/bin/env python
import importlib
import math
import sys

import click
import pytorch_lightning as pl
import torch
from lhotse import compute_num_samples
from omegaconf import OmegaConf

from nemo.collections.asr.models import ASRModel
from nemo.collections.asr.models.aed_multitask_models import EncDecMultiTaskModel
@@ -153,7 +156,14 @@ def type_cast_value(self, ctx, value):


@click.command(context_settings={'show_default': True})
@click.option("-m", "--model-name", type=str, required=True, help="Name of model to use, e.g. 'nvidia/canary-1b'.")
@click.option("-n", "--pretrained-name", type=str, help="Name of a pretrained model to use, e.g. 'nvidia/canary-1b'.")
@click.option(
"-m",
"--module-name",
type=str,
help="Full path to NeMo's module corresponding to CONFIG_PATH, e.g. 'nemo.collections.asr.models.EncDecMultiTaskModel'.",
)
@click.option("-c", "--config-path", type=str, help="Path to the training configuration file for MODULE_NAME.")
@click.option("-o", "--optimizer-name", type=str, default="adamw", help="Name of optimizer to use.")
@click.option(
"-b",
@@ -166,19 +176,21 @@ def type_cast_value(self, ctx, value):
"-t",
"--threshold",
type=float,
default=0.1,
default=0.05,
help="Search stopping criterion in range [0, 1], lower is more precise. Interpret as the uncerainty gap, i.e. (min_oom_batch_size - max_ok_batch_size) / min_oom_batch_size.",
)
@click.option("-s", "--start-batch-size", type=int, default=1024, help="Initial batch size to start the search from.")
@click.option(
"-l",
"--labels-per-second",
type=int,
default=10,
default=15, # conservative estimate towards longer transcripts
help="How many labels/second should we simulate. More means longer output text sequences, and can increase memory consumption.",
)
def oomptimizer(
model_name: str,
pretrained_name: str | None,
module_name: str | None,
config_path: str | None,
optimizer_name: str,
buckets: list[float],
threshold: float,
@@ -188,6 +200,13 @@ def oomptimizer(
"""
OOMptimizer finds the optimal batch sizes for training your model with bucketing dataloading.
\b
There are two main usage patterns: using a pretrained model, or using an untrained model configuration.
The latter is more flexible but requires the user to provide two separate arguments. Examples:
* python oomptimizer.py --pretrained-name nvidia/canary-1b
* python oomptimizer.py --module-name nemo.collections.asr.models.EncDecMultiTaskModel \
--config-path examples/asr/conf/speech_multitask/fast-conformer_aed.yaml
Dynamic bucketing is notoriously difficult to tune as you risk running into CUDA OOM many steps into the training.
In order to simplify finding the optimal settings, OOMptimizer scans each bucket to find the maximum possible
batch size that doesn't trigger a CUDA OOM.
@@ -198,16 +217,29 @@ def oomptimizer(
2) Run OOMptimizer to find the optimal batch sizes for your specific model, optimizer, and GPU.
3) Use these optimal settings in your actual training script and enjoy optimal GPU utilization OOM-free.
"""
if all(opt is None for opt in (pretrained_name, module_name, config_path)):
click.secho(
"You need to provide either PRETRAINED_NAME or the pair of MODULE_NAME and CONFIG_PATH.", fg="yellow"
)
sys.exit(1)
logging.setLevel(logging.CRITICAL)
device = "cuda"

print("Intializing ASR model.")
# TODO(pzelasko): This currently only supports "from_pretrained".
# We need to be able to read a model training configuration and instantiate the model
# and figure out all the necessary details from that.
trainer = pl.Trainer(barebones=True)
trainer.log_every_n_steps = 1000000
model = ASRModel.from_pretrained(model_name, trainer=trainer).to(device)
if pretrained_name is not None:
assert (
config_path is None and module_name is None
), "--pretrained-name cannot be used together with --module-name/--config-path"
print(f"Intializing ASR model from pretrained checkpoint {pretrained_name}.")
model = ASRModel.from_pretrained(pretrained_name, trainer=trainer).to(device)
else:
assert config_path is not None, "--module-name requires --config-path to be specified as well."
assert module_name is not None, "--config-path requires --module-name to be specified as well."
cfg = OmegaConf.load(config_path)
namespace, name = module_name.rsplit('.', maxsplit=1)
model_cls = getattr(importlib.import_module(namespace), name)
model = model_cls(cfg=cfg.model, trainer=trainer).to(device)

# TODO(pzelasko): ideally move into model @property e.g. "oomptimizer_schema" :D
schema = [
@@ -243,7 +275,7 @@ def get_max_seq_lens(buckets):
gen = ProfilingBatchGenerator(schema=schema, start_batch_size=start_batch_size, rel_gap_thresh=threshold)
profile = {}

for seq_len_in, seq_len_out in max_seq_lens:
for bucket, (seq_len_in, seq_len_out) in zip(buckets, max_seq_lens):
print(f"The current sequence lengths are: input={seq_len_in} output={seq_len_out}.")
gen.reset()
batch_idx = 0
@@ -270,13 +302,28 @@ def step():
oom = step()

print(f"Optimal setting for input={seq_len_in} output={seq_len_out} is max_batch_size={gen.max_batch_size}")
profile[(seq_len_in, seq_len_out)] = gen.max_batch_size
profile[(bucket, seq_len_in, seq_len_out)] = gen.max_batch_size
gen.start_batch_size = gen.max_batch_size
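# Descriptive note: reusing the previous bucket's max_batch_size as the starting point for the next,
# longer bucket is presumably a warm start: longer sequences typically fit an equal or smaller batch,
# so the search converges faster from there.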

# TODO(pzelasko): Output the profile as a copy-pastable configuration/CLI option.
print("The 1st stage profile is:")
for (bucket, seq_len_in, seq_len_out), bs in profile.items():
print(f"Bucket={bucket} (input={seq_len_in} output={seq_len_out}) => max_batch_size={bs}")

print("Bucket merging stage...")
final_profile = []
for idx, ((bucket, seq_len_in, seq_len_out), bs) in enumerate(profile.items()):
if idx == 0:
final_profile.append([bucket, bs])
continue
if bs == final_profile[-1][1]:
print(f"Merging bucket {idx} with bucket {idx-1} due to identical batch sizes.")
final_profile[-1][0] = bucket
continue
final_profile.append([bucket, bs])
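# Worked example (hypothetical numbers): a profile of buckets 5.0s/10.0s/15.0s with max batch sizes
# 64/64/32 merges the first two buckets because they share the same batch size, yielding
# final_profile=[[10.0, 64], [15.0, 32]].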

print("The final profile is:")
for (seq_len_in, seq_len_out), v in profile.items():
print(f"Optimal setting for input={seq_len_in} output={seq_len_out} is max_batch_size={v}")
print("\tbucket_duration_bins=[", ",".join(str(seqlen) for seqlen, bs in final_profile), "]", sep="")
print("\tbucket_batch_size=[", ",".join(str(bs) for seqlen, bs in final_profile), "]", sep="")


if __name__ == "__main__":