Disable gradient computation in evaluation mode.
csukuangfj committed Jul 29, 2021
1 parent acc63a9 commit b94d97d
Showing 1 changed file with 18 additions and 16 deletions.
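The substantive change is in `compute_loss`: the attention-decoder forward pass is now wrapped in `torch.set_grad_enabled(is_training)`, so no autograd graph is built when the loss is computed during validation. A minimal, self-contained sketch of the pattern (illustrative only; the model and shapes are made up, not taken from this repository):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # stand-in for the real encoder/decoder
x = torch.randn(4, 10)

def compute_loss(is_training: bool) -> torch.Tensor:
    # Build the autograd graph only in training mode; with
    # is_training=False this behaves like torch.no_grad(), so
    # activations are not kept around for a backward pass.
    with torch.set_grad_enabled(is_training):
        loss = model(x).pow(2).mean()
    return loss

assert compute_loss(is_training=True).requires_grad
assert not compute_loss(is_training=False).requires_grad
```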
34 changes: 18 additions & 16 deletions egs/librispeech/ASR/conformer_ctc/train.py
@@ -13,21 +13,17 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 import torch.nn as nn
-import torch.optim as optim
 from conformer import Conformer
-from transformer import Noam
-
 from lhotse.utils import fix_random_seed
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.nn.utils import clip_grad_value_
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.tensorboard import SummaryWriter
+from transformer import Noam
 
+from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.checkpoint import load_checkpoint
 from icefall.checkpoint import save_checkpoint as save_checkpoint_impl
 from icefall.dataset.librispeech import LibriSpeechAsrDataModule
 from icefall.dist import cleanup_dist, setup_dist
-from icefall.bpe_graph_compiler import BpeCtcTrainingGraphCompiler
 from icefall.lexicon import Lexicon
 from icefall.utils import (
     AttributeDict,
@@ -194,7 +190,10 @@ def load_checkpoint_if_available(
 
     filename = params.exp_dir / f"epoch-{params.start_epoch-1}.pt"
     saved_params = load_checkpoint(
-        filename, model=model, optimizer=optimizer, scheduler=scheduler,
+        filename,
+        model=model,
+        optimizer=optimizer,
+        scheduler=scheduler,
     )
 
     keys = [
@@ -312,13 +311,14 @@ def compute_loss(
     )
 
     if params.att_rate != 0.0:
-        att_loss = model.decoder_forward(
-            encoder_memory,
-            memory_mask,
-            token_ids=token_ids,
-            sos_id=graph_compiler.sos_id,
-            eos_id=graph_compiler.eos_id,
-        )
+        with torch.set_grad_enabled(is_training):
+            att_loss = model.decoder_forward(
+                encoder_memory,
+                memory_mask,
+                token_ids=token_ids,
+                sos_id=graph_compiler.sos_id,
+                eos_id=graph_compiler.eos_id,
+            )
         loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
     else:
         loss = ctc_loss
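Using `torch.set_grad_enabled(is_training)` rather than `torch.no_grad()` lets training and validation share this code path: the context manager is a no-op when `is_training` is True and suppresses graph construction when it is False, which cuts memory use for the decoder pass during evaluation.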
@@ -431,7 +431,6 @@ def train_one_epoch(
 
         optimizer.zero_grad()
         loss.backward()
-        clip_grad_value_(model.parameters(), 5.0)
        optimizer.step()
 
         loss_cpu = loss.detach().cpu().item()
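This hunk also drops the per-element gradient clipping that previously ran between `loss.backward()` and `optimizer.step()` (the matching `clip_grad_value_` import is removed in the first hunk); no replacement clipping is added in this commit. For reference, the removed call clamps each gradient element into [-5.0, 5.0] in place; a rough equivalent looks like this (a sketch, not the actual PyTorch implementation):

```python
import torch

def clip_grad_value(parameters, clip_value: float) -> None:
    # Same effect as torch.nn.utils.clip_grad_value_: clamp every
    # gradient element into [-clip_value, clip_value], in place.
    for p in parameters:
        if p.grad is not None:
            p.grad.clamp_(min=-clip_value, max=clip_value)
```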
@@ -575,7 +574,10 @@ def run(rank, world_size, args):
         )
 
         save_checkpoint(
-            params=params, model=model, optimizer=optimizer, rank=rank,
+            params=params,
+            model=model,
+            optimizer=optimizer,
+            rank=rank,
         )
 
     logging.info("Done!")
