OOMptimizer: bucketing batch size profiles to make GPUs go 🔥 #9763
@@ -808,7 +808,17 @@ The following script may be used:

    bucket_duration_bins=[1.78,2.34,2.69,...
    <other diagnostic information about the dataset>

-For multi-dataset setups, one may provide multiple manifests and even their weights:
+For multi-dataset setups, one may provide a dataset config directly:

.. code-block:: bash

    $ python scripts/speech_recognition/estimate_duration_bins.py -b 30 input_cfg.yaml
    Use the following options in your config:
    num_buckets=30
    bucket_duration_bins=[1.91,3.02,3.56,...
    <other diagnostic information about the dataset>
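
For reference, the dataset config is a YAML file describing one or more dataset groups.
Below is a minimal sketch of building such a file programmatically; the exact schema is
defined by NeMo's Lhotse dataloader, and the field names used here (``type``,
``manifest_filepath``, ``weight``) are illustrative assumptions to be checked against
the dataloader documentation:

.. code-block:: python

    import yaml

    # Two hypothetical dataset groups sampled with 70/30 weights; paths are placeholders.
    input_cfg = [
        {"type": "nemo", "manifest_filepath": "/data/set_a/train_manifest.json", "weight": 0.7},
        {"type": "nemo", "manifest_filepath": "/data/set_b/train_manifest.json", "weight": 0.3},
    ]

    with open("input_cfg.yaml", "w") as f:
        yaml.safe_dump(input_cfg, f)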

It's also possible to manually specify the list of data manifests (optionally together with weights):

.. code-block:: bash

@@ -818,6 +828,122 @@ For multi-dataset setups, one may provide multiple manifests and even their weig

    bucket_duration_bins=[1.91,3.02,3.56,...
    <other diagnostic information about the dataset>

2D bucketing
~~~~~~~~~~~~

To achieve maximum training efficiency for some classes of models, it is necessary to stratify the sampling
on both the input sequence lengths and the output sequence lengths.
One such example is attention encoder-decoder models, where the overall GPU memory usage can be factorized
into two main components: one bound by input sequence length (encoder activations) and one bound by output
sequence length (decoder activations).
Classical bucketing techniques stratify only on the input sequence length (e.g., duration in speech),
which leverages the encoder effectively but leads to excessive padding on the decoder's side.

To address this, we support a 2D bucketing technique that estimates the buckets in two stages.
The first stage is identical to 1D bucketing, i.e., we determine the input-sequence bucket bins so that
every bin holds roughly an equal duration of audio.
In the second stage, we use a tokenizer and optionally a prompt formatter (for prompted models) to
estimate the total number of tokens in each duration bin, and sub-divide it into several sub-buckets,
where each sub-bucket again holds roughly an equal number of tokens.
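
Conceptually, the two-stage estimation can be sketched as follows. This is a minimal
illustration assuming we already collected per-utterance ``(duration, num_tokens)`` pairs;
the actual script additionally handles tokenization, prompts, and outlier filtering:

.. code-block:: python

    import numpy as np

    def estimate_2d_bins(durations, num_tokens, num_buckets=30, sub_buckets=5):
        """Return [[max_duration, max_tokens], ...] with num_buckets * sub_buckets bins."""
        order = np.argsort(durations)
        durations = np.asarray(durations)[order]
        num_tokens = np.asarray(num_tokens)[order]

        # Stage 1: choose duration bin edges so each bin holds ~equal total audio.
        cum_dur = np.cumsum(durations)
        targets = cum_dur[-1] * np.arange(1, num_buckets + 1) / num_buckets
        edge_idx = np.minimum(np.searchsorted(cum_dur, targets), len(durations) - 1)

        bins, lo = [], 0
        for hi in edge_idx:
            if hi < lo:
                continue  # degenerate bin (possible with extreme duration outliers)
            toks = np.sort(num_tokens[lo:hi + 1])
            # Stage 2: sub-divide on token counts so each sub-bucket holds ~equal tokens.
            cum_tok = np.cumsum(toks)
            sub_targets = cum_tok[-1] * np.arange(1, sub_buckets + 1) / sub_buckets
            sub_idx = np.minimum(np.searchsorted(cum_tok, sub_targets), len(toks) - 1)
            bins.extend([float(durations[hi]), int(toks[j])] for j in sub_idx)
            lo = hi + 1
        return bins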

To run 2D bucketing with 30 buckets sub-divided into 5 sub-buckets each (150 buckets total), use the following script:

.. code-block:: bash

    $ python scripts/speech_recognition/estimate_duration_bins_2d.py \
        --tokenizer path/to/tokenizer.model \
        --buckets 30 \
        --sub-buckets 5 \
        input_cfg.yaml
    Use the following options in your config:
    num_buckets=30
    bucket_duration_bins=[[1.91,10],[1.91,17],[1.91,25],...
    max_duration=...
    max_tps=...
    <other diagnostic information about the dataset>

Note that the output in ``bucket_duration_bins`` is a nested list, where every bin specifies
the maximum duration and the maximum number of tokens that go into the bucket.

Review comment: Does this assume that the buckets are in tar? Maybe you could add some kind of example of how the data looks; it would significantly help with preparing data to use this tool.

Reply: It doesn't make any assumptions about the data. These can be regular NeMo manifests or Lhotse manifests; tarred/non-tarred doesn't matter. See the existing examples above for how to provide NeMo/Lhotse inputs or how to build an input config for this script.

Passing this option to the Lhotse dataloader will automatically enable 2D bucketing.
Note the presence of the ``max_duration`` and ``max_tps`` (tokens-per-second) options:
these need to be included in the dataloader's configuration to ensure the buckets are used correctly at runtime
when outliers are encountered.
In general, if you change your training data, it is highly advisable to re-estimate the duration bins.
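
The runtime behavior can be pictured with the following sketch: an example is matched to the
first 2D bin that fits it, and ``max_duration`` / ``max_tps`` act as caps for outliers that
no bin would otherwise accommodate. This mirrors the semantics described above, not the
literal Lhotse sampler code:

.. code-block:: python

    def select_bucket(bins, duration, num_tokens, max_duration, max_tps):
        """Pick an index into the nested [[max_dur, max_toks], ...] bin list, or None."""
        # Outliers exceeding the caps are discarded instead of being bucketed.
        if duration > max_duration or num_tokens / duration > max_tps:
            return None
        for idx, (max_dur, max_toks) in enumerate(bins):
            if duration <= max_dur and num_tokens <= max_toks:
                return idx
        return None  # unreachable if the caps match the last bin's limits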

Note that reasonable tokens-per-second values rarely exceed 12 TPS with reasonably good tokenizers.
If you find your dataset's TPS is much higher than that, you may have some bad data outliers.
In that case, you may specify the ``--max_tps`` option to discard them both in bin estimation and in dataloading.

We also support aggregate tokenizers for 2D bucketing estimation:

.. code-block:: bash

    $ python scripts/speech_recognition/estimate_duration_bins_2d.py \
        --tokenizer path/to/en/tokenizer.model path/to/pl/tokenizer1.model \
        --langs en pl \
        --buckets 30 \
        --sub-buckets 5 \
        input_cfg.yaml

To estimate 2D buckets for a prompted model such as Canary-1B, provide the prompt format name and an example prompt.
For Canary-1B, we'll also provide the special tokens tokenizer. Example:

.. code-block:: bash

    $ python scripts/speech_recognition/estimate_duration_bins_2d.py \
        --prompt-format canary \
        --prompt "[{'role':'user','slots':{'source_lang':'en','target_lang':'de','task':'ast','pnc':'yes'}}]" \
        --tokenizer path/to/spl_tokens/tokenizer.model path/to/en/tokenizer.model path/to/de/tokenizer1.model \
        --langs spl_tokens en de \
        --buckets 30 \
        --sub-buckets 5 \
        input_cfg.yaml
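
The ``--prompt`` argument above is written as a Python-style literal. Assuming the script
parses it accordingly (an assumption on our side), you can sanity-check the string locally
before launching a long estimation run:

.. code-block:: python

    import ast

    prompt_arg = "[{'role':'user','slots':{'source_lang':'en','target_lang':'de','task':'ast','pnc':'yes'}}]"

    # Parse the literal and check the expected role/slots structure.
    prompt = ast.literal_eval(prompt_arg)
    assert isinstance(prompt, list)
    assert all({"role", "slots"} <= turn.keys() for turn in prompt)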

Pushing GPU utilization to the limits with bucketing and OOMptimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The default approach of specifying a ``batch_duration``, ``bucket_duration_bins``, and ``quadratic_duration``
is quite flexible, but not maximally efficient. In practice, it often leads to under-utilization of GPU memory
and compute for most buckets (especially those with shorter durations).
While it is impossible to estimate GPU memory usage up-front, we can determine it empirically with a bit of search.

OOMptimizer is an approach that, given a NeMo model, optimizer, and a list of buckets (1D or 2D),
estimates the maximum possible batch size to use for each bucket.
It performs a binary search over batch sizes, probing which ones succeed and which lead to CUDA OOM, until convergence.
We find that the resulting bucketing batch size profiles enable full GPU utilization in training,
while the search itself takes only a couple of minutes to complete.
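
The core of the search can be sketched as a standard bisection. The hypothetical
``try_train_step`` callable below stands in for running the model's training step on a
synthetic batch of the bucket's maximal shape:

.. code-block:: python

    import torch

    def find_max_batch_size(try_train_step, upper=1024):
        """Bisect the largest batch size whose training step avoids CUDA OOM."""

        def fits(batch_size):
            try:
                try_train_step(batch_size)  # forward + backward at this batch size
                return True
            except torch.cuda.OutOfMemoryError:
                return False
            finally:
                torch.cuda.empty_cache()  # release cached blocks before the next probe

        lo, hi = 1, upper  # assume batch size 1 always fits
        while lo < hi:
            mid = (lo + hi + 1) // 2
            if fits(mid):
                lo = mid  # mid succeeded; search higher
            else:
                hi = mid - 1  # mid hit OOM; search lower
        return lo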

In order to run OOMptimizer, you only need the bucketing bins (from the previous sections) and a model configuration:

.. code-block:: bash

    $ python scripts/speech_recognition/oomptimizer.py \
        --config-path fast-conformer_aed.yaml \
        --module-name nemo.collections.asr.models.EncDecMultiTaskModel \
        --buckets '[[3.975,30],[3.975,48],[4.97,37],[4.97,60],[5.851,42],[5.851,71],[6.563,46],[6.563,79],[7.32,49],[7.32,88],[8.19,54],[8.19,99],[8.88,61],[8.88,107],[9.75,66],[9.75,117],[10.55,72],[10.55,127],[11.21,76],[11.21,135],[11.87,79],[11.87,143],[12.54,82],[12.54,151],[13.08,87],[13.08,157],[13.62,91],[13.62,164],[14.16,93],[14.16,170],[14.7,96],[14.7,177],[15.19,99],[15.19,183],[15.67,101],[15.67,189],[16.13,103],[16.13,194],[16.66,105],[16.66,200],[17.2,108],[17.2,207],[17.73,111],[17.73,213],[18.2,114],[18.2,219],[18.69,117],[18.69,225],[19.15,120],[19.15,230],[19.62,123],[19.62,236],[20.264,122],[20.264,244],[32.547,173],[32.547,391],[36.587,227],[36.587,440],[40.0,253],[40.0,480]]'
    <output logs from the search>
    The final profile is:
    bucket_duration_bins=[[3.975,30],[3.975,48],[4.97,37],[4.97,60],[5.851,42],[5.851,71],[6.563,46],[6.563,79],[7.32,49],[7.32,88],[8.19,54],[8.19,99],[8.88,61],[8.88,107],[9.75,66],[9.75,117],[10.55,72],[10.55,127],[11.21,76],[11.21,135],[11.87,79],[11.87,143],[12.54,82],[12.54,151],[13.08,87],[13.08,157],[13.62,91],[13.62,164],[14.16,93],[14.16,170],[14.7,96],[14.7,177],[15.19,99],[15.19,183],[15.67,101],[15.67,189],[16.13,103],[16.13,194],[16.66,105],[16.66,200],[17.2,108],[17.2,207],[17.73,111],[17.73,213],[18.2,114],[18.2,219],[18.69,117],[18.69,225],[19.15,120],[19.15,230],[19.62,123],[19.62,236],[20.264,122],[20.264,244],[32.547,173],[32.547,391],[36.587,227],[36.587,440],[40.0,253],[40.0,480]]
    bucket_batch_size=[352,308,280,245,245,206,206,180,186,163,168,142,151,132,136,119,126,106,116,98,110,92,104,88,99,83,94,79,90,76,86,72,86,72,81,68,80,65,78,63,74,60,72,58,70,58,68,54,66,52,65,52,62,50,37,28,31,24,28,21]
    max_tps=12.0
    max_duration=40.0

Review comment: Not sure how this is rendered on the final webpage, but note that 9.75 gets its 5 cut off into the next line here.

Reply: Right, I merged the lines now.

Use the resulting options in your training configuration (typically under the ``model.train_ds`` namespace) to apply the profile.
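
For example, assuming an OmegaConf-based training config, applying the profile could look
like this (values abbreviated; paste the full lists emitted by the script):

.. code-block:: python

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("fast-conformer_aed.yaml")

    # Options copied from the OOMptimizer output (truncated here for readability).
    cfg.model.train_ds.bucket_duration_bins = [[3.975, 30], [3.975, 48], [4.97, 37], [4.97, 60]]
    cfg.model.train_ds.bucket_batch_size = [352, 308, 280, 245]
    cfg.model.train_ds.max_tps = 12.0
    cfg.model.train_ds.max_duration = 40.0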

It's also possible to run OOMptimizer using a pretrained model's name and bucket bins corresponding
to your fine-tuning data:

.. code-block:: bash

    $ python scripts/speech_recognition/oomptimizer.py \
        --pretrained-name nvidia/canary-1b \
        --buckets '[2.0,3.1,5.6,6.6,...]'

Note that in rare cases your training script may perform additional actions that use GPU RAM in ways
OOMptimizer cannot anticipate. In that case, you can try re-estimating the profile with the option
``--memory-fraction 0.75`` (or another value) to cap the GPU RAM available to OOMptimizer (the default is 0.9).

Review comment: Specify that the default is 0.9, if you're comfortable with that.

Seeds and randomness
~~~~~~~~~~~~~~~~~~~~
@@ -512,6 +512,7 @@ def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool
                prompt_format_fn=get_prompt_format_fn(self.prompt_format),
                inference=inference,
            ),
            tokenizer=self.tokenizer,
        )

    def setup_training_data(self, train_data_config: Optional[DictConfig]):

Review comment (on ``tokenizer=self.tokenizer``): Do other subclasses' ``_setup_dataloader_from_config`` methods need to be changed similarly now to pass the tokenizer?

Reply: Great point! Adding that. Will also add integration tests for 2D bucketing estimation, the OOMptimizer script, and a training script with 2D buckets once my other PR with new integration tests is merged.
@@ -682,9 +683,18 @@ def training_step(self, batch, batch_nb): | |
|
||
audio_loss = self.loss(log_probs=transf_log_probs, labels=labels) | ||
|
||
num_frames = signal_len.sum() | ||
num_tokens = transcript_len.sum() | ||
tot_frames = signal.numel() | ||
tot_tokens = transcript.numel() | ||
tensorboard_logs = { | ||
'train_loss': audio_loss, | ||
'learning_rate': self._optimizer.param_groups[0]['lr'], | ||
'batch_size': signal.shape[0], | ||
'num_frames': num_frames, | ||
'num_tokens': num_tokens, | ||
'input_to_padding_ratio': num_frames / tot_frames, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. FYI this line and the next will do a GPU to CPU synchronization at some point because num_frames is (likely) a cuda tensor, while tot_frames is an int. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you think of an alternative way to compute this without requiring the sync? I think it's a useful diagnostic. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IDK. If you don't see any large gaps, just leave it for now. I'm not sure at what point lightning will convert this tensorboard_logs variable to a serialized string, at which point it will have to do the GPU to CPU copy. |
||
'output_to_padding_ratio': num_tokens / tot_tokens, | ||
} | ||
|
||
return {'loss': audio_loss, 'log': tensorboard_logs} | ||
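
For context on the synchronization discussion above, here is a minimal sketch of when the
GPU-to-CPU copy actually happens (assumes a CUDA device is available):

.. code-block:: python

    import torch

    signal = torch.randn(4, 100, device="cuda")
    signal_len = torch.tensor([80, 90, 100, 70], device="cuda")

    ratio = signal_len.sum() / signal.numel()  # still a CUDA tensor; no sync yet
    print(ratio.item())  # .item() copies to host and forces a synchronization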

@@ -1051,6 +1061,26 @@ def predict_step(self, batch, batch_idx=0, dataloader_idx=0, has_processed_signa

    def adapter_module_names(self) -> List[str]:
        return ['', 'encoder', 'transf_encoder', 'transf_decoder']

    @property
    def oomptimizer_schema(self) -> list[dict]:
        """
        Return a typing schema for optimal batch size calibration for various
        sequence lengths using OOMptimizer.
        """
        assert hasattr(self, "tokenizer"), "OOMptimizer currently supports only models that use tokenizers."
        return [
            {"type": NeuralType(("B", "T"), AudioSignal()), "seq_length": "input"},
            {"type": NeuralType(("B",), LengthsType()), "seq_length": "input"},
            {
                "type": NeuralType(("B", "T"), LabelsType()),
                "seq_length": "output",
                "vocab_size": self.tokenizer.vocab_size,
            },
            {"type": NeuralType(("B",), LengthsType()), "seq_length": "output"},
            {"type": "dummy"},
            {"type": "dummy"},
        ]


def parse_multitask_prompt(prompt: dict | None) -> list[dict]:
    if prompt is None or not prompt:

Review comment (on the ``"dummy"`` entries): Are these two lines still needed? It looks like "dummy" entries get filtered out. But I may be missing code in lhotse that might work with this.

Reply: Yes, dummy will pass an empty tensor, but it still needs to be there to match the ``training_step`` function signature.
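
As an aside, here is a minimal sketch of how a schema like the above could be consumed to
synthesize a worst-case batch for one bucket. This is illustrative only: it assumes
``NeuralType`` exposes its ``axes`` tuple and that 2D entries are (B, T) sequences while
1D entries are (B,) length vectors:

.. code-block:: python

    import torch

    def make_synthetic_batch(schema, batch_size, in_len, out_len, device="cuda"):
        """Build a maximal-shape batch from an oomptimizer_schema-style description."""
        batch = []
        for item in schema:
            if item["type"] == "dummy":
                batch.append(torch.empty(0, device=device))  # keeps training_step's signature
                continue
            seq_len = in_len if item["seq_length"] == "input" else out_len
            if len(item["type"].axes) == 2:  # (B, T): audio samples or token ids
                if "vocab_size" in item:
                    tensor = torch.randint(0, item["vocab_size"], (batch_size, seq_len), device=device)
                else:
                    tensor = torch.randn(batch_size, seq_len, device=device)
            else:  # (B,): per-example lengths, all set to the maximum
                tensor = torch.full((batch_size,), seq_len, dtype=torch.long, device=device)
            batch.append(tensor)
        return tuple(batch)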

Review comment (on the ``bucket_duration_bins`` example): My complaint about this line is that it isn't obvious to a user that this array of values will be output by running ``estimate_duration_bins_2d.py``. (Same problem with ``estimate_duration_bins.py`` above.)

Reply: Thanks, fixed.