diff --git a/egs/librispeech/ASR/finetune.sh b/egs/librispeech/ASR/finetune.sh index 63d0966ed0..bc6357312e 100755 --- a/egs/librispeech/ASR/finetune.sh +++ b/egs/librispeech/ASR/finetune.sh @@ -79,6 +79,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then --use-averaged-model True \ --beam-size 4 \ --exp-dir pruned_transducer_stateless7/exp_giga_finetune \ + --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \ --max-duration 400 \ --decoding-method $m done diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py index 4f64850b65..b0e4be0d1a 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py @@ -1,8 +1,9 @@ #!/usr/bin/env python3 # -# Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang, +# Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang, # Zengwei Yao, -# Xiaoyu Yang) +# Xiaoyu Yang, +# Yifan Yang,) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -20,36 +21,36 @@ """ Usage: (1) greedy search -./pruned_transducer_stateless7/decode.py \ +./pruned_transducer_stateless7/decode_gigaspeech.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ + --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \ --max-duration 600 \ --decoding-method greedy_search (2) beam search (not recommended) -./pruned_transducer_stateless7/decode.py \ +./pruned_transducer_stateless7/decode_gigaspeech.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ + --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \ --max-duration 600 \ --decoding-method beam_search \ --beam-size 4 (3) modified beam search -./pruned_transducer_stateless7/decode.py \ +./pruned_transducer_stateless7/decode_gigaspeech.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ + --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \ --max-duration 600 \ --decoding-method modified_beam_search \ --beam-size 4 (4) fast beam search (one best) -./pruned_transducer_stateless7/decode.py \ +./pruned_transducer_stateless7/decode_gigaspeech.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ + --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \ --max-duration 600 \ --decoding-method fast_beam_search \ --beam 20.0 \ @@ -57,10 +58,10 @@ --max-states 64 (5) fast beam search (nbest) -./pruned_transducer_stateless7/decode.py \ +./pruned_transducer_stateless7/decode_gigaspeech.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ + --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \ --max-duration 600 \ --decoding-method fast_beam_search_nbest \ --beam 20.0 \ @@ -70,10 +71,10 @@ --nbest-scale 0.5 (6) fast beam search (nbest oracle WER) -./pruned_transducer_stateless7/decode.py \ +./pruned_transducer_stateless7/decode_gigaspeech.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ + --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_oracle \ --beam 20.0 \ @@ -83,10 +84,10 @@ --nbest-scale 0.5 (7) fast beam search (with LG) -./pruned_transducer_stateless7/decode.py \ +./pruned_transducer_stateless7/decode_gigaspeech.py \ --epoch 28 \ --avg 15 \ - --exp-dir ./pruned_transducer_stateless7/exp \ + --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \ --max-duration 600 \ --decoding-method fast_beam_search_nbest_LG \ --beam 20.0 \ @@ -187,7 +188,7 @@ def get_parser(): parser.add_argument( "--exp-dir", type=str, - default="pruned_transducer_stateless7/exp", + default="pruned_transducer_stateless7/exp_giga_finetune", help="The experiment dir", ) diff --git a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py index 726a24809e..02e8bd9eb4 100755 --- a/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py @@ -1,8 +1,10 @@ #!/usr/bin/env python3 -# Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang, +# Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang, # Wei Kang, -# Mingshuang Luo,) -# Zengwei Yao) +# Mingshuang Luo, +# Zengwei Yao, +# Xiaoyu Yang, +# Yifan Yang) # # See ../../../../LICENSE for clarification regarding multiple authors # @@ -20,27 +22,23 @@ """ Usage: -export CUDA_VISIBLE_DEVICES="0,1,2,3" +export CUDA_VISIBLE_DEVICES="0,1" -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ - --start-epoch 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 300 - -# For mix precision training: - -./pruned_transducer_stateless7/train.py \ - --world-size 4 \ - --num-epochs 30 \ +./pruned_transducer_stateless7/finetune.py \ + --world-size 2 \ + --num-epochs 20 \ --start-epoch 1 \ + --exp-dir pruned_transducer_stateless7/exp_giga_finetune \ + --subset S \ --use-fp16 1 \ - --exp-dir pruned_transducer_stateless7/exp \ - --full-libri 1 \ - --max-duration 550 - + --base-lr 0.005 \ + --lr-epochs 100 \ + --lr-batches 100000 \ + --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \ + --do-finetune True \ + --use-mux True \ + --finetune-ckpt icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/pretrained.pt \ + --max-duration 500 """ @@ -59,9 +57,10 @@ import torch.multiprocessing as mp import torch.nn as nn from decoder import Decoder +from asr_datamodule import LibriSpeechAsrDataModule from gigaspeech import GigaSpeechAsrDataModule from joiner import Joiner -from lhotse.cut import Cut +from lhotse.cut import Cut, CutSet from lhotse.dataset.sampling.base import CutSampler from lhotse.utils import fix_random_seed from model import Transducer @@ -103,7 +102,21 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None: def add_finetune_arguments(parser: argparse.ArgumentParser): - parser.add_argument("--do-finetune", type=str2bool, default=False) + parser.add_argument( + "--do-finetune", + type=str2bool, + default=True, + help="Whether to fine-tune.", + ) + parser.add_argument( + "--use-mux", + type=str2bool, + default=False, + help=""" + Whether to adapt. If true, we will mix 5% of the new data + with 95% of the original data to fine-tune. + """, + ) parser.add_argument( "--init-modules", @@ -907,7 +920,11 @@ def train_one_epoch( # NOTE: We use reduction==sum and loss is computed over utterances # in the batch and there is no normalization to it so far. scaler.scale(loss).backward() - set_batch_count(model, params.batch_idx_train) + # Skip the warmup by adding a huge number to batch_count + if params.do_finetune: + set_batch_count(model, params.batch_idx_train + 100000) + else: + set_batch_count(model, params.batch_idx_train) scheduler.step_batch(params.batch_idx_train) scaler.step(optimizer) @@ -1104,7 +1121,12 @@ def run(rank, world_size, args): parameters_names=parameters_names, ) - scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs) + scheduler = Eden( + optimizer=optimizer, + lr_batches=params.lr_batches, + lr_epochs=params.lr_epochs, + warmup_batches=0, + ) if checkpoints and "optimizer" in checkpoints: logging.info("Loading optimizer state dict") @@ -1129,7 +1151,15 @@ def run(rank, world_size, args): gigaspeech = GigaSpeechAsrDataModule(args) - train_cuts = gigaspeech.train_cuts() + if params.use_mux: + librispeech = LibriSpeechAsrDataModule(args) + train_cuts = CutSet.mux( + librispeech.train_all_shuf_cuts(), + gigaspeech.train_cuts(), + weights=[0.95, 0.05], + ) + else: + train_cuts = gigaspeech.train_cuts() def remove_short_and_long_utt(c: Cut): # Keep only utterances with duration between 1 second and 20 seconds @@ -1141,9 +1171,9 @@ def remove_short_and_long_utt(c: Cut): # an utterance duration distribution for your dataset to select # the threshold if c.duration < 1.0 or c.duration > 20.0: - logging.warning( - f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" - ) + # logging.warning( + # f"Exclude cut with ID {c.id} from training. Duration: {c.duration}" + # ) return False # In pruned RNN-T, we require that T >= S