From 398ed80d7a5b4cda5d1e82d879267d9ac28cf7a9 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Sat, 31 Jul 2021 15:26:57 +0800
Subject: [PATCH] Minor fixes to support DDP training.

---
 egs/librispeech/ASR/conformer_ctc/decode.py |  12 ++-
 egs/librispeech/ASR/conformer_ctc/train.py  | 105 +++++++++++++++++---
 icefall/decode.py                           |   3 +-
 3 files changed, 99 insertions(+), 21 deletions(-)

diff --git a/egs/librispeech/ASR/conformer_ctc/decode.py b/egs/librispeech/ASR/conformer_ctc/decode.py
index 625afeda3d..d1cbc14de9 100755
--- a/egs/librispeech/ASR/conformer_ctc/decode.py
+++ b/egs/librispeech/ASR/conformer_ctc/decode.py
@@ -85,10 +85,10 @@ def get_params() -> AttributeDict:
             # - whole-lattice-rescoring
             # - attention-decoder
             # "method": "whole-lattice-rescoring",
-            "method": "attention-decoder",
+            "method": "1best",
             # num_paths is used when method is "nbest", "nbest-rescoring",
             # and attention-decoder
-            "num_paths": 1000,
+            "num_paths": 100,
         }
     )
     return params
@@ -192,7 +192,7 @@ def decode_one_batch(
             key = f"no_rescore-{params.num_paths}"
 
         hyps = get_texts(best_path)
-        hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
+        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
         return {key: hyps}
 
     assert params.method in [
@@ -234,7 +234,7 @@
     ans = dict()
     for lm_scale_str, best_path in best_path_dict.items():
         hyps = get_texts(best_path)
-        hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
+        hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
         ans[lm_scale_str] = hyps
 
     return ans
@@ -374,6 +374,8 @@ def main():
     if not hasattr(HLG, "lm_scores"):
         HLG.lm_scores = HLG.scores.clone()
 
+    # HLG = k2.ctc_topo(4999).to(device)
+
     if params.method in (
         "nbest-rescoring",
         "whole-lattice-rescoring",
@@ -383,7 +385,7 @@ def main():
         logging.info("Loading G_4_gram.fst.txt")
         logging.warning("It may take 8 minutes.")
         with open(params.lm_dir / "G_4_gram.fst.txt") as f:
-            first_word_disambig_id = lexicon.words["#0"]
+            first_word_disambig_id = lexicon.word_table["#0"]
             G = k2.Fsa.from_openfst(f.read(), acceptor=False)
 
             # G.aux_labels is not needed in later computations, so
diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 23af85f229..40d3cf7fbb 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -130,14 +130,14 @@ def get_params() -> AttributeDict:
             "weight_decay": 0.0,
             "subsampling_factor": 4,
             "start_epoch": 0,
-            "num_epochs": 10,
+            "num_epochs": 50,
             "best_train_loss": float("inf"),
             "best_valid_loss": float("inf"),
             "best_train_epoch": -1,
             "best_valid_epoch": -1,
             "batch_idx_train": 0,
             "log_interval": 10,
-            "valid_interval": 1000,
+            "valid_interval": 3000,
             "beam_size": 10,
             "reduction": "sum",
             "use_double_scores": True,
@@ -312,16 +312,26 @@ def compute_loss(
 
     if params.att_rate != 0.0:
         with torch.set_grad_enabled(is_training):
-            att_loss = model.decoder_forward(
-                encoder_memory,
-                memory_mask,
-                token_ids=token_ids,
-                sos_id=graph_compiler.sos_id,
-                eos_id=graph_compiler.eos_id,
-            )
+            if hasattr(model, "module"):
+                att_loss = model.module.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
+            else:
+                att_loss = model.decoder_forward(
+                    encoder_memory,
+                    memory_mask,
+                    token_ids=token_ids,
+                    sos_id=graph_compiler.sos_id,
+                    eos_id=graph_compiler.eos_id,
+                )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
     else:
         loss = ctc_loss
+        att_loss = torch.tensor([0])
 
     # train_frames and valid_frames are used for printing.
     if is_training:
@@ -331,7 +341,7 @@
 
     assert loss.requires_grad == is_training
 
-    return loss
+    return loss, ctc_loss.detach(), att_loss.detach()
 
 
 def compute_validation_loss(
@@ -347,9 +357,11 @@
     model.eval()
 
     tot_loss = 0.0
+    tot_ctc_loss = 0.0
+    tot_att_loss = 0.0
     tot_frames = 0.0
     for batch_idx, batch in enumerate(valid_dl):
-        loss = compute_loss(
+        loss, ctc_loss, att_loss = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -357,19 +369,32 @@
             is_training=False,
         )
         assert loss.requires_grad is False
+        assert ctc_loss.requires_grad is False
+        assert att_loss.requires_grad is False
 
         loss_cpu = loss.detach().cpu().item()
         tot_loss += loss_cpu
+
+        tot_ctc_loss += ctc_loss.detach().cpu().item()
+        tot_att_loss += att_loss.detach().cpu().item()
+
        tot_frames += params.valid_frames
 
     if world_size > 1:
-        s = torch.tensor([tot_loss, tot_frames], device=loss.device)
+        s = torch.tensor(
+            [tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
+            device=loss.device,
+        )
         dist.all_reduce(s, op=dist.ReduceOp.SUM)
         s = s.cpu().tolist()
         tot_loss = s[0]
-        tot_frames = s[1]
+        tot_ctc_loss = s[1]
+        tot_att_loss = s[2]
+        tot_frames = s[3]
 
     params.valid_loss = tot_loss / tot_frames
+    params.valid_ctc_loss = tot_ctc_loss / tot_frames
+    params.valid_att_loss = tot_att_loss / tot_frames
 
     if params.valid_loss < params.best_valid_loss:
         params.best_valid_epoch = params.cur_epoch
@@ -413,12 +438,15 @@
     model.train()
     tot_loss = 0.0  # sum of losses over all batches
+    tot_ctc_loss = 0.0
+    tot_att_loss = 0.0
+
     tot_frames = 0.0  # sum of frames over all batches
 
     for batch_idx, batch in enumerate(train_dl):
         params.batch_idx_train += 1
         batch_size = len(batch["supervisions"]["text"])
 
-        loss = compute_loss(
+        loss, ctc_loss, att_loss = compute_loss(
             params=params,
             model=model,
             batch=batch,
@@ -434,19 +462,63 @@
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
+        ctc_loss_cpu = ctc_loss.detach().cpu().item()
+        att_loss_cpu = att_loss.detach().cpu().item()
 
         tot_frames += params.train_frames
         tot_loss += loss_cpu
+        tot_ctc_loss += ctc_loss_cpu
+        tot_att_loss += att_loss_cpu
+
         tot_avg_loss = tot_loss / tot_frames
+        tot_avg_ctc_loss = tot_ctc_loss / tot_frames
+        tot_avg_att_loss = tot_att_loss / tot_frames
 
         if batch_idx % params.log_interval == 0:
             logging.info(
                 f"Epoch {params.cur_epoch}, batch {batch_idx}, "
+                f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
+                f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
                 f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
+                f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
+                f"total avg att loss: {tot_avg_att_loss:.4f}, "
                 f"total avg loss: {tot_avg_loss:.4f}, "
                 f"batch size: {batch_size}"
             )
+            if tb_writer is not None:
+                tb_writer.add_scalar(
+                    "train/current_ctc_loss",
+                    ctc_loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/current_att_loss",
+                    att_loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/current_loss",
+                    loss_cpu / params.train_frames,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/tot_avg_ctc_loss",
+                    tot_avg_ctc_loss,
+                    params.batch_idx_train,
+                )
+
+                tb_writer.add_scalar(
+                    "train/tot_avg_att_loss",
+                    tot_avg_att_loss,
+                    params.batch_idx_train,
+                )
+                tb_writer.add_scalar(
+                    "train/tot_avg_loss",
+                    tot_avg_loss,
+                    params.batch_idx_train,
+                )
+
         if batch_idx > 0 and batch_idx % params.valid_interval == 0:
             compute_validation_loss(
                 params=params,
                 model=model,
@@ -457,7 +529,10 @@
             )
             model.train()
             logging.info(
-                f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
+                f"Epoch {params.cur_epoch}, "
+                f"valid ctc loss {params.valid_ctc_loss:.4f},"
+                f"valid att loss {params.valid_att_loss:.4f},"
+                f"valid loss {params.valid_loss:.4f},"
                 f" best valid loss: {params.best_valid_loss:.4f} "
                 f"best valid epoch: {params.best_valid_epoch}"
             )
diff --git a/icefall/decode.py b/icefall/decode.py
index 4801185b81..ed08405fa0 100644
--- a/icefall/decode.py
+++ b/icefall/decode.py
@@ -659,8 +659,9 @@ def rescore_with_attention_decoder(
         0, path_to_seq_map_long
     )
 
+    # TODO: pass the sos_token_id and eos_token_id via function arguments
     nll = model.decoder_nll(
-        expanded_memory, expanded_memory_key_padding_mask, token_ids
+        expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
     )
     assert nll.ndim == 2
     assert nll.shape[0] == num_word_seqs
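
Note: the hasattr(model, "module") branches added to compute_loss() are needed because, when training runs with world_size > 1, train.py wraps the model in torch.nn.parallel.DistributedDataParallel, and the wrapper exposes only forward(); custom methods such as decoder_forward() are then reachable only through model.module. The snippet below is a minimal illustrative sketch of that wrapping, not part of this patch: the helper name wrap_for_ddp is hypothetical, and it assumes torch.distributed.init_process_group() has already been called for the current rank.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(
    model: torch.nn.Module, rank: int, world_size: int
) -> torch.nn.Module:
    """Move the model to its GPU and, for multi-GPU runs, wrap it in DDP.

    Assumes torch.distributed.init_process_group() was already called.
    """
    device = torch.device("cuda", rank)
    model = model.to(device)
    if world_size > 1:
        # From here on the wrapper only exposes forward(); e.g. decoder_forward()
        # must be called as model.module.decoder_forward(), which is exactly
        # what the hasattr(model, "module") check in compute_loss() handles.
        model = DDP(model, device_ids=[rank])
    return model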