Skip to content
This repository was archived by the owner on Oct 13, 2022. It is now read-only.

Commit 3502531

Browse files
authored
Merge pull request #238 from pzelasko/feature/sampler-updates
Complementary sampler update following Lhotse's changes
2 parents fd40d8a + d7b7db8 commit 3502531

File tree

10 files changed

+229
-199
lines changed

10 files changed

+229
-199
lines changed

egs/aishell/asr/simple_v1/mmi_att_transformer_decode.py

+14-11
Original file line number | Diff line number | Diff line change
@@ -6,17 +6,18 @@
66
# Apache 2.0
77

88
import argparse
9-
import k2
109
import logging
11-
import numpy as np
1210
import os
13-
import torch
14-
from k2 import Fsa, SymbolTable
15-
from kaldialign import edit_distance
1611
from pathlib import Path
1712
from typing import List
1813
from typing import Union
1914

15+
import k2
16+
import numpy as np
17+
import torch
18+
from k2 import Fsa, SymbolTable
19+
from kaldialign import edit_distance
20+
2021
from snowfall.common import average_checkpoint
2122
from snowfall.common import find_first_disambig_symbol
2223
from snowfall.common import get_texts
@@ -25,16 +26,20 @@
2526
from snowfall.data import AishellAsrDataModule
2627
from snowfall.decoding.graph import compile_LG
2728
from snowfall.models import AcousticModel
28-
from snowfall.models.transformer import Transformer
2929
from snowfall.models.conformer import Conformer
30+
from snowfall.models.transformer import Transformer
3031
from snowfall.training.ctc_graph import build_ctc_topo
3132
from snowfall.training.mmi_graph import create_bigram_phone_lm
3233
from snowfall.training.mmi_graph import get_phone_symbols
3334

3435

3536
def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
3637
device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
37-
tot_num_cuts = len(dataloader.dataset.cuts)
38+
num_batches = None
39+
try:
40+
num_batches = len(dataloader)
41+
except TypeError:
42+
pass
3843
num_cuts = 0
3944
results = [] # a list of pair (ref_words, hyp_words)
4045
for batch_idx, batch in enumerate(dataloader):
@@ -83,10 +88,8 @@ def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
8388
results.append((ref_words, hyp_words))
8489

8590
if batch_idx % 10 == 0:
86-
logging.info(
87-
'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
88-
batch_idx, num_cuts, tot_num_cuts,
89-
float(num_cuts) / tot_num_cuts * 100))
91+
batch_str = f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}"
92+
logging.info(f"batch {batch_str}, number of cuts processed until now is {num_cuts}")
9093

9194
num_cuts += len(texts)
9295

0 commit comments

Comments (0)