OOMptimizer: bucketing batch size profiles to make GPUs go 🔥 (NVIDIA#9763)

* Initial working draft of the OOMptimizer.

Signed-off-by: Piotr Żelasko <[email protected]>

* Support model config. Add bucket merging.

Signed-off-by: Piotr Żelasko <[email protected]>

* fix

Signed-off-by: Piotr Żelasko <[email protected]>

* code review

Signed-off-by: Piotr Żelasko <[email protected]>

* Support bucket_batch_size option for lhotse dataloading

Signed-off-by: Piotr Żelasko <[email protected]>

* Ability to force a memory fraction to be unused in OOMptimizer

Signed-off-by: Piotr Żelasko <[email protected]>

* Ability to force a memory fraction to be unused in OOMptimizer

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix for autocast and configurable dtype

Signed-off-by: Piotr Żelasko <[email protected]>

* Allow token-per-second filtering

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix an issue with canary tokenizer

Signed-off-by: Piotr Żelasko <[email protected]>

* Lift the requirement to use CanaryTokenizer with canary prompt format

* Fixes

Signed-off-by: Piotr Żelasko <[email protected]>

* Initial 2D bucketing draft

Signed-off-by: Piotr Żelasko <[email protected]>

* Separate script for 2D bucket estimation

Signed-off-by: Piotr Żelasko <[email protected]>

* Full 2D bucketing support: estimate_duration_bins_2d, oomptimizer, training

Signed-off-by: Piotr Żelasko <[email protected]>

* fix

Signed-off-by: Piotr Żelasko <[email protected]>

* fix

Signed-off-by: Piotr Żelasko <[email protected]>

* fix

Signed-off-by: Piotr Żelasko <[email protected]>

* fix

Signed-off-by: Piotr Żelasko <[email protected]>

* fix

Signed-off-by: Piotr Żelasko <[email protected]>

* fix

Signed-off-by: Piotr Żelasko <[email protected]>

* Unit tests for bucket_batch_size and 2D bucketing for audio

Signed-off-by: Piotr Żelasko <[email protected]>

* Docs for 2D estimate duration bins

Signed-off-by: Piotr Żelasko <[email protected]>

* Fixes

Signed-off-by: Piotr Żelasko <[email protected]>

* Preliminary support for prompt format in estimate_duration_bins_2d

Signed-off-by: Piotr Żelasko <[email protected]>

* fixes

Signed-off-by: Piotr Żelasko <[email protected]>

* fix for bucket selection edge case

Signed-off-by: Piotr Żelasko <[email protected]>

* Add more info about the distribution to estimate_duration_bins_2d.py

Signed-off-by: Piotr Żelasko <[email protected]>

* Include CUDA RAM usage tracking in OOMptimizer

Signed-off-by: Piotr Żelasko <[email protected]>

* Track batch_size, num frames/tokens, and their padding ratio for AED multi task models

Signed-off-by: Piotr Żelasko <[email protected]>

* OOMptimizer documentation

Signed-off-by: Piotr Żelasko <[email protected]>

* Resolve TODOs and support any combination of (audio|text)->(audio|text) modalities

Signed-off-by: Piotr Żelasko <[email protected]>

* Add missing property decorator

Signed-off-by: Piotr Żelasko <[email protected]>

* fixes

Signed-off-by: Piotr Żelasko <[email protected]>

* Add docs about 2D bucketing with tokenizer and prompts

Signed-off-by: Piotr Żelasko <[email protected]>

* Fix bucket allocation logic for 2D bucketing

Signed-off-by: Piotr Żelasko <[email protected]>

* Bump lhotse version

Signed-off-by: Piotr Żelasko <[email protected]>

* fix...

Signed-off-by: Piotr Żelasko <[email protected]>

* Reverse bucket iteration order; move oomptimizer_schema to AsrModel

Signed-off-by: Piotr Żelasko <[email protected]>

* Make OOMptimizer compatible with dataclass mini-batches

Signed-off-by: Piotr Żelasko <[email protected]>

* Refine the schema

Signed-off-by: Piotr Żelasko <[email protected]>

* fixes after merging main

Signed-off-by: Piotr Żelasko <[email protected]>

* fix oomptimizer with pretrained models; verified canary, parakeet tdt and ctc

Signed-off-by: Piotr Żelasko <[email protected]>

* Disable concurrent bucketing to prevent spawning extra threads in tests

Signed-off-by: Piotr Żelasko <[email protected]>

* fix tests and make life more colorful

Signed-off-by: Piotr Żelasko <[email protected]>

* formatting

Signed-off-by: Piotr Żelasko <[email protected]>

* more reasonable starting batch size settings

Signed-off-by: Piotr Żelasko <[email protected]>

* Disable clearing of cuda memory cache

Signed-off-by: Piotr Żelasko <[email protected]>

* Even more conservative profile by incorporating DDP overhead simulation

Signed-off-by: Piotr Żelasko <[email protected]>

* Bucket selection fix and an extended unit test

* Refactor registered_prompt_format_fn to enable prompt formatting before Sampler

Signed-off-by: Piotr Żelasko <[email protected]>

* porting fix

Signed-off-by: Piotr Żelasko <[email protected]>

* Fixes, move fast-path to prompted dataset

Signed-off-by: Piotr Żelasko <[email protected]>

* Changes from Daniel's review

Signed-off-by: Piotr Żelasko <[email protected]>

* OOMptimizer tests + fixes for 1D bucketing case

Signed-off-by: Piotr Żelasko <[email protected]>

* estimate duration bins tests

Signed-off-by: Piotr Żelasko <[email protected]>

* address Daniel's review

Signed-off-by: Piotr Żelasko <[email protected]>

* fix CPU unit test

Signed-off-by: Piotr Żelasko <[email protected]>

* try to fix CI test

Signed-off-by: Piotr Żelasko <[email protected]>

* Apply suggestions from code review

Co-authored-by: oliver könig <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>

* Disable 2D bucketing test with prompt due to quoting issue

Signed-off-by: Piotr Żelasko <[email protected]>

---------

Signed-off-by: Piotr Żelasko <[email protected]>
Signed-off-by: Piotr Żelasko <[email protected]>
Co-authored-by: oliver könig <[email protected]>
Signed-off-by: adityavavre <[email protected]>
2 people authored and adityavavre committed Sep 15, 2024
1 parent 428c805 commit 8bca86e
Showing 24 changed files with 1,756 additions and 184 deletions.
61 changes: 61 additions & 0 deletions .github/workflows/cicd-main.yml
@@ -841,6 +841,67 @@ jobs:
AFTER_SCRIPT: |
rm -rf examples/asr/speech_to_text_adapters_mha_results
# L2: Speech Estimate Duration Bins
L2_Speech_Estimate_Duration_Bins:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
with:
RUNNER: self-hosted-azure
SCRIPT: |
set -x
# 1D buckets [SSL, CTC]
python scripts/speech_recognition/estimate_duration_bins.py \
/home/TestData/an4_dataset/an4_train.json \
--buckets 5
# 2D buckets [CTC, RNNT, TDT] / with tokenizer
python scripts/speech_recognition/estimate_duration_bins_2d.py \
/home/TestData/an4_dataset/an4_train_lang.json \
--tokenizer /home/TestData/asr_tokenizers/canary/en/tokenizer_spe_bpe_v1024_max_4/tokenizer.model \
--buckets 5 \
--sub-buckets 2
# TODO(pzelasko): Figure out how to quote the value in the test properly for CI to accept it...
# 2D buckets with prompt [AED/Canary, SpeechLM] / with aggregate tokenizer + prompt format
# python scripts/speech_recognition/estimate_duration_bins_2d.py \
# /home/TestData/an4_dataset/an4_train_lang.json \
# --tokenizer /home/TestData/asr_tokenizers/canary/canary_spl_tokenizer_v32/tokenizer.model \
# /home/TestData/asr_tokenizers/canary/en/tokenizer_spe_bpe_v1024_max_4/tokenizer.model \
# /home/TestData/asr_tokenizers/canary/es/tokenizer_spe_bpe_v1024_max_4/tokenizer.model \
# --langs spl_tokens en es \
# --prompt-format canary \
# --prompt '[{"role":"user","slots":{"source_lang":"en","target_lang":"en","task":"asr","pnc":"yes"}}]' \
# --buckets 5 \
# --sub-buckets 2
# L2: OOMptimizer
L2_Speech_Batch_Size_OOMptimizer:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
with:
RUNNER: self-hosted-azure
SCRIPT: |
# 1D bucketing
python scripts/speech_recognition/oomptimizer.py \
-c /home/TestData/oomptimizer/fast-conformer_ctc_bpe.yaml \
-m nemo.collections.asr.models.EncDecCTCModelBPE \
-b "[5.0,10.0]"
# 2D bucketing
python scripts/speech_recognition/oomptimizer.py \
-c /home/TestData/oomptimizer/fast-conformer_ctc_bpe.yaml \
-m nemo.collections.asr.models.EncDecCTCModelBPE \
-b "[[5.0,30],[5.0,45],[10.0,57],[10.0,71]]"
# L2: OOMptimizer Canary (has a different batch schema)
L2_Speech_Batch_Size_OOMptimizer_Canary:
needs: [cicd-test-container-setup]
uses: ./.github/workflows/_test_template.yml
with:
RUNNER: self-hosted-azure
SCRIPT: |
python scripts/speech_recognition/oomptimizer.py \
-c /home/TestData/oomptimizer/fast-conformer_aed.yaml \
-m nemo.collections.asr.models.EncDecMultiTaskModel \
-b "[[5.0,30],[5.0,45],[10.0,57],[10.0,71]]"
# L2: Speech Transcription
L2_Speech_Transcription_Speech_to_Text_Transcribe:
needs: [cicd-test-container-setup]
134 changes: 133 additions & 1 deletion docs/source/asr/datasets.rst
@@ -803,21 +803,153 @@ The following script may be used:
.. code-block:: bash
$ python scripts/speech_recognition/estimate_duration_bins.py -b 30 manifest.json
# The script's output:
Use the following options in your config:
num_buckets=30
bucket_duration_bins=[1.78,2.34,2.69,...
<other diagnostic information about the dataset>
For multi-dataset setups, one may provide a dataset config directly:
.. code-block:: bash
$ python scripts/speech_recognition/estimate_duration_bins.py -b 30 input_cfg.yaml
# The script's output:
Use the following options in your config:
num_buckets=30
bucket_duration_bins=[1.91,3.02,3.56,...
<other diagnostic information about the dataset>
It's also possible to manually specify the list of data manifests (optionally together with weights):
.. code-block:: bash
$ python scripts/speech_recognition/estimate_duration_bins.py -b 30 [[manifest.json,0.7],[other.json,0.3]]
# The script's output:
Use the following options in your config:
num_buckets=30
bucket_duration_bins=[1.91,3.02,3.56,...
<other diagnostic information about the dataset>
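For intuition, the estimation is essentially an equal-share split: the bin edges are chosen so that
each bucket holds roughly the same total amount of audio. A minimal sketch of that idea
(a simplification, not the actual script's implementation):

.. code-block:: python

   # Simplified sketch of equal-total-duration binning (not the actual
   # estimate_duration_bins implementation).
   def estimate_bins(durations, num_buckets):
       durations = sorted(durations)
       total = sum(durations)
       bins, cumulative = [], 0.0
       for d in durations:
           cumulative += d
           # Close a bucket once it holds ~1/num_buckets of the total audio.
           if cumulative >= total * (len(bins) + 1) / num_buckets and len(bins) < num_buckets - 1:
               bins.append(d)
       return bins  # num_buckets - 1 bin edges; the last bucket is open-ended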
2D bucketing
~~~~~~~~~~~~
To achieve maximum training efficiency for some classes of models, it is necessary to stratify the sampling
on both the input sequence lengths and the output sequence lengths.
One such example is attention encoder-decoder models, where the overall GPU memory usage can be factorized
into two main components: input-sequence-length bound (encoder activations) and output-sequence-length bound
(decoder activations).
Classical bucketing techniques stratify only on the input sequence length (e.g. duration in speech),
which utilizes the encoder effectively but leads to excessive padding on the decoder's side.
To address this, we support a 2D bucketing technique which estimates the buckets in two stages.
The first stage is identical to 1D bucketing, i.e. we determine the input-sequence bucket bins so that
every bin holds roughly an equal duration of audio.
In the second stage, we use a tokenizer and optionally a prompt formatter (for prompted models) to
estimate the total number of tokens in each duration bin, and sub-divide it into several sub-buckets,
where each sub-bucket again holds roughly an equal number of tokens.
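Conceptually, the second stage repeats the same equal-share split within each duration bin, this time
over token counts. A minimal sketch of that idea (a simplification, not the actual
``estimate_duration_bins_2d`` implementation; ``tokenize`` stands in for the tokenizer and the
optional prompt formatter):

.. code-block:: python

   # Simplified sketch of the second 2D-bucketing stage: sub-divide a single
   # duration bin by output token counts.
   def sub_divide(bin_max_duration, texts, tokenize, num_sub_buckets):
       token_counts = sorted(len(tokenize(t)) for t in texts)
       total = sum(token_counts)
       sub_bins, cumulative = [], 0
       for n in token_counts:
           cumulative += n
           # Close a sub-bucket once it holds its equal share of all tokens.
           if cumulative >= total * (len(sub_bins) + 1) / num_sub_buckets:
               sub_bins.append([bin_max_duration, n])
       return sub_bins  # e.g. [[1.91, 10], [1.91, 17]] for 2 sub-buckets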
To run 2D bucketing with 30 buckets sub-divided into 5 sub-buckets each (150 buckets total), use the following script:
.. code-block:: bash
$ python scripts/speech_recognition/estimate_duration_bins_2d.py \
--tokenizer path/to/tokenizer.model \
--buckets 30 \
--sub-buckets 5 \
input_cfg.yaml
# The script's output:
Use the following options in your config:
num_buckets=30
bucket_duration_bins=[[1.91,10],[1.91,17],[1.91,25],...
max_duration=...
max_tps=...
<other diagnostic information about the dataset>
Note that the output in ``bucket_duration_bins`` is a nested list, where every bin specifies
the maximum duration and the maximum number of tokens that go into the bucket.
Passing this option to the Lhotse dataloader will automatically enable 2D bucketing.
Note the presence of the ``max_duration`` and ``max_tps`` (tokens-per-second) options:
these need to be included in the dataloader's configuration to ensure the buckets are used
correctly at runtime in the presence of outliers.
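At runtime, each example is assigned to the first bucket that accommodates both its duration and its
token count, and examples that fit no bucket are filtered out via ``max_duration`` and ``max_tps``.
A rough sketch of the selection (simplified; the actual logic lives in the Lhotse sampler):

.. code-block:: python

   # Rough sketch of runtime 2D bucket selection (simplified).
   def select_bucket(bins_2d, duration, num_tokens):
       for idx, (max_dur, max_tokens) in enumerate(bins_2d):
           if duration <= max_dur and num_tokens <= max_tokens:
               return idx
       return None  # outlier: dropped via max_duration / max_tps filtering

   bins = [[1.91, 10], [1.91, 17], [2.34, 12], [2.34, 21]]
   assert select_bucket(bins, duration=1.5, num_tokens=15) == 1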
In general, if your training data changes, it is highly advisable to re-estimate the duration bins.
Note that reasonable tokens-per-second values rarely exceed 12 tps with reasonably good tokenizers.
If you find your dataset's TPS is much higher than that, you may have some bad data outliers.
In that case, you may specify the ``--max_tps`` option to discard them during both bin estimation and dataloading.
We also support aggregate tokenizers for 2D bucketing estimation:
.. code-block:: bash
$ python scripts/speech_recognition/estimate_duration_bins_2d.py \
--tokenizer path/to/en/tokenizer.model path/to/pl/tokenizer1.model \
--langs en pl \
--buckets 30 \
--sub-buckets 5 \
input_cfg.yaml
To estimate 2D buckets for a prompted model such as Canary-1B, provide the prompt format name and an example prompt.
For Canary-1B, we'll also provide the special tokens tokenizer. Example:
.. code-block:: bash
$ python scripts/speech_recognition/estimate_duration_bins_2d.py \
--prompt-format canary \
--prompt "[{'role':'user','slots':{'source_lang':'en','target_lang':'de','task':'ast','pnc':'yes'}}]" \
--tokenizer path/to/spl_tokens/tokenizer.model path/to/en/tokenizer.model path/to/de/tokenizer1.model \
--langs spl_tokens en de \
--buckets 30 \
--sub-buckets 5 \
input_cfg.yaml
Pushing GPU utilization to the limits with bucketing and OOMptimizer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The default approach of specifying a ``batch_duration``, ``bucket_duration_bins`` and ``quadratic_duration``
is quite flexible, but is not maximally efficient. We observed that in practice it often leads to under-utilization
of GPU memory and compute for most buckets (especially those with shorter durations).
While it is difficult to predict GPU memory usage up-front, we can determine it empirically with a bit of search.
OOMptimizer is an approach that, given a NeMo model, an optimizer, and a list of buckets (1D or 2D),
estimates the maximum batch size that can be used for each bucket.
It performs a binary search over batch sizes, running training steps that either succeed or hit CUDA OOM, until the search converges.
We find that the resulting bucketing batch size profiles enable full GPU utilization in training,
while the search itself takes only a couple of minutes to complete.
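The search can be pictured with the following sketch (a simplification of the actual script;
``run_training_step`` is a hypothetical helper that executes one forward and backward pass at a
given batch size):

.. code-block:: python

   import torch

   # Simplified sketch of the per-bucket search: grow the batch size until the
   # first OOM, then bisect between the last success and the first failure.
   def find_max_batch_size(run_training_step, start=32):
       lo, hi = 0, start
       while True:  # exponential growth phase
           try:
               run_training_step(hi)
               lo, hi = hi, hi * 2
           except torch.cuda.OutOfMemoryError:
               break
       while lo + 1 < hi:  # bisection phase
           mid = (lo + hi) // 2
           try:
               run_training_step(mid)
               lo = mid
           except torch.cuda.OutOfMemoryError:
               hi = mid
       return lo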
In order to run OOMptimizer, you only need the bucketing bins (from previous sections) and a model configuration:
.. code-block:: bash
$ python scripts/speech_recognition/oomptimizer.py \
--config-path fast-conformer_aed.yaml \
--module-name nemo.collections.asr.models.EncDecMultiTaskModel \
--buckets '[[3.975,30],[3.975,48],[4.97,37],[4.97,60],[5.851,42],[5.851,71],[6.563,46],[6.563,79],[7.32,49],[7.32,88],[8.19,54],[8.19,99],[8.88,61],[8.88,107],[9.75,66],[9.75,117],[10.55,72],[10.55,127],[11.21,76],[11.21,135],[11.87,79],[11.87,143],[12.54,82],[12.54,151],[13.08,87],[13.08,157],[13.62,91],[13.62,164],[14.16,93],[14.16,170],[14.7,96],[14.7,177],[15.19,99],[15.19,183],[15.67,101],[15.67,189],[16.13,103],[16.13,194],[16.66,105],[16.66,200],[17.2,108],[17.2,207],[17.73,111],[17.73,213],[18.2,114],[18.2,219],[18.69,117],[18.69,225],[19.15,120],[19.15,230],[19.62,123],[19.62,236],[20.264,122],[20.264,244],[32.547,173],[32.547,391],[36.587,227],[36.587,440],[40.0,253],[40.0,480]]'
# The script's output:
<output logs from the search>
The final profile is:
bucket_duration_bins=[[3.975,30],[3.975,48],[4.97,37],[4.97,60],[5.851,42],[5.851,71],[6.563,46],[6.563,79],[7.32,49],[7.32,88],[8.19,54],[8.19,99],[8.88,61],[8.88,107],[9.75,66],[9.75,117],[10.55,72],[10.55,127],[11.21,76],[11.21,135],[11.87,79],[11.87,143],[12.54,82],[12.54,151],[13.08,87],[13.08,157],[13.62,91],[13.62,164],[14.16,93],[14.16,170],[14.7,96],[14.7,177],[15.19,99],[15.19,183],[15.67,101],[15.67,189],[16.13,103],[16.13,194],[16.66,105],[16.66,200],[17.2,108],[17.2,207],[17.73,111],[17.73,213],[18.2,114],[18.2,219],[18.69,117],[18.69,225],[19.15,120],[19.15,230],[19.62,123],[19.62,236],[20.264,122],[20.264,244],[32.547,173],[32.547,391],[36.587,227],[36.587,440],[40.0,253],[40.0,480]]
bucket_batch_size=[352,308,280,245,245,206,206,180,186,163,168,142,151,132,136,119,126,106,116,98,110,92,104,88,99,83,94,79,90,76,86,72,86,72,81,68,80,65,78,63,74,60,72,58,70,58,68,54,66,52,65,52,62,50,37,28,31,24,28,21]
max_tps=12.0
max_duration=40.0
Use the resulting options in your training configuration (typically under namespace ``model.train_ds``) to apply the profile.
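For illustration, the profile maps onto the dataloader options as in the following hedged sketch
(``bucket_duration_bins``, ``bucket_batch_size``, ``max_duration``, and ``max_tps`` come from the
script's output; the remaining keys are assumptions about the Lhotse dataloader config, and the
values are truncated):

.. code-block:: python

   from omegaconf import OmegaConf

   # Hedged sketch: splice the OOMptimizer profile into model.train_ds options.
   train_ds = OmegaConf.create(
       {
           "use_lhotse": True,
           "use_bucketing": True,
           "bucket_duration_bins": [[3.975, 30], [3.975, 48]],  # truncated
           "bucket_batch_size": [352, 308],  # truncated; one entry per bin
           "max_duration": 40.0,
           "max_tps": 12.0,
       }
   )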
It's also possible to run OOMptimizer using a pretrained model's name and bucket bins corresponding
to your fine-tuning data:
.. code-block:: bash

   $ python scripts/speech_recognition/oomptimizer.py \
       --pretrained-name nvidia/canary-1b \
       --buckets '[2.0,3.1,5.6,6.6,...]'
Note that your training script can perform additional actions that use GPU RAM and that OOMptimizer cannot anticipate.
By default, we let the script use up to 90% of the GPU's RAM for this estimation to account for that.
In the unlikely case you run into an OutOfMemoryError during training, you can re-estimate the profile with the option ``--memory-fraction 0.75`` (or another value), which further caps the GPU RAM available to OOMptimizer.
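Conceptually, such a cap corresponds to PyTorch's per-process memory fraction; a sketch under that
assumption (the script's internal mechanism may differ):

.. code-block:: python

   import torch

   # Restrict the CUDA allocator to 75% of device memory, leaving headroom for
   # allocations that only occur in real training runs.
   torch.cuda.set_per_process_memory_fraction(0.75, device=0)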
Seeds and randomness
~~~~~~~~~~~~~~~~~~~~
14 changes: 13 additions & 1 deletion nemo/collections/asr/data/audio_to_text_lhotse.py
@@ -50,7 +50,19 @@ def __init__(self, tokenizer):

def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
audio, audio_lens, cuts = self.load_audio(cuts)
tokens = [
torch.as_tensor(
sum(
(
# Supervisions may come pre-tokenized from the dataloader.
s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language)
for s in c.supervisions
),
start=[],
)
)
for c in cuts
]
token_lens = torch.tensor([t.size(0) for t in tokens], dtype=torch.long)
tokens = collate_vectors(tokens, padding_value=0)
return audio, audio_lens, tokens, token_lens