From 5a0b9bcb23cec0b1f363137e950fbdc10c8319ce Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Wed, 4 Aug 2021 14:53:02 +0800 Subject: [PATCH] Refactoring (#4) * Fix an error in TDNN-LSTM training. * WIP: Refactoring * Refactor transformer.py * Remove unused code. * Minor fixes. --- .gitignore | 1 + .../ASR/conformer_ctc/conformer.py | 16 +- egs/librispeech/ASR/conformer_ctc/decode.py | 63 +- .../ASR/conformer_ctc/subsampling.py | 144 +++ .../ASR/conformer_ctc/test_subsampling.py | 33 + .../ASR/conformer_ctc/test_transformer.py | 89 ++ egs/librispeech/ASR/conformer_ctc/train.py | 25 +- .../ASR/conformer_ctc/transformer.py | 859 ++++++++---------- egs/librispeech/ASR/local/compile_hlg.py | 36 +- .../ASR/local/compute_fbank_librispeech.py | 28 +- .../ASR/local/compute_fbank_musan.py | 22 +- egs/librispeech/ASR/local/download_lm.py | 52 +- egs/librispeech/ASR/local/prepare_lang.py | 10 +- egs/librispeech/ASR/local/prepare_lang_bpe.py | 16 +- egs/librispeech/ASR/local/train_bpe_model.py | 9 +- egs/librispeech/ASR/prepare.sh | 103 ++- egs/librispeech/ASR/shared | 1 + egs/librispeech/ASR/tdnn_lstm_ctc/decode.py | 6 +- egs/librispeech/ASR/tdnn_lstm_ctc/train.py | 7 +- icefall/decode.py | 39 +- icefall/lexicon.py | 30 +- .../local => icefall/shared}/parse_options.sh | 0 icefall/utils.py | 19 +- 23 files changed, 964 insertions(+), 644 deletions(-) create mode 100644 egs/librispeech/ASR/conformer_ctc/subsampling.py create mode 100755 egs/librispeech/ASR/conformer_ctc/test_subsampling.py create mode 100644 egs/librispeech/ASR/conformer_ctc/test_transformer.py create mode 120000 egs/librispeech/ASR/shared rename {egs/librispeech/ASR/local => icefall/shared}/parse_options.sh (100%) diff --git a/.gitignore b/.gitignore index 6cb9f22997..839a1c34a3 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ path.sh exp exp*/ *.pt +download/ diff --git a/egs/librispeech/ASR/conformer_ctc/conformer.py b/egs/librispeech/ASR/conformer_ctc/conformer.py index 1e82eff2fa..a00664a992 100644 --- a/egs/librispeech/ASR/conformer_ctc/conformer.py +++ b/egs/librispeech/ASR/conformer_ctc/conformer.py @@ -84,20 +84,26 @@ def __init__( # and throws an error without this change. self.after_norm = identity - def encode( + def run_encoder( self, x: Tensor, supervisions: Optional[Supervisions] = None ) -> Tuple[Tensor, Optional[Tensor]]: """ Args: - x: Tensor of dimension (batch_size, num_features, input_length). - supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute encoder padding mask, which is used as memory key padding + mask for the decoder. Returns: Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). 
Tensor: Mask tensor of dimension (batch_size, input_length) """ - x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F) - x = self.encoder_embed(x) x, pos_emb = self.encoder_pos(x) x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py index d1cbc14de9..889a0a4744 100755 --- a/egs/librispeech/ASR/conformer_ctc/decode.py +++ b/egs/librispeech/ASR/conformer_ctc/decode.py @@ -15,6 +15,7 @@ import torch.nn as nn from conformer import Conformer +from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler from icefall.checkpoint import average_checkpoints, load_checkpoint from icefall.dataset.librispeech import LibriSpeechAsrDataModule from icefall.decode import ( @@ -62,7 +63,7 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang/bpe"), + "lang_dir": Path("data/lang_bpe"), "lm_dir": Path("data/lm"), "feature_dim": 80, "nhead": 8, @@ -85,7 +86,7 @@ def get_params() -> AttributeDict: # - whole-lattice-rescoring # - attention-decoder # "method": "whole-lattice-rescoring", - "method": "1best", + "method": "attention-decoder", # num_paths is used when method is "nbest", "nbest-rescoring", # and attention-decoder "num_paths": 100, @@ -100,6 +101,8 @@ def decode_one_batch( HLG: k2.Fsa, batch: dict, lexicon: Lexicon, + sos_id: int, + eos_id: int, G: Optional[k2.Fsa] = None, ) -> Dict[str, List[List[int]]]: """Decode one batch and return the result in a dict. The dict has the @@ -133,6 +136,10 @@ def decode_one_batch( for the format of the `batch`. lexicon: It contains word symbol table. + sos_id: + The token ID of the SOS. + eos_id: + The token ID of the EOS. G: An LM. It is not None when params.method is "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in HLG @@ -147,15 +154,10 @@ def decode_one_batch( feature = feature.to(device) # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] - supervisions = batch["supervisions"] nnet_output, memory, memory_key_padding_mask = model(feature, supervisions) - # nnet_output is [N, C, T] - - nnet_output = nnet_output.permute(0, 2, 1) - # now nnet_output is [N, T, C] + # nnet_output is [N, T, C] supervision_segments = torch.stack( ( @@ -227,6 +229,8 @@ def decode_one_batch( model=model, memory=memory, memory_key_padding_mask=memory_key_padding_mask, + sos_id=sos_id, + eos_id=eos_id, ) else: assert False, f"Unsupported decoding method: {params.method}" @@ -245,6 +249,8 @@ def decode_dataset( model: nn.Module, HLG: k2.Fsa, lexicon: Lexicon, + sos_id: int, + eos_id: int, G: Optional[k2.Fsa] = None, ) -> Dict[str, List[Tuple[List[int], List[int]]]]: """Decode dataset. @@ -260,6 +266,10 @@ def decode_dataset( The decoding graph. lexicon: It contains word symbol table. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. G: An LM. It is not None when params.method is "nbest-rescoring" or "whole-lattice-rescoring". In general, the G in HLG @@ -287,6 +297,8 @@ def decode_dataset( batch=batch, lexicon=lexicon, G=G, + sos_id=sos_id, + eos_id=eos_id, ) for lm_scale, hyps in hyps_dict.items(): @@ -314,20 +326,31 @@ def save_results( test_set_name: str, results_dict: Dict[str, List[Tuple[List[int], List[int]]]], ): + if params.method == "attention-decoder": + # Set it to False since there are too many logs. 
+ enable_log = False + else: + enable_log = True test_set_wers = dict() for key, results in results_dict.items(): recog_path = params.exp_dir / f"recogs-{test_set_name}-{key}.txt" store_transcripts(filename=recog_path, texts=results) - logging.info(f"The transcripts are stored in {recog_path}") + if enable_log: + logging.info(f"The transcripts are stored in {recog_path}") # The following prints out WERs, per-word error statistics and aligned # ref/hyp pairs. errs_filename = params.exp_dir / f"errs-{test_set_name}-{key}.txt" with open(errs_filename, "w") as f: - wer = write_error_stats(f, f"{test_set_name}-{key}", results) + wer = write_error_stats( + f, f"{test_set_name}-{key}", results, enable_log=enable_log + ) test_set_wers[key] = wer - logging.info("Wrote detailed error stats to {}".format(errs_filename)) + if enable_log: + logging.info( + "Wrote detailed error stats to {}".format(errs_filename) + ) test_set_wers = sorted(test_set_wers.items(), key=lambda x: x[1]) errs_info = params.exp_dir / f"wer-summary-{test_set_name}.txt" @@ -367,15 +390,22 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load(f"{params.lm_dir}/HLG_bpe.pt")) + graph_compiler = BpeCtcTrainingGraphCompiler( + params.lang_dir, + device=device, + sos_token="", + eos_token="", + ) + sos_id = graph_compiler.sos_id + eos_id = graph_compiler.eos_id + + HLG = k2.Fsa.from_dict(torch.load(f"{params.lang_dir}/HLG.pt")) HLG = HLG.to(device) assert HLG.requires_grad is False if not hasattr(HLG, "lm_scores"): HLG.lm_scores = HLG.scores.clone() - # HLG = k2.ctc_topo(4999).to(device) - if params.method in ( "nbest-rescoring", "whole-lattice-rescoring", @@ -461,6 +491,8 @@ def main(): HLG=HLG, lexicon=lexicon, G=G, + sos_id=sos_id, + eos_id=eos_id, ) save_results( @@ -470,5 +502,8 @@ def main(): logging.info("Done!") +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/conformer_ctc/subsampling.py b/egs/librispeech/ASR/conformer_ctc/subsampling.py new file mode 100644 index 0000000000..5c3e1222ef --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/subsampling.py @@ -0,0 +1,144 @@ +import torch +import torch.nn as nn + + +class Conv2dSubsampling(nn.Module): + """Convolutional 2D subsampling (to 1/4 length). + + Convert an input of shape [N, T, idim] to an output + with shape [N, T', odim], where + T' = ((T-1)//2 - 1)//2, which approximates T' == T//4 + + It is based on + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py # noqa + """ + + def __init__(self, idim: int, odim: int) -> None: + """ + Args: + idim: + Input dim. The input shape is [N, T, idim]. + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + """ + assert idim >= 7 + super().__init__() + self.conv = nn.Sequential( + nn.Conv2d( + in_channels=1, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + nn.Conv2d( + in_channels=odim, out_channels=odim, kernel_size=3, stride=2 + ), + nn.ReLU(), + ) + self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is [N, T, idim]. 
+ + Returns: + Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + """ + # On entry, x is [N, T, idim] + x = x.unsqueeze(1) # [N, T, idim] -> [N, 1, T, idim] i.e., [N, C, H, W] + x = self.conv(x) + # Now x is of shape [N, odim, ((T-1)//2 - 1)//2, ((idim-1)//2 - 1)//2] + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + # Now x is of shape [N, ((T-1)//2 - 1))//2, odim] + return x + + +class VggSubsampling(nn.Module): + """Trying to follow the setup described in the following paper: + https://arxiv.org/pdf/1910.09799.pdf + + This paper is not 100% explicit so I am guessing to some extent, + and trying to compare with other VGG implementations. + + Convert an input of shape [N, T, idim] to an output + with shape [N, T', odim], where + T' = ((T-1)//2 - 1)//2, which approximates T' = T//4 + """ + + def __init__(self, idim: int, odim: int) -> None: + """Construct a VggSubsampling object. + + This uses 2 VGG blocks with 2 Conv2d layers each, + subsampling its input by a factor of 4 in the time dimensions. + + Args: + idim: + Input dim. The input shape is [N, T, idim]. + Caution: It requires: T >=7, idim >=7 + odim: + Output dim. The output shape is [N, ((T-1)//2 - 1)//2, odim] + """ + super().__init__() + + cur_channels = 1 + layers = [] + block_dims = [32, 64] + + # The decision to use padding=1 for the 1st convolution, then padding=0 + # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by + # a back-compatibility concern so that the number of frames at the + # output would be equal to: + # (((T-1)//2)-1)//2. + # We can consider changing this by using padding=1 on the + # 2nd convolution, so the num-frames at the output would be T//4. + for block_dim in block_dims: + layers.append( + torch.nn.Conv2d( + in_channels=cur_channels, + out_channels=block_dim, + kernel_size=3, + padding=1, + stride=1, + ) + ) + layers.append(torch.nn.ReLU()) + layers.append( + torch.nn.Conv2d( + in_channels=block_dim, + out_channels=block_dim, + kernel_size=3, + padding=0, + stride=1, + ) + ) + layers.append( + torch.nn.MaxPool2d( + kernel_size=2, stride=2, padding=0, ceil_mode=True + ) + ) + cur_channels = block_dim + + self.layers = nn.Sequential(*layers) + + self.out = nn.Linear( + block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Subsample x. + + Args: + x: + Its shape is [N, T, idim]. 
+ + Returns: + Return a tensor of shape [N, ((T-1)//2 - 1)//2, odim] + """ + x = x.unsqueeze(1) + x = self.layers(x) + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) + return x diff --git a/egs/librispeech/ASR/conformer_ctc/test_subsampling.py b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py new file mode 100755 index 0000000000..937845d779 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/test_subsampling.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python3 + +from subsampling import Conv2dSubsampling +from subsampling import VggSubsampling +import torch + + +def test_conv2d_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = Conv2dSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim + + +def test_vgg_subsampling(): + N = 3 + odim = 2 + + for T in range(7, 19): + for idim in range(7, 20): + model = VggSubsampling(idim=idim, odim=odim) + x = torch.empty(N, T, idim) + y = model(x) + assert y.shape[0] == N + assert y.shape[1] == ((T - 1) // 2 - 1) // 2 + assert y.shape[2] == odim diff --git a/egs/librispeech/ASR/conformer_ctc/test_transformer.py b/egs/librispeech/ASR/conformer_ctc/test_transformer.py new file mode 100644 index 0000000000..08e6806074 --- /dev/null +++ b/egs/librispeech/ASR/conformer_ctc/test_transformer.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python3 + +import torch +from transformer import ( + Transformer, + encoder_padding_mask, + generate_square_subsequent_mask, + decoder_padding_mask, + add_sos, + add_eos, +) + +from torch.nn.utils.rnn import pad_sequence + + +def test_encoder_padding_mask(): + supervisions = { + "sequence_idx": torch.tensor([0, 1, 2]), + "start_frame": torch.tensor([0, 0, 0]), + "num_frames": torch.tensor([18, 7, 13]), + } + + max_len = ((18 - 1) // 2 - 1) // 2 + mask = encoder_padding_mask(max_len, supervisions) + expected_mask = torch.tensor( + [ + [False, False, False], # ((18 - 1)//2 - 1)//2 = 3, + [False, True, True], # ((7 - 1)//2 - 1)//2 = 1, + [False, False, True], # ((13 - 1)//2 - 1)//2 = 2, + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_transformer(): + num_features = 40 + num_classes = 87 + model = Transformer(num_features=num_features, num_classes=num_classes) + + N = 31 + + for T in range(7, 30): + x = torch.rand(N, T, num_features) + y, _, _ = model(x) + assert y.shape == (N, (((T - 1) // 2) - 1) // 2, num_classes) + + +def test_generate_square_subsequent_mask(): + s = 5 + mask = generate_square_subsequent_mask(s) + inf = float("inf") + expected_mask = torch.tensor( + [ + [0.0, -inf, -inf, -inf, -inf], + [0.0, 0.0, -inf, -inf, -inf], + [0.0, 0.0, 0.0, -inf, -inf], + [0.0, 0.0, 0.0, 0.0, -inf], + [0.0, 0.0, 0.0, 0.0, 0.0], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_decoder_padding_mask(): + x = [torch.tensor([1, 2]), torch.tensor([3]), torch.tensor([2, 5, 8])] + y = pad_sequence(x, batch_first=True, padding_value=-1) + mask = decoder_padding_mask(y, ignore_id=-1) + expected_mask = torch.tensor( + [ + [False, False, True], + [False, True, True], + [False, False, False], + ] + ) + assert torch.all(torch.eq(mask, expected_mask)) + + +def test_add_sos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_sos(x, sos_id=0) + expected_y = [[0, 1, 2], [0, 3], [0, 2, 5, 8]] + assert y == expected_y + + +def test_add_eos(): + x = [[1, 2], [3], [2, 5, 8]] + y = add_eos(x, eos_id=0) + expected_y = [[1, 
2, 0], [3, 0], [2, 5, 8, 0]] + assert y == expected_y diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py index 40d3cf7fbb..552db81ecc 100755 --- a/egs/librispeech/ASR/conformer_ctc/train.py +++ b/egs/librispeech/ASR/conformer_ctc/train.py @@ -125,7 +125,7 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("conformer_ctc/exp"), - "lang_dir": Path("data/lang/bpe"), + "lang_dir": Path("data/lang_bpe"), "feature_dim": 80, "weight_decay": 0.0, "subsampling_factor": 4, @@ -275,15 +275,13 @@ def compute_loss( device = graph_compiler.device feature = batch["inputs"] # at entry, feature is [N, T, C] - feature = feature.permute(0, 2, 1) # now feature is [N, C, T] assert feature.ndim == 3 feature = feature.to(device) supervisions = batch["supervisions"] with torch.set_grad_enabled(is_training): nnet_output, encoder_memory, memory_mask = model(feature, supervisions) - # nnet_output is [N, C, T] - nnet_output = nnet_output.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + # nnet_output is [N, T, C] # NOTE: We need `encode_supervisions` to sort sequences with # different duration in decreasing order, required by @@ -536,6 +534,22 @@ def train_one_epoch( f" best valid loss: {params.best_valid_loss:.4f} " f"best valid epoch: {params.best_valid_epoch}" ) + if tb_writer is not None: + tb_writer.add_scalar( + "train/valid_ctc_loss", + params.valid_ctc_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_att_loss", + params.valid_att_loss, + params.batch_idx_train, + ) + tb_writer.add_scalar( + "train/valid_loss", + params.valid_loss, + params.batch_idx_train, + ) params.train_loss = tot_loss / tot_frames @@ -675,5 +689,8 @@ def main(): run(rank=0, world_size=1, args=args) +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + if __name__ == "__main__": main() diff --git a/egs/librispeech/ASR/conformer_ctc/transformer.py b/egs/librispeech/ASR/conformer_ctc/transformer.py index 1df16e3467..a974be4e02 100644 --- a/egs/librispeech/ASR/conformer_ctc/transformer.py +++ b/egs/librispeech/ASR/conformer_ctc/transformer.py @@ -1,6 +1,4 @@ -#!/usr/bin/env python3 - -# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) +# Copyright (c) 2021 University of Chinese Academy of Sciences (author: Han Zhu) # Apache 2.0 import math @@ -8,30 +6,17 @@ import k2 import torch -from torch import Tensor, nn +import torch.nn as nn +from subsampling import Conv2dSubsampling, VggSubsampling from icefall.utils import get_texts +from torch.nn.utils.rnn import pad_sequence # Note: TorchScript requires Dict/List/etc. to be fully typed. -Supervisions = Dict[str, Tensor] +Supervisions = Dict[str, torch.Tensor] class Transformer(nn.Module): - """ - Args: - num_features (int): Number of input features - num_classes (int): Number of output classes - subsampling_factor (int): subsampling factor of encoder (the convolution layers before transformers) - d_model (int): attention dimension - nhead (int): number of head - dim_feedforward (int): feedforward dimention - num_encoder_layers (int): number of encoder layers - num_decoder_layers (int): number of decoder layers - dropout (float): dropout rate - normalize_before (bool): whether to use layer_norm before the first block. - vgg_frontend (bool): whether to use vgg frontend. 
- """ - def __init__( self, num_features: int, @@ -48,6 +33,36 @@ def __init__( mmi_loss: bool = True, use_feat_batchnorm: bool = False, ) -> None: + """ + Args: + num_features: + The input dimension of the model. + num_classes: + The output dimension of the model. + subsampling_factor: + Number of output frames is num_in_frames // subsampling_factor. + Currently, subsampling_factor MUST be 4. + d_model: + Attention dimension. + nhead: + Number of heads in multi-head attention. + Must satisfy d_model // nhead == 0. + dim_feedforward: + The output dimension of the feedforward layers in encoder/decoder. + num_encoder_layers: + Number of encoder layers. + num_decoder_layers: + Number of decoder layers. + dropout: + Dropout in encoder/decoder. + normalize_before: + If True, use pre-layer norm; False to use post-layer norm. + vgg_frontend: + True to use vgg style frontend for subsampling. + mmi_loss: + use_feat_batchnorm: + True to use batchnorm for the input layer. + """ super().__init__() self.use_feat_batchnorm = use_feat_batchnorm if use_feat_batchnorm: @@ -59,18 +74,23 @@ def __init__( if subsampling_factor != 4: raise NotImplementedError("Support only 'subsampling_factor=4'.") - self.encoder_embed = ( - VggSubsampling(num_features, d_model) - if vgg_frontend - else Conv2dSubsampling(num_features, d_model) - ) + # self.encoder_embed converts the input of shape [N, T, num_classes] + # to the shape [N, T//subsampling_factor, d_model]. + # That is, it does two things simultaneously: + # (1) subsampling: T -> T//subsampling_factor + # (2) embedding: num_classes -> d_model + if vgg_frontend: + self.encoder_embed = VggSubsampling(num_features, d_model) + else: + self.encoder_embed = Conv2dSubsampling(num_features, d_model) + self.encoder_pos = PositionalEncoding(d_model, dropout) encoder_layer = TransformerEncoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, normalize_before=normalize_before, ) @@ -80,9 +100,12 @@ def __init__( encoder_norm = None self.encoder = nn.TransformerEncoder( - encoder_layer, num_encoder_layers, encoder_norm + encoder_layer=encoder_layer, + num_layers=num_encoder_layers, + norm=encoder_norm, ) + # TODO(fangjun): remove dropout self.encoder_output_layer = nn.Sequential( nn.Dropout(p=dropout), nn.Linear(d_model, num_classes) ) @@ -97,14 +120,16 @@ def __init__( self.num_classes ) # bpe model already has sos/eos symbol - self.decoder_embed = nn.Embedding(self.decoder_num_class, d_model) + self.decoder_embed = nn.Embedding( + num_embeddings=self.decoder_num_class, embedding_dim=d_model + ) self.decoder_pos = PositionalEncoding(d_model, dropout) decoder_layer = TransformerDecoderLayer( - d_model, - nhead, - dim_feedforward, - dropout, + d_model=d_model, + nhead=nhead, + dim_feedforward=dim_feedforward, + dropout=dropout, normalize_before=normalize_before, ) @@ -114,7 +139,9 @@ def __init__( decoder_norm = None self.decoder = nn.TransformerDecoder( - decoder_layer, num_decoder_layers, decoder_norm + decoder_layer=decoder_layer, + num_layers=num_decoder_layers, + norm=decoder_norm, ) self.decoder_output_layer = torch.nn.Linear( @@ -126,128 +153,143 @@ def __init__( self.decoder_criterion = None def forward( - self, x: Tensor, supervision: Optional[Supervisions] = None - ) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + self, x: torch.Tensor, supervision: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: """ Args: - x: Tensor of 
dimension (batch_size, num_features, input_length). - supervision: Supervison in lhotse format, get from batch['supervisions'] + x: + The input tensor. Its shape is [N, T, C]. + supervision: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) Returns: - Tensor: After log-softmax tensor of dimension (batch_size, number_of_classes, input_length). - Tensor: Before linear layer tensor of dimension (input_length, batch_size, d_model). - Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None. - + Return a tuple containing 3 tensors: + - CTC output for ctc decoding. Its shape is [N, T, C] + - Encoder output with shape [T, N, C]. It can be used as key and + value for the decoder. + - Encoder output padding mask. It can be used as + memory_key_padding_mask for the decoder. Its shape is [N, T]. + It is None if `supervision` is None. """ if self.use_feat_batchnorm: + x = x.permute(0, 2, 1) # [N, T, C] -> [N, C, T] x = self.feat_batchnorm(x) - encoder_memory, memory_mask = self.encode(x, supervision) - x = self.encoder_output(encoder_memory) - return x, encoder_memory, memory_mask + x = x.permute(0, 2, 1) # [N, C, T] -> [N, T, C] + encoder_memory, memory_key_padding_mask = self.run_encoder( + x, supervision + ) + x = self.ctc_output(encoder_memory) + return x, encoder_memory, memory_key_padding_mask - def encode( - self, x: Tensor, supervisions: Optional[Supervisions] = None - ) -> Tuple[Tensor, Optional[Tensor]]: - """ - Args: - x: Tensor of dimension (batch_size, num_features, input_length). - supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + def run_encoder( + self, x: torch.Tensor, supervisions: Optional[Supervisions] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Run the transformer encoder. + Args: + x: + The model input. Its shape is [N, T, C]. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + CAUTION: It contains length information, i.e., start and number of + frames, before subsampling + It is read directly from the batch, without any sorting. It is used + to compute the encoder padding mask, which is used as memory key + padding mask for the decoder. Returns: - Tensor: Predictor tensor of dimension (input_length, batch_size, d_model). - Optional[Tensor]: Mask tensor of dimension (batch_size, input_length) or None. + Return a tuple with two tensors: + - The encoder output, with shape [T, N, C] + - encoder padding mask, with shape [N, T]. + The mask is None if `supervisions` is None. + It is used as memory key padding mask in the decoder. """ - x = x.permute(0, 2, 1) # (B, F, T) -> (B, T, F) - x = self.encoder_embed(x) x = self.encoder_pos(x) - x = x.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) mask = encoder_padding_mask(x.size(0), supervisions) - mask = mask.to(x.device) if mask != None else None - x = self.encoder(x, src_key_padding_mask=mask) # (T, B, F) + mask = mask.to(x.device) if mask is not None else None + x = self.encoder(x, src_key_padding_mask=mask) # (T, N, C) return x, mask - def encoder_output(self, x: Tensor) -> Tensor: + def ctc_output(self, x: torch.Tensor) -> torch.Tensor: """ Args: - x: Tensor of dimension (input_length, batch_size, d_model). 
+ x: + The output tensor from the transformer encoder. + Its shape is [T, N, C] Returns: - Tensor: After log-softmax tensor of dimension (batch_size, number_of_classes, input_length). + Return a tensor that can be used for CTC decoding. + Its shape is [N, T, C] """ - x = self.encoder_output_layer(x).permute( - 1, 2, 0 - ) # (T, B, F) ->(B, F, T) - x = nn.functional.log_softmax(x, dim=1) # (B, F, T) + x = self.encoder_output_layer(x) + x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + x = nn.functional.log_softmax(x, dim=-1) # (N, T, C) return x def decoder_forward( self, - x: Tensor, - encoder_mask: Tensor, - supervision: Supervisions = None, - graph_compiler: object = None, - token_ids: List[int] = None, - sos_id: Optional[int] = None, - eos_id: Optional[int] = None, - ) -> Tensor: + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, + token_ids: List[List[int]], + sos_id: int, + eos_id: int, + ) -> torch.Tensor: """ Args: - x: Tensor of dimension (input_length, batch_size, d_model). - encoder_mask: Mask tensor of dimension (batch_size, input_length) - supervision: Supervison in lhotse format, get from batch['supervisions'] - graph_compiler: use graph_compiler.L_inv (Its labels are words, while its aux_labels are phones) - , graph_compiler.words and graph_compiler.oov - sos_id: sos token id - eos_id: eos token id + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs. Each sublist contains IDs for an utterance. + The IDs can be either phone IDs or word piece IDs. + sos_id: + sos token id + eos_id: + eos token id Returns: - Tensor: Decoder loss. + A scalar, the **sum** of label smoothing loss over utterances + in the batch without any normalization. """ - if supervision is not None and graph_compiler is not None: - batch_text = get_normal_transcripts( - supervision, graph_compiler.lexicon.words, graph_compiler.oov - ) - ys_in_pad, ys_out_pad = add_sos_eos( - batch_text, graph_compiler.L_inv, sos_id, eos_id, - ) - elif token_ids is not None: - _sos = torch.tensor([sos_id]) - _eos = torch.tensor([eos_id]) - ys_in = [ - torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids - ] - ys_out = [ - torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids - ] - ys_in_pad = pad_list(ys_in, eos_id) - ys_out_pad = pad_list(ys_out, -1) + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) - else: - raise ValueError("Invalid input for decoder self attetion") + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) - ys_in_pad = ys_in_pad.to(x.device) - ys_out_pad = ys_out_pad.to(x.device) + device = memory.device + ys_in_pad = ys_in_pad.to(device) + ys_out_pad = ys_out_pad.to(device) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - x.device + device ) + # TODO: Use eos_id as ignore_id. 
+ # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) - tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) + tgt = self.decoder_embed(ys_in_pad) # (N, T) -> (N, T, C) tgt = self.decoder_pos(tgt) - tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) + tgt = tgt.permute(1, 0, 2) # (N, T, C) -> (T, N, C) pred_pad = self.decoder( tgt=tgt, - memory=x, + memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=encoder_mask, - ) # (T, B, F) - pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) - pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) + memory_key_padding_mask=memory_key_padding_mask, + ) # (T, N, C) + pred_pad = pred_pad.permute(1, 0, 2) # (T, N, C) -> (N, T, C) + pred_pad = self.decoder_output_layer(pred_pad) # (N, T, C) decoder_loss = self.decoder_criterion(pred_pad, ys_out_pad) @@ -255,44 +297,50 @@ def decoder_forward( def decoder_nll( self, - x: Tensor, - encoder_mask: Tensor, + memory: torch.Tensor, + memory_key_padding_mask: torch.Tensor, token_ids: List[List[int]], sos_id: int, eos_id: int, - ) -> Tensor: + ) -> torch.Tensor: """ Args: - x: encoder-output, Tensor of dimension (input_length, batch_size, d_model). - encoder_mask: Mask tensor of dimension (batch_size, input_length) - token_ids: n-best list extracted from lattice before rescore - + memory: + It's the output of the encoder with shape [T, N, C] + memory_key_padding_mask: + The padding mask from the encoder. + token_ids: + A list-of-list IDs (e.g., word piece IDs). + Each sublist represents an utterance. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. Returns: - Tensor: negative log-likelihood. + A 2-D tensor of shape (len(token_ids), max_token_length) + representing the cross entropy loss (i.e., negative log-likelihood). """ - # The common part between this fuction and decoder_forward could be - # extracted as a seperated function. - if token_ids is not None: - _sos = torch.tensor([sos_id]) - _eos = torch.tensor([eos_id]) - ys_in = [ - torch.cat([_sos, torch.tensor(y)], dim=0) for y in token_ids - ] - ys_out = [ - torch.cat([torch.tensor(y), _eos], dim=0) for y in token_ids - ] - ys_in_pad = pad_list(ys_in, eos_id) - ys_out_pad = pad_list(ys_out, -1) - else: - raise ValueError("Invalid input for decoder self attetion") + # The common part between this function and decoder_forward could be + # extracted as a separate function. + + ys_in = add_sos(token_ids, sos_id=sos_id) + ys_in = [torch.tensor(y) for y in ys_in] + ys_in_pad = pad_sequence(ys_in, batch_first=True, padding_value=eos_id) - ys_in_pad = ys_in_pad.to(x.device, dtype=torch.int64) - ys_out_pad = ys_out_pad.to(x.device, dtype=torch.int64) + ys_out = add_eos(token_ids, eos_id=eos_id) + ys_out = [torch.tensor(y) for y in ys_out] + ys_out_pad = pad_sequence(ys_out, batch_first=True, padding_value=-1) + + device = memory.device + ys_in_pad = ys_in_pad.to(device, dtype=torch.int64) + ys_out_pad = ys_out_pad.to(device, dtype=torch.int64) tgt_mask = generate_square_subsequent_mask(ys_in_pad.shape[-1]).to( - x.device + device ) + # TODO: Use eos_id as ignore_id. 
+ # tgt_key_padding_mask = decoder_padding_mask(ys_in_pad, ignore_id=eos_id) tgt_key_padding_mask = decoder_padding_mask(ys_in_pad) tgt = self.decoder_embed(ys_in_pad) # (B, T) -> (B, T, F) @@ -300,10 +348,10 @@ def decoder_nll( tgt = tgt.permute(1, 0, 2) # (B, T, F) -> (T, B, F) pred_pad = self.decoder( tgt=tgt, - memory=x, + memory=memory, tgt_mask=tgt_mask, tgt_key_padding_mask=tgt_key_padding_mask, - memory_key_padding_mask=encoder_mask, + memory_key_padding_mask=memory_key_padding_mask, ) # (T, B, F) pred_pad = pred_pad.permute(1, 0, 2) # (T, B, F) -> (B, T, F) pred_pad = self.decoder_output_layer(pred_pad) # (B, T, F) @@ -322,16 +370,24 @@ def decoder_nll( class TransformerEncoderLayer(nn.Module): """ - Modified from torch.nn.TransformerEncoderLayer. Add support of normalize_before, + Modified from torch.nn.TransformerEncoderLayer. + Add support of normalize_before, i.e., use layer_norm before the first block. Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of intermediate layer, relu or gelu (default=relu). - normalize_before: whether to use layer_norm before the first block. + d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). + normalize_before: + whether to use layer_norm before the first block. Examples:: >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) @@ -371,23 +427,24 @@ def __setstate__(self, state): def forward( self, - src: Tensor, - src_mask: Optional[Tensor] = None, - src_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + src: torch.Tensor, + src_mask: Optional[torch.Tensor] = None, + src_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """ Pass the input through the encoder layer. Args: src: the sequence to the encoder layer (required). src_mask: the mask for the src sequence (optional). - src_key_padding_mask: the mask for the src keys per batch (optional). + src_key_padding_mask: the mask for the src keys per batch (optional) Shape: src: (S, N, E). src_mask: (S, S). src_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number """ residual = src if self.normalize_before: @@ -415,15 +472,22 @@ def forward( class TransformerDecoderLayer(nn.Module): """ - Modified from torch.nn.TransformerDecoderLayer. Add support of normalize_before, + Modified from torch.nn.TransformerDecoderLayer. + Add support of normalize_before, i.e., use layer_norm before the first block. Args: - d_model: the number of expected features in the input (required). - nhead: the number of heads in the multiheadattention models (required). - dim_feedforward: the dimension of the feedforward network model (default=2048). - dropout: the dropout value (default=0.1). - activation: the activation function of intermediate layer, relu or gelu (default=relu). 
+ d_model: + the number of expected features in the input (required). + nhead: + the number of heads in the multiheadattention models (required). + dim_feedforward: + the dimension of the feedforward network model (default=2048). + dropout: + the dropout value (default=0.1). + activation: + the activation function of intermediate layer, relu or + gelu (default=relu). Examples:: >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8) @@ -467,22 +531,28 @@ def __setstate__(self, state): def forward( self, - tgt: Tensor, - memory: Tensor, - tgt_mask: Optional[Tensor] = None, - memory_mask: Optional[Tensor] = None, - tgt_key_padding_mask: Optional[Tensor] = None, - memory_key_padding_mask: Optional[Tensor] = None, - ) -> Tensor: + tgt: torch.Tensor, + memory: torch.Tensor, + tgt_mask: Optional[torch.Tensor] = None, + memory_mask: Optional[torch.Tensor] = None, + tgt_key_padding_mask: Optional[torch.Tensor] = None, + memory_key_padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: """Pass the inputs (and mask) through the decoder layer. Args: - tgt: the sequence to the decoder layer (required). - memory: the sequence from the last layer of the encoder (required). - tgt_mask: the mask for the tgt sequence (optional). - memory_mask: the mask for the memory sequence (optional). - tgt_key_padding_mask: the mask for the tgt keys per batch (optional). - memory_key_padding_mask: the mask for the memory keys per batch (optional). + tgt: + the sequence to the decoder layer (required). + memory: + the sequence from the last layer of the encoder (required). + tgt_mask: + the mask for the tgt sequence (optional). + memory_mask: + the mask for the memory sequence (optional). + tgt_key_padding_mask: + the mask for the tgt keys per batch (optional). + memory_key_padding_mask: + the mask for the memory keys per batch (optional). Shape: tgt: (T, N, E). @@ -491,7 +561,8 @@ def forward( memory_mask: (T, S). tgt_key_padding_mask: (N, T). memory_key_padding_mask: (N, S). - S is the source sequence length, T is the target sequence length, N is the batch size, E is the feature number + S is the source sequence length, T is the target sequence length, + N is the batch size, E is the feature number """ residual = tgt if self.normalize_before: @@ -542,164 +613,55 @@ def _get_activation_fn(activation: str): ) -class Conv2dSubsampling(nn.Module): - """Convolutional 2D subsampling (to 1/4 length). - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/subsampling.py - - Args: - idim: Input dimension. - odim: Output dimension. - - """ - - def __init__(self, idim: int, odim: int) -> None: - """Construct a Conv2dSubsampling object.""" - super(Conv2dSubsampling, self).__init__() - self.conv = nn.Sequential( - nn.Conv2d( - in_channels=1, out_channels=odim, kernel_size=3, stride=2 - ), - nn.ReLU(), - nn.Conv2d( - in_channels=odim, out_channels=odim, kernel_size=3, stride=2 - ), - nn.ReLU(), - ) - self.out = nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim) - - def forward(self, x: Tensor) -> Tensor: - """Subsample x. - - Args: - x: Input tensor of dimension (batch_size, input_length, num_features). (#batch, time, idim). - - Returns: - torch.Tensor: Subsampled tensor of dimension (batch_size, input_length, d_model). - where time' = time // 4. 
- - """ - x = x.unsqueeze(1) # (b, c, t, f) - x = self.conv(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x +class PositionalEncoding(nn.Module): + """This class implements the positional encoding + proposed in the following paper: + - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf -class VggSubsampling(nn.Module): - """Trying to follow the setup described here https://arxiv.org/pdf/1910.09799.pdf - This paper is not 100% explicit so I am guessing to some extent, - and trying to compare with other VGG implementations. + PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) + PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) - Args: - idim: Input dimension. - odim: Output dimension. + Note:: + 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) + = exp(-1* 2i / d_model * log(100000)) + = exp(2i * -(log(10000) / d_model)) """ - def __init__(self, idim: int, odim: int) -> None: - """Construct a VggSubsampling object. This uses 2 VGG blocks with 2 - Conv2d layers each, subsampling its input by a factor of 4 in the - time dimensions. - - Args: - idim: Number of features at input, e.g. 40 or 80 for MFCC - (will be treated as the image height). - odim: Output dimension (number of features), e.g. 256 + def __init__(self, d_model: int, dropout: float = 0.1) -> None: """ - super(VggSubsampling, self).__init__() - - cur_channels = 1 - layers = [] - block_dims = [32, 64] - - # The decision to use padding=1 for the 1st convolution, then padding=0 - # for the 2nd and for the max-pooling, and ceil_mode=True, was driven by - # a back-compatibility concern so that the number of frames at the - # output would be equal to: - # (((T-1)//2)-1)//2. - # We can consider changing this by using padding=1 on the 2nd convolution, - # so the num-frames at the output would be T//4. - for block_dim in block_dims: - layers.append( - torch.nn.Conv2d( - in_channels=cur_channels, - out_channels=block_dim, - kernel_size=3, - padding=1, - stride=1, - ) - ) - layers.append(torch.nn.ReLU()) - layers.append( - torch.nn.Conv2d( - in_channels=block_dim, - out_channels=block_dim, - kernel_size=3, - padding=0, - stride=1, - ) - ) - layers.append( - torch.nn.MaxPool2d( - kernel_size=2, stride=2, padding=0, ceil_mode=True - ) - ) - cur_channels = block_dim - - self.layers = nn.Sequential(*layers) - - self.out = nn.Linear( - block_dims[-1] * (((idim - 1) // 2 - 1) // 2), odim - ) - - def forward(self, x: Tensor) -> Tensor: - """Subsample x. - Args: - x: Input tensor of dimension (batch_size, input_length, num_features). (#batch, time, idim). - - Returns: - torch.Tensor: Subsampled tensor of dimension (batch_size, input_length', d_model). - where input_length' == (((input_length - 1) // 2) - 1) // 2 - + d_model: + Embedding dimension. + dropout: + Dropout probability to be applied to the output of this module. """ - x = x.unsqueeze(1) # (b, c, t, f) - x = self.layers(x) - b, c, t, f = x.size() - x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f)) - return x - - -class PositionalEncoding(nn.Module): - """ - Positional encoding. - - Args: - d_model: Embedding dimension. - dropout: Dropout rate. - max_len: Maximum input length. 
- - """ - - def __init__( - self, d_model: int, dropout: float = 0.1, max_len: int = 5000 - ) -> None: - """Construct an PositionalEncoding object.""" - super(PositionalEncoding, self).__init__() + super().__init__() self.d_model = d_model self.xscale = math.sqrt(self.d_model) self.dropout = nn.Dropout(p=dropout) self.pe = None - self.extend_pe(torch.tensor(0.0).expand(1, max_len)) - def extend_pe(self, x: Tensor) -> None: - """Reset the positional encodings.""" + def extend_pe(self, x: torch.Tensor) -> None: + """Extend the time t in the positional encoding if required. + + The shape of `self.pe` is [1, T1, d_model]. The shape of the input x + is [N, T, d_model]. If T > T1, then we change the shape of self.pe + to [N, T, d_model]. Otherwise, nothing is done. + + Args: + x: + It is a tensor of shape [N, T, C]. + Returns: + Return None. + """ if self.pe is not None: if self.pe.size(1) >= x.size(1): if self.pe.dtype != x.dtype or self.pe.device != x.device: self.pe = self.pe.to(dtype=x.dtype, device=x.device) return - pe = torch.zeros(x.size(1), self.d_model) + pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32) position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) div_term = torch.exp( torch.arange(0, self.d_model, 2, dtype=torch.float32) @@ -708,34 +670,44 @@ def extend_pe(self, x: Tensor) -> None: pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0) + # Now pe is of shape [1, T, d_model], where T is x.size(1) self.pe = pe.to(device=x.device, dtype=x.dtype) - def forward(self, x: Tensor) -> Tensor: + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Add positional encoding. Args: - x: Input tensor of dimention (batch_size, input_length, d_model). + x: + Its shape is [N, T, C] Returns: - torch.Tensor: Encoded tensor of dimention (batch_size, input_length, d_model). - + Return a tensor of shape [N, T, C] """ self.extend_pe(x) - x = x * self.xscale + self.pe[:, : x.size(1)] + x = x * self.xscale + self.pe[:, : x.size(1), :] return self.dropout(x) class Noam(object): """ - Implements Noam optimizer. Proposed in "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf - Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py + Implements Noam optimizer. + + Proposed in + "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf + + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa Args: - params (iterable): iterable of parameters to optimize or dicts defining parameter groups - model_size: attention dimension of the transformer model - factor: learning rate factor - warm_step: warmup steps + params: + iterable of parameters to optimize or dicts defining parameter groups + model_size: + attention dimension of the transformer model + factor: + learning rate factor + warm_step: + warmup steps """ def __init__( @@ -808,7 +780,8 @@ class LabelSmoothingLoss(nn.Module): """ Label-smoothing loss. KL-divergence between q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w) is minimized. 
- Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py + Modified from + https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/label_smoothing_loss.py # noqa Args: size: the number of class @@ -837,19 +810,23 @@ def __init__( self.true_dist = None self.normalize_length = normalize_length - def forward(self, x: Tensor, target: Tensor) -> Tensor: + def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Compute loss between x and target. Args: - x: prediction of dimention (batch_size, input_length, number_of_classes). - target: target masked with self.padding_id of dimention (batch_size, input_length). + x: + prediction of dimension + (batch_size, input_length, number_of_classes). + target: + target masked with self.padding_id of + dimension (batch_size, input_length). Returns: - torch.Tensor: scalar float value + A scalar tensor containing the loss without normalization. """ assert x.size(2) == self.size - batch_size = x.size(0) + # batch_size = x.size(0) x = x.view(-1, self.size) target = target.view(-1) with torch.no_grad(): @@ -867,12 +844,23 @@ def forward(self, x: Tensor, target: Tensor) -> Tensor: def encoder_padding_mask( max_len: int, supervisions: Optional[Supervisions] = None -) -> Optional[Tensor]: - """Make mask tensor containing indices of padded part. +) -> Optional[torch.Tensor]: + """Make mask tensor containing indexes of padded part. + + TODO:: + This function **assumes** that the model uses + a subsampling factor of 4. We should remove that + assumption later. Args: - max_len: maximum length of input features - supervisions : Supervison in lhotse format, i.e., batch['supervisions'] + max_len: + Maximum length of input features. + CAUTION: It is the length after subsampling. + supervisions: + Supervision in lhotse format. + See https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 # noqa + (CAUTION: It contains length information, i.e., start and number of + frames, before subsampling) Returns: Tensor: Mask tensor of dimension (batch_size, input_length), True denote the masked indices. @@ -912,59 +900,44 @@ def encoder_padding_mask( return mask -def decoder_padding_mask(ys_pad: Tensor, ignore_id: int = -1) -> Tensor: - """Generate a length mask for input. The masked position are filled with bool(True), - Unmasked positions are filled with bool(False). +def decoder_padding_mask( + ys_pad: torch.Tensor, ignore_id: int = -1 +) -> torch.Tensor: + """Generate a length mask for input. + + The masked position are filled with True, + Unmasked positions are filled with False. Args: - ys_pad: padded tensor of dimension (batch_size, input_length). - ignore_id: the ignored number (the padding number) in ys_pad + ys_pad: + padded tensor of dimension (batch_size, input_length). + ignore_id: + the ignored number (the padding number) in ys_pad Returns: - Tensor: a mask tensor of dimension (batch_size, input_length). + Tensor: + a bool tensor of the same shape as the input tensor. """ ys_mask = ys_pad == ignore_id return ys_mask -def get_normal_transcripts( - supervision: Supervisions, words: k2.SymbolTable, oov: str = "" -) -> List[List[int]]: - """Get normal transcripts (1 input recording has 1 transcript) from lhotse cut format. - Achieved by concatenate the transcripts corresponding to the same recording. - - Args: - supervision : Supervison in lhotse format, i.e., batch['supervisions'] - words: The word symbol table. 
- oov: Out of vocabulary word. - - Returns: - List[List[int]]: List of concatenated transcripts, length is batch_size - """ - - texts = [ - [token if token in words else oov for token in text.split(" ")] - for text in supervision["text"] - ] - texts_ids = [[words[token] for token in text] for text in texts] - - batch_text = [ - [] for _ in range(int(supervision["sequence_idx"].max().item()) + 1) - ] - for sequence_idx, text in zip(supervision["sequence_idx"], texts_ids): - batch_text[sequence_idx] = batch_text[sequence_idx] + text - return batch_text +def generate_square_subsequent_mask(sz: int) -> torch.Tensor: + """Generate a square mask for the sequence. The masked positions are + filled with float('-inf'). Unmasked positions are filled with float(0.0). + The mask can be used for masked self-attention. + For instance, if sz is 3, it returns:: -def generate_square_subsequent_mask(sz: int) -> Tensor: - """Generate a square mask for the sequence. The masked positions are filled with float('-inf'). - Unmasked positions are filled with float(0.0). + tensor([[0., -inf, -inf], + [0., 0., -inf], + [0., 0., 0]]) Args: - sz: mask size + sz: mask size Returns: - Tensor: a square mask of dimension (sz, sz) + A square mask of dimension (sz, sz) """ mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) mask = ( @@ -975,121 +948,41 @@ def generate_square_subsequent_mask(sz: int) -> Tensor: return mask -def add_sos_eos( - ys: List[List[int]], - lexicon: k2.Fsa, - sos_id: int, - eos_id: int, - ignore_id: int = -1, -) -> Tuple[Tensor, Tensor]: - """Add and labels. +def add_sos(token_ids: List[List[int]], sos_id: int) -> List[List[int]]: + """Prepend sos_id to each utterance. Args: - ys: batch of unpadded target sequences - lexicon: Its labels are words, while its aux_labels are phones. - sos_id: index of - eos_id: index of - ignore_id: index of padding - - Returns: - Tensor: Input of transformer decoder. Padded tensor of dimention (batch_size, max_length). - Tensor: Output of transformer decoder. padded tensor of dimention (batch_size, max_length). + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + sos_id: + The ID of the SOS token. + + Return: + Return a new list-of-list, where each sublist starts + with SOS ID. """ + ans = [] + for utt in token_ids: + ans.append([sos_id] + utt) + return ans - _sos = torch.tensor([sos_id]) - _eos = torch.tensor([eos_id]) - ys = get_hierarchical_targets(ys, lexicon) - ys_in = [torch.cat([_sos, y], dim=0) for y in ys] - ys_out = [torch.cat([y, _eos], dim=0) for y in ys] - return pad_list(ys_in, eos), pad_list(ys_out, ignore_id) - -def pad_list(ys: List[Tensor], pad_value: float) -> Tensor: - """Perform padding for the list of tensors. - - Args: - ys: List of tensors. len(ys) = batch_size. - pad_value: Value for padding. - - Returns: - Tensor: Padded tensor (batch_size, max_length, `*`). 
- - Examples: - >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)] - >>> x - [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])] - >>> pad_list(x, 0) - tensor([[1., 1., 1., 1.], - [1., 1., 0., 0.], - [1., 0., 0., 0.]]) - - """ - n_batch = len(ys) - max_len = max(x.size(0) for x in ys) - pad = ys[0].new_full((n_batch, max_len, *ys[0].size()[1:]), pad_value) - - for i in range(n_batch): - pad[i, : ys[i].size(0)] = ys[i] - - return pad - - -def get_hierarchical_targets( - ys: List[List[int]], lexicon: k2.Fsa -) -> List[Tensor]: - """Get hierarchical transcripts (i.e., phone level transcripts) from transcripts (i.e., word level transcripts). +def add_eos(token_ids: List[List[int]], eos_id: int) -> List[List[int]]: + """Append eos_id to each utterance. Args: - ys: Word level transcripts. - lexicon: Its labels are words, while its aux_labels are phones. - - Returns: - List[Tensor]: Phone level transcripts. - + token_ids: + A list-of-list of token IDs. Each sublist contains + token IDs (e.g., word piece IDs) of an utterance. + eos_id: + The ID of the EOS token. + + Return: + Return a new list-of-list, where each sublist ends + with EOS ID. """ - - if lexicon is None: - return ys - else: - L_inv = lexicon - - n_batch = len(ys) - device = L_inv.device - - transcripts = k2.create_fsa_vec( - [k2.linear_fsa(x, device=device) for x in ys] - ) - transcripts_with_self_loops = k2.add_epsilon_self_loops(transcripts) - - transcripts_lexicon = k2.intersect( - L_inv, transcripts_with_self_loops, treat_epsilons_specially=False - ) - # Don't call invert_() above because we want to return phone IDs, - # which is the `aux_labels` of transcripts_lexicon - transcripts_lexicon = k2.remove_epsilon(transcripts_lexicon) - transcripts_lexicon = k2.top_sort(transcripts_lexicon) - - transcripts_lexicon = k2.shortest_path( - transcripts_lexicon, use_double_scores=True - ) - - ys = get_texts(transcripts_lexicon) - ys = [torch.tensor(y) for y in ys] - - return ys - - -def test_transformer(): - t = Transformer(40, 1281) - T = 200 - f = torch.rand(31, 40, T) - g, _, _ = t(f) - assert g.shape == (31, 1281, (((T - 1) // 2) - 1) // 2) - - -def main(): - test_transformer() - - -if __name__ == "__main__": - main() + ans = [] + for utt in token_ids: + ans.append(utt + [eos_id]) + return ans diff --git a/egs/librispeech/ASR/local/compile_hlg.py b/egs/librispeech/ASR/local/compile_hlg.py index 605d72daed..b304021616 100755 --- a/egs/librispeech/ASR/local/compile_hlg.py +++ b/egs/librispeech/ASR/local/compile_hlg.py @@ -26,7 +26,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: """ Args: lang_dir: - The language directory, e.g., data/lang or data/lang/bpe. + The language directory, e.g., data/lang_phone or data/lang_bpe. Return: An FSA representing HLG. 
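With this refactoring, compile_HLG() writes HLG.pt into each lang directory (data/lang_phone, data/lang_bpe) via the refactored main() below, instead of a single data/lm/HLG_bpe.pt, and decode.py above now loads it from params.lang_dir. A minimal sketch of that loading side, mirroring the decode.py hunk earlier in this patch; the load_hlg helper name is illustrative and not part of the commit:

import k2
import torch

def load_hlg(lang_dir: str, device: torch.device) -> k2.Fsa:
    # Load the decoding graph written by compile_hlg.py into the lang dir.
    HLG = k2.Fsa.from_dict(torch.load(f"{lang_dir}/HLG.pt"))
    HLG = HLG.to(device)
    # Whole-lattice rescoring expects HLG.lm_scores; keep a copy of the
    # original scores so LM scale changes can be applied and undone.
    if not hasattr(HLG, "lm_scores"):
        HLG.lm_scores = HLG.scores.clone()
    return HLG
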
@@ -45,7 +45,7 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: logging.info("Loading G_3_gram.fst.txt") with open("data/lm/G_3_gram.fst.txt") as f: G = k2.Fsa.from_openfst(f.read(), acceptor=False) - torch.save(G.as_dict(), "G_3_gram.pt") + torch.save(G.as_dict(), "data/lm/G_3_gram.pt") first_token_disambig_id = lexicon.token_table["#0"] first_word_disambig_id = lexicon.word_table["#0"] @@ -103,30 +103,18 @@ def compile_HLG(lang_dir: str) -> k2.Fsa: return HLG -def phone_based_HLG(): - if Path("data/lm/HLG.pt").is_file(): - return - - logging.info("Compiling phone based HLG") - HLG = compile_HLG("data/lang") - - logging.info("Saving HLG.pt to data/lm") - torch.save(HLG.as_dict(), "data/lm/HLG.pt") - - -def bpe_based_HLG(): - if Path("data/lm/HLG_bpe.pt").is_file(): - return - - logging.info("Compiling BPE based HLG") - HLG = compile_HLG("data/lang/bpe") - logging.info("Saving HLG_bpe.pt to data/lm") - torch.save(HLG.as_dict(), "data/lm/HLG_bpe.pt") +def main(): + for d in ["data/lang_phone", "data/lang_bpe"]: + d = Path(d) + logging.info(f"Processing {d}") + if (d / "HLG.pt").is_file(): + logging.info(f"{d}/HLG.pt already exists - skipping") + continue -def main(): - phone_based_HLG() - bpe_based_HLG() + HLG = compile_HLG(d) + logging.info(f"Saving HLG.pt to {d}") + torch.save(HLG.as_dict(), f"{d}/HLG.pt") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/local/compute_fbank_librispeech.py b/egs/librispeech/ASR/local/compute_fbank_librispeech.py index 947d9f8d9d..d81096070f 100755 --- a/egs/librispeech/ASR/local/compute_fbank_librispeech.py +++ b/egs/librispeech/ASR/local/compute_fbank_librispeech.py @@ -1,19 +1,28 @@ #!/usr/bin/env python3 """ -This file computes fbank features of the librispeech dataset. -Its looks for manifests in the directory data/manifests -and generated fbank features are saved in data/fbank. +This file computes fbank features of the LibriSpeech dataset. +Its looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. """ +import logging import os from pathlib import Path +import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor +# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and +# slow things down. Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + def compute_fbank_librispeech(): src_dir = Path("data/manifests") @@ -40,12 +49,11 @@ def compute_fbank_librispeech(): with get_executor() as ex: # Initialize the executor only once. 
for partition, m in manifests.items(): if (output_dir / f"cuts_{partition}.json.gz").is_file(): - print(f"{partition} already exists - skipping.") + logging.info(f"{partition} already exists - skipping.") continue - print("Processing", partition) + logging.info(f"Processing {partition}") cut_set = CutSet.from_manifests( - recordings=m["recordings"], - supervisions=m["supervisions"], + recordings=m["recordings"], supervisions=m["supervisions"], ) if "train" in partition: cut_set = ( @@ -65,4 +73,10 @@ def compute_fbank_librispeech(): if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + compute_fbank_librispeech() diff --git a/egs/librispeech/ASR/local/compute_fbank_musan.py b/egs/librispeech/ASR/local/compute_fbank_musan.py index d63131da89..0fc515d8c2 100755 --- a/egs/librispeech/ASR/local/compute_fbank_musan.py +++ b/egs/librispeech/ASR/local/compute_fbank_musan.py @@ -2,18 +2,27 @@ """ This file computes fbank features of the musan dataset. -Its looks for manifests in the directory data/manifests -and generated fbank features are saved in data/fbank. +Its looks for manifests in the directory data/manifests. + +The generated fbank features are saved in data/fbank. """ +import logging import os from pathlib import Path +import torch from lhotse import CutSet, Fbank, FbankConfig, LilcomHdf5Writer, combine from lhotse.recipes.utils import read_manifests_if_cached from icefall.utils import get_executor +# Torch's multithreaded behavior needs to be disabled or it wastes a lot of CPU and +# slow things down. Do this outside of main() in case it needs to take effect +# even when we are not invoking the main (e.g. when spawning subprocesses). +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + def compute_fbank_musan(): src_dir = Path("data/manifests") @@ -34,10 +43,10 @@ def compute_fbank_musan(): musan_cuts_path = output_dir / "cuts_musan.json.gz" if musan_cuts_path.is_file(): - print(f"{musan_cuts_path} already exists - skipping") + logging.info(f"{musan_cuts_path} already exists - skipping") return - print("Extracting features for Musan") + logging.info("Extracting features for Musan") extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins)) @@ -63,4 +72,9 @@ def compute_fbank_musan(): if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) compute_fbank_musan() diff --git a/egs/librispeech/ASR/local/download_lm.py b/egs/librispeech/ASR/local/download_lm.py index 0bdc2935ba..5c9e2a6751 100755 --- a/egs/librispeech/ASR/local/download_lm.py +++ b/egs/librispeech/ASR/local/download_lm.py @@ -2,10 +2,25 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) """ -This file downloads librispeech LM files to data/lm +This file downloads the following LibriSpeech LM files: + + - 3-gram.pruned.1e-7.arpa.gz + - 4-gram.arpa.gz + - librispeech-vocab.txt + - librispeech-lexicon.txt + +from http://www.openslr.org/resources/11 +and save them in the user provided directory. + +Files are not re-downloaded if they already exist. 
+ +Usage: + ./local/download_lm.py --out-dir ./download/lm """ +import argparse import gzip +import logging import os import shutil from pathlib import Path @@ -14,9 +29,17 @@ from tqdm.auto import tqdm -def download_lm(): +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--out-dir", type=str, help="Output directory.") + + args = parser.parse_args() + return args + + +def main(out_dir: str): url = "http://www.openslr.org/resources/11" - target_dir = Path("data/lm") + out_dir = Path(out_dir) files_to_download = ( "3-gram.pruned.1e-7.arpa.gz", @@ -26,7 +49,7 @@ def download_lm(): ) for f in tqdm(files_to_download, desc="Downloading LibriSpeech LM files"): - filename = target_dir / f + filename = out_dir / f if filename.is_file() is False: urlretrieve_progress( f"{url}/{f}", @@ -34,17 +57,26 @@ def download_lm(): desc=f"Downloading {filename}", ) else: - print(f"{filename} already exists - skipping") + logging.info(f"{filename} already exists - skipping") if ".gz" in str(filename): - unzip_file = Path(os.path.splitext(filename)[0]) - if unzip_file.is_file() is False: + unzipped = Path(os.path.splitext(filename)[0]) + if unzipped.is_file() is False: with gzip.open(filename, "rb") as f_in: - with open(unzip_file, "wb") as f_out: + with open(unzipped, "wb") as f_out: shutil.copyfileobj(f_in, f_out) else: - print(f"{unzip_file} already exist - skipping") + logging.info(f"{unzipped} already exist - skipping") if __name__ == "__main__": - download_lm() + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + + logging.basicConfig(format=formatter, level=logging.INFO) + + args = get_args() + logging.info(f"out_dir: {args.out_dir}") + + main(out_dir=args.out_dir) diff --git a/egs/librispeech/ASR/local/prepare_lang.py b/egs/librispeech/ASR/local/prepare_lang.py index b9d13f5bb4..f7fde7796f 100755 --- a/egs/librispeech/ASR/local/prepare_lang.py +++ b/egs/librispeech/ASR/local/prepare_lang.py @@ -3,7 +3,7 @@ # Copyright (c) 2021 Xiaomi Corporation (authors: Fangjun Kuang) """ -This script takes as input a lexicon file "data/lang/lexicon.txt" +This script takes as input a lexicon file "data/lang_phone/lexicon.txt" consisting of words and tokens (i.e., phones) and does the following: 1. Add disambiguation symbols to the lexicon and generate lexicon_disambig.txt @@ -20,8 +20,6 @@ 5. Generate L_disambig.pt, in k2 format. 
""" import math -import re -import sys from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, Tuple @@ -284,7 +282,9 @@ def lexicon_to_fst( disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, disambig_token=disambig_token, disambig_word=disambig_word, + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, ) final_state = next_state @@ -301,7 +301,7 @@ def lexicon_to_fst( def main(): - out_dir = Path("data/lang") + out_dir = Path("data/lang_phone") lexicon_filename = out_dir / "lexicon.txt" sil_token = "SIL" sil_prob = 0.5 diff --git a/egs/librispeech/ASR/local/prepare_lang_bpe.py b/egs/librispeech/ASR/local/prepare_lang_bpe.py index 0c3e9ede54..e31220d9b2 100755 --- a/egs/librispeech/ASR/local/prepare_lang_bpe.py +++ b/egs/librispeech/ASR/local/prepare_lang_bpe.py @@ -5,10 +5,10 @@ """ This script takes as inputs the following two files: - - data/lang/bpe/bpe.model, - - data/lang/bpe/words.txt + - data/lang_bpe/bpe.model, + - data/lang_bpe/words.txt -and generates the following files in the directory data/lang/bpe: +and generates the following files in the directory data/lang_bpe: - lexicon.txt - lexicon_disambig.txt @@ -88,7 +88,9 @@ def lexicon_to_fst_no_sil( disambig_token = token2id["#0"] disambig_word = word2id["#0"] arcs = add_self_loops( - arcs, disambig_token=disambig_token, disambig_word=disambig_word, + arcs, + disambig_token=disambig_token, + disambig_word=disambig_word, ) final_state = next_state @@ -140,7 +142,7 @@ def generate_lexicon( def main(): - lang_dir = Path("data/lang/bpe") + lang_dir = Path("data/lang_bpe") model_file = lang_dir / "bpe.model" word_sym_table = k2.SymbolTable.from_file(lang_dir / "words.txt") @@ -173,7 +175,9 @@ def main(): write_lexicon(lang_dir / "lexicon_disambig.txt", lexicon_disambig) L = lexicon_to_fst_no_sil( - lexicon, token2id=token_sym_table, word2id=word_sym_table, + lexicon, + token2id=token_sym_table, + word2id=word_sym_table, ) L_disambig = lexicon_to_fst_no_sil( diff --git a/egs/librispeech/ASR/local/train_bpe_model.py b/egs/librispeech/ASR/local/train_bpe_model.py index b5c6c7541a..59746ad9a6 100755 --- a/egs/librispeech/ASR/local/train_bpe_model.py +++ b/egs/librispeech/ASR/local/train_bpe_model.py @@ -14,18 +14,17 @@ # # Please install a version >=0.1.96 +import shutil from pathlib import Path import sentencepiece as spm -import shutil - def main(): model_type = "unigram" vocab_size = 5000 - model_prefix = f"data/lang/bpe/{model_type}_{vocab_size}" - train_text = "data/lang/bpe/train.txt" + model_prefix = f"data/lang_bpe/{model_type}_{vocab_size}" + train_text = "data/lang_bpe/train.txt" character_coverage = 1.0 input_sentence_size = 100000000 @@ -53,7 +52,7 @@ def main(): sp = spm.SentencePieceProcessor(model_file=str(model_file)) vocab_size = sp.vocab_size() - shutil.copyfile(model_file, "data/lang/bpe/bpe.model") + shutil.copyfile(model_file, "data/lang_bpe/bpe.model") if __name__ == "__main__": diff --git a/egs/librispeech/ASR/prepare.sh b/egs/librispeech/ASR/prepare.sh index 406527b713..ae676b199b 100755 --- a/egs/librispeech/ASR/prepare.sh +++ b/egs/librispeech/ASR/prepare.sh @@ -6,8 +6,38 @@ nj=15 stage=-1 stop_stage=100 -. local/parse_options.sh || exit 1 - +# We assume dl_dir (download dir) contains the following +# directories and files. If not, they will be downloaded +# by this script automatically. +# +# - $dl_dir/LibriSpeech +# You can find BOOKS.TXT, test-clean, train-clean-360, etc, inside it. 
+#   You can download them from https://www.openslr.org/12
+#
+# - $dl_dir/lm
+#   This directory contains the following files downloaded from
+#   http://www.openslr.org/resources/11
+#
+#    - 3-gram.pruned.1e-7.arpa.gz
+#    - 3-gram.pruned.1e-7.arpa
+#    - 4-gram.arpa.gz
+#    - 4-gram.arpa
+#    - librispeech-vocab.txt
+#    - librispeech-lexicon.txt
+#
+# - $dl_dir/musan
+#   This directory contains the following directories downloaded from
+#   http://www.openslr.org/17/
+#
+#    - music
+#    - noise
+#    - speech
+dl_dir=$PWD/download
+
+. shared/parse_options.sh || exit 1
+
+
+# All generated files by this script are saved in "data"
 mkdir -p data
 
 log() {
@@ -16,10 +46,11 @@ log() {
   echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
 }
 
+log "dl_dir: $dl_dir"
+
 if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
   log "stage -1: Download LM"
-  mkdir -p data/lm
-  ./local/download_lm.py
+  ./local/download_lm.py --out-dir=$dl_dir/lm
 fi
 
 if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
@@ -28,38 +59,28 @@ if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
   # If you have pre-downloaded it to /path/to/LibriSpeech,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/LibriSpeech data/
-  #
-  # The script checks that if
-  #
-  #   data/LibriSpeech/test-clean/.completed exists,
+  #   ln -sfv /path/to/LibriSpeech $dl_dir/LibriSpeech
   #
-  # it will not re-download it.
-  #
-  # The same goes for dev-clean, dev-other, test-other, train-clean-100
-  # train-clean-360, and train-other-500
-
-  mkdir -p data/LibriSpeech
-  lhotse download librispeech --full data
+  if [ ! -d $dl_dir/LibriSpeech/train-other-500 ]; then
+    lhotse download librispeech --full $dl_dir
+  fi
 
   # If you have pre-downloaded it to /path/to/musan,
   # you can create a symlink
   #
-  #   ln -sfv /path/to/musan data/
+  #   ln -sfv /path/to/musan $dl_dir/
   #
-  # and create a file data/.musan_completed
-  # to avoid downloading it again
-  if [ ! -f data/.musan_completed ]; then
-    lhotse download musan data
+  if [ ! -d $dl_dir/musan ]; then
+    lhotse download musan $dl_dir
   fi
 fi
 
 if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
-  log "Stage 1: Prepare librispeech manifest"
-  # We assume that you have downloaded the librispeech corpus
-  # to data/LibriSpeech
+  log "Stage 1: Prepare LibriSpeech manifest"
+  # We assume that you have downloaded the LibriSpeech corpus
+  # to $dl_dir/LibriSpeech
   mkdir -p data/manifests
-  lhotse prepare librispeech -j $nj data/LibriSpeech data/manifests
+  lhotse prepare librispeech -j $nj $dl_dir/LibriSpeech data/manifests
 fi
 
 if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
@@ -67,7 +88,7 @@ if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
   # We assume that you have downloaded the musan corpus
   # to data/musan
   mkdir -p data/manifests
-  lhotse prepare musan data/musan data/manifests
+  lhotse prepare musan $dl_dir/musan data/manifests
 fi
 
 if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
@@ -84,24 +105,25 @@ fi
 
 if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
   log "Stage 5: Prepare phone based lang"
-  # TODO: add BPE based lang
-  mkdir -p data/lang
+  mkdir -p data/lang_phone
 
   (echo '!SIL SIL'; echo '<SPOKEN_NOISE> SPN'; echo '<UNK> SPN'; ) |
-    cat - data/lm/librispeech-lexicon.txt |
-    sort | uniq > data/lang/lexicon.txt
+    cat - $dl_dir/lm/librispeech-lexicon.txt |
+    sort | uniq > data/lang_phone/lexicon.txt
 
-  if [ ! -f data/lang/L_disambig.pt ]; then
+  if [ !
-f data/lang_phone/L_disambig.pt ]; then ./local/prepare_lang.py fi fi if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then log "State 6: Prepare BPE based lang" - mkdir -p data/lang/bpe - cp data/lang/words.txt data/lang/bpe/ + mkdir -p data/lang_bpe + # We reuse words.txt from phone based lexicon + # so that the two can share G.pt later. + cp data/lang_phone/words.txt data/lang_bpe/ - if [ ! -f data/lang/bpe/train.txt ]; then + if [ ! -f data/lang_bpe/train.txt ]; then log "Generate data for BPE training" files=$( find "data/LibriSpeech/train-clean-100" -name "*.trans.txt" @@ -110,12 +132,12 @@ if [ $stage -le 6 ] && [ $stop_stage -ge 6 ]; then ) for f in ${files[@]}; do cat $f | cut -d " " -f 2- - done > data/lang/bpe/train.txt + done > data/lang_bpe/train.txt fi python3 ./local/train_bpe_model.py - if [ ! -f data/lang/bpe/L_disambig.pt ]; then + if [ ! -f data/lang_bpe/L_disambig.pt ]; then ./local/prepare_lang_bpe.py fi fi @@ -125,22 +147,23 @@ if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then # We assume you have install kaldilm, if not, please install # it using: pip install kaldilm + mkdir -p data/lm if [ ! -f data/lm/G_3_gram.fst.txt ]; then # It is used in building HLG python3 -m kaldilm \ - --read-symbol-table="data/lang/words.txt" \ + --read-symbol-table="data/lang_phone/words.txt" \ --disambig-symbol='#0' \ --max-order=3 \ - data/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt + $dl_dir/lm/3-gram.pruned.1e-7.arpa > data/lm/G_3_gram.fst.txt fi if [ ! -f data/lm/G_4_gram.fst.txt ]; then # It is used for LM rescoring python3 -m kaldilm \ - --read-symbol-table="data/lang/words.txt" \ + --read-symbol-table="data/lang_phone/words.txt" \ --disambig-symbol='#0' \ --max-order=4 \ - data/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt + $dl_dir/lm/4-gram.arpa > data/lm/G_4_gram.fst.txt fi fi diff --git a/egs/librispeech/ASR/shared b/egs/librispeech/ASR/shared new file mode 120000 index 0000000000..4c5e91438c --- /dev/null +++ b/egs/librispeech/ASR/shared @@ -0,0 +1 @@ +../../../icefall/shared/ \ No newline at end of file diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py index 2c45b4e317..137fa795c0 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/decode.py @@ -58,7 +58,7 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("tdnn_lstm_ctc/exp/"), - "lang_dir": Path("data/lang"), + "lang_dir": Path("data/lang_phone"), "lm_dir": Path("data/lm"), "feature_dim": 80, "subsampling_factor": 3, @@ -328,7 +328,7 @@ def main(): logging.info(f"device: {device}") - HLG = k2.Fsa.from_dict(torch.load("data/lm/HLG.pt")) + HLG = k2.Fsa.from_dict(torch.load("data/lang_phone/HLG.pt")) HLG = HLG.to(device) assert HLG.requires_grad is False @@ -340,7 +340,7 @@ def main(): logging.info("Loading G_4_gram.fst.txt") logging.warning("It may take 8 minutes.") with open(params.lm_dir / "G_4_gram.fst.txt") as f: - first_word_disambig_id = lexicon.words["#0"] + first_word_disambig_id = lexicon.word_table["#0"] G = k2.Fsa.from_openfst(f.read(), acceptor=False) # G.aux_labels is not needed in later computations, so diff --git a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py index d94a2f7258..dbb9f64ecf 100755 --- a/egs/librispeech/ASR/tdnn_lstm_ctc/train.py +++ b/egs/librispeech/ASR/tdnn_lstm_ctc/train.py @@ -127,7 +127,7 @@ def get_params() -> AttributeDict: params = AttributeDict( { "exp_dir": Path("tdnn_lstm_ctc/exp"), - "lang_dir": 
Path("data/lang"), + "lang_dir": Path("data/lang_phone"), "lr": 1e-3, "feature_dim": 80, "weight_decay": 5e-4, @@ -501,8 +501,9 @@ def run(rank, world_size, args): ) scheduler = StepLR(optimizer, step_size=8, gamma=0.1) - optimizer.load_state_dict(checkpoints["optimizer"]) - scheduler.load_state_dict(checkpoints["scheduler"]) + if checkpoints: + optimizer.load_state_dict(checkpoints["optimizer"]) + scheduler.load_state_dict(checkpoints["scheduler"]) librispeech = LibriSpeechAsrDataModule(args) train_dl = librispeech.train_dataloaders() diff --git a/icefall/decode.py b/icefall/decode.py index ed08405fa0..0e9baf2e46 100644 --- a/icefall/decode.py +++ b/icefall/decode.py @@ -555,24 +555,31 @@ def rescore_with_attention_decoder( model: nn.Module, memory: torch.Tensor, memory_key_padding_mask: torch.Tensor, + sos_id: int, + eos_id: int, ) -> Dict[str, k2.Fsa]: """This function extracts n paths from the given lattice and uses an attention decoder to rescore them. The path with the highest score is used as the decoding output. - lattice: - An FsaVec. It can be the return value of :func:`get_lattice`. - num_paths: - Number of paths to extract from the given lattice for rescoring. - model: - A transformer model. See the class "Transformer" in - conformer_ctc/transformer.py for its interface. - memory: - The encoder memory of the given model. It is the output of - the last torch.nn.TransformerEncoder layer in the given model. - Its shape is `[T, N, C]`. - memory_key_padding_mask: - The padding mask for memory with shape [N, T]. + Args: + lattice: + An FsaVec. It can be the return value of :func:`get_lattice`. + num_paths: + Number of paths to extract from the given lattice for rescoring. + model: + A transformer model. See the class "Transformer" in + conformer_ctc/transformer.py for its interface. + memory: + The encoder memory of the given model. It is the output of + the last torch.nn.TransformerEncoder layer in the given model. + Its shape is `[T, N, C]`. + memory_key_padding_mask: + The padding mask for memory with shape [N, T]. + sos_id: + The token ID for SOS. + eos_id: + The token ID for EOS. 
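An illustrative aside (not part of the patch): sos_id and eos_id are used to frame each hypothesis before the attention decoder scores it, in the same spirit as the add_eos helper added to transformer.py earlier in this diff. A minimal sketch with made-up IDs; the real values come from the BPE model:

    # Word-piece IDs of two n-best paths extracted from the lattice (hypothetical).
    token_ids = [[25, 7, 133], [25, 9]]
    sos_id, eos_id = 1, 1  # placeholders only

    # The decoder input is prefixed with SOS; the target it is scored
    # against is suffixed with EOS (the latter is what add_eos returns).
    ys_in = [[sos_id] + utt for utt in token_ids]
    ys_out = [utt + [eos_id] for utt in token_ids]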
     Returns:
       A dict of FsaVec, whose key contains a string
       ngram_lm_scale_attention_scale and the value is the
@@ -661,7 +668,11 @@ def rescore_with_attention_decoder(
 
     # TODO: pass the sos_token_id and eos_token_id via function arguments
     nll = model.decoder_nll(
-        expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
+        memory=expanded_memory,
+        memory_key_padding_mask=expanded_memory_key_padding_mask,
+        token_ids=token_ids,
+        sos_id=sos_id,
+        eos_id=eos_id,
     )
     assert nll.ndim == 2
     assert nll.shape[0] == num_word_seqs
diff --git a/icefall/lexicon.py b/icefall/lexicon.py
index 3b52c70c92..89747b11b0 100644
--- a/icefall/lexicon.py
+++ b/icefall/lexicon.py
@@ -1,7 +1,8 @@
 import logging
 import re
+import sys
 from pathlib import Path
-from typing import List, Tuple, Union
+from typing import List, Tuple
 
 import k2
 import torch
@@ -31,13 +32,19 @@ def read_lexicon(filename: str) -> List[Tuple[str, List[str]]]:
             continue
 
         if len(a) < 2:
-            print(f"Found bad line {line} in lexicon file {filename}")
-            print("Every line is expected to contain at least 2 fields")
+            logging.info(
+                f"Found bad line {line} in lexicon file {filename}"
+            )
+            logging.info(
+                "Every line is expected to contain at least 2 fields"
+            )
             sys.exit(1)
 
         word = a[0]
         if word == "<eps>":
-            print(f"Found bad line {line} in lexicon file {filename}")
-            print("<eps> should not be a valid word")
+            logging.info(
+                f"Found bad line {line} in lexicon file {filename}"
+            )
+            logging.info("<eps> should not be a valid word")
             sys.exit(1)
 
         tokens = a[1:]
@@ -61,13 +68,12 @@ def write_lexicon(filename: str, lexicon: List[Tuple[str, List[str]]]) -> None:
 
 
 class Lexicon(object):
-    """Phone based lexicon.
-
-    TODO: Add BpeLexicon for BPE models.
-    """
+    """Phone based lexicon."""
 
     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Args:
@@ -121,7 +127,9 @@ def tokens(self) -> List[int]:
 
 class BpeLexicon(Lexicon):
     def __init__(
-        self, lang_dir: Path, disambig_pattern: str = re.compile(r"^#\d+$"),
+        self,
+        lang_dir: Path,
+        disambig_pattern: str = re.compile(r"^#\d+$"),
     ):
         """
         Refer to the help information in Lexicon.__init__.
diff --git a/egs/librispeech/ASR/local/parse_options.sh b/icefall/shared/parse_options.sh
similarity index 100%
rename from egs/librispeech/ASR/local/parse_options.sh
rename to icefall/shared/parse_options.sh
diff --git a/icefall/utils.py b/icefall/utils.py
index 1f2cf95f34..3d48badfef 100644
--- a/icefall/utils.py
+++ b/icefall/utils.py
@@ -225,7 +225,10 @@ def store_transcripts(
 
 
 def write_error_stats(
-    f: TextIO, test_set_name: str, results: List[Tuple[str, str]]
+    f: TextIO,
+    test_set_name: str,
+    results: List[Tuple[str, str]],
+    enable_log: bool = True,
 ) -> float:
     """Write statistics based on predicted results and reference
     transcripts.
@@ -255,6 +258,9 @@ def write_error_stats(
     results:
       An iterable of tuples. The first element is the reference transcript
       while the second element is the predicted result.
+    enable_log:
+      If True, also print detailed WER to the console.
+      Otherwise, it is written only to the given file.
     Returns:
       Return None.
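A usage sketch (not from the patch) of the extended signature; the file path and the (reference, hypothesis) pair below are made up, and the keyword-argument call style is only for readability:

    from icefall.utils import write_error_stats

    results = [("HELLO WORLD", "HELLO WORD")]  # hypothetical (ref, hyp) pair
    with open("conformer_ctc/exp/errs-test-clean.txt", "w") as f:
        # enable_log=False keeps the detailed WER line out of the console log;
        # the statistics are still written to the file handle f.
        write_error_stats(
            f, test_set_name="test-clean", results=results, enable_log=False
        )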
""" @@ -290,11 +296,12 @@ def write_error_stats( tot_errs = sub_errs + ins_errs + del_errs tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) - logging.info( - f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " - f"[{tot_errs} / {ref_len}, {ins_errs} ins, " - f"{del_errs} del, {sub_errs} sub ]" - ) + if enable_log: + logging.info( + f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} " + f"[{tot_errs} / {ref_len}, {ins_errs} ins, " + f"{del_errs} del, {sub_errs} sub ]" + ) print(f"%WER = {tot_err_rate}", file=f) print(