diff --git a/egs/librispeech/ASR/conformer_ctc3/decode.py b/egs/librispeech/ASR/conformer_ctc3/decode.py index 39186e5468..3b24ad5971 100755 --- a/egs/librispeech/ASR/conformer_ctc3/decode.py +++ b/egs/librispeech/ASR/conformer_ctc3/decode.py @@ -96,8 +96,7 @@ from icefall.utils import ( AttributeDict, get_texts, - get_texts_with_timestamp, - parse_hyp_and_timestamp, + parse_fsa_timestamps_and_texts, setup_logger, store_transcripts_and_timestamps, str2bool, @@ -396,13 +395,8 @@ def decode_one_batch( best_path = one_best_decoding( lattice=lattice, use_double_scores=params.use_double_scores ) - # Note: `best_path.aux_labels` contains token IDs, not word IDs - # since we are using H, not HLG here. - # - # token_ids is a lit-of-list of IDs - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, sp=bpe_model, subsampling_factor=params.subsampling_factor, frame_shift_ms=params.frame_shift_ms, @@ -435,12 +429,11 @@ def decode_one_batch( lattice=lattice, use_double_scores=params.use_double_scores ) key = f"no_rescore_hlg_scale_{params.hlg_scale}" - res = get_texts_with_timestamp(best_path) - hyps, timestamps = parse_hyp_and_timestamp( - res=res, + timestamps, hyps = parse_fsa_timestamps_and_texts( + best_paths=best_path, + word_table=word_table, subsampling_factor=params.subsampling_factor, frame_shift_ms=params.frame_shift_ms, - word_table=word_table, ) else: best_path = nbest_decoding( @@ -504,7 +497,18 @@ def decode_dataset( sos_id: int, eos_id: int, G: Optional[k2.Fsa] = None, -) -> Dict[str, List[Tuple[str, List[str], List[str], List[float], List[float]]]]: +) -> Dict[ + str, + List[ + Tuple[ + str, + List[str], + List[str], + List[Tuple[float, float]], + List[Tuple[float, float]], + ] + ], +]: """Decode dataset. 
Args: @@ -555,7 +559,7 @@ def decode_dataset( time = [] if s.alignment is not None and "word" in s.alignment: time = [ - aliword.start + (aliword.start, aliword.end) for aliword in s.alignment["word"] if aliword.symbol != "" ] @@ -601,7 +605,15 @@ def save_results( test_set_name: str, results_dict: Dict[ str, - List[Tuple[List[str], List[str], List[str], List[float], List[float]]], + List[ + Tuple[ + List[str], + List[str], + List[str], + List[Tuple[float, float]], + List[Tuple[float, float]], + ] + ], ], ): test_set_wers = dict() @@ -621,7 +633,11 @@ def save_results( ) with open(errs_filename, "w") as f: wer, mean_delay, var_delay = write_error_stats_with_timestamps( - f, f"{test_set_name}-{key}", results, enable_log=True + f, + f"{test_set_name}-{key}", + results, + enable_log=True, + with_end_time=True, ) test_set_wers[key] = wer test_set_delays[key] = (mean_delay, var_delay) @@ -637,16 +653,17 @@ def save_results( for key, val in test_set_wers: print("{}\t{}".format(key, val), file=f) - test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0]) + # sort according to the mean start symbol delay + test_set_delays = sorted(test_set_delays.items(), key=lambda x: x[1][0][0]) delays_info = ( params.res_dir / f"symbol-delay-summary-{test_set_name}-{key}-{params.suffix}.txt" ) with open(delays_info, "w") as f: - print("settings\tsymbol-delay", file=f) + print("settings\t(start, end) symbol-delay (s) (start, end)", file=f) for key, val in test_set_delays: print( - "{}\tmean: {}s, variance: {}".format(key, val[0], val[1]), + "{}\tmean: {}, variance: {}".format(key, val[0], val[1]), file=f, ) @@ -657,10 +674,12 @@ def save_results( note = "" logging.info(s) - s = "\nFor {}, symbol-delay of different settings are:\n".format(test_set_name) + s = "\nFor {}, (start, end) symbol-delay (s) of different settings are:\n".format( + test_set_name + ) note = "\tbest for {}".format(test_set_name) for key, val in test_set_delays: - s += "{}\tmean: {}s, variance: {}{}\n".format(key, val[0], val[1], note) + s += "{}\tmean: {}, variance: {}{}\n".format(key, val[0], val[1], note) note = "" logging.info(s) diff --git a/icefall/utils.py b/icefall/utils.py index ba0b7fe43d..2358ed02f6 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -1,5 +1,6 @@ -# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang -# Mingshuang Luo) +# Copyright 2021 Xiaomi Corp. 
(authors: Fangjun Kuang, +# Mingshuang Luo, +# Zengwei Yao) # # See ../../LICENSE for clarification regarding multiple authors # @@ -453,11 +454,32 @@ def store_transcripts_and_timestamps( for cut_id, ref, hyp, time_ref, time_hyp in texts: print(f"{cut_id}:\tref={ref}", file=f) print(f"{cut_id}:\thyp={hyp}", file=f) + if len(time_ref) > 0: - s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" + if isinstance(time_ref[0], tuple): + # each element is pair + s = ( + "[" + + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_ref]) + + "]" + ) + else: + # each element is a float number + s = "[" + ", ".join(["%0.3f" % i for i in time_ref]) + "]" print(f"{cut_id}:\ttimestamp_ref={s}", file=f) - s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" - print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) + + if len(time_hyp) > 0: + if isinstance(time_hyp[0], tuple): + # each element is pair + s = ( + "[" + + ", ".join(["(%0.3f, %.03f)" % (i, j) for (i, j) in time_hyp]) + + "]" + ) + else: + # each element is a float number + s = "[" + ", ".join(["%0.3f" % i for i in time_hyp]) + "]" + print(f"{cut_id}:\ttimestamp_hyp={s}", file=f) def write_error_stats( @@ -624,9 +646,18 @@ def write_error_stats( def write_error_stats_with_timestamps( f: TextIO, test_set_name: str, - results: List[Tuple[str, List[str], List[str], List[float], List[float]]], + results: List[ + Tuple[ + str, + List[str], + List[str], + List[Union[float, Tuple[float, float]]], + List[Union[float, Tuple[float, float]]], + ] + ], enable_log: bool = True, -) -> Tuple[float, float, float]: + with_end_time: bool = False, +) -> Tuple[float, Union[float, Tuple[float, float]], Union[float, Tuple[float, float]]]: """Write statistics based on predicted results and reference transcripts as well as their timestamps. @@ -659,6 +690,8 @@ def write_error_stats_with_timestamps( enable_log: If True, also print detailed WER to the console. Otherwise, it is written only to the given file. + with_end_time: + Whether use end timestamps. Returns: Return total word error rate and mean delay. 
@@ -704,7 +737,15 @@ def write_error_stats_with_timestamps( words[ref_word][0] += 1 num_corr += 1 if has_time: - all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) + if with_end_time: + all_delay.append( + ( + time_hyp[p_hyp][0] - time_ref[p_ref][0], + time_hyp[p_hyp][1] - time_ref[p_ref][1], + ) + ) + else: + all_delay.append(time_hyp[p_hyp] - time_ref[p_ref]) p_hyp += 1 p_ref += 1 if has_time: @@ -716,16 +757,39 @@ def write_error_stats_with_timestamps( ins_errs = sum(ins.values()) del_errs = sum(dels.values()) tot_errs = sub_errs + ins_errs + del_errs - tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len) + tot_err_rate = float("%.2f" % (100.0 * tot_errs / ref_len)) - mean_delay = "inf" - var_delay = "inf" + if with_end_time: + mean_delay = (float("inf"), float("inf")) + var_delay = (float("inf"), float("inf")) + else: + mean_delay = float("inf") + var_delay = float("inf") num_delay = len(all_delay) if num_delay > 0: - mean_delay = sum(all_delay) / num_delay - var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay - mean_delay = "%.3f" % mean_delay - var_delay = "%.3f" % var_delay + if with_end_time: + all_delay_start = [i[0] for i in all_delay] + mean_delay_start = sum(all_delay_start) / num_delay + var_delay_start = ( + sum([(i - mean_delay_start) ** 2 for i in all_delay_start]) / num_delay + ) + + all_delay_end = [i[1] for i in all_delay] + mean_delay_end = sum(all_delay_end) / num_delay + var_delay_end = ( + sum([(i - mean_delay_end) ** 2 for i in all_delay_end]) / num_delay + ) + + mean_delay = ( + float("%.3f" % mean_delay_start), + float("%.3f" % mean_delay_end), + ) + var_delay = (float("%.3f" % var_delay_start), float("%.3f" % var_delay_end)) + else: + mean_delay = sum(all_delay) / num_delay + var_delay = sum([(i - mean_delay) ** 2 for i in all_delay]) / num_delay + mean_delay = float("%.3f" % mean_delay) + var_delay = float("%.3f" % var_delay) if enable_log: logging.info( @@ -734,7 +798,8 @@ def write_error_stats_with_timestamps( f"{del_errs} del, {sub_errs} sub ]" ) logging.info( - f"[{test_set_name}] %symbol-delay mean: {mean_delay}s, variance: {var_delay} " # noqa + f"[{test_set_name}] %symbol-delay mean (s): " + f"{mean_delay}, variance: {var_delay} " # noqa f"computed on {num_delay} correct words" ) @@ -817,7 +882,8 @@ def write_error_stats_with_timestamps( hyp_count = corr + hyp_sub + ins print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f) - return float(tot_err_rate), float(mean_delay), float(var_delay) + + return tot_err_rate, mean_delay, var_delay class MetricsTracker(collections.defaultdict): @@ -1431,3 +1497,270 @@ def filter_uneven_sized_batch(batch: dict, allowed_max_frames: int): batch["supervisions"][k] = v[:keep_num_utt] return batch + + +def parse_bpe_start_end_pairs( + tokens: List[str], is_first_token: List[bool] +) -> List[Tuple[int, int]]: + """Parse pairs of start and end frame indexes for each word. + + Args: + tokens: + List of BPE tokens. + is_first_token: + List of bool values, which indicates whether it is the first token, + i.e., not repeat or blank. + + Returns: + List of (start-frame-index, end-frame-index) pairs for each word. 
+ """ + assert len(tokens) == len(is_first_token), (len(tokens), len(is_first_token)) + + start_token = b"\xe2\x96\x81".decode() # '_' + blank_token = "" + + non_blank_idx = [i for i in range(len(tokens)) if tokens[i] != blank_token] + num_non_blank = len(non_blank_idx) + + pairs = [] + start = -1 + end = -1 + for j in range(num_non_blank): + # The index in all frames + i = non_blank_idx[j] + + found_start = False + if is_first_token[i] and (j == 0 or tokens[i].startswith(start_token)): + found_start = True + if tokens[i] == start_token: + if j == num_non_blank - 1: + # It is the last non-blank token + found_start = False + elif is_first_token[non_blank_idx[j + 1]] and tokens[ + non_blank_idx[j + 1] + ].startswith(start_token): + # The next not-blank token is a first-token and also starts with start_token + found_start = False + if found_start: + start = i + + if start != -1: + found_end = False + if j == num_non_blank - 1: + # It is the last non-blank token + found_end = True + elif is_first_token[non_blank_idx[j + 1]] and tokens[ + non_blank_idx[j + 1] + ].startswith(start_token): + # The next not-blank token is a first-token and also starts with start_token + found_end = True + if found_end: + end = i + + if start != -1 and end != -1: + if not all([tokens[t] == start_token for t in range(start, end + 1)]): + # except the case of all start_token + pairs.append((start, end)) + # Reset start and end + start = -1 + end = -1 + + return pairs + + +def parse_bpe_timestamps_and_texts( + best_paths: k2.Fsa, sp: spm.SentencePieceProcessor +) -> Tuple[List[Tuple[int, int]], List[List[str]]]: + """Parse timestamps (frame indexes) and texts. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). Its attribtutes `labels` and `aux_labels` + are both BPE tokens. + sp: + The BPE model. + + Returns: + utt_index_pairs: + A list of pair list. utt_index_pairs[i] is a list of + (start-frame-index, end-frame-index) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + shape = best_paths.arcs.shape().remove_axis(1) + + # labels: [utt][arcs] + labels = k2.RaggedTensor(shape, best_paths.labels.contiguous()) + # remove -1's. + labels = labels.remove_values_eq(-1) + labels = labels.tolist() + + # aux_labels: [utt][arcs] + aux_labels = k2.RaggedTensor(shape, best_paths.aux_labels.contiguous()) + + # remove -1's. + all_aux_labels = aux_labels.remove_values_eq(-1) + # len(all_aux_labels[i]) is equal to the number of frames + all_aux_labels = all_aux_labels.tolist() + + # remove 0's and -1's. + out_aux_labels = aux_labels.remove_values_leq(0) + # len(out_aux_labels[i]) is equal to the number of output BPE tokens + out_aux_labels = out_aux_labels.tolist() + + utt_index_pairs = [] + utt_words = [] + for i in range(len(labels)): + tokens = sp.id_to_piece(labels[i]) + words = sp.decode(out_aux_labels[i]).split() + + # Indicates whether it is the first token, i.e., not-repeat and not-blank. 
+ is_first_token = [a != 0 for a in all_aux_labels[i]] + index_pairs = parse_bpe_start_end_pairs(tokens, is_first_token) + assert len(index_pairs) == len(words), (len(index_pairs), len(words), tokens) + utt_index_pairs.append(index_pairs) + utt_words.append(words) + + return utt_index_pairs, utt_words + + +def parse_timestamps_and_texts( + best_paths: k2.Fsa, word_table: k2.SymbolTable +) -> Tuple[List[Tuple[int, int]], List[List[str]]]: + """Parse timestamps (frame indexes) and texts. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). Attribtute `labels` is the prediction unit, + e.g., phone or BPE tokens. Attribute `aux_labels` is the word index. + word_table: + The word symbol table. + + Returns: + utt_index_pairs: + A list of pair list. utt_index_pairs[i] is a list of + (start-frame-index, end-frame-index) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + # [utt][words] + word_ids = get_texts(best_paths) + + shape = best_paths.arcs.shape().remove_axis(1) + + # labels: [utt][arcs] + labels = k2.RaggedTensor(shape, best_paths.labels.contiguous()) + # remove -1's. + labels = labels.remove_values_eq(-1) + labels = labels.tolist() + + # aux_labels: [utt][arcs] + aux_shape = shape.compose(best_paths.aux_labels.shape) + aux_labels = k2.RaggedTensor(aux_shape, best_paths.aux_labels.values.contiguous()) + aux_labels = aux_labels.tolist() + + utt_index_pairs = [] + utt_words = [] + for i, (label, aux_label) in enumerate(zip(labels, aux_labels)): + num_arcs = len(label) + # The last arc of aux_label is the arc entering the final state + assert num_arcs == len(aux_label) - 1, (num_arcs, len(aux_label)) + + index_pairs = [] + start = -1 + end = -1 + for arc in range(num_arcs): + # len(aux_label[arc]) is 0 or 1 + if label[arc] != 0 and len(aux_label[arc]) != 0: + if start != -1 and end != -1: + index_pairs.append((start, end)) + start = arc + if label[arc] != 0: + end = arc + if start != -1 and end != -1: + index_pairs.append((start, end)) + + words = [word_table[w] for w in word_ids[i]] + assert len(index_pairs) == len(words), (len(index_pairs), len(words)) + + utt_index_pairs.append(index_pairs) + utt_words.append(words) + + return utt_index_pairs, utt_words + + +def parse_fsa_timestamps_and_texts( + best_paths: k2.Fsa, + sp: Optional[spm.SentencePieceProcessor] = None, + word_table: Optional[k2.SymbolTable] = None, + subsampling_factor: int = 4, + frame_shift_ms: float = 10, +) -> Tuple[List[Tuple[float, float]], List[List[str]]]: + """Parse timestamps (in seconds) and texts for given decoded fsa paths. + Currently it supports two cases: + (1) ctc-decoding, the attribtutes `labels` and `aux_labels` + are both BPE tokens. In this case, sp should be provided. + (2) HLG-based 1best, the attribtute `labels` is the prediction unit, + e.g., phone or BPE tokens; attribute `aux_labels` is the word index. + In this case, word_table should be provided. + + Args: + best_paths: + A k2.Fsa with best_paths.arcs.num_axes() == 3, i.e. + containing multiple FSAs, which is expected to be the result + of k2.shortest_path (otherwise the returned values won't + be meaningful). + sp: + The BPE model. + word_table: + The word symbol table. + subsampling_factor: + The subsampling factor of the model. + frame_shift_ms: + Frame shift in milliseconds between two contiguous frames. 
+ + Returns: + utt_time_pairs: + A list of pair list. utt_time_pairs[i] is a list of + (start-time, end-time) pairs for each word in + utterance-i. + utt_words: + A list of str list. utt_words[i] is a word list of utterence-i. + """ + if sp is not None: + assert word_table is None, "word_table is not needed if sp is provided." + utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts( + best_paths=best_paths, sp=sp + ) + elif word_table is not None: + assert sp is None, "sp is not needed if word_table is provided." + utt_index_pairs, utt_words = parse_timestamps_and_texts( + best_paths=best_paths, word_table=word_table + ) + else: + raise ValueError("Either sp or word_table should be provided.") + + utt_time_pairs = [] + for utt in utt_index_pairs: + start = convert_timestamp( + frames=[i[0] for i in utt], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + end = convert_timestamp( + # The duration in frames is (end_frame_index - start_frame_index + 1) + frames=[i[1] + 1 for i in utt], + subsampling_factor=subsampling_factor, + frame_shift_ms=frame_shift_ms, + ) + utt_time_pairs.append(list(zip(start, end))) + + return utt_time_pairs, utt_words diff --git a/test/test_parse_timestamp.py b/test/test_parse_timestamp.py new file mode 100755 index 0000000000..92bfb49c65 --- /dev/null +++ b/test/test_parse_timestamp.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Zengwei Yao) +# +# See ../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from pathlib import Path + +import k2 +import sentencepiece as spm +import torch + +from icefall.lexicon import Lexicon +from icefall.utils import parse_bpe_timestamps_and_texts, parse_timestamps_and_texts + +ICEFALL_DIR = Path(__file__).resolve().parent.parent + + +def test_parse_bpe_timestamps_and_texts(): + lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500" + if not lang_dir.is_dir(): + print(f"{lang_dir} does not exist.") + return + + sp = spm.SentencePieceProcessor() + sp.load(str(lang_dir / "bpe.model")) + + text_1 = "HELLO WORLD" + token_ids_1 = sp.encode(text_1, out_type=int) + # out_type=str: ['_HE', 'LL', 'O', '_WORLD'] + # out_type=int: [22, 58, 24, 425] + + # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0] + labels_1 = ( + token_ids_1[0:1] * 2 + + token_ids_1[1:3] + + [0] * 2 + + token_ids_1[3:4] * 3 + + [0] * 2 + ) + # [22, 0, 58, 24, 0, 0, 425, 0, 0, 0, 0, -1] + aux_labels_1 = ( + token_ids_1[0:1] + + [0] + + token_ids_1[1:3] + + [0] * 2 + + token_ids_1[3:4] + + [0] * 4 + + [-1] + ) + fsa_1 = k2.linear_fsa(labels_1) + fsa_1.aux_labels = torch.tensor(aux_labels_1).to(torch.int32) + + text_2 = "SAY GOODBYE" + token_ids_2 = sp.encode(text_2, out_type=int) + # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E'] + # out_type=int: [289, 286, 41, 16, 11] + + # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0] + labels_2 = ( + token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2 + ) + # [289, 0, 0, 286, 0, 41, 16, 11, 0, 0, -1] + aux_labels_2 = ( + token_ids_2[0:1] + + [0] * 2 + + token_ids_2[1:2] + + [0] + + token_ids_2[2:5] + + [0] * 2 + + [-1] + ) + fsa_2 = k2.linear_fsa(labels_2) + fsa_2.aux_labels = torch.tensor(aux_labels_2).to(torch.int32) + + fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2]) + + utt_index_pairs, utt_words = parse_bpe_timestamps_and_texts(fsa_vec, sp) + assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0] + assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0] + assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1] + assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1] + + +def test_parse_timestamps_and_texts(): + lang_dir = ICEFALL_DIR / "egs/librispeech/ASR/data/lang_bpe_500" + if not lang_dir.is_dir(): + print(f"{lang_dir} does not exist.") + return + + lexicon = Lexicon(lang_dir) + + sp = spm.SentencePieceProcessor() + sp.load(str(lang_dir / "bpe.model")) + word_table = lexicon.word_table + + text_1 = "HELLO WORLD" + token_ids_1 = sp.encode(text_1, out_type=int) + # out_type=str: ['_HE', 'LL', 'O', '_WORLD'] + # out_type=int: [22, 58, 24, 425] + word_ids_1 = [word_table[s] for s in text_1.split()] # [79677, 196937] + # [22, 22, 58, 24, 0, 0, 425, 425, 425, 0, 0] + labels_1 = ( + token_ids_1[0:1] * 2 + + token_ids_1[1:3] + + [0] * 2 + + token_ids_1[3:4] * 3 + + [0] * 2 + ) + # [[79677], [], [], [], [], [], [196937], [], [], [], [], []] + aux_labels_1 = [word_ids_1[0:1]] + [[]] * 5 + [word_ids_1[1:2]] + [[]] * 5 + + fsa_1 = k2.linear_fsa(labels_1) + fsa_1.aux_labels = k2.RaggedTensor(aux_labels_1) + + text_2 = "SAY GOODBYE" + token_ids_2 = sp.encode(text_2, out_type=int) + # out_type=str: ['_SAY', '_GOOD', 'B', 'Y', 'E'] + # out_type=int: [289, 286, 41, 16, 11] + word_ids_2 = [word_table[s] for s in text_2.split()] # [154967, 72079] + # [289, 0, 0, 286, 286, 41, 16, 11, 0, 0] + labels_2 = ( + token_ids_2[0:1] + [0] * 2 + token_ids_2[1:2] * 2 + token_ids_2[2:5] + [0] * 2 + ) + # [[154967], [], [], [72079], [], [], [], [], [], [], []] + aux_labels_2 = [word_ids_2[0:1]] + [[]] * 2 + [word_ids_2[1:2]] + [[]] * 7 
+
+    fsa_2 = k2.linear_fsa(labels_2)
+    fsa_2.aux_labels = k2.RaggedTensor(aux_labels_2)
+
+    fsa_vec = k2.create_fsa_vec([fsa_1, fsa_2])
+
+    utt_index_pairs, utt_words = parse_timestamps_and_texts(fsa_vec, word_table)
+    assert utt_index_pairs[0] == [(0, 3), (6, 8)], utt_index_pairs[0]
+    assert utt_words[0] == ["HELLO", "WORLD"], utt_words[0]
+    assert utt_index_pairs[1] == [(0, 0), (3, 7)], utt_index_pairs[1]
+    assert utt_words[1] == ["SAY", "GOODBYE"], utt_words[1]
+
+
+if __name__ == "__main__":
+    test_parse_bpe_timestamps_and_texts()
+    test_parse_timestamps_and_texts()
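
A minimal usage sketch of the new parse_fsa_timestamps_and_texts() helper for the ctc-decoding case (where labels and aux_labels are both BPE token IDs), mirroring test_parse_bpe_timestamps_and_texts() above. It assumes the librispeech lang_bpe_500 BPE model (the path below is illustrative) and that convert_timestamp() maps frame index f to f * subsampling_factor * frame_shift_ms / 1000 seconds; in decode.py the FsaVec would instead come from one_best_decoding().

import k2
import sentencepiece as spm
import torch

from icefall.utils import parse_fsa_timestamps_and_texts

sp = spm.SentencePieceProcessor()
sp.load("egs/librispeech/ASR/data/lang_bpe_500/bpe.model")  # illustrative path

# Frame-level BPE token IDs for "HELLO WORLD" (0 is the blank):
# ['▁HE', '▁HE', 'LL', 'O', blk, blk, '▁WORLD', '▁WORLD', '▁WORLD', blk, blk]
token_ids = sp.encode("HELLO WORLD", out_type=int)  # [22, 58, 24, 425]
labels = token_ids[0:1] * 2 + token_ids[1:3] + [0] * 2 + token_ids[3:4] * 3 + [0] * 2
# aux_labels keeps each token only at its first (non-repeat, non-blank) frame;
# the trailing -1 corresponds to the arc entering the final state.
aux_labels = (
    token_ids[0:1] + [0] + token_ids[1:3] + [0] * 2 + token_ids[3:4] + [0] * 4 + [-1]
)

best_path = k2.linear_fsa(labels)
best_path.aux_labels = torch.tensor(aux_labels).to(torch.int32)
best_paths = k2.create_fsa_vec([best_path])

timestamps, hyps = parse_fsa_timestamps_and_texts(
    best_paths=best_paths,
    sp=sp,
    subsampling_factor=4,
    frame_shift_ms=10,
)
# Per the unit test above, "HELLO" spans frames 0-3 and "WORLD" frames 6-8, so
# hyps[0] == ["HELLO", "WORLD"] and timestamps[0] is roughly [(0.0, 0.16), (0.24, 0.36)].
print(hyps[0], timestamps[0])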
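
The word boundaries themselves come from parse_bpe_start_end_pairs(), which only needs the per-frame token strings and a first-token mask; below is a self-contained sketch of its behaviour on the same toy sequence (the blank is the empty string, as in the patch).

from icefall.utils import parse_bpe_start_end_pairs

# Per-frame BPE tokens for "HELLO WORLD"; "" marks blank frames.
tokens = ["▁HE", "▁HE", "LL", "O", "", "", "▁WORLD", "▁WORLD", "▁WORLD", "", ""]
# True only at the first frame of each emitted token (neither a repeat nor a blank).
is_first_token = [True, False, True, True, False, False, True, False, False, False, False]

pairs = parse_bpe_start_end_pairs(tokens, is_first_token)
print(pairs)  # [(0, 3), (6, 8)]: "HELLO" covers frames 0-3, "WORLD" covers frames 6-8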
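
With with_end_time=True, write_error_stats_with_timestamps() collects one (hyp_start - ref_start, hyp_end - ref_end) pair per correctly recognized word and returns per-dimension mean and variance alongside the WER. A toy call is sketched below, assuming the usual icefall dependencies are installed; the cut id, words, and times are invented for illustration.

import sys

from icefall.utils import write_error_stats_with_timestamps

results = [
    (
        "cut-1",                        # cut id
        ["HELLO", "WORLD"],             # reference words
        ["HELLO", "WORLD"],             # hypothesis words
        [(0.00, 0.15), (0.25, 0.40)],   # reference (start, end) times in seconds
        [(0.04, 0.16), (0.28, 0.44)],   # hypothesis (start, end) times in seconds
    ),
]

wer, mean_delay, var_delay = write_error_stats_with_timestamps(
    sys.stdout,
    "toy-test-set",
    results,
    enable_log=False,
    with_end_time=True,
)
# Start delays are 0.04 and 0.03 s, end delays 0.01 and 0.04 s, so after rounding
# to three decimals: wer == 0.0, mean_delay == (0.035, 0.025), var_delay == (0.0, 0.0).
print(wer, mean_delay, var_delay)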