Skip to content

Commit

Permalink
Minor fixes to support DDP training.
Browse files Browse the repository at this point in the history
  • Loading branch information
csukuangfj committed Jul 31, 2021
1 parent b94d97d commit 398ed80
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 21 deletions.
12 changes: 7 additions & 5 deletions egs/librispeech/ASR/conformer_ctc/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def get_params() -> AttributeDict:
# - whole-lattice-rescoring
# - attention-decoder
# "method": "whole-lattice-rescoring",
"method": "attention-decoder",
"method": "1best",
# num_paths is used when method is "nbest", "nbest-rescoring",
# and attention-decoder
"num_paths": 1000,
"num_paths": 100,
}
)
return params
Expand Down Expand Up @@ -192,7 +192,7 @@ def decode_one_batch(
key = f"no_rescore-{params.num_paths}"

hyps = get_texts(best_path)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
return {key: hyps}

assert params.method in [
Expand Down Expand Up @@ -234,7 +234,7 @@ def decode_one_batch(
ans = dict()
for lm_scale_str, best_path in best_path_dict.items():
hyps = get_texts(best_path)
hyps = [[lexicon.words[i] for i in ids] for ids in hyps]
hyps = [[lexicon.word_table[i] for i in ids] for ids in hyps]
ans[lm_scale_str] = hyps
return ans

Expand Down Expand Up @@ -374,6 +374,8 @@ def main():
if not hasattr(HLG, "lm_scores"):
HLG.lm_scores = HLG.scores.clone()

# HLG = k2.ctc_topo(4999).to(device)

if params.method in (
"nbest-rescoring",
"whole-lattice-rescoring",
Expand All @@ -383,7 +385,7 @@ def main():
logging.info("Loading G_4_gram.fst.txt")
logging.warning("It may take 8 minutes.")
with open(params.lm_dir / "G_4_gram.fst.txt") as f:
first_word_disambig_id = lexicon.words["#0"]
first_word_disambig_id = lexicon.word_table["#0"]

G = k2.Fsa.from_openfst(f.read(), acceptor=False)
# G.aux_labels is not needed in later computations, so
Expand Down
105 changes: 90 additions & 15 deletions egs/librispeech/ASR/conformer_ctc/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,14 @@ def get_params() -> AttributeDict:
"weight_decay": 0.0,
"subsampling_factor": 4,
"start_epoch": 0,
"num_epochs": 10,
"num_epochs": 50,
"best_train_loss": float("inf"),
"best_valid_loss": float("inf"),
"best_train_epoch": -1,
"best_valid_epoch": -1,
"batch_idx_train": 0,
"log_interval": 10,
"valid_interval": 1000,
"valid_interval": 3000,
"beam_size": 10,
"reduction": "sum",
"use_double_scores": True,
Expand Down Expand Up @@ -312,16 +312,26 @@ def compute_loss(

if params.att_rate != 0.0:
with torch.set_grad_enabled(is_training):
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
if hasattr(model, "module"):
att_loss = model.module.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
else:
att_loss = model.decoder_forward(
encoder_memory,
memory_mask,
token_ids=token_ids,
sos_id=graph_compiler.sos_id,
eos_id=graph_compiler.eos_id,
)
loss = (1.0 - params.att_rate) * ctc_loss + params.att_rate * att_loss
else:
loss = ctc_loss
att_loss = torch.tensor([0])

# train_frames and valid_frames are used for printing.
if is_training:
Expand All @@ -331,7 +341,7 @@ def compute_loss(

assert loss.requires_grad == is_training

return loss
return loss, ctc_loss.detach(), att_loss.detach()


def compute_validation_loss(
Expand All @@ -347,29 +357,44 @@ def compute_validation_loss(
model.eval()

tot_loss = 0.0
tot_ctc_loss = 0.0
tot_att_loss = 0.0
tot_frames = 0.0
for batch_idx, batch in enumerate(valid_dl):
loss = compute_loss(
loss, ctc_loss, att_loss = compute_loss(
params=params,
model=model,
batch=batch,
graph_compiler=graph_compiler,
is_training=False,
)
assert loss.requires_grad is False
assert ctc_loss.requires_grad is False
assert att_loss.requires_grad is False

loss_cpu = loss.detach().cpu().item()
tot_loss += loss_cpu

tot_ctc_loss += ctc_loss.detach().cpu().item()
tot_att_loss += att_loss.detach().cpu().item()

tot_frames += params.valid_frames

if world_size > 1:
s = torch.tensor([tot_loss, tot_frames], device=loss.device)
s = torch.tensor(
[tot_loss, tot_ctc_loss, tot_att_loss, tot_frames],
device=loss.device,
)
dist.all_reduce(s, op=dist.ReduceOp.SUM)
s = s.cpu().tolist()
tot_loss = s[0]
tot_frames = s[1]
tot_ctc_loss = s[1]
tot_att_loss = s[2]
tot_frames = s[3]

params.valid_loss = tot_loss / tot_frames
params.valid_ctc_loss = tot_ctc_loss / tot_frames
params.valid_att_loss = tot_att_loss / tot_frames

if params.valid_loss < params.best_valid_loss:
params.best_valid_epoch = params.cur_epoch
Expand Down Expand Up @@ -413,12 +438,15 @@ def train_one_epoch(
model.train()

tot_loss = 0.0 # sum of losses over all batches
tot_ctc_loss = 0.0
tot_att_loss = 0.0

tot_frames = 0.0 # sum of frames over all batches
for batch_idx, batch in enumerate(train_dl):
params.batch_idx_train += 1
batch_size = len(batch["supervisions"]["text"])

loss = compute_loss(
loss, ctc_loss, att_loss = compute_loss(
params=params,
model=model,
batch=batch,
Expand All @@ -434,19 +462,63 @@ def train_one_epoch(
optimizer.step()

loss_cpu = loss.detach().cpu().item()
ctc_loss_cpu = ctc_loss.detach().cpu().item()
att_loss_cpu = att_loss.detach().cpu().item()

tot_frames += params.train_frames
tot_loss += loss_cpu
tot_ctc_loss += ctc_loss_cpu
tot_att_loss += att_loss_cpu

tot_avg_loss = tot_loss / tot_frames
tot_avg_ctc_loss = tot_ctc_loss / tot_frames
tot_avg_att_loss = tot_att_loss / tot_frames

if batch_idx % params.log_interval == 0:
logging.info(
f"Epoch {params.cur_epoch}, batch {batch_idx}, "
f"batch avg ctc loss {ctc_loss_cpu/params.train_frames:.4f}, "
f"batch avg att loss {att_loss_cpu/params.train_frames:.4f}, "
f"batch avg loss {loss_cpu/params.train_frames:.4f}, "
f"total avg ctc loss: {tot_avg_ctc_loss:.4f}, "
f"total avg att loss: {tot_avg_att_loss:.4f}, "
f"total avg loss: {tot_avg_loss:.4f}, "
f"batch size: {batch_size}"
)

if tb_writer is not None:
tb_writer.add_scalar(
"train/current_ctc_loss",
ctc_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_att_loss",
att_loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/current_loss",
loss_cpu / params.train_frames,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_ctc_loss",
tot_avg_ctc_loss,
params.batch_idx_train,
)

tb_writer.add_scalar(
"train/tot_avg_att_loss",
tot_avg_att_loss,
params.batch_idx_train,
)
tb_writer.add_scalar(
"train/tot_avg_loss",
tot_avg_loss,
params.batch_idx_train,
)

if batch_idx > 0 and batch_idx % params.valid_interval == 0:
compute_validation_loss(
params=params,
Expand All @@ -457,7 +529,10 @@ def train_one_epoch(
)
model.train()
logging.info(
f"Epoch {params.cur_epoch}, valid loss {params.valid_loss:.4f},"
f"Epoch {params.cur_epoch}, "
f"valid ctc loss {params.valid_ctc_loss:.4f},"
f"valid att loss {params.valid_att_loss:.4f},"
f"valid loss {params.valid_loss:.4f},"
f" best valid loss: {params.best_valid_loss:.4f} "
f"best valid epoch: {params.best_valid_epoch}"
)
Expand Down
3 changes: 2 additions & 1 deletion icefall/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,8 +659,9 @@ def rescore_with_attention_decoder(
0, path_to_seq_map_long
)

# TODO: pass the sos_token_id and eos_token_id via function arguments
nll = model.decoder_nll(
expanded_memory, expanded_memory_key_padding_mask, token_ids
expanded_memory, expanded_memory_key_padding_mask, token_ids, 1, 1
)
assert nll.ndim == 2
assert nll.shape[0] == num_word_seqs
Expand Down

0 comments on commit 398ed80

Please sign in to comment.