
Commit 84f37bd

Merge remote-tracking branch 'dan/master' into lm-rescoring

csukuangfj committed Apr 11, 2021
2 parents 51e0329 + 4264636

Showing 19 changed files with 713 additions and 675 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,5 +1,6 @@
 
 data*/
+!snowfall/data
 exp*/
 
 .DS_Store
20 changes: 6 additions & 14 deletions egs/aishell/asr/simple_v1/mmi_att_transformer_decode.py
@@ -17,13 +17,12 @@
 from typing import List
 from typing import Union
 
-from lhotse import CutSet
-from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
 from snowfall.common import average_checkpoint
 from snowfall.common import find_first_disambig_symbol
 from snowfall.common import get_texts
 from snowfall.common import load_checkpoint
 from snowfall.common import setup_logger
+from snowfall.data import AishellAsrDataModule
 from snowfall.decoding.graph import compile_LG
 from snowfall.models import AcousticModel
 from snowfall.models.transformer import Transformer
@@ -199,11 +198,12 @@ def get_parser():
 
 
 def main():
-    args = get_parser().parse_args()
+    parser = get_parser()
+    AishellAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
 
     model_type = args.model_type
     epoch = args.epoch
-    max_frames = args.max_frames
     avg = args.avg
     att_rate = args.att_rate
 
@@ -289,16 +289,8 @@ def main():
     LG = k2.Fsa.from_dict(d)
 
     # load dataset
-    # feature_dir = Path('/export/gpudisk2/data/hegc/audio_workspace/snowfall_aishell1/exp/data')
-    feature_dir = Path('exp/data')
-    logging.debug("About to get test cuts")
-    cuts_test = CutSet.from_json(feature_dir / 'cuts_test.json.gz')
-
-    logging.debug("About to create test dataset")
-    test = K2SpeechRecognitionDataset(cuts_test)
-    sampler = SingleCutSampler(cuts_test, max_frames=max_frames)
-    logging.debug("About to create test dataloader")
-    test_dl = torch.utils.data.DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
+    aishell = AishellAsrDataModule(args)
+    test_dl = aishell.test_dataloaders()
 
     # if not torch.cuda.is_available():
     #     logging.error('No GPU detected!')
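
The deleted lines above are a fair guide to what `AishellAsrDataModule.test_dataloaders()` now does internally. A minimal sketch under that assumption (the free-function form, name, and default values here are illustrative; the real implementation lives in `snowfall/data/aishell.py` and may differ):

    from pathlib import Path

    import torch
    from lhotse import CutSet
    from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler


    def make_test_dataloader(feature_dir: Path = Path('exp/data'),
                             max_frames: int = 25000) -> torch.utils.data.DataLoader:
        # Hypothetical reconstruction of the code removed from main() above.
        cuts_test = CutSet.from_json(feature_dir / 'cuts_test.json.gz')
        test = K2SpeechRecognitionDataset(cuts_test)
        sampler = SingleCutSampler(cuts_test, max_frames=max_frames)
        return torch.utils.data.DataLoader(test, batch_size=None,
                                           sampler=sampler, num_workers=1)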
104 changes: 8 additions & 96 deletions egs/aishell/asr/simple_v1/mmi_att_transformer_train.py
@@ -22,13 +22,12 @@
 from torch.utils.tensorboard import SummaryWriter
 from typing import Dict, Optional, Tuple
 
-from lhotse import CutSet
-from lhotse.dataset import BucketingSampler, CutConcatenate, CutMix, K2SpeechRecognitionDataset, SingleCutSampler
 from lhotse.utils import fix_random_seed
-from snowfall.common import describe, str2bool
+from snowfall.common import describe
 from snowfall.common import load_checkpoint, save_checkpoint
 from snowfall.common import save_training_info
 from snowfall.common import setup_logger
+from snowfall.data.aishell import AishellAsrDataModule
 from snowfall.models import AcousticModel
 from snowfall.models.transformer import Noam, Transformer
 from snowfall.models.conformer import Conformer
@@ -366,11 +365,6 @@ def get_parser():
         type=int,
         default=0,
         help="Number of start epoch.")
-    parser.add_argument(
-        '--max-frames',
-        type=int,
-        default=25000,
-        help="Maximum number of feature frames in a single batch.")
     parser.add_argument(
         '--warm-step',
         type=int,
@@ -402,51 +396,17 @@ def get_parser():
         type=int,
         default=256,
         help="Number of units in transformer attention layers.")
-    parser.add_argument(
-        '--bucketing_sampler',
-        type=str2bool,
-        default=False,
-        help='When enabled, the batches will come from buckets of '
-             'similar duration (saves padding frames).')
-    parser.add_argument(
-        '--num-buckets',
-        type=int,
-        default=30,
-        help='The number of buckets for the BucketingSampler'
-             '(you might want to increase it for larger datasets).')
-    parser.add_argument(
-        '--concatenate-cuts',
-        type=str2bool,
-        default=True,
-        help='When enabled, utterances (cuts) will be concatenated '
-             'to minimize the amount of padding.')
-    parser.add_argument(
-        '--duration-factor',
-        type=float,
-        default=1.0,
-        help='Determines the maximum duration of a concatenated cut '
-             'relative to the duration of the longest cut in a batch.')
-    parser.add_argument(
-        '--gap',
-        type=float,
-        default=1.0,
-        help='The amount of padding (in seconds) inserted between concatenated cuts. '
-             'This padding is filled with noise when noise augmentation is used.')
-    parser.add_argument(
-        '--full-libri',
-        type=str2bool,
-        default=False,
-        help='When enabled, use 960h LibriSpeech.')
     return parser
 
 
 def main():
-    args = get_parser().parse_args()
+    parser = get_parser()
+    AishellAsrDataModule.add_arguments(parser)
+    args = parser.parse_args()
 
     model_type = args.model_type
     start_epoch = args.start_epoch
     num_epochs = args.num_epochs
-    max_frames = args.max_frames
     accum_grad = args.accum_grad
     den_scale = args.den_scale
     att_rate = args.att_rate
@@ -485,57 +445,9 @@ def main():
     P.scores = torch.zeros_like(P.scores)
     P = P.to(device)
 
-    # load dataset
-    # feature_dir = Path('/export/gpudisk2/data/hegc/audio_workspace/snowfall_aishell1/exp/data')
-    feature_dir = Path('exp/data')
-    logging.info("About to get train cuts")
-    cuts_train = CutSet.from_json(feature_dir /
-                                  'cuts_train.json.gz')
-    logging.info("About to get dev cuts")
-    cuts_dev = CutSet.from_json(feature_dir / 'cuts_dev.json.gz')
-    logging.info("About to get Musan cuts")
-    cuts_musan = CutSet.from_json(feature_dir / 'cuts_musan.json.gz')
-
-    logging.info("About to create train dataset")
-    transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
-    if args.concatenate_cuts:
-        logging.info(f'Using cut concatenation with duration factor {args.duration_factor} and gap {args.gap}.')
-        # Cut concatenation should be the first transform in the list,
-        # so that if we e.g. mix noise in, it will fill the gaps between different utterances.
-        transforms = [CutConcatenate(duration_factor=args.duration_factor, gap=args.gap)] + transforms
-    train = K2SpeechRecognitionDataset(cuts_train, cut_transforms=transforms)
-    if args.bucketing_sampler:
-        logging.info('Using BucketingSampler.')
-        train_sampler = BucketingSampler(
-            cuts_train,
-            max_frames=max_frames,
-            shuffle=True,
-            num_buckets=args.num_buckets
-        )
-    else:
-        logging.info('Using SingleCutSampler.')
-        train_sampler = SingleCutSampler(
-            cuts_train,
-            max_frames=max_frames,
-            shuffle=True,
-        )
-    logging.info("About to create train dataloader")
-    train_dl = torch.utils.data.DataLoader(
-        train,
-        sampler=train_sampler,
-        batch_size=None,
-        num_workers=4
-    )
-    logging.info("About to create dev dataset")
-    validate = K2SpeechRecognitionDataset(cuts_dev)
-    valid_sampler = SingleCutSampler(cuts_dev, max_frames=max_frames)
-    logging.info("About to create dev dataloader")
-    valid_dl = torch.utils.data.DataLoader(
-        validate,
-        sampler=valid_sampler,
-        batch_size=None,
-        num_workers=1
-    )
+    aishell = AishellAsrDataModule(args)
+    train_dl = aishell.train_dataloaders()
+    valid_dl = aishell.valid_dataloaders()
 
     if not torch.cuda.is_available():
         logging.error('No GPU detected!')
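
The options and dataloader code deleted here moved into the shared data module. A minimal sketch of how `AishellAsrDataModule` could be laid out, pieced together from the removed lines; everything below is an assumption based on this diff, not the actual contents of `snowfall/data/aishell.py`:

    import argparse
    from pathlib import Path

    import torch
    from lhotse import CutSet
    from lhotse.dataset import (BucketingSampler, CutConcatenate, CutMix,
                                K2SpeechRecognitionDataset, SingleCutSampler)
    from snowfall.common import str2bool


    class AishellAsrDataModule:
        # Hypothetical reconstruction; only the class name and the methods used
        # by the scripts (add_arguments, train/valid/test_dataloaders) are
        # attested by this diff.
        def __init__(self, args: argparse.Namespace):
            self.args = args
            self.feature_dir = Path('exp/data')

        @classmethod
        def add_arguments(cls, parser: argparse.ArgumentParser) -> None:
            # The flags deleted from get_parser() above presumably move here,
            # so the train and decode scripts share one set of data options.
            parser.add_argument('--max-frames', type=int, default=25000,
                                help='Maximum number of feature frames in a single batch.')
            parser.add_argument('--bucketing_sampler', type=str2bool, default=False,
                                help='When enabled, batches come from buckets of similar duration.')
            parser.add_argument('--num-buckets', type=int, default=30)
            parser.add_argument('--concatenate-cuts', type=str2bool, default=True)
            parser.add_argument('--duration-factor', type=float, default=1.0)
            parser.add_argument('--gap', type=float, default=1.0)

        def train_dataloaders(self) -> torch.utils.data.DataLoader:
            # Mirrors the block removed from main() above.
            cuts_train = CutSet.from_json(self.feature_dir / 'cuts_train.json.gz')
            cuts_musan = CutSet.from_json(self.feature_dir / 'cuts_musan.json.gz')
            transforms = [CutMix(cuts=cuts_musan, prob=0.5, snr=(10, 20))]
            if self.args.concatenate_cuts:
                # Concatenation must come first so mixed-in noise fills the gaps.
                transforms = [CutConcatenate(duration_factor=self.args.duration_factor,
                                             gap=self.args.gap)] + transforms
            train = K2SpeechRecognitionDataset(cuts_train, cut_transforms=transforms)
            if self.args.bucketing_sampler:
                sampler = BucketingSampler(cuts_train, max_frames=self.args.max_frames,
                                           shuffle=True, num_buckets=self.args.num_buckets)
            else:
                sampler = SingleCutSampler(cuts_train, max_frames=self.args.max_frames,
                                           shuffle=True)
            return torch.utils.data.DataLoader(train, sampler=sampler,
                                               batch_size=None, num_workers=4)

        # valid_dataloaders() and test_dataloaders() would follow the same
        # pattern with cuts_dev.json.gz / cuts_test.json.gz and a SingleCutSampler.

Centralizing this logic means a new data option only has to be added once and is available to both training and decoding; it would also explain why `--full-libri`, a LibriSpeech leftover, disappears from the aishell script.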
93 changes: 13 additions & 80 deletions egs/librispeech/asr/simple_v1/ctc_att_transformer_train.py
@@ -9,69 +9,35 @@
 import argparse
 import k2
 import logging
-import math
-import numpy as np
 import os
 import sys
 import torch
 from datetime import datetime
 from pathlib import Path
 from torch import nn
 from torch.nn.utils import clip_grad_value_
 from torch.utils.tensorboard import SummaryWriter
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional
 
 from lhotse import CutSet, Fbank, load_manifest
 from lhotse.dataset import BucketingSampler, CutConcatenate, CutMix, K2SpeechRecognitionDataset, SingleCutSampler, \
     SpecAugment
 from lhotse.dataset.cut_transforms.perturb_speed import PerturbSpeed
 from lhotse.dataset.input_strategies import OnTheFlyFeatures
-from lhotse.utils import fix_random_seed
+from lhotse.utils import fix_random_seed, nullcontext
 from snowfall.common import describe, str2bool
 from snowfall.common import load_checkpoint, save_checkpoint
 from snowfall.common import save_training_info
 from snowfall.common import setup_logger
 from snowfall.models import AcousticModel
 from snowfall.models.conformer import Conformer
 from snowfall.models.transformer import Noam, Transformer
+from snowfall.objectives import CTCLoss, encode_supervisions
 from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change
 from snowfall.training.ctc_graph import CtcTrainingGraphCompiler
 from snowfall.training.mmi_graph import get_phone_symbols
-
-
-def get_tot_objf_and_num_frames(tot_scores: torch.Tensor,
-                                frames_per_seq: torch.Tensor
-                                ) -> Tuple[float, int, int]:
-    ''' Figures out the total score(log-prob) over all successful supervision segments
-    (i.e. those for which the total score wasn't -infinity), and the corresponding
-    number of frames of neural net output
-        Args:
-            tot_scores: a Torch tensor of shape (num_segments,) containing total scores
-                        from forward-backward
-            frames_per_seq: a Torch tensor of shape (num_segments,) containing the number of
-                            frames for each segment
-        Returns:
-            Returns a tuple of 3 scalar tensors: (tot_score, ok_frames, all_frames)
-            where ok_frames is the frames for successful (finite) segments, and
-            all_frames is the frames for all segments (finite or not).
-    '''
-    mask = torch.ne(tot_scores, -math.inf)
-    # finite_indexes is a tensor containing successful segment indexes, e.g.
-    # [ 0 1 3 4 5 ]
-    finite_indexes = torch.nonzero(mask).squeeze(1)
-    if False:
-        bad_indexes = torch.nonzero(~mask).squeeze(1)
-        if bad_indexes.shape[0] > 0:
-            print("Bad indexes: ", bad_indexes, ", bad lengths: ",
-                  frames_per_seq[bad_indexes], " vs. max length ",
-                  torch.max(frames_per_seq), ", avg ",
-                  (torch.sum(frames_per_seq) / frames_per_seq.numel()))
-    # print("finite_indexes = ", finite_indexes, ", tot_scores = ", tot_scores)
-    ok_frames = frames_per_seq[finite_indexes].sum()
-    all_frames = frames_per_seq.sum()
-    return (tot_scores[finite_indexes].sum(), ok_frames, all_frames)
 
 
 def get_objf(batch: Dict,
              model: AcousticModel,
             device: torch.device,
@@ -84,56 +50,23 @@ def get_objf(batch: Dict,
              global_batch_idx_train: Optional[int] = None,
              optimizer: Optional[torch.optim.Optimizer] = None):
     feature = batch['inputs']
-    supervisions = batch['supervisions']
-    supervision_segments = torch.stack(
-        (supervisions['sequence_idx'],
-         (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
-         (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1).to(torch.int32)
-    supervision_segments = torch.clamp(supervision_segments, min=0)
-    indices = torch.argsort(supervision_segments[:, 2], descending=True)
-    supervision_segments = supervision_segments[indices]
-
-    texts = supervisions['text']
-    texts = [texts[idx] for idx in indices]
     assert feature.ndim == 3
-    # print(supervision_segments[:, 1] + supervision_segments[:, 2])
 
     feature = feature.to(device)
     # at entry, feature is [N, T, C]
     feature = feature.permute(0, 2, 1)  # now feature is [N, C, T]
-    if is_training:
+    supervisions = batch['supervisions']
+    supervision_segments, texts = encode_supervisions(supervisions)
+
+    loss_fn = CTCLoss(graph_compiler)
+    grad_context = nullcontext if is_training else torch.no_grad
+
+    with grad_context():
         nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
         if att_rate != 0.0:
             att_loss = model.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler)
-    else:
-        with torch.no_grad():
-            nnet_output, encoder_memory, memory_mask = model(feature, supervisions)
-            if att_rate != 0.0:
-                att_loss = model.decoder_forward(encoder_memory, memory_mask, supervisions, graph_compiler)
 
-    # nnet_output is [N, C, T]
-    nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]
-
-    decoding_graph = graph_compiler.compile(texts).to(device)
-
-    # nnet_output2 = nnet_output.clone()
-    # blank_bias = -7.0
-    # nnet_output2[:,:,0] += blank_bias
-
-    dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
-    assert decoding_graph.is_cuda()
-    assert decoding_graph.device == device
-    assert nnet_output.device == device
-
-    target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
-
-    tot_scores = target_graph.get_tot_scores(
-        log_semiring=True,
-        use_double_scores=True)
-
-    (tot_score, tot_frames,
-     all_frames) = get_tot_objf_and_num_frames(tot_scores,
-                                               supervision_segments[:, 2])
+        # nnet_output is [N, C, T]
+        nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]
+        tot_score, tot_frames, all_frames = loss_fn(nnet_output, texts, supervision_segments)
 
     if is_training:
         def maybe_log_gradients(tag: str):
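
The removed segment bookkeeping and k2 intersection now live behind `snowfall.objectives.encode_supervisions` and `snowfall.objectives.CTCLoss`. A plausible sketch of both, assembled from the code deleted above (the actual implementations in snowfall may differ, for example in how they handle the attention loss):

    import math
    from typing import Dict, List, Tuple

    import k2
    import torch


    def encode_supervisions(supervisions: Dict) -> Tuple[torch.Tensor, List[str]]:
        # Reconstructed from the removed inline code: build rows of
        # (sequence_idx, start_frame, num_frames), apply (((x - 1) // 2 - 1) // 2)
        # to account for two stride-2 subsampling layers in the encoder, and sort
        # by decreasing length, as k2.DenseFsaVec requires.
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             (((supervisions['start_frame'] - 1) // 2 - 1) // 2),
             (((supervisions['num_frames'] - 1) // 2 - 1) // 2)), 1).to(torch.int32)
        supervision_segments = torch.clamp(supervision_segments, min=0)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = [supervisions['text'][idx] for idx in indices]
        return supervision_segments, texts


    class CTCLoss:
        # Hypothetical wrapper around the removed intersection code.
        def __init__(self, graph_compiler):
            self.graph_compiler = graph_compiler

        def __call__(self, nnet_output: torch.Tensor, texts: List[str],
                     supervision_segments: torch.Tensor):
            # nnet_output is [N, T, C] log-probs.
            decoding_graph = self.graph_compiler.compile(texts).to(nnet_output.device)
            dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
            target_graph = k2.intersect_dense(decoding_graph, dense_fsa_vec, 10.0)
            tot_scores = target_graph.get_tot_scores(log_semiring=True,
                                                     use_double_scores=True)
            # Keep only segments with a finite total score, exactly like the
            # deleted get_tot_objf_and_num_frames() helper.
            frames_per_seq = supervision_segments[:, 2]
            finite = torch.nonzero(torch.ne(tot_scores, -math.inf)).squeeze(1)
            ok_frames = frames_per_seq[finite].sum()
            all_frames = frames_per_seq.sum()
            return tot_scores[finite].sum(), ok_frames, all_frames

Factoring the loss out this way lets the MMI and CTC scripts share one implementation instead of near-identical copies.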
(The remaining 15 changed files are not shown.)

