|
6 | 6 | # Apache 2.0
|
7 | 7 |
|
8 | 8 | import argparse
|
9 |
| -import k2 |
10 | 9 | import logging
|
11 |
| -import numpy as np |
12 | 10 | import os
|
13 |
| -import torch |
14 |
| -from k2 import Fsa, SymbolTable |
15 |
| -from kaldialign import edit_distance |
16 | 11 | from pathlib import Path
|
17 | 12 | from typing import List
|
18 | 13 | from typing import Union
|
19 | 14 |
|
| 15 | +import k2 |
| 16 | +import numpy as np |
| 17 | +import torch |
| 18 | +from k2 import Fsa, SymbolTable |
| 19 | +from kaldialign import edit_distance |
| 20 | + |
20 | 21 | from snowfall.common import average_checkpoint
|
21 | 22 | from snowfall.common import find_first_disambig_symbol
|
22 | 23 | from snowfall.common import get_texts
|
|
25 | 26 | from snowfall.data import AishellAsrDataModule
|
26 | 27 | from snowfall.decoding.graph import compile_LG
|
27 | 28 | from snowfall.models import AcousticModel
|
28 |
| -from snowfall.models.transformer import Transformer |
29 | 29 | from snowfall.models.conformer import Conformer
|
| 30 | +from snowfall.models.transformer import Transformer |
30 | 31 | from snowfall.training.ctc_graph import build_ctc_topo
|
31 | 32 | from snowfall.training.mmi_graph import create_bigram_phone_lm
|
32 | 33 | from snowfall.training.mmi_graph import get_phone_symbols
|
33 | 34 |
|
34 | 35 |
|
35 | 36 | def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
|
36 | 37 | device: Union[str, torch.device], LG: Fsa, symbols: SymbolTable):
|
37 |
| - tot_num_cuts = len(dataloader.dataset.cuts) |
| 38 | + num_batches = None |
| 39 | + try: |
| 40 | + num_batches = len(dataloader) |
| 41 | + except TypeError: |
| 42 | + pass |
38 | 43 | num_cuts = 0
|
39 | 44 | results = [] # a list of pair (ref_words, hyp_words)
|
40 | 45 | for batch_idx, batch in enumerate(dataloader):
|
@@ -83,10 +88,8 @@ def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
|
83 | 88 | results.append((ref_words, hyp_words))
|
84 | 89 |
|
85 | 90 | if batch_idx % 10 == 0:
|
86 |
| - logging.info( |
87 |
| - 'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format( |
88 |
| - batch_idx, num_cuts, tot_num_cuts, |
89 |
| - float(num_cuts) / tot_num_cuts * 100)) |
| 91 | + batch_str = f"{batch_idx}" if num_batches is None else f"{batch_idx}/{num_batches}" |
| 92 | + logging.info(f"batch {batch_str}, number of cuts processed until now is {num_cuts}") |
90 | 93 |
|
91 | 94 | num_cuts += len(texts)
|
92 | 95 |
|
|
0 commit comments