diff --git a/egs/librispeech/ASR/long_file_recog.sh b/egs/librispeech/ASR/long_file_recog.sh new file mode 100755 index 0000000000..acd1b1253c --- /dev/null +++ b/egs/librispeech/ASR/long_file_recog.sh @@ -0,0 +1,94 @@ +#!/usr/bin/env bash + +# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674 +export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python + +export CUDA_VISIBLE_DEVICES="0,1,2,3" + +set -eou pipefail + +# This script is used to recogize long audios. The process is as follows: +# 1) Split long audios into chunks with overlaps. +# 2) Perform speech recognition on chunks, getting tokens and timestamps. +# 3) Merge the overlapped chunks into utterances acording to the timestamps. + +# Each chunk (except the first and the last) is padded with extra left side and right side. +# The chunk length is: left_side + chunk_size + right_side. +chunk=30.0 +extra=2.0 + +stage=1 +stop_stage=4 + +# We assume that you have downloaded the LibriLight dataset +# with audio files in $corpus_dir and texts in $text_dir +corpus_dir=$PWD/download/libri-light +text_dir=$PWD/download/librilight_text +# Path to save the manifests +output_dir=$PWD/data/librilight + +world_size=4 + + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then + # We will get librilight_recodings_{subset}.jsonl.gz and librilight_supervisions_{subset}.jsonl.gz + # saved in $output_dir/manifests + log "Stage 1: Prepare LibriLight manifest" + lhotse prepare librilight $corpus_dir $text_dir $output_dir/manifests -j 10 +fi + +if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then + # Chunk manifests are saved to $output_dir/manifests_chunk/librilight_cuts_{subset}.jsonl.gz + log "Stage 2: Split long audio into chunks" + ./long_file_recog/split_into_chunks.py \ + --manifest-in-dir $output_dir/manifests \ + --manifest-out-dir $output_dir/manifests_chunk \ + --chunk $chunk \ + --extra $extra # Extra duration (in seconds) at both sides +fi + +if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then + # Recognized tokens and timestamps are saved to $output_dir/manifests_chunk_recog/librilight_cuts_{subset}.jsonl.gz + + # This script loads torchscript models, exported by `torch.jit.script()`, + # and uses it to decode waves. + # You can download the jit model from https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 + + log "Stage 3: Perform speech recognition on splitted chunks" + for subset in small median large; do + ./long_file_recog/recognize.py \ + --world-size $world_size \ + --num-workers 8 \ + --subset $subset \ + --manifest-in-dir $output_dir/manifests_chunk \ + --manifest-out-dir $output_dir/manifests_chunk_recog \ + --nn-model-filename long_file_recog/exp/jit_model.pt \ + --bpe-model data/lang_bpe_500/bpe.model \ + --max-duration 2400 \ + --decoding-method greedy_search + --master 12345 + + if [ $world_size -gt 1 ]; then + # Combine manifests from different jobs + lhotse combine $(find $output_dir/manifests_chunk_recog -name librilight_cuts_${subset}_job_*.jsonl.gz | tr "\n" " ") $output_dir/manifests_chunk_recog/librilight_cuts_${subset}.jsonl.gz + fi + done +fi + +if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then + # Final results are saved in $output_dir/manifests/librilight_cuts_{subset}.jsonl.gz + log "Stage 4: Merge splitted chunks into utterances." + ./long_file_recog/merge_chunks.py \ + --manifest-in-dir $output_dir/manifests_chunk_recog \ + --manifest-out-dir $output_dir/manifests \ + --bpe-model data/lang_bpe_500/bpe.model \ + --extra $extra +fi + + diff --git a/egs/librispeech/ASR/long_file_recog/asr_datamodule.py b/egs/librispeech/ASR/long_file_recog/asr_datamodule.py new file mode 100644 index 0000000000..eddce7213b --- /dev/null +++ b/egs/librispeech/ASR/long_file_recog/asr_datamodule.py @@ -0,0 +1,189 @@ +# Copyright 2021 Piotr Żelasko +# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo) +# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import logging +from functools import lru_cache +from pathlib import Path +from typing import Dict, List, Union + +import torch +from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy +from lhotse.cut import Cut +from lhotse.dataset import ( # noqa F401 for PrecomputedFeatures + CutConcatenate, + CutMix, + DynamicBucketingSampler, + K2SpeechRecognitionDataset, + PrecomputedFeatures, + SimpleCutSampler, + SpecAugment, +) +from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples + AudioSamples, + BatchIO, + OnTheFlyFeatures, +) +from torch.utils.data import DataLoader + +from icefall.utils import str2bool + + +class SpeechRecognitionDataset(K2SpeechRecognitionDataset): + def __init__( + self, + return_cuts: bool = False, + input_strategy: BatchIO = PrecomputedFeatures(), + ): + super().__init__(return_cuts=return_cuts, input_strategy=input_strategy) + + def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[Cut]]]: + """ + Return a new batch, with the batch size automatically determined using the constraints + of max_frames and max_cuts. + """ + self.hdf5_fix.update() + + # Note: don't sort cuts here + # Sort the cuts by duration so that the first one determines the batch time dimensions. + # cuts = cuts.sort_by_duration(ascending=False) + + # Get a tensor with batched feature matrices, shape (B, T, F) + # Collation performs auto-padding, if necessary. + input_tpl = self.input_strategy(cuts) + if len(input_tpl) == 3: + # An input strategy with fault tolerant audio reading mode. + # "cuts" may be a subset of the original "cuts" variable, + # that only has cuts for which we succesfully read the audio. + inputs, _, cuts = input_tpl + else: + inputs, _ = input_tpl + + # Get a dict of tensors that encode the positional information about supervisions + # in the batch of feature matrices. The tensors are named "sequence_idx", + # "start_frame/sample" and "num_frames/samples". + supervision_intervals = self.input_strategy.supervision_intervals(cuts) + + batch = {"inputs": inputs, "supervisions": supervision_intervals} + if self.return_cuts: + batch["supervisions"]["cut"] = [cut for cut in cuts] + + return batch + + +class AsrDataModule: + """ + DataModule for k2 ASR experiments. + It assumes there is always one train and valid dataloader, + but there can be multiple test dataloaders (e.g. LibriSpeech test-clean + and test-other). + + It contains all the common data pipeline modules used in ASR + experiments, e.g.: + - dynamic batch size, + - bucketing samplers, + - cut concatenation, + - augmentation, + - on-the-fly feature extraction + + This class should be derived for specific corpora used in ASR tasks. + """ + + def __init__(self, args: argparse.Namespace): + self.args = args + + @classmethod + def add_arguments(cls, parser: argparse.ArgumentParser): + group = parser.add_argument_group( + title="ASR data related options", + description="These options are used for the preparation of " + "PyTorch DataLoaders from Lhotse CutSet's -- they control the " + "effective batch sizes, sampling strategies, applied data " + "augmentations, etc.", + ) + group.add_argument( + "--manifest-dir", + type=Path, + default=Path("data/manifests_chunk"), + help="Path to directory with train/valid/test cuts.", + ) + group.add_argument( + "--max-duration", + type=int, + default=600.0, + help="Maximum pooled recordings duration (seconds) in a " + "single batch. You can reduce it if it causes CUDA OOM.", + ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) + group.add_argument( + "--return-cuts", + type=str2bool, + default=True, + help="When enabled, each batch will have the " + "field: batch['supervisions']['cut'] with the cuts that " + "were used to construct it.", + ) + group.add_argument( + "--num-workers", + type=int, + default=8, + help="The number of training dataloader workers that " + "collect the batches.", + ) + + group.add_argument( + "--input-strategy", + type=str, + default="PrecomputedFeatures", + help="AudioSamples or PrecomputedFeatures", + ) + + def dataloaders(self, cuts: CutSet) -> DataLoader: + logging.debug("About to create test dataset") + test = SpeechRecognitionDataset( + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))), + return_cuts=self.args.return_cuts, + ) + + sampler = SimpleCutSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + drop_last=False, + ) + + logging.debug("About to create test dataloader") + test_dl = DataLoader( + test, + batch_size=None, + sampler=sampler, + num_workers=self.args.num_workers, + persistent_workers=False, + ) + return test_dl + + @lru_cache() + def load_subset(self, cuts_filename: Path) -> CutSet: + return load_manifest_lazy(cuts_filename) diff --git a/egs/librispeech/ASR/long_file_recog/beam_search.py b/egs/librispeech/ASR/long_file_recog/beam_search.py new file mode 100644 index 0000000000..f8c31861c2 --- /dev/null +++ b/egs/librispeech/ASR/long_file_recog/beam_search.py @@ -0,0 +1,613 @@ +# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang +# Xiaoyu Yang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Union + +import k2 +import torch + +from icefall.decode import one_best_decoding +from icefall.utils import DecodingResults, get_texts, get_texts_with_timestamp + + +def fast_beam_search( + model: torch.nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, +) -> k2.Fsa: + """It limits the maximum number of symbols per frame to 1. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + Returns: + Return an FsaVec with axes [utt][state][arc] containing the decoded + lattice. Note: When the input graph is a TrivialGraph, the returned + lattice is actually an acceptor. + """ + assert encoder_out.ndim == 3 + + context_size = model.decoder.context_size + vocab_size = model.decoder.vocab_size + + B, T, C = encoder_out.shape + + config = k2.RnntDecodingConfig( + vocab_size=vocab_size, + decoder_history_len=context_size, + beam=beam, + max_contexts=max_contexts, + max_states=max_states, + ) + individual_streams = [] + for i in range(B): + individual_streams.append(k2.RnntDecodingStream(decoding_graph)) + decoding_streams = k2.RnntDecodingStreams(individual_streams, config) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + for t in range(T): + # shape is a RaggedShape of shape (B, context) + # contexts is a Tensor of shape (shape.NumElements(), context_size) + shape, contexts = decoding_streams.get_contexts() + # `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts = contexts.to(torch.int64) + # decoder_out is of shape (shape.NumElements(), 1, decoder_out_dim) + decoder_out = model.decoder(contexts, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # current_encoder_out is of shape + # (shape.NumElements(), 1, joiner_dim) + # fmt: off + current_encoder_out = torch.index_select( + encoder_out[:, t:t + 1, :], 0, shape.row_ids(1).to(torch.int64) + ) + # fmt: on + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + logits = logits.squeeze(1).squeeze(1) + log_probs = (logits / temperature).log_softmax(dim=-1) + decoding_streams.advance(log_probs) + decoding_streams.terminate_and_flush_to_streams() + lattice = decoding_streams.format_output(encoder_out_lens.tolist()) + + return lattice + + +def fast_beam_search_one_best( + model: torch.nn.Module, + decoding_graph: k2.Fsa, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: float, + max_states: int, + max_contexts: int, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """It limits the maximum number of symbols per frame to 1. + + A lattice is first obtained using fast beam search, and then + the shortest path within the lattice is used as the final output. + + Args: + model: + An instance of `Transducer`. + decoding_graph: + Decoding graph used for decoding, may be a TrivialGraph or a LG. + encoder_out: + A tensor of shape (N, T, C) from the encoder. + encoder_out_lens: + A tensor of shape (N,) containing the number of frames in `encoder_out` + before padding. + beam: + Beam value, similar to the beam used in Kaldi.. + max_states: + Max states per stream per frame. + max_contexts: + Max contexts pre stream per frame. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + lattice = fast_beam_search( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=beam, + max_states=max_states, + max_contexts=max_contexts, + temperature=temperature, + ) + + best_path = one_best_decoding(lattice) + + if not return_timestamps: + return get_texts(best_path) + else: + return get_texts_with_timestamp(best_path) + + +def greedy_search_batch( + model: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3 + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + device = next(model.parameters()).device + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + hyps = [[-1] * (context_size - 1) + [blank_id] for _ in range(N)] + + # timestamp[n][i] is the frame index after subsampling + # on which hyp[n][i] is decoded + timestamps = [[] for _ in range(N)] + # scores[n][i] is the logits on which hyp[n][i] is decoded + scores = [[] for _ in range(N)] + + decoder_input = torch.tensor( + hyps, + device=device, + dtype=torch.int64, + ) # (N, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out: (N, 1, decoder_out_dim) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape: (batch_size, 1, 1, encoder_out_dim) + offset = end + + decoder_out = decoder_out[:batch_size] + + logits = model.joiner( + current_encoder_out, decoder_out.unsqueeze(1), project_input=False + ) + # logits'shape (batch_size, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (batch_size, vocab_size) + log_probs = logits.log_softmax(dim=-1) + assert log_probs.ndim == 2, log_probs.shape + y = log_probs.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v not in (blank_id, unk_id): + hyps[i].append(v) + timestamps[i].append(t) + scores[i].append(log_probs[i, v].item()) + emitted = True + if emitted: + # update decoder output + decoder_input = [h[-context_size:] for h in hyps[:batch_size]] + decoder_input = torch.tensor( + decoder_input, + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + sorted_ans = [h[context_size:] for h in hyps] + ans = [] + ans_timestamps = [] + ans_scores = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(timestamps[unsorted_indices[i]]) + ans_scores.append(scores[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + scores=ans_scores, + ) + + +@dataclass +class Hypothesis: + # The predicted tokens so far. + # Newly predicted tokens are appended to `ys`. + ys: List[int] + + # The log prob of ys. + # It contains only one entry. + log_prob: torch.Tensor + + # timestamp[i] is the frame index after subsampling + # on which ys[i] is decoded + timestamp: List[int] = field(default_factory=list) + + @property + def key(self) -> str: + """Return a string representation of self.ys""" + return "_".join(map(str, self.ys)) + + +class HypothesisList(object): + def __init__(self, data: Optional[Dict[str, Hypothesis]] = None) -> None: + """ + Args: + data: + A dict of Hypotheses. Its key is its `value.key`. + """ + if data is None: + self._data = {} + else: + self._data = data + + @property + def data(self) -> Dict[str, Hypothesis]: + return self._data + + def add(self, hyp: Hypothesis) -> None: + """Add a Hypothesis to `self`. + + If `hyp` already exists in `self`, its probability is updated using + `log-sum-exp` with the existed one. + + Args: + hyp: + The hypothesis to be added. + """ + key = hyp.key + if key in self: + old_hyp = self._data[key] # shallow copy + torch.logaddexp(old_hyp.log_prob, hyp.log_prob, out=old_hyp.log_prob) + else: + self._data[key] = hyp + + def get_most_probable(self, length_norm: bool = False) -> Hypothesis: + """Get the most probable hypothesis, i.e., the one with + the largest `log_prob`. + + Args: + length_norm: + If True, the `log_prob` of a hypothesis is normalized by the + number of tokens in it. + Returns: + Return the hypothesis that has the largest `log_prob`. + """ + if length_norm: + return max(self._data.values(), key=lambda hyp: hyp.log_prob / len(hyp.ys)) + else: + return max(self._data.values(), key=lambda hyp: hyp.log_prob) + + def remove(self, hyp: Hypothesis) -> None: + """Remove a given hypothesis. + + Caution: + `self` is modified **in-place**. + + Args: + hyp: + The hypothesis to be removed from `self`. + Note: It must be contained in `self`. Otherwise, + an exception is raised. + """ + key = hyp.key + assert key in self, f"{key} does not exist" + del self._data[key] + + def filter(self, threshold: torch.Tensor) -> "HypothesisList": + """Remove all Hypotheses whose log_prob is less than threshold. + + Caution: + `self` is not modified. Instead, a new HypothesisList is returned. + + Returns: + Return a new HypothesisList containing all hypotheses from `self` + with `log_prob` being greater than the given `threshold`. + """ + ans = HypothesisList() + for _, hyp in self._data.items(): + if hyp.log_prob > threshold: + ans.add(hyp) # shallow copy + return ans + + def topk(self, k: int) -> "HypothesisList": + """Return the top-k hypothesis.""" + hyps = list(self._data.items()) + + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + + ans = HypothesisList(dict(hyps)) + return ans + + def __contains__(self, key: str): + return key in self._data + + def __iter__(self): + return iter(self._data.values()) + + def __len__(self) -> int: + return len(self._data) + + def __str__(self) -> str: + s = [] + for key in self: + s.append(key) + return ", ".join(s) + + +def get_hyps_shape(hyps: List[HypothesisList]) -> k2.RaggedShape: + """Return a ragged shape with axes [utt][num_hyps]. + + Args: + hyps: + len(hyps) == batch_size. It contains the current hypothesis for + each utterance in the batch. + Returns: + Return a ragged shape with 2 axes [utt][num_hyps]. Note that + the shape is on CPU. + """ + num_hyps = [len(h) for h in hyps] + + # torch.cumsum() is inclusive sum, so we put a 0 at the beginning + # to get exclusive sum later. + num_hyps.insert(0, 0) + + num_hyps = torch.tensor(num_hyps) + row_splits = torch.cumsum(num_hyps, dim=0, dtype=torch.int32) + ans = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=row_splits[-1].item() + ) + return ans + + +def modified_beam_search( + model: torch.nn.Module, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + beam: int = 4, + temperature: float = 1.0, + return_timestamps: bool = False, +) -> Union[List[List[int]], DecodingResults]: + """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + beam: + Number of active paths during the beam search. + temperature: + Softmax temperature. + return_timestamps: + Whether to return timestamps. + Returns: + If return_timestamps is False, return the decoded result. + Else, return a DecodingResults object containing + decoded result and corresponding timestamps. + """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) >= 1, encoder_out.size(0) + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + + blank_id = model.decoder.blank_id + unk_id = getattr(model, "unk_id", blank_id) + context_size = model.decoder.context_size + device = next(model.parameters()).device + + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for (t, batch_size) in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + + hyps_shape = get_hyps_shape(B).to(device) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. + current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=hyps_shape.row_ids(1).to(torch.int64), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, + decoder_out, + project_input=False, + ) # (num_hyps, 1, 1, vocab_size) + + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + + log_probs = (logits / temperature).log_softmax(dim=-1) # (num_hyps, vocab_size) + + log_probs.add_(ys_log_probs) + + vocab_size = log_probs.size(-1) + + log_probs = log_probs.reshape(-1) + + row_splits = hyps_shape.row_splits(1) * vocab_size + log_probs_shape = k2.ragged.create_ragged_shape2( + row_splits=row_splits, cached_tot_size=log_probs.numel() + ) + ragged_log_probs = k2.RaggedTensor(shape=log_probs_shape, value=log_probs) + + for i in range(batch_size): + topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam) + + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + topk_hyp_indexes = (topk_indexes // vocab_size).tolist() + topk_token_indexes = (topk_indexes % vocab_size).tolist() + + for k in range(len(topk_hyp_indexes)): + hyp_idx = topk_hyp_indexes[k] + hyp = A[i][hyp_idx] + + new_ys = hyp.ys[:] + new_token = topk_token_indexes[k] + new_timestamp = hyp.timestamp[:] + if new_token not in (blank_id, unk_id): + new_ys.append(new_token) + new_timestamp.append(t) + + new_log_prob = topk_log_probs[k] + new_hyp = Hypothesis( + ys=new_ys, log_prob=new_log_prob, timestamp=new_timestamp + ) + B[i].add(new_hyp) + + B = B + finalized_B + best_hyps = [b.get_most_probable(length_norm=True) for b in B] + + sorted_ans = [h.ys[context_size:] for h in best_hyps] + sorted_timestamps = [h.timestamp for h in best_hyps] + ans = [] + ans_timestamps = [] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + for i in range(N): + ans.append(sorted_ans[unsorted_indices[i]]) + ans_timestamps.append(sorted_timestamps[unsorted_indices[i]]) + + if not return_timestamps: + return ans + else: + return DecodingResults( + hyps=ans, + timestamps=ans_timestamps, + ) diff --git a/egs/librispeech/ASR/long_file_recog/merge_chunks.py b/egs/librispeech/ASR/long_file_recog/merge_chunks.py new file mode 100755 index 0000000000..d38d9c86a0 --- /dev/null +++ b/egs/librispeech/ASR/long_file_recog/merge_chunks.py @@ -0,0 +1,240 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This file merge overlapped chunks into utterances accroding to recording ids. +""" + +import argparse +import logging +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List + +import sentencepiece as spm +from lhotse import ( + CutSet, + MonoCut, + SupervisionSegment, + SupervisionSet, + load_manifest, + load_manifest_lazy, +) +from lhotse.cut import Cut +from lhotse.serialization import SequentialJsonlWriter + + +def get_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--manifest-in-dir", + type=Path, + default=Path("data/librilight/manifests_chunk_recog"), + help="Path to directory of chunk cuts with recognition results.", + ) + + parser.add_argument( + "--manifest-out-dir", + type=Path, + default=Path("data/manifests"), + help="Path to directory to save full utterance by merging overlapped chunks.", + ) + + parser.add_argument( + "--extra", + type=float, + default=2.0, + help="""Extra duration (in seconds) at both sides.""", + ) + + return parser.parse_args() + + +def merge_chunks( + cuts_chunk: CutSet, + supervisions: SupervisionSet, + cuts_writer: SequentialJsonlWriter, + sp: spm.SentencePieceProcessor, + extra: float, +) -> int: + """Merge chunk-wise cuts accroding to recording ids. + + Args: + cuts_chunk: + The chunk-wise cuts opened in a lazy mode. + supervisions: + The supervision manifest containing text file path, opened in a lazy mode. + cuts_writer: + Writer to save the cuts with recognition results. + sp: + The BPE model. + extra: + Extra duration (in seconds) to drop at both sides of each chunk. + """ + + # Background worker to add alignemnt and save cuts to disk. + def _save_worker(utt_cut: Cut, flush=False): + cuts_writer.write(utt_cut, flush=flush) + + def _merge(cut_list: List[Cut], rec_id: str, utt_idx: int): + """Merge chunks with same recording_id.""" + for cut in cut_list: + assert cut.recording.id == rec_id, (cut.recording.id, rec_id) + + # For each group with a same recording, sort it accroding to the start time + # In fact, we don't need to do this since the cuts have been sorted + # according to the start time + cut_list = sorted(cut_list, key=(lambda cut: cut.start)) + + rec = cut_list[0].recording + alignments = [] + cur_end = 0 + for cut in cut_list: + # Get left and right borders + left = cut.start + extra if cut.start > 0 else 0 + chunk_end = cut.start + cut.duration + right = chunk_end - extra if chunk_end < rec.duration else rec.duration + + # Assert the chunks are continuous + assert left == cur_end, (left, cur_end) + cur_end = right + + assert len(cut.supervisions) == 1, len(cut.supervisions) + for ali in cut.supervisions[0].alignment["symbol"]: + t = ali.start + cut.start + if left <= t < right: + alignments.append(ali.with_offset(cut.start)) + + old_sup = supervisions[rec_id] + # Assuming the supervisions are sorted with the same recoding order as in cuts_chunk + # old_sup = supervisions[utt_idx] + assert old_sup.recording_id == rec_id, (old_sup.recording_id, rec_id) + + new_sup = SupervisionSegment( + id=rec_id, + recording_id=rec_id, + start=0, + duration=rec.duration, + alignment={"symbol": alignments}, + language=old_sup.language, + speaker=old_sup.speaker, + ) + + utt_cut = MonoCut( + id=rec_id, + start=0, + duration=rec.duration, + channel=0, + recording=rec, + supervisions=[new_sup], + ) + # Set a custom attribute to the cut + utt_cut.text_path = old_sup.book + + return utt_cut + + last_rec_id = None + cut_list = [] + utt_idx = 0 + + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + + for cut in cuts_chunk: + cur_rec_id = cut.recording.id + if len(cut_list) == 0: + # Case of the first cut + last_rec_id = cur_rec_id + cut_list.append(cut) + elif cur_rec_id == last_rec_id: + cut_list.append(cut) + else: + # Case of a cut belonging to a new recording + utt_cut = _merge(cut_list, last_rec_id, utt_idx) + utt_idx += 1 + + futures.append(executor.submit(_save_worker, utt_cut)) + + last_rec_id = cur_rec_id + cut_list = [cut] + + if utt_idx % 5000 == 0: + logging.info(f"Procesed {utt_idx} utterances.") + + # For the cuts belonging to the last recording + if len(cut_list) != 0: + utt_cut = _merge(cut_list, last_rec_id, utt_idx) + utt_idx += 1 + + futures.append(executor.submit(_save_worker, utt_cut)) + logging.info("Finished") + + for f in futures: + f.result() + + return utt_idx + + +def main(): + args = get_parser() + + sp = spm.SentencePieceProcessor() + sp.load(args.bpe_model) + + # It contains "librilight_recordings_*.jsonl.gz" and "librilight_supervisions_small.jsonl.gz" + manifest_out_dir = args.manifest_out_dir + + subsets = ["small", "median", "large"] + + for subset in subsets: + logging.info(f"Processing {subset} subset") + + manifest_out = manifest_out_dir / f"librilight_cuts_{subset}.jsonl.gz" + if manifest_out.is_file(): + logging.info(f"{manifest_out} already exists - skipping.") + continue + + supervisions = load_manifest( + manifest_out_dir / f"librilight_supervisions_{subset}.jsonl.gz" + ) # We will use the text path from supervisions + + cuts_chunk = load_manifest_lazy( + args.manifest_in_dir / f"librilight_cuts_{subset}.jsonl.gz" + ) + + cuts_writer = CutSet.open_writer(manifest_out, overwrite=True) + num_utt = merge_chunks( + cuts_chunk, supervisions, cuts_writer=cuts_writer, sp=sp, extra=args.extra + ) + cuts_writer.close() + logging.info(f"{num_utt} cuts saved to {manifest_out}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/long_file_recog/recognize.py b/egs/librispeech/ASR/long_file_recog/recognize.py new file mode 100755 index 0000000000..96c83f8591 --- /dev/null +++ b/egs/librispeech/ASR/long_file_recog/recognize.py @@ -0,0 +1,435 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corporation (Author: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This script loads torchscript models, exported by `torch.jit.script()`, +and uses them to decode waves. +You can use the following command to get the exported models: + +./pruned_transducer_stateless7/export.py \ + --exp-dir ./pruned_transducer_stateless7/exp \ + --bpe-model data/lang_bpe_500/bpe.model \ + --epoch 20 \ + --avg 10 \ + --jit 1 + +You can also download the jit model from +https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11 +""" + +import argparse +import torch.multiprocessing as mp +import torch +import torch.nn as nn +import logging +from concurrent.futures import ThreadPoolExecutor +from typing import List, Optional, Tuple + +from pathlib import Path + +import k2 +import sentencepiece as spm +from asr_datamodule import AsrDataModule +from beam_search import ( + fast_beam_search_one_best, + greedy_search_batch, + modified_beam_search, +) +from icefall.utils import AttributeDict, convert_timestamp, setup_logger +from lhotse import CutSet, load_manifest_lazy +from lhotse.cut import Cut +from lhotse.supervision import AlignmentItem +from lhotse.serialization import SequentialJsonlWriter + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + parser.add_argument( + "--subset", + type=str, + default="small", + help="Subset to process. Possible values are 'small', 'medium', 'large'", + ) + + parser.add_argument( + "--manifest-in-dir", + type=Path, + default=Path("data/librilight/manifests_chunk"), + help="Path to directory with chunks cuts.", + ) + + parser.add_argument( + "--manifest-out-dir", + type=Path, + default=Path("data/librilight/manifests_chunk_recog"), + help="Path to directory to save the chunk cuts with recognition results.", + ) + + parser.add_argument( + "--log-dir", + type=Path, + default=Path("long_file_recog/log"), + help="Path to directory to save logs.", + ) + + parser.add_argument( + "--nn-model-filename", + type=str, + required=True, + help="Path to the torchscript model cpu_jit.pt", + ) + + parser.add_argument( + "--bpe-model", + type=str, + default="data/lang_bpe_500/bpe.model", + help="Path to the BPE model", + ) + + parser.add_argument( + "--decoding-method", + type=str, + default="greedy_search", + help="""Possible values are: + - greedy_search + - modified_beam_search + - fast_beam_search + """, + ) + + return parser + + +def get_params() -> AttributeDict: + """Return a dict containing decoding parameters.""" + params = AttributeDict( + { + "subsampling_factor": 4, + "frame_shift_ms": 10, + # Used only when --method is beam_search or modified_beam_search. + "beam_size": 4, + # Used only when --method is beam_search or fast_beam_search. + # A floating point value to calculate the cutoff score during beam + # search (i.e., `cutoff = max-score - beam`), which is the same as the + # `beam` in Kaldi. + "beam": 4, + "max_contexts": 4, + "max_states": 8, + } + ) + return params + + +def decode_one_batch( + params: AttributeDict, + model: nn.Module, + batch: dict, + decoding_graph: Optional[k2.Fsa] = None, +) -> Tuple[List[List[str]], List[List[float]], List[List[float]]]: + """Decode one batch. + + Args: + params: + It's the return value of :func:`get_params`. + paramsmodel: + The neural model. + batch: + It is the return value from iterating + `lhotse.dataset.K2SpeechRecognitionDataset`. See its documentation + for the format of the `batch`. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. + + Returns: + Return the decoding result, timestamps, and scores. + """ + device = next(model.parameters()).device + feature = batch["inputs"] + assert feature.ndim == 3 + + feature = feature.to(device) + # at entry, feature is (N, T, C) + + supervisions = batch["supervisions"] + feature_lens = supervisions["num_frames"].to(device) + + encoder_out, encoder_out_lens = model.encoder(x=feature, x_lens=feature_lens) + + if params.decoding_method == "fast_beam_search": + res = fast_beam_search_one_best( + model=model, + decoding_graph=decoding_graph, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam, + max_contexts=params.max_contexts, + max_states=params.max_states, + return_timestamps=True, + ) + elif params.decoding_method == "greedy_search": + res = greedy_search_batch( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + return_timestamps=True, + ) + elif params.decoding_method == "modified_beam_search": + res = modified_beam_search( + model=model, + encoder_out=encoder_out, + encoder_out_lens=encoder_out_lens, + beam=params.beam_size, + return_timestamps=True, + ) + else: + raise ValueError(f"Unsupported decoding method: {params.decoding_method}") + + hyps = [] + timestamps = [] + scores = [] + for i in range(feature.shape[0]): + hyps.append(res.hyps[i]) + timestamps.append( + convert_timestamp( + res.timestamps[i], params.subsampling_factor, params.frame_shift_ms + ) + ) + scores.append(res.scores[i]) + + return hyps, timestamps, scores + + +def decode_dataset( + dl: torch.utils.data.DataLoader, + params: AttributeDict, + model: nn.Module, + sp: spm.SentencePieceProcessor, + cuts_writer: SequentialJsonlWriter, + decoding_graph: Optional[k2.Fsa] = None, +) -> None: + """Decode dataset and store the recognition results to manifest. + + Args: + dl: + PyTorch's dataloader containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + sp: + The BPE model. + cuts_writer: + Writer to save the cuts with recognition results. + decoding_graph: + The decoding graph. Can be either a `k2.trivial_graph` or LG, Used + only when --decoding_method is fast_beam_search. + + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains five elements: + - cut_id + - reference transcript + - predicted result + - timestamps of reference transcript + - timestamps of predicted result + """ + # Background worker to add alignemnt and save cuts to disk. + def _save_worker( + cuts: List[Cut], + hyps: List[List[str]], + timestamps: List[List[float]], + scores: List[List[float]], + ): + for cut, symbol_list, time_list, score_list in zip( + cuts, hyps, timestamps, scores + ): + symbol_list = sp.id_to_piece(symbol_list) + ali = [ + AlignmentItem(symbol=symbol, start=start, duration=None, score=score) + for symbol, start, score in zip(symbol_list, time_list, score_list) + ] + assert len(cut.supervisions) == 1, len(cut.supervisions) + cut.supervisions[0].alignment = {"symbol": ali} + cuts_writer.write(cut, flush=True) + + num_cuts = 0 + log_interval = 10 + futures = [] + with ThreadPoolExecutor(max_workers=1) as executor: + # We only want one background worker so that serialization is deterministic. + + for batch_idx, batch in enumerate(dl): + cuts = batch["supervisions"]["cut"] + + hyps, timestamps, scores = decode_one_batch( + params=params, + model=model, + decoding_graph=decoding_graph, + batch=batch, + ) + + futures.append( + executor.submit(_save_worker, cuts, hyps, timestamps, scores) + ) + + num_cuts += len(cuts) + if batch_idx % log_interval == 0: + logging.info(f"cuts processed until now is {num_cuts}") + + for f in futures: + f.result() + + +@torch.no_grad() +def run(rank, world_size, args, in_cuts): + """ + Args: + rank: + It is a value between 0 and `world_size-1`. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + setup_logger(f"{params.log_dir}/log-decode") + logging.info("Decoding started") + + assert params.decoding_method in ( + "greedy_search", + "fast_beam_search", + "modified_beam_search", + ), params.decoding_method + + sp = spm.SentencePieceProcessor() + sp.load(params.bpe_model) + + # is defined in local/train_bpe_model.py + params.blank_id = sp.piece_to_id("") + params.unk_id = sp.piece_to_id("") + params.vocab_size = sp.get_piece_size() + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"device: {device}") + + logging.info("Loading jit model") + model = torch.jit.load(params.nn_model_filename) + model.to(device) + model.eval() + + if params.decoding_method == "fast_beam_search": + decoding_graph = k2.trivial_graph(params.vocab_size - 1, device=device) + else: + decoding_graph = None + + # we will store new cuts with recognition results. + args.return_cuts = True + asr_data_module = AsrDataModule(args) + + if world_size > 1: + in_cuts = in_cuts[rank] + out_cuts_filename = params.manifest_out_dir / ( + f"{params.cuts_filename}_job_{rank}" + params.suffix + ) + else: + out_cuts_filename = params.manifest_out_dir / ( + f"{params.cuts_filename}" + params.suffix + ) + + dl = asr_data_module.dataloaders(in_cuts) + + cuts_writer = CutSet.open_writer(out_cuts_filename, overwrite=True) + decode_dataset( + dl=dl, + params=params, + model=model, + sp=sp, + decoding_graph=decoding_graph, + cuts_writer=cuts_writer, + ) + cuts_writer.close() + logging.info(f"Cuts saved to {out_cuts_filename}") + + logging.info("Done!") + + +def main(): + parser = get_parser() + AsrDataModule.add_arguments(parser) + args = parser.parse_args() + + subset = args.subset + assert subset in ["small", "medium", "large"], subset + + manifest_out_dir = args.manifest_out_dir + manifest_out_dir.mkdir(parents=True, exist_ok=True) + + args.suffix = ".jsonl.gz" + args.cuts_filename = f"librilight_cuts_{args.subset}" + + out_cuts_filename = manifest_out_dir / (args.cuts_filename + args.suffix) + if out_cuts_filename.is_file(): + logging.info(f"{out_cuts_filename} already exists - skipping.") + return + + in_cuts_filename = args.manifest_in_dir / (args.cuts_filename + args.suffix) + in_cuts = load_manifest_lazy(in_cuts_filename) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + chunk_size = (len(in_cuts) + (world_size - 1)) // world_size + # Each manifest is saved at: ``{output_dir}/{prefix}.{split_idx}.jsonl.gz`` + splits = in_cuts.split_lazy( + output_dir=args.manifest_in_dir / "split", + chunk_size=chunk_size, + prefix=args.cuts_filename, + ) + assert len(splits) == world_size, (len(splits), world_size) + mp.spawn(run, args=(world_size, args, splits), nprocs=world_size, join=True) + else: + run(rank=0, world_size=world_size, args=args, in_cuts=in_cuts) + + +torch.set_num_threads(1) +torch.set_num_interop_threads(1) + +if __name__ == "__main__": + main() diff --git a/egs/librispeech/ASR/long_file_recog/split_into_chunks.py b/egs/librispeech/ASR/long_file_recog/split_into_chunks.py new file mode 100755 index 0000000000..4a900831cd --- /dev/null +++ b/egs/librispeech/ASR/long_file_recog/split_into_chunks.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 +# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang, Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +""" +This script splits long utterances into chunks with overlaps. +Each chunk (except the first and the last) is padded with extra left side and right side. +The chunk length is: left_side + chunk_size + right_side. +""" + +import argparse +import logging +from pathlib import Path + +from lhotse import CutSet, load_manifest + + +def get_args(): + parser = argparse.ArgumentParser() + + parser.add_argument( + "--manifest-in-dir", + type=Path, + default=Path("data/librilight/manifests"), + help="Path to directory of full utterances.", + ) + + parser.add_argument( + "--manifest-out-dir", + type=Path, + default=Path("data/librilight/manifests_chunk"), + help="Path to directory to save splitted chunks.", + ) + + parser.add_argument( + "--chunk", + type=float, + default=300.0, + help="""Duration (in seconds) of each chunk.""", + ) + + parser.add_argument( + "--extra", + type=float, + default=2.0, + help="""Extra duration (in seconds) at both sides.""", + ) + + return parser.parse_args() + + +def main(): + args = get_args() + logging.info(vars(args)) + + manifest_out_dir = args.manifest_out_dir + manifest_out_dir.mkdir(parents=True, exist_ok=True) + + subsets = ["small", "medium", "large"] + + for subset in subsets: + logging.info(f"Processing {subset} subset") + + manifest_out = manifest_out_dir / f"librilight_cuts_{subset}.jsonl.gz" + if manifest_out.is_file(): + logging.info(f"{manifest_out} already exists - skipping.") + continue + + manifest_in = args.manifest_in_dir / f"librilight_recordings_{subset}.jsonl.gz" + recordings = load_manifest(manifest_in) + + cuts = CutSet.from_manifests(recordings=recordings) + cuts = cuts.cut_into_windows( + duration=args.chunk, hop=args.chunk - args.extra * 2 + ) + cuts = cuts.fill_supervisions() + + cuts.to_file(manifest_out) + logging.info(f"Cuts saved to {manifest_out}") + + +if __name__ == "__main__": + formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + logging.basicConfig(format=formatter, level=logging.INFO) + + main() diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py index 0280193ca7..f5f15808d4 100644 --- a/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py +++ b/egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py @@ -670,6 +670,8 @@ def greedy_search_batch( # timestamp[n][i] is the frame index after subsampling # on which hyp[n][i] is decoded timestamps = [[] for _ in range(N)] + # scores[n][i] is the logits on which hyp[n][i] is decoded + scores = [[] for _ in range(N)] decoder_input = torch.tensor( hyps, @@ -707,6 +709,7 @@ def greedy_search_batch( if v not in (blank_id, unk_id): hyps[i].append(v) timestamps[i].append(t) + scores[i].append(logits[i, v].item()) emitted = True if emitted: # update decoder output @@ -722,10 +725,12 @@ def greedy_search_batch( sorted_ans = [h[context_size:] for h in hyps] ans = [] ans_timestamps = [] + ans_scores = [] unsorted_indices = packed_encoder_out.unsorted_indices.tolist() for i in range(N): ans.append(sorted_ans[unsorted_indices[i]]) ans_timestamps.append(timestamps[unsorted_indices[i]]) + ans_scores.append(scores[unsorted_indices[i]]) if not return_timestamps: return ans @@ -733,6 +738,7 @@ def greedy_search_batch( return DecodingResults( hyps=ans, timestamps=ans_timestamps, + scores=ans_scores, ) diff --git a/icefall/utils.py b/icefall/utils.py index eba95ee111..fb350a73ff 100644 --- a/icefall/utils.py +++ b/icefall/utils.py @@ -272,6 +272,9 @@ class DecodingResults: # for the i-th utterance with fast_beam_search_nbest_LG. hyps: Union[List[List[int]], k2.RaggedTensor] + # scores[i][k] contains the log-prob of tokens[i][k] + scores: Optional[List[List[float]]] = None + def get_texts_with_timestamp( best_paths: k2.Fsa, return_ragged: bool = False @@ -1442,7 +1445,7 @@ def convert_timestamp( frame_shift = frame_shift_ms / 1000.0 time = [] for f in frames: - time.append(f * subsampling_factor * frame_shift) + time.append(round(f * subsampling_factor * frame_shift, ndigits=3)) return time