Skip to content
This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Commit

Permalink
Add results for CTC conformer training with SpecAug (#143)
Browse files Browse the repository at this point in the history
* Fix feat dim.

* Update ctc transformer training to use SpecAug.
  • Loading branch information
csukuangfj authored Mar 30, 2021
1 parent 2a49d45 commit ed3a16a
Show file tree
Hide file tree
Showing 3 changed files with 353 additions and 148 deletions.
186 changes: 109 additions & 77 deletions egs/librispeech/asr/simple_v1/ctc_att_transformer_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,26 +7,28 @@
import argparse
import k2
import logging
import numpy as np
import os
import torch
from k2 import Fsa, SymbolTable
from kaldialign import edit_distance
from pathlib import Path
from typing import List
from typing import Union

from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset import SingleCutSampler
from snowfall.common import average_checkpoint
from lhotse import CutSet, load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from snowfall.common import average_checkpoint, store_transcripts
from snowfall.common import find_first_disambig_symbol
from snowfall.common import get_phone_symbols
from snowfall.common import get_texts
from snowfall.common import load_checkpoint
from snowfall.common import setup_logger
from snowfall.decoding.graph import compile_HLG
from snowfall.models import AcousticModel
from snowfall.models.transformer import Transformer
from snowfall.models.conformer import Conformer
from snowfall.training.ctc_graph import build_ctc_topo
from snowfall.training.mmi_graph import get_phone_symbols


def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
Expand Down Expand Up @@ -92,16 +94,22 @@ def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,

def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model-type',
type=str,
default="conformer",
choices=["transformer", "conformer"],
help="Model type.")
parser.add_argument(
'--epoch',
type=int,
default=20,
default=10,
help="Decoding epoch.")
parser.add_argument(
'--max-frames',
'--max-duration',
type=int,
default=100000,
help="Maximum number of feature frames in a single batch.")
default=1000.0,
help="Maximum pooled recordings duration (seconds) in a single batch.")
parser.add_argument(
'--avg',
type=int,
Expand All @@ -113,64 +121,41 @@ def get_parser():
type=float,
default=0.0,
help="Attention loss rate.")
parser.add_argument(
'--nhead',
type=int,
default=4,
help="Number of attention heads in transformer.")
parser.add_argument(
'--attention-dim',
type=int,
default=256,
help="Number of units in transformer attention layers.")
return parser


def main():
args = get_parser().parse_args()

model_type = args.model_type
epoch = args.epoch
max_frames = args.max_frames
max_duration = args.max_duration
avg = args.avg
att_rate = args.att_rate

exp_dir = Path('exp-transformer-noam-ctc-att-musan')
exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

# load L, G, symbol_table
lang_dir = Path('data/lang_nosp')
symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

phone_ids = get_phone_symbols(phone_symbol_table)
phone_ids_with_blank = [0] + phone_ids
ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

if not os.path.exists(lang_dir / 'HLG.pt'):
print("Loading L_disambig.fst.txt")
with open(lang_dir / 'L_disambig.fst.txt') as f:
L = k2.Fsa.from_openfst(f.read(), acceptor=False)
print("Loading G.fst.txt")
with open(lang_dir / 'G.fst.txt') as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)
first_word_disambig_id = find_first_disambig_symbol(symbol_table)
HLG = compile_HLG(L=L,
G=G,
H=ctc_topo,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)
torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
else:
print("Loading pre-compiled HLG")
d = torch.load(lang_dir / 'HLG.pt')
HLG = k2.Fsa.from_dict(d)

# load dataset
feature_dir = Path('exp/data')
print("About to get test cuts")
cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

print("About to create test dataset")
test = K2SpeechRecognitionDataset(cuts_test)
sampler = SingleCutSampler(cuts_test, max_frames=max_frames)
print("About to create test dataloader")
test_dl = torch.utils.data.DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)

# if not torch.cuda.is_available():
# logging.error('No GPU detected!')
# sys.exit(-1)

print("About to load model")
logging.debug("About to load model")
# Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
# device = torch.device('cuda', 1)
device = torch.device('cuda')
Expand All @@ -180,11 +165,22 @@ def main():
else:
num_decoder_layers = 0

model = Transformer(
num_features=40,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)
if model_type == "transformer":
model = Transformer(
num_features=80,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)
else:
model = Conformer(
num_features=80,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)

if avg == 1:
checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
Expand All @@ -197,34 +193,70 @@ def main():
model.to(device)
model.eval()

print("convert HLG to device")
if not os.path.exists(lang_dir / 'HLG.pt'):
logging.debug("Loading L_disambig.fst.txt")
with open(lang_dir / 'L_disambig.fst.txt') as f:
L = k2.Fsa.from_openfst(f.read(), acceptor=False)
logging.debug("Loading G.fst.txt")
with open(lang_dir / 'G.fst.txt') as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)
first_word_disambig_id = find_first_disambig_symbol(symbol_table)
HLG = compile_HLG(L=L,
G=G,
H=ctc_topo,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)
torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
else:
logging.debug("Loading pre-compiled HLG")
d = torch.load(lang_dir / 'HLG.pt')
HLG = k2.Fsa.from_dict(d)

logging.debug("convert HLG to device")
HLG = HLG.to(device)
HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
HLG.requires_grad_(False)
print("About to decode")
results = decode(dataloader=test_dl,
model=model,
device=device,
HLG=HLG,
symbols=symbol_table)
s = ''
for ref, hyp in results:
s += f'ref={ref}\n'
s += f'hyp={hyp}\n'
logging.info(s)
# compute WER
dists = [edit_distance(r, h) for r, h in results]
errors = {
key: sum(dist[key] for dist in dists)
for key in ['sub', 'ins', 'del', 'total']
}
total_words = sum(len(ref) for ref, _ in results)
# Print Kaldi-like message:
# %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
logging.info(
f'%WER {errors["total"] / total_words:.2%} '
f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
)

# load dataset
feature_dir = Path('exp/data')
test_sets = ['test-clean', 'test-other']
for test_set in test_sets:
logging.info(f'* DECODING: {test_set}')

logging.debug("About to get test cuts")
cuts_test = load_manifest(feature_dir / f'cuts_{test_set}.json.gz')
logging.debug("About to create test dataset")
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse import Fbank, FbankConfig
test = K2SpeechRecognitionDataset(cuts_test, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))))
sampler = SingleCutSampler(cuts_test, max_duration=max_duration)
logging.debug("About to create test dataloader")
test_dl = torch.utils.data.DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)

logging.debug("About to decode")
results = decode(dataloader=test_dl,
model=model,
device=device,
HLG=HLG,
symbols=symbol_table)

recog_path = exp_dir / f'recogs-{test_set}.txt'
store_transcripts(path=recog_path, texts=results)
logging.info(f'The transcripts are stored in {recog_path}')
# compute WER
dists = [edit_distance(r, h) for r, h in results]
errors = {
key: sum(dist[key] for dist in dists)
for key in ['sub', 'ins', 'del', 'total']
}
total_words = sum(len(ref) for ref, _ in results)
# Print Kaldi-like message:
# %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
logging.info(
f'[{test_set}] %WER {errors["total"] / total_words:.2%} '
f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
)


torch.set_num_threads(1)
Expand Down
Loading

0 comments on commit ed3a16a

Please sign in to comment.