This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Add results for CTC conformer training with SpecAug #143

Merged · 2 commits · Mar 30, 2021
186 changes: 109 additions & 77 deletions egs/librispeech/asr/simple_v1/ctc_att_transformer_decode.py
@@ -7,26 +7,28 @@
import argparse
import k2
import logging
import numpy as np
import os
import torch
from k2 import Fsa, SymbolTable
from kaldialign import edit_distance
from pathlib import Path
from typing import List
from typing import Union

from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset import SingleCutSampler
from snowfall.common import average_checkpoint
from lhotse import CutSet, load_manifest
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
from snowfall.common import average_checkpoint, store_transcripts
from snowfall.common import find_first_disambig_symbol
from snowfall.common import get_phone_symbols
from snowfall.common import get_texts
from snowfall.common import load_checkpoint
from snowfall.common import setup_logger
from snowfall.decoding.graph import compile_HLG
from snowfall.models import AcousticModel
from snowfall.models.transformer import Transformer
from snowfall.models.conformer import Conformer
from snowfall.training.ctc_graph import build_ctc_topo
from snowfall.training.mmi_graph import get_phone_symbols


def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
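The body of decode() is folded in this view. For orientation, a rough sketch of snowfall's usual per-batch CTC decoding loop (the model-output convention and beam values are illustrative assumptions, not code from this PR):

nnet_output, _, _ = model(feature, supervisions)  # assumed to return (N, C, T) log-posteriors
nnet_output = nnet_output.permute(0, 2, 1)        # -> (N, T, C), the layout k2 expects
dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
lattice = k2.intersect_dense_pruned(HLG, dense_fsa_vec,
                                    search_beam=20.0, output_beam=8.0,
                                    min_active_states=30, max_active_states=10000)
best_path = k2.shortest_path(lattice, use_double_scores=True)
hyps = get_texts(best_path, indices)              # word IDs, mapped via symbol_table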
@@ -92,16 +94,22 @@ def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,

def get_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
'--model-type',
type=str,
default="conformer",
choices=["transformer", "conformer"],
help="Model type.")
parser.add_argument(
'--epoch',
type=int,
default=20,
default=10,
help="Decoding epoch.")
parser.add_argument(
'--max-frames',
'--max-duration',
type=int,
default=100000,
help="Maximum number of feature frames in a single batch.")
default=1000.0,
help="Maximum pooled recordings duration (seconds) in a single batch.")
parser.add_argument(
'--avg',
type=int,
@@ -113,64 +121,41 @@ def get_parser():
type=float,
default=0.0,
help="Attention loss rate.")
parser.add_argument(
'--nhead',
type=int,
default=4,
help="Number of attention heads in transformer.")
parser.add_argument(
'--attention-dim',
type=int,
default=256,
help="Number of units in transformer attention layers.")
return parser
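The new --max-duration flag replaces --max-frames: batches are capped by pooled audio duration in seconds rather than by feature-frame count, which keeps the cap meaningful when features are computed on the fly. A minimal sketch of how the sampler consumes it (the manifest path comes from this recipe; the sampler's yield type differs across lhotse versions):

from lhotse import load_manifest
from lhotse.dataset import SingleCutSampler

cuts = load_manifest('exp/data/cuts_test-clean.json.gz')
# Cuts accumulate until adding one more would exceed max_duration seconds.
sampler = SingleCutSampler(cuts, max_duration=1000.0)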


def main():
args = get_parser().parse_args()

model_type = args.model_type
epoch = args.epoch
max_frames = args.max_frames
max_duration = args.max_duration
avg = args.avg
att_rate = args.att_rate

exp_dir = Path('exp-transformer-noam-ctc-att-musan')
exp_dir = Path('exp-' + model_type + '-noam-ctc-att-musan-sa')
setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

# load L, G, symbol_table
lang_dir = Path('data/lang_nosp')
symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')

phone_ids = get_phone_symbols(phone_symbol_table)
phone_ids_with_blank = [0] + phone_ids
ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

if not os.path.exists(lang_dir / 'HLG.pt'):
print("Loading L_disambig.fst.txt")
with open(lang_dir / 'L_disambig.fst.txt') as f:
L = k2.Fsa.from_openfst(f.read(), acceptor=False)
print("Loading G.fst.txt")
with open(lang_dir / 'G.fst.txt') as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)
first_word_disambig_id = find_first_disambig_symbol(symbol_table)
HLG = compile_HLG(L=L,
G=G,
H=ctc_topo,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)
torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
else:
print("Loading pre-compiled HLG")
d = torch.load(lang_dir / 'HLG.pt')
HLG = k2.Fsa.from_dict(d)

# load dataset
feature_dir = Path('exp/data')
print("About to get test cuts")
cuts_test = CutSet.from_json(feature_dir / 'cuts_test-clean.json.gz')

print("About to create test dataset")
test = K2SpeechRecognitionDataset(cuts_test)
sampler = SingleCutSampler(cuts_test, max_frames=max_frames)
print("About to create test dataloader")
test_dl = torch.utils.data.DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)

# if not torch.cuda.is_available():
# logging.error('No GPU detected!')
# sys.exit(-1)

print("About to load model")
logging.debug("About to load model")
# Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
# device = torch.device('cuda', 1)
device = torch.device('cuda')
@@ -180,11 +165,22 @@ def main():
else:
num_decoder_layers = 0

model = Transformer(
num_features=40,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)
if model_type == "transformer":
model = Transformer(
num_features=80,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)
else:
model = Conformer(
num_features=80,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)
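The two branches above pass identical constructor arguments and differ only in the class. A compact equivalent, purely illustrative:

model_cls = Transformer if model_type == "transformer" else Conformer
model = model_cls(
    num_features=80,  # matches the 80-dim on-the-fly Fbank used below
    nhead=args.nhead,
    d_model=args.attention_dim,
    num_classes=len(phone_ids) + 1,  # +1 for the blank symbol
    subsampling_factor=4,
    num_decoder_layers=num_decoder_layers)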

if avg == 1:
checkpoint = os.path.join(exp_dir, 'epoch-' + str(epoch - 1) + '.pt')
@@ -197,34 +193,70 @@ def main():
model.to(device)
model.eval()
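When --avg > 1, the folded branch above uses snowfall's average_checkpoint to average the parameters of the last avg epochs before decoding. A minimal sketch of the idea, assuming a 'state_dict' key in each checkpoint file (the real helper's signature and keys may differ):

def average_checkpoints_sketch(paths, model):
    avg_state = None
    for p in paths:
        state = torch.load(p, map_location='cpu')['state_dict']  # assumed layout
        if avg_state is None:
            avg_state = {k: v.clone().float() for k, v in state.items()}
        else:
            for k in avg_state:
                avg_state[k] += state[k].float()
    for k in avg_state:  # arithmetic mean over the selected epochs
        avg_state[k] /= len(paths)
    model.load_state_dict(avg_state)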

print("convert HLG to device")
if not os.path.exists(lang_dir / 'HLG.pt'):
logging.debug("Loading L_disambig.fst.txt")
with open(lang_dir / 'L_disambig.fst.txt') as f:
L = k2.Fsa.from_openfst(f.read(), acceptor=False)
logging.debug("Loading G.fst.txt")
with open(lang_dir / 'G.fst.txt') as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)
first_word_disambig_id = find_first_disambig_symbol(symbol_table)
HLG = compile_HLG(L=L,
G=G,
H=ctc_topo,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)
torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
else:
logging.debug("Loading pre-compiled HLG")
d = torch.load(lang_dir / 'HLG.pt')
HLG = k2.Fsa.from_dict(d)
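compile_HLG composes the CTC topology H with the lexicon L and grammar G, and the result is cached as HLG.pt so later runs skip the expensive composition. Very roughly, the composition amounts to the following (a simplification: the real helper also handles the disambiguation symbols passed above):

LG = k2.connect(k2.determinize(k2.compose(k2.arc_sort(L), k2.arc_sort(G))))
HLG = k2.arc_sort(k2.compose(ctc_topo, k2.arc_sort(LG)))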

logging.debug("convert HLG to device")
HLG = HLG.to(device)
HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
HLG.requires_grad_(False)
print("About to decode")
results = decode(dataloader=test_dl,
model=model,
device=device,
HLG=HLG,
symbols=symbol_table)
s = ''
for ref, hyp in results:
s += f'ref={ref}\n'
s += f'hyp={hyp}\n'
logging.info(s)
# compute WER
dists = [edit_distance(r, h) for r, h in results]
errors = {
key: sum(dist[key] for dist in dists)
for key in ['sub', 'ins', 'del', 'total']
}
total_words = sum(len(ref) for ref, _ in results)
# Print Kaldi-like message:
# %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
logging.info(
f'%WER {errors["total"] / total_words:.2%} '
f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
)

# load dataset
feature_dir = Path('exp/data')
test_sets = ['test-clean', 'test-other']
for test_set in test_sets:
logging.info(f'* DECODING: {test_set}')

logging.debug("About to get test cuts")
cuts_test = load_manifest(feature_dir / f'cuts_{test_set}.json.gz')
logging.debug("About to create test dataset")
from lhotse.dataset.input_strategies import OnTheFlyFeatures
from lhotse import Fbank, FbankConfig
test = K2SpeechRecognitionDataset(cuts_test, input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))))
sampler = SingleCutSampler(cuts_test, max_duration=max_duration)
logging.debug("About to create test dataloader")
test_dl = torch.utils.data.DataLoader(test, batch_size=None, sampler=sampler, num_workers=1)
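OnTheFlyFeatures extracts 80-dim Fbank inside the dataloader, so the test sets need no precomputed features and the extractor stays consistent with num_features=80 in the model. For contrast, the dataset's default input strategy reads precomputed features from the manifest, which is what this recipe did before:

test = K2SpeechRecognitionDataset(cuts_test)  # default strategy: precomputed features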

logging.debug("About to decode")
results = decode(dataloader=test_dl,
model=model,
device=device,
HLG=HLG,
symbols=symbol_table)

recog_path = exp_dir / f'recogs-{test_set}.txt'
store_transcripts(path=recog_path, texts=results)
logging.info(f'The transcripts are stored in {recog_path}')
# compute WER
dists = [edit_distance(r, h) for r, h in results]
errors = {
key: sum(dist[key] for dist in dists)
for key in ['sub', 'ins', 'del', 'total']
}
total_words = sum(len(ref) for ref, _ in results)
# Print Kaldi-like message:
# %WER 8.20 [ 4459 / 54402, 695 ins, 427 del, 3337 sub ]
logging.info(
f'[{test_set}] %WER {errors["total"] / total_words:.2%} '
f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
)
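The WER block above relies on kaldialign's edit_distance, which returns per-pair 'ins'/'del'/'sub'/'total' counts that are then summed over the test set. A tiny worked example with made-up word lists:

from kaldialign import edit_distance

ref = ['the', 'cat', 'sat']
hyp = ['the', 'bat', 'sat', 'down']
dist = edit_distance(ref, hyp)
# dist == {'ins': 1, 'del': 0, 'sub': 1, 'total': 2}
wer = dist['total'] / len(ref)  # 2/3, which ':.2%' formats as '66.67%'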


torch.set_num_threads(1)