Skip to content

Commit

Permalink
Get (start, end) timestamps for CTC models (#876)
Browse files Browse the repository at this point in the history
* parse timestamps and texts for BPE-based models

* parse timestamps (frame indexes) and texts for other cases

* add test functions

* add parse_fsa_timestamps_and_texts function, test in conformer_ctc3/decode.py

* calculate symbol delay for (start, end) timestamps
  • Loading branch information
yaozengwei authored Feb 7, 2023
1 parent 7ae03f6 commit d12e6f0
Show file tree
Hide file tree
Showing 3 changed files with 545 additions and 39 deletions.
63 changes: 41 additions & 22 deletions egs/librispeech/ASR/conformer_ctc3/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@
from icefall.utils import (
AttributeDict,
get_texts,
get_texts_with_timestamp,
parse_hyp_and_timestamp,
parse_fsa_timestamps_and_texts,
setup_logger,
store_transcripts_and_timestamps,
str2bool,
Expand Down Expand Up @@ -396,13 +395,8 @@ def decode_one_batch(
best_path = one_best_decoding(
lattice=lattice, use_double_scores=params.use_double_scores
)
# Note: `best_path.aux_labels` contains token IDs, not word IDs
# since we are using H, not HLG here.
#
# token_ids is a lit-of-list of IDs
res = get_texts_with_timestamp(best_path)
hyps, timestamps = parse_hyp_and_timestamp(
res=res,
timestamps, hyps = parse_fsa_timestamps_and_texts(
best_paths=best_path,
sp=bpe_model,
subsampling_factor=params.subsampling_factor,
frame_shift_ms=params.frame_shift_ms,
Expand Down Expand Up @@ -435,12 +429,11 @@ def decode_one_batch(
lattice=lattice, use_double_scores=params.use_double_scores
)
key = f"no_rescore_hlg_scale_{params.hlg_scale}"
res = get_texts_with_timestamp(best_path)
hyps, timestamps = parse_hyp_and_timestamp(
res=res,
timestamps, hyps = parse_fsa_timestamps_and_texts(
best_paths=best_path,
word_table=word_table,
subsampling_factor=params.subsampling_factor,
frame_shift_ms=params.frame_shift_ms,
word_table=word_table,
)
else:
best_path = nbest_decoding(
Expand Down Expand Up @@ -504,7 +497,18 @@ def decode_dataset(
sos_id: int,
eos_id: int,
G: Optional[k2.Fsa] = None,
) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]:
) -> Dict[
str,
List[
Tuple[
str,
List[str],
List[str],
List[Tuple[float, float]],
List[Tuple[float, float]],
]
],
]:
"""Decode dataset.
Args:
Expand Down Expand Up @@ -555,7 +559,7 @@ def decode_dataset(
time = []
if s.alignment is not None and "word" in s.alignment:
time = [
aliword.start
(aliword.start, aliword.end)
for aliword in s.alignment["word"]
if aliword.symbol != ""
]
Expand Down Expand Up @@ -601,7 +605,15 @@ def save_results(
test_set_name: str,
results_dict: Dict[
str,
List[Tuple[List[str], List[str], List[str], List[float], List[float]]],
List[
Tuple[
List[str],
List[str],
List[str],
List[Tuple[float, float]],
List[Tuple[float, float]],
]
],
],
):
test_set_wers = dict()
Expand All @@ -621,7 +633,11 @@ def save_results(
)
with open(errs_filename, "w") as f:
wer, mean_delay, var_delay = write_error_stats_with_timestamps(
f, f"{test_set_name}-{key}", results, enable_log=True
f,
f"{test_set_name}-{key}",
results,
enable_log=True,
with_end_time=True,
)
test_set_wers[key] = wer
test_set_delays[key] = (mean_delay, var_delay)
Expand All @@ -637,16 +653,17 @@ def save_results(
for key, val in test_set_wers:
print("{}\t{}".format(key, val), file=f)

test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0])
# sort according to the mean start symbol delay
test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0])
delays_info = (
params.res_dir
/ f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt"
)
with open(delays_info, "w") as f:
print("settings\tsymbol-delay", file=f)
print("settings\t(start, end) symbol-delay (s) (start, end)", file=f)
for key, val in test_set_delays:
print(
"{}\tmean: {}s, variance: {}".format(key, val[0], val[1]),
"{}\tmean: {}, variance: {}".format(key, val[0], val[1]),
file=f,
)

Expand All @@ -657,10 +674,12 @@ def save_results(
note = ""
logging.info(s)

s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name)
s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format(
test_set_name
)
note = "\tbest for {}".format(test_set_name)
for key, val in test_set_delays:
s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note)
s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note)
note = ""
logging.info(s)

Expand Down
Loading

0 comments on commit d12e6f0

Please sign in to comment.