From c800c25c850a0b105e7d255db3992ce770474445 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sun, 11 Apr 2021 15:37:49 +0800
Subject: [PATCH] Add results.

---
 egs/librispeech/asr/simple_v1/RESULTS.md      | 18 ++++
 .../simple_v1/mmi_att_transformer_decode.py   | 87 +++++++++++++++++-
 snowfall/decoding/lm_rescore.py               | 18 ++--
 3 files changed, 111 insertions(+), 12 deletions(-)

diff --git a/egs/librispeech/asr/simple_v1/RESULTS.md b/egs/librispeech/asr/simple_v1/RESULTS.md
index e547aab8..80533ff3 100644
--- a/egs/librispeech/asr/simple_v1/RESULTS.md
+++ b/egs/librispeech/asr/simple_v1/RESULTS.md
@@ -390,6 +390,24 @@ Average over last 5 epochs, When using 40 filter banks instead of 80 (also twice
 2021-03-25 21:52:23,645 INFO [mmi_att_transformer_decode.py:329] [test-clean] %WER 6.93% [3645 / 52576, 529 ins, 308 del, 2808 sub ]
 2021-03-25 21:53:10,674 INFO [mmi_att_transformer_decode.py:329] [test-other] %WER 18.53% [9697 / 52343, 1136 ins, 929 del, 7632 sub ]
 
+## 2021-04-11
+
+By Fangjun.
+
+### Average over last 5 epochs
+
+#### LM rescoring with the whole lattice
+
+$ ./mmi_att_transformer_decode.py --use-lm-rescoring=1 --num-paths=-1 --max-duration=10 --output-beam-size=8
+2021-04-11 10:37:58,913 INFO [common.py:356] [test-clean] %WER 5.72% [3008 / 52576, 562 ins, 164 del, 2282 sub ]
+2021-04-11 10:46:03,670 INFO [common.py:356] [test-other] %WER 15.71% [8224 / 52343, 1331 ins, 562 del, 6331 sub ]
+
+#### LM rescoring with an n-best list
+
+$ ./mmi_att_transformer_decode.py --use-lm-rescoring=1 --num-paths=100 --max-duration=500 --output-beam-size=20
+2021-04-11 15:17:07,792 INFO [common.py:356] [test-clean] %WER 6.31% [3316 / 52576, 746 ins, 160 del, 2410 sub ]
+2021-04-11 15:19:48,583 INFO [common.py:356] [test-other] %WER 16.93% [8863 / 52343, 1649 ins, 514 del, 6700 sub ]
+
 ## 2021-03-08
 (Han Zhu):
 Results of
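For reference, %WER in the logs above is (insertions + deletions + substitutions) divided by the number of reference words, and the bracketed counts are self-consistent. A quick standalone check in Python; the helper name below is illustrative, not a snowfall API:

def wer_percent(ins: int, dels: int, subs: int, ref_words: int) -> float:
    # %WER = 100 * (ins + del + sub) / ref_words
    return 100.0 * (ins + dels + subs) / ref_words

# test-clean with whole-lattice rescoring:
# [3008 / 52576, 562 ins, 164 del, 2282 sub]
assert 562 + 164 + 2282 == 3008
print(f'{wer_percent(562, 164, 2282, 52576):.2f}%')  # prints 5.72%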
diff --git a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
index 2d921dc8..010192e4 100755
--- a/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
+++ b/egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -21,8 +21,10 @@
 from snowfall.common import write_error_stats
 from snowfall.common import load_checkpoint
 from snowfall.common import setup_logger
+from snowfall.common import str2bool
 from snowfall.data import LibriSpeechAsrDataModule
 from snowfall.decoding.graph import compile_HLG
+from snowfall.decoding.lm_rescore import decode_with_lm_rescoring
 from snowfall.models import AcousticModel
 from snowfall.models.transformer import Transformer
 from snowfall.models.conformer import Conformer
@@ -32,7 +34,8 @@
 
 
 def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
-           device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable):
+           device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable,
+           num_paths: int, G: k2.Fsa, use_whole_lattice: bool):
     tot_num_cuts = len(dataloader.dataset.cuts)
     num_cuts = 0
     results = []  # a list of pair (ref_words, hyp_words)
@@ -70,8 +73,15 @@ def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
 
         lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0,
                                              30, 10000)
-        # lattices = k2.intersect_dense(HLG, dense_fsa_vec, 10.0)
-        best_paths = k2.shortest_path(lattices, use_double_scores=True)
+        if G is None:
+            best_paths = k2.shortest_path(lattices, use_double_scores=True)
+        else:
+            best_paths = decode_with_lm_rescoring(
+                lattices,
+                G,
+                num_paths=num_paths,
+                use_whole_lattice=use_whole_lattice)
+        assert best_paths.shape[0] == len(texts)
         hyps = get_texts(best_paths, indices)
         assert len(hyps) == len(texts)
 
@@ -188,6 +198,27 @@
         type=int,
         default=256,
         help="Number of units in transformer attention layers.")
+    parser.add_argument(
+        '--output-beam-size',
+        type=int,
+        default=8,
+        help='Output beam size. Used in k2.intersect_dense_pruned. '
+        'Choose a large value (e.g., 20) for 1-best decoding '
+        'and n-best rescoring. Choose a small value (e.g., 8) for '
+        'rescoring with the whole lattice.')
+    parser.add_argument(
+        '--use-lm-rescoring',
+        type=str2bool,
+        default=True,
+        help='When enabled, use an LM for rescoring.')
+    parser.add_argument(
+        '--num-paths',
+        type=int,
+        default=-1,
+        help='Number of paths for n-best list rescoring. '
+        'If it is negative, rescore with the whole lattice instead. '
+        'CAUTION: You may have to reduce --max-duration to avoid CUDA OOM.'
+    )
     return parser
 
 
@@ -200,6 +231,13 @@
     epoch = args.epoch
     avg = args.avg
     att_rate = args.att_rate
+    num_paths = args.num_paths
+    use_lm_rescoring = args.use_lm_rescoring
+    use_whole_lattice = False
+    if use_lm_rescoring and num_paths < 1:
+        # It doesn't make sense to use an n-best list for rescoring
+        # when n is less than 1.
+        use_whole_lattice = True
 
     exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-sa')
     setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')
@@ -282,14 +320,52 @@
 
     d = torch.load(lang_dir / 'HLG.pt')
     HLG = k2.Fsa.from_dict(d)
+    if use_lm_rescoring:
+        if use_whole_lattice:
+            logging.info('Rescoring with the whole lattice')
+        else:
+            logging.info(f'Rescoring with an n-best list, n is {num_paths}')
+        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
+        if not os.path.exists(lang_dir / 'G_4_gram.pt'):
+            logging.debug('Loading G_4_gram.fst.txt')
+            with open(lang_dir / 'G_4_gram.fst.txt') as f:
+                G = k2.Fsa.from_openfst(f.read(), acceptor=False)
+                # G.aux_labels is not needed in later computations, so
+                # remove it here.
+                del G.aux_labels
+                # CAUTION(fangjun): The following line is crucial.
+                # Arcs entering the back-off state have label equal to #0.
+                # We have to change it to 0 here.
+                G.labels[G.labels >= first_word_disambig_id] = 0
+                G = k2.create_fsa_vec([G]).to(device)
+                G = k2.arc_sort(G)
+                torch.save(G.as_dict(), lang_dir / 'G_4_gram.pt')
+        else:
+            logging.debug('Loading pre-compiled G_4_gram.pt')
+            d = torch.load(lang_dir / 'G_4_gram.pt')
+            G = k2.Fsa.from_dict(d).to(device)
+
+        if use_whole_lattice:
+            # Add epsilon self-loops to G as we will compose
+            # it with the whole lattice later.
+            G = k2.add_epsilon_self_loops(G)
+            G = k2.arc_sort(G)
+            G = G.to(device)
+    else:
+        logging.debug('Decoding without LM rescoring')
+        G = None
+
     logging.debug("convert HLG to device")
     HLG = HLG.to(device)
     HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
     HLG.requires_grad_(False)
+    if not hasattr(HLG, 'lm_scores'):
+        HLG.lm_scores = HLG.scores.clone()
+
     # load dataset
     librispeech = LibriSpeechAsrDataModule(args)
 
     test_sets = ['test-clean', 'test-other']
     for test_set, test_dl in zip(test_sets, librispeech.test_dataloaders()):
         logging.info(f'* DECODING: {test_set}')
 
@@ -297,7 +374,10 @@
                          model=model,
                          device=device,
                          HLG=HLG,
-                         symbols=symbol_table)
+                         symbols=symbol_table,
+                         num_paths=num_paths,
+                         G=G,
+                         use_whole_lattice=use_whole_lattice)
 
         recog_path = exp_dir / f'recogs-{test_set}.txt'
         store_transcripts(path=recog_path, texts=results)
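The new --use-lm-rescoring flag relies on str2bool imported from snowfall.common, whose definition is outside this diff. A typical argparse-compatible implementation looks like the sketch below; the actual snowfall helper may differ in detail:

import argparse

def str2bool(v):
    # Interpret common command-line spellings of true/false for
    # use as an argparse `type=` converter.
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    if v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')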
diff --git a/snowfall/decoding/lm_rescore.py b/snowfall/decoding/lm_rescore.py
index 6ac37fc0..53fa51b5 100644
--- a/snowfall/decoding/lm_rescore.py
+++ b/snowfall/decoding/lm_rescore.py
@@ -255,9 +255,8 @@ def rescore_with_whole_lattice(lats: k2.Fsa,
 
 
 @torch.no_grad()
-def decode_with_lm_rescoring(lats: k2.Fsa,
-                             G: k2.Fsa,
-                             num_paths: Optional[int] = None) -> k2.Fsa:
+def decode_with_lm_rescoring(lats: k2.Fsa, G: k2.Fsa, num_paths: int,
+                             use_whole_lattice: bool) -> k2.Fsa:
     '''Decode using n-best list with LM rescoring.
 
     `lats` is a decoding lattice, which has 3 axes. This function first
@@ -275,13 +274,16 @@
       An FsaVec representing the language model (LM). Note that it is an
       FsaVec, but it contains only one Fsa.
     num_paths:
-      Optional, If not None, it is the size `n` in `n-best` list.
-      Otherwise, use the whole lattice for rescoring.
+      The size `n` of the `n-best` list.
+      Used only if use_whole_lattice is False.
+    use_whole_lattice:
+      True to use the whole lattice for rescoring. False to use an
+      n-best list for rescoring.
     Returns:
       An FsaVec representing the best decoding path for each sequence in
       the lattice.
     '''
-    if num_paths is not None:
-        return rescore_with_n_best_list(lats, G, num_paths)
-    else:
+    if use_whole_lattice:
         return rescore_with_whole_lattice(lats, G)
+    else:
+        return rescore_with_n_best_list(lats, G, num_paths)
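After this change, num_paths is always required and the rescoring mode is selected explicitly by the caller. A sketch of the two call patterns, assuming `lats` and `G` have been built as in decode() and main() above (not runnable on its own):

# n-best rescoring with n = 100
best_paths = decode_with_lm_rescoring(
    lats, G, num_paths=100, use_whole_lattice=False)

# Whole-lattice rescoring; num_paths is ignored in this mode, so the
# --num-paths default of -1 acts as a placeholder.
best_paths = decode_with_lm_rescoring(
    lats, G, num_paths=-1, use_whole_lattice=True)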