diff --git a/egs/librispeech/ASR/conformer_ctc/train.py b/egs/librispeech/ASR/conformer_ctc/train.py
index 118e7e7176..23af85f229 100755
--- a/egs/librispeech/ASR/conformer_ctc/train.py
+++ b/egs/librispeech/ASR/conformer_ctc/train.py
@@ -13,21 +13,17 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-import torch.optim as optim
 from conformer import Conformer
-from transformer import Noam
-
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_value_
-from torch.optim.lr_scheduler import StepLR
 from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
 
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -194,7 +190,10 @@ def load_checkpoint_if_available(
         filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
 
     saved_params = load_checkpoint(
-        filename, model=model, optimizer=optimizer, scheduler=scheduler,
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
     )
 
     keys = [
@@ -312,13 +311,14 @@ def compute_loss(
     )
 
     if params.att_rate != 0.0:
-        att_loss = model.decoder_forward(
-            encoder_memory,
-            memory_mask,
-            token_ids=token_ids,
-            sos_id=graph_compiler.sos_id,
-            eos_id=graph_compiler.eos_id,
-        )
+        with torch.set_grad_enabled(is_training):
+            att_loss = model.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
     else:
         loss = ctc_loss
@@ -431,7 +431,6 @@ def train_one_epoch(
 
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_value_(model.parameters(), 5.0)
         optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
@@ -575,7 +574,10 @@ def run(rank, world_size, args):
     )
 
     save_checkpoint(
-        params=params, model=model, optimizer=optimizer, rank=rank,
+        params=params,
+        model=model,
+        optimizer=optimizer,
+        rank=rank,
     )
 
     logging.info("Done!")
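
For context on the main behavioral change: wrapping model.decoder_forward in torch.set_grad_enabled(is_training) lets compute_loss share one code path for training and validation while skipping autograd graph construction (and its memory cost) when is_training is False. A minimal standalone sketch of that PyTorch mechanism; the tensors and the is_training flag here are illustrative, not taken from train.py:

import torch

x = torch.randn(3, requires_grad=True)

is_training = False  # e.g. a validation pass
with torch.set_grad_enabled(is_training):
    y = (x * 2).sum()
print(y.requires_grad)  # False: no autograd history was recorded

is_training = True  # a training pass
with torch.set_grad_enabled(is_training):
    y = (x * 2).sum()
print(y.requires_grad)  # True: backward() can reach x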