Add adaption recipe for pruned_transducer_stateless7 (#1059)
* Add mux for finetune

* Add comments

* Fix for black

* Update finetune.py
yfyeung authored May 17, 2023
1 parent bccd20d commit 562bda9
Showing 3 changed files with 78 additions and 46 deletions.
1 change: 1 addition & 0 deletions egs/librispeech/ASR/finetune.sh
@@ -79,6 +79,7 @@ if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
--use-averaged-model True \
--beam-size 4 \
--exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+ --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \
--max-duration 400 \
--decoding-method $m
done
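Note: the added --bpe-model flag points decoding at the BPE model shipped with the pretrained checkpoint, presumably because the fine-tuned model keeps the original lang_bpe_500 vocabulary, so token IDs must match. A minimal sanity check with sentencepiece; the expected size of 500 is an assumption based on the directory name:

```python
import sentencepiece as spm

sp = spm.SentencePieceProcessor()
sp.load(
    "icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/"
    "data/lang_bpe_500/bpe.model"
)
print(sp.get_piece_size())  # expected to be 500 for lang_bpe_500
```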
egs/librispeech/ASR/pruned_transducer_stateless7/decode_gigaspeech.py
@@ -1,8 +1,9 @@
#!/usr/bin/env python3
#
- # Copyright 2021-2022 Xiaomi Corporation (Author: Fangjun Kuang,
+ # Copyright 2021-2023 Xiaomi Corporation (Author: Fangjun Kuang,
# Zengwei Yao,
- # Xiaoyu Yang)
+ # Xiaoyu Yang,
+ # Yifan Yang,)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -20,47 +21,47 @@
"""
Usage:
(1) greedy search
- ./pruned_transducer_stateless7/decode.py \
+ ./pruned_transducer_stateless7/decode_gigaspeech.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless7/exp \
+ --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
--max-duration 600 \
--decoding-method greedy_search
(2) beam search (not recommended)
- ./pruned_transducer_stateless7/decode.py \
+ ./pruned_transducer_stateless7/decode_gigaspeech.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless7/exp \
+ --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
--max-duration 600 \
--decoding-method beam_search \
--beam-size 4
(3) modified beam search
- ./pruned_transducer_stateless7/decode.py \
+ ./pruned_transducer_stateless7/decode_gigaspeech.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless7/exp \
+ --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
--max-duration 600 \
--decoding-method modified_beam_search \
--beam-size 4
(4) fast beam search (one best)
- ./pruned_transducer_stateless7/decode.py \
+ ./pruned_transducer_stateless7/decode_gigaspeech.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless7/exp \
+ --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
--max-duration 600 \
--decoding-method fast_beam_search \
--beam 20.0 \
--max-contexts 8 \
--max-states 64
(5) fast beam search (nbest)
- ./pruned_transducer_stateless7/decode.py \
+ ./pruned_transducer_stateless7/decode_gigaspeech.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless7/exp \
+ --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
--max-duration 600 \
--decoding-method fast_beam_search_nbest \
--beam 20.0 \
@@ -70,10 +71,10 @@
--nbest-scale 0.5
(6) fast beam search (nbest oracle WER)
- ./pruned_transducer_stateless7/decode.py \
+ ./pruned_transducer_stateless7/decode_gigaspeech.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless7/exp \
+ --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_oracle \
--beam 20.0 \
@@ -83,10 +84,10 @@
--nbest-scale 0.5
(7) fast beam search (with LG)
- ./pruned_transducer_stateless7/decode.py \
+ ./pruned_transducer_stateless7/decode_gigaspeech.py \
--epoch 28 \
--avg 15 \
- --exp-dir ./pruned_transducer_stateless7/exp \
+ --exp-dir ./pruned_transducer_stateless7/exp_giga_finetune \
--max-duration 600 \
--decoding-method fast_beam_search_nbest_LG \
--beam 20.0 \
@@ -187,7 +188,7 @@ def get_parser():
parser.add_argument(
"--exp-dir",
type=str,
default="pruned_transducer_stateless7/exp",
default="pruned_transducer_stateless7/exp_giga_finetune",
help="The experiment dir",
)

88 changes: 59 additions & 29 deletions egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py
@@ -1,8 +1,10 @@
#!/usr/bin/env python3
- # Copyright 2021-2022 Xiaomi Corp. (authors: Fangjun Kuang,
+ # Copyright 2021-2023 Xiaomi Corp. (authors: Fangjun Kuang,
# Wei Kang,
- # Mingshuang Luo,)
- # Zengwei Yao)
+ # Mingshuang Luo,
+ # Zengwei Yao,
+ # Xiaoyu Yang,
+ # Yifan Yang)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
@@ -20,27 +22,23 @@
"""
Usage:
- export CUDA_VISIBLE_DEVICES="0,1,2,3"
+ export CUDA_VISIBLE_DEVICES="0,1"
- ./pruned_transducer_stateless7/train.py \
- --world-size 4 \
- --num-epochs 30 \
- --start-epoch 1 \
- --exp-dir pruned_transducer_stateless7/exp \
- --full-libri 1 \
- --max-duration 300
- # For mix precision training:
- ./pruned_transducer_stateless7/train.py \
- --world-size 4 \
- --num-epochs 30 \
+ ./pruned_transducer_stateless7/finetune.py \
+ --world-size 2 \
+ --num-epochs 20 \
--start-epoch 1 \
+ --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
+ --subset S \
--use-fp16 1 \
- --exp-dir pruned_transducer_stateless7/exp \
- --full-libri 1 \
- --max-duration 550
+ --base-lr 0.005 \
+ --lr-epochs 100 \
+ --lr-batches 100000 \
+ --bpe-model icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/data/lang_bpe_500/bpe.model \
+ --do-finetune True \
+ --use-mux True \
+ --finetune-ckpt icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/pretrained.pt \
+ --max-duration 500
"""


@@ -59,9 +57,10 @@
import torch.multiprocessing as mp
import torch.nn as nn
from decoder import Decoder
+ from asr_datamodule import LibriSpeechAsrDataModule
from gigaspeech import GigaSpeechAsrDataModule
from joiner import Joiner
- from lhotse.cut import Cut
+ from lhotse.cut import Cut, CutSet
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from model import Transducer
@@ -103,7 +102,21 @@ def set_batch_count(model: Union[nn.Module, DDP], batch_count: float) -> None:


def add_finetune_arguments(parser: argparse.ArgumentParser):
parser.add_argument("--do-finetune", type=str2bool, default=False)
parser.add_argument(
"--do-finetune",
type=str2bool,
default=True,
help="Whether to fine-tune.",
)
parser.add_argument(
"--use-mux",
type=str2bool,
default=False,
help="""
Whether to adapt. If true, we will mix 5% of the new data
with 95% of the original data to fine-tune.
""",
)
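The --use-mux help text above describes mixing roughly 5% new data with 95% original data; lhotse's CutSet.mux does this lazily over the two cut sets. The plain-Python sketch below only illustrates the weighted-sampling behaviour implied by weights=[0.95, 0.05] and is not lhotse's implementation:

```python
import random


def mux(streams, weights, n, seed=0):
    """Yield up to n items, picking the source stream by weight at each step."""
    rng = random.Random(seed)
    iterators = [iter(s) for s in streams]
    for _ in range(n):
        (it,) = rng.choices(iterators, weights=weights, k=1)
        try:
            yield next(it)
        except StopIteration:
            return


librispeech = (f"libri-{i}" for i in range(10_000))  # stand-in for the original data
gigaspeech = (f"giga-{i}" for i in range(10_000))    # stand-in for the new data
sample = list(mux([librispeech, gigaspeech], weights=[0.95, 0.05], n=20))
print(sample)  # on average roughly 19 "libri" items per "giga" item
```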

parser.add_argument(
"--init-modules",
@@ -907,7 +920,11 @@ def train_one_epoch(
# NOTE: We use reduction==sum and loss is computed over utterances
# in the batch and there is no normalization to it so far.
scaler.scale(loss).backward()
- set_batch_count(model, params.batch_idx_train)
+ # Skip the warmup by adding a huge number to batch_count
+ if params.do_finetune:
+     set_batch_count(model, params.batch_idx_train + 100000)
+ else:
+     set_batch_count(model, params.batch_idx_train)
scheduler.step_batch(params.batch_idx_train)

scaler.step(optimizer)
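Note on the warmup skip above: as I understand it, set_batch_count propagates a batch counter into submodules that use it to schedule warm-up behaviour, so adding a large constant effectively ends the warm-up phase immediately when fine-tuning from an already-trained model. A toy, hypothetical illustration (not icefall's actual modules):

```python
import torch
import torch.nn as nn


class WarmupGatedLayer(nn.Module):
    """Toy layer that behaves differently until batch_count passes a threshold."""

    def __init__(self, dim: int, warmup_batches: int = 2000):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.warmup_batches = warmup_batches
        self.batch_count = 0.0  # set from outside, like icefall's set_batch_count()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.linear(x)
        if self.batch_count < self.warmup_batches:
            # During warm-up, partially bypass the layer for stability.
            y = 0.5 * y + 0.5 * x
        return y


layer = WarmupGatedLayer(16)
layer.batch_count = 100 + 100000  # the fine-tuning path pushes past warm-up at once
out = layer(torch.randn(2, 16))
```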
@@ -1104,7 +1121,12 @@ def run(rank, world_size, args):
parameters_names=parameters_names,
)

- scheduler = Eden(optimizer, params.lr_batches, params.lr_epochs)
+ scheduler = Eden(
+     optimizer=optimizer,
+     lr_batches=params.lr_batches,
+     lr_epochs=params.lr_epochs,
+     warmup_batches=0,
+ )

if checkpoints and "optimizer" in checkpoints:
logging.info("Loading optimizer state dict")
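For context on why warmup_batches=0 together with the large --lr-batches and --lr-epochs values keeps the learning rate nearly flat at --base-lr, here is a standalone approximation of the Eden schedule; the formula is reproduced from memory and may differ in detail from icefall's optim.py:

```python
def eden_lr(base_lr, batch, epoch, lr_batches, lr_epochs, warmup_batches=0):
    """Approximate shape of the Eden learning-rate schedule (from memory)."""
    batch_factor = ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
    epoch_factor = ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    warmup = 1.0 if batch >= warmup_batches else 0.5 + 0.5 * batch / max(warmup_batches, 1)
    return base_lr * batch_factor * epoch_factor * warmup


# With --base-lr 0.005 --lr-epochs 100 --lr-batches 100000 and warmup_batches=0,
# the decay factors stay close to 1.0 over a short fine-tuning run:
print(eden_lr(0.005, batch=5000, epoch=10, lr_batches=100000, lr_epochs=100))  # ~0.005
```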
@@ -1129,7 +1151,15 @@

gigaspeech = GigaSpeechAsrDataModule(args)

- train_cuts = gigaspeech.train_cuts()
+ if params.use_mux:
+     librispeech = LibriSpeechAsrDataModule(args)
+     train_cuts = CutSet.mux(
+         librispeech.train_all_shuf_cuts(),
+         gigaspeech.train_cuts(),
+         weights=[0.95, 0.05],
+     )
+ else:
+     train_cuts = gigaspeech.train_cuts()

def remove_short_and_long_utt(c: Cut):
# Keep only utterances with duration between 1 second and 20 seconds
@@ -1141,9 +1171,9 @@ def remove_short_and_long_utt(c: Cut):
# an utterance duration distribution for your dataset to select
# the threshold
if c.duration < 1.0 or c.duration > 20.0:
- logging.warning(
-     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
- )
+ # logging.warning(
+ #     f"Exclude cut with ID {c.id} from training. Duration: {c.duration}"
+ # )
return False

# In pruned RNN-T, we require that T >= S
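The comment above suggests inspecting an utterance-duration distribution before choosing the 1.0-20.0 second thresholds. A minimal sketch, assuming `cuts` is any iterable of lhotse cuts with a `duration` attribute (the manifest path shown is hypothetical):

```python
import numpy as np


def duration_percentiles(cuts, qs=(1, 5, 50, 95, 99, 99.9)):
    """Return duration percentiles (in seconds) for an iterable of cuts."""
    durations = np.array([c.duration for c in cuts])
    return {q: float(np.percentile(durations, q)) for q in qs}


# Example (hypothetical manifest path):
# from lhotse import load_manifest_lazy
# cuts = load_manifest_lazy("data/fbank/gigaspeech_cuts_S.jsonl.gz")
# print(duration_percentiles(cuts))
```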
