Add results.
csukuangfj committed Apr 11, 2021
1 parent 84f37bd commit c800c25
Showing 3 changed files with 114 additions and 12 deletions.
20 changes: 20 additions & 0 deletions egs/librispeech/asr/simple_v1/RESULTS.md
@@ -390,6 +390,26 @@ Average over last 5 epochs, When using 40 filter banks instead of 80 (also twice
2021-03-25 21:52:23,645 INFO [mmi_att_transformer_decode.py:329] [test-clean] %WER 6.93% [3645 / 52576, 529 ins, 308 del, 2808 sub ]
2021-03-25 21:53:10,674 INFO [mmi_att_transformer_decode.py:329] [test-other] %WER 18.53% [9697 / 52343, 1136 ins, 929 del, 7632 sub ]

## 2021-04-11

By Fangjun.

### Average over last 5 epochs

#### LM rescoring with whole lattice

$ ./mmi_att_transformer_decode.py --use-lm-rescoring=1 --num-paths=-1 --max-duration=10 --output-beam-size=8
2021-04-11 10:37:58,913 INFO [common.py:356] [test-clean] %WER 5.72% [3008 / 52576, 562 ins, 164 del, 2282 sub ]
2021-04-11 10:46:03,670 INFO [common.py:356] [test-other] %WER 15.71% [8224 / 52343, 1331 ins, 562 del, 6331 sub ]

#### LM rescoring with n-best list

$ ./mmi_att_transformer_decode.py --use-lm-rescoring=1 --num-paths=100 --max-duration=500 --output-beam-size=20
2021-04-11 15:17:07,792 INFO [common.py:356] [test-clean] %WER 6.31% [3316 / 52576, 746 ins, 160 del, 2410 sub ]
2021-04-11 15:19:48,583 INFO [common.py:356] [test-other] %WER 16.93% [8863 / 52343, 1649 ins, 514 del, 6700 sub ]
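Note: the rescoring mode is selected by `--num-paths`: a negative value rescores with the
whole lattice, while a positive value rescores with an n-best list (see the changes to
`mmi_att_transformer_decode.py` below). Whole-lattice rescoring is also why the first
command uses the much smaller `--max-duration`, to avoid CUDA OOM.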



## 2021-03-08

(Han Zhu): Results of <https://github.com/k2-fsa/snowfall/pull/119>
88 changes: 84 additions & 4 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -21,8 +21,10 @@
from snowfall.common import write_error_stats
from snowfall.common import load_checkpoint
from snowfall.common import setup_logger
from snowfall.common import str2bool
from snowfall.data import LibriSpeechAsrDataModule
from snowfall.decoding.graph import compile_HLG
from snowfall.decoding.lm_rescore import decode_with_lm_rescoring
from snowfall.models import AcousticModel
from snowfall.models.transformer import Transformer
from snowfall.models.conformer import Conformer
@@ -32,7 +34,8 @@


def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable):
device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable,
num_paths: int, G: k2.Fsa, use_whole_lattice: bool):
tot_num_cuts = len(dataloader.dataset.cuts)
num_cuts = 0
results = [] # a list of pair (ref_words, hyp_words)
@@ -70,8 +73,15 @@ def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
10000)
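# (The positional arguments above are search_beam=20.0, output_beam=7.0,
# min_active_states=30 and max_active_states=10000, assuming the usual
# k2.intersect_dense_pruned signature.)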

# lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0)
best_paths = k2.shortest_path(lattices, use_double_scores=True)
if G is None:
best_paths = k2.shortest_path(lattices, use_double_scores=True)
else:
best_paths = decode_with_lm_rescoring(
lattices,
G,
num_paths=num_paths,
use_whole_lattice=use_whole_lattice)

assert best_paths.shape[0] == len(texts)
hyps = get_texts(best_paths, indices)
assert len(hyps) == len(texts)
@@ -188,6 +198,27 @@ def get_parser():
type=int,
default=256,
help="Number of units in transformer attention layers.")
parser.add_argument(
'--output-beam-size',
type=int,
default=8,
help='Output beam size. Used in k2.intersect_dense_pruned. '
'Choose a large value (e.g., 20) for 1-best decoding '
'and n-best rescoring. Choose a small value (e.g., 8) for '
'rescoring with the whole lattice.')
parser.add_argument(
'--use-lm-rescoring',
type=str2bool,
default=True,
help='When enabled, use an LM for rescoring.')
parser.add_argument(
'--num-paths',
type=int,
default=-1,
help='Number of paths for rescoring with an n-best list. '
'If it is negative, rescore with the whole lattice. '
'CAUTION: You have to reduce max_duration in case of CUDA OOM.')
return parser


@@ -200,6 +231,13 @@ def main():
epoch = args.epoch
avg = args.avg
att_rate = args.att_rate
num_paths = args.num_paths
use_lm_rescoring = args.use_lm_rescoring
use_whole_lattice = False
if use_lm_rescoring and num_paths < 1:
# It doesn't make sense to use an n-best list for rescoring
# when n is less than 1.
use_whole_lattice = True

exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa')
setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
@@ -282,22 +320,64 @@ def main():
d = torch.load(lang_dir / 'HLG.pt')
HLG = k2.Fsa.from_dict(d)

if use_lm_rescoring:
if use_whole_lattice:
logging.info('Rescoring with the whole lattice')
else:
logging.info(f'Rescoring with n-best list, n is {num_paths}')
first_word_disambig_id = find_first_disambig_symbol(symbol_table)
if not os.path.exists(lang_dir / 'G_4_gram.pt'):
logging.debug('Loading G_4_gram.fst.txt')
with open(lang_dir / 'G_4_gram.fst.txt') as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
# remove it here.
del G.aux_labels
# CAUTION(fangjun): The following line is crucial.
# Arcs entering the back-off state have label equal to #0.
# We have to change it to 0 here.
G.labels[G.labels >= first_word_disambig_id] = 0
G = k2.create_fsa_vec([G]).to(device)
G = k2.arc_sort(G)
torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt')
else:
logging.debug('Loading pre-compiled G_4_gram.pt')
d = torch.load(lang_dir / 'G_4_gram.pt')
G = k2.Fsa.from_dict(d).to(device)

if use_whole_lattice:
# Add epsilon self-loops to G as we will compose
# it with the whole lattice later
G = k2.add_epsilon_self_loops(G)
G = k2.arc_sort(G)
G = G.to(device)
else:
logging.debug('Decoding without LM rescoring')
G = None

logging.debug("convert HLG to device")
HLG = HLG.to(device)
HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
HLG.requires_grad_(False)

if not hasattr(HLG, 'lm_scores'):
HLG.lm_scores = HLG.scores.clone()
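# (Presumably saved so the rescoring code can recover acoustic-only scores
# as scores - lm_scores before composing with the 4-gram G; see
# snowfall/decoding/lm_rescore.py below.)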

# load dataset
librispeech = LibriSpeechAsrDataModule(args)
test_sets = ['test-clean', 'test-other']
# test_sets = ['test-other']
for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
logging.info(f'* DECODING: {test_set}')

results = decode(dataloader=test_dl,
model=model,
device=device,
HLG=HLG,
symbols=symbol_table)
symbols=symbol_table,
num_paths=num_paths,
G=G,
use_whole_lattice=use_whole_lattice)

recog_path = exp_dir / f'recogs-{test_set}.txt'
store_transcripts(path=recog_path, texts=results)
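As a minimal illustration of the `G.labels[G.labels >= first_word_disambig_id] = 0`
fix above (the ids below are hypothetical; in a real lexicon the disambig symbols
#0, #1, ... take the largest word ids):

import torch

labels = torch.tensor([3, 7, 200001, 200002])  # 200001 plays the role of #0
first_word_disambig_id = 200001
# Arcs entering the back-off state carry #0; mapping every disambig symbol
# to 0 turns those arcs into epsilon arcs before composing with the lattice.
labels[labels >= first_word_disambig_id] = 0
print(labels)  # tensor([3, 7, 0, 0])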
18 changes: 10 additions & 8 deletions snowfall/decoding/lm_rescore.py
@@ -255,9 +255,8 @@ def rescore_with_whole_lattice(lats: k2.Fsa,


@torch.no_grad()
def decode_with_lm_rescoring(lats: k2.Fsa,
G: k2.Fsa,
num_paths: Optional[int] = None) -> k2.Fsa:
def decode_with_lm_rescoring(lats: k2.Fsa, G: k2.Fsa, num_paths: int,
use_whole_lattice: bool) -> k2.Fsa:
'''Decode with LM rescoring, using either the whole lattice or an n-best list.
`lats` is a decoding lattice, which has 3 axes. This function first
@@ -275,13 +274,16 @@ def decode_with_lm_rescoring(lats: k2.Fsa,
An FsaVec representing the language model (LM). Note that it
is an FsaVec, but it contains only one Fsa.
num_paths:
Optional, If not None, it is the size `n` in `n-best` list.
Otherwise, use the whole lattice for rescoring.
It is the size `n` in the n-best list.
Used only if use_whole_lattice is False.
use_whole_lattice:
True to use the whole lattice for rescoring; False to use an
n-best list.
Returns:
An FsaVec representing the best decoding path for each sequence
in the lattice.
'''
if num_paths is not None:
return rescore_with_n_best_list(lats, G, num_paths)
else:
if use_whole_lattice:
return rescore_with_whole_lattice(lats, G)
else:
return rescore_with_n_best_list(lats, G, num_paths)
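
A minimal usage sketch of the updated interface, assuming `lattices` (a 3-axis
FsaVec from the decoding step) and `G` (an FsaVec holding a single LM Fsa) are
already built and on the same device:

from snowfall.decoding.lm_rescore import decode_with_lm_rescoring

# Whole-lattice rescoring; num_paths is ignored in this mode.
best_paths = decode_with_lm_rescoring(lattices, G, num_paths=-1,
                                      use_whole_lattice=True)

# 100-best rescoring.
best_paths = decode_with_lm_rescoring(lattices, G, num_paths=100,
                                      use_whole_lattice=False)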
