From 177c7347f56bcb171cbd43346e8890c8356e09a1 Mon Sep 17 00:00:00 2001 From: pkufool Date: Tue, 11 Jun 2024 17:41:38 +0800 Subject: [PATCH] Use streaming asr to transcript the audio --- examples/libriheavy/matching.py | 4 +- examples/libriheavy/tools/asr_datamodule.py | 83 +-- examples/libriheavy/tools/beam_search.py | 317 +++++++- examples/libriheavy/tools/cut_into_segment.py | 105 +++ examples/libriheavy/tools/decode_stream.py | 120 +++ examples/libriheavy/tools/recognize.py | 5 +- .../libriheavy/tools/streaming_recognize.py | 703 ++++++++++++++++++ textsearch/python/textsearch/match.py | 4 +- 8 files changed, 1242 insertions(+), 99 deletions(-) create mode 100644 examples/libriheavy/tools/cut_into_segment.py create mode 100644 examples/libriheavy/tools/decode_stream.py create mode 100755 examples/libriheavy/tools/streaming_recognize.py diff --git a/examples/libriheavy/matching.py b/examples/libriheavy/matching.py index b470ad7..b76aade 100755 --- a/examples/libriheavy/matching.py +++ b/examples/libriheavy/matching.py @@ -94,8 +94,8 @@ def get_params() -> AttributeDict: # you can find the docs in textsearch/match.py#split_aligned_queries "preceding_context_length": 1000, "timestamp_position": "current", - "duration_add_on_left": 0.0, - "duration_add_on_right": 0.5, + "duration_add_on_left": -0.4, + "duration_add_on_right": -0.8, "silence_length_to_break": 0.45, "overlap_ratio": 0.25, "min_duration": 2, diff --git a/examples/libriheavy/tools/asr_datamodule.py b/examples/libriheavy/tools/asr_datamodule.py index 3f35147..842e156 100644 --- a/examples/libriheavy/tools/asr_datamodule.py +++ b/examples/libriheavy/tools/asr_datamodule.py @@ -27,6 +27,7 @@ from lhotse.cut import Cut from lhotse.dataset import ( K2SpeechRecognitionDataset, + DynamicBucketingSampler, SimpleCutSampler, ) from lhotse.dataset.input_strategies import ( @@ -38,53 +39,6 @@ from textsearch.utils import str2bool -class SpeechRecognitionDataset(K2SpeechRecognitionDataset): - def __init__( - self, - return_cuts: bool = False, - input_strategy: BatchIO = OnTheFlyFeatures(Fbank()), - ): - super().__init__(return_cuts=return_cuts, input_strategy=input_strategy) - - def __getitem__( - self, cuts: CutSet - ) -> Dict[str, Union[torch.Tensor, List[Cut]]]: - """ - Return a new batch, with the batch size automatically determined using the constraints - of max_frames and max_cuts. - """ - self.hdf5_fix.update() - - # Note: don't sort cuts here - # Sort the cuts by duration so that the first one determines the batch time dimensions. - # cuts = cuts.sort_by_duration(ascending=False) - - # Resample cuts since the ASR model works at 16kHz - cuts = cuts.resample(16000) - - # Get a tensor with batched feature matrices, shape (B, T, F) - # Collation performs auto-padding, if necessary. - input_tpl = self.input_strategy(cuts) - if len(input_tpl) == 3: - # An input strategy with fault tolerant audio reading mode. - # "cuts" may be a subset of the original "cuts" variable, - # that only has cuts for which we succesfully read the audio. - inputs, _, cuts = input_tpl - else: - inputs, _ = input_tpl - - # Get a dict of tensors that encode the positional information about supervisions - # in the batch of feature matrices. The tensors are named "sequence_idx", - # "start_frame/sample" and "num_frames/samples". 
- supervision_intervals = self.input_strategy.supervision_intervals(cuts) - - batch = {"inputs": inputs, "supervisions": supervision_intervals} - if self.return_cuts: - batch["supervisions"]["cut"] = [cut for cut in cuts] - - return batch - - class AsrDataModule: """ DataModule for k2 ASR experiments. @@ -117,6 +71,13 @@ def add_arguments(cls, parser: argparse.ArgumentParser): "field: batch['supervisions']['cut'] with the cuts that " "were used to construct it.", ) + group.add_argument( + "--bucketing-sampler", + type=str2bool, + default=True, + help="When enabled, the batches will come from buckets of " + "similar duration (saves padding frames).", + ) group.add_argument( "--num-mel-bins", type=int, @@ -130,22 +91,36 @@ def add_arguments(cls, parser: argparse.ArgumentParser): help="The number of training dataloader workers that " "collect the batches.", ) + group.add_argument( + "--batch-size", + type=int, + default=10, + help="The number of utterances in a batch", + ) def dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") - dataset = SpeechRecognitionDataset( + dataset = K2SpeechRecognitionDataset( input_strategy=OnTheFlyFeatures( Fbank(FbankConfig(num_mel_bins=self.args.num_mel_bins)) ), return_cuts=self.args.return_cuts, ) - sampler = SimpleCutSampler( - cuts, - max_duration=self.args.max_duration, - shuffle=False, - drop_last=False, - ) + if self.args.bucketing_sampler: + logging.info("Using DynamicBucketingSampler.") + sampler = DynamicBucketingSampler( + cuts, + max_duration=self.args.max_duration, + shuffle=False, + drop_last=False, + ) + else: + logging.info("Using SimpleCutSampler.") + sampler = SimpleCutSampler( + cuts, + max_cuts=self.args.batch_size, + ) logging.debug("About to create test dataloader") dl = DataLoader( diff --git a/examples/libriheavy/tools/beam_search.py b/examples/libriheavy/tools/beam_search.py index 526ad84..a05a890 100644 --- a/examples/libriheavy/tools/beam_search.py +++ b/examples/libriheavy/tools/beam_search.py @@ -18,9 +18,11 @@ import warnings from dataclasses import dataclass, field from typing import Dict, List, Optional, Union +import sentencepiece as spm import torch from utils import row_splits_to_row_ids +from decode_stream import DecodeStream @dataclass @@ -157,14 +159,14 @@ class Hypothesis: # Newly predicted tokens are appended to `ys`. ys: List[int] - # The log_prob of each token in ys[context_size:] - # It is derived from the nnet_output. - scores: List[float] - # The log prob of ys. # It contains only one entry. log_prob: torch.Tensor + # The log_prob of each token in ys[context_size:] + # It is derived from the nnet_output. + scores: List[float] = field(default_factory=list) + # timestamp[i] is the frame index after subsampling # on which ys[i] is decoded timestamp: List[int] = field(default_factory=list) @@ -228,43 +230,16 @@ def get_most_probable(self, length_norm: bool = False) -> Hypothesis: else: return max(self._data.values(), key=lambda hyp: hyp.log_prob) - def remove(self, hyp: Hypothesis) -> None: - """Remove a given hypothesis. - - Caution: - `self` is modified **in-place**. - - Args: - hyp: - The hypothesis to be removed from `self`. - Note: It must be contained in `self`. Otherwise, - an exception is raised. - """ - key = hyp.key - assert key in self, f"{key} does not exist" - del self._data[key] - - def filter(self, threshold: torch.Tensor) -> "HypothesisList": - """Remove all Hypotheses whose log_prob is less than threshold. - - Caution: - `self` is not modified. 
Instead, a new HypothesisList is returned. - - Returns: - Return a new HypothesisList containing all hypotheses from `self` - with `log_prob` being greater than the given `threshold`. - """ - ans = HypothesisList() - for _, hyp in self._data.items(): - if hyp.log_prob > threshold: - ans.add(hyp) # shallow copy - return ans - - def topk(self, k: int) -> "HypothesisList": + def topk(self, k: int, length_norm: bool = False) -> "HypothesisList": """Return the top-k hypothesis.""" hyps = list(self._data.items()) - hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] + if length_norm: + hyps = sorted( + hyps, key=lambda h: h[1].log_prob / len(h[1].ys), reverse=True + )[:k] + else: + hyps = sorted(hyps, key=lambda h: h[1].log_prob, reverse=True)[:k] ans = HypothesisList(dict(hyps)) return ans @@ -490,3 +465,269 @@ def modified_beam_search( timestamps=ans_timestamps, scores=ans_scores, ) + + +# The force alignment problem can be formulated as finding +# a path in a rectangular lattice, where the path starts +# from the lower left corner and ends at the upper right +# corner. The horizontal axis of the lattice is `t` (representing +# acoustic frame indexes) and the vertical axis is `u` (representing +# BPE tokens of the transcript). +# +# The notations `t` and `u` are from the paper +# https://arxiv.org/pdf/1211.3711.pdf +# +# Beam search is used to find the path with the highest log probabilities. +# +# It assumes the maximum number of symbols that can be +# emitted per frame is 1. + + +def batch_force_alignment( + model: torch.nn.Module, + sp: spm.SentencePieceProcessor, + encoder_out: torch.Tensor, + encoder_out_lens: torch.Tensor, + ys_list: List[List[int]], + beam_size: int = 4, +) -> List[int]: + """Compute the force alignment of a batch of utterances given their transcripts + in BPE tokens and the corresponding acoustic output from the encoder. + + Caution: + This function is modified from `modified_beam_search` in beam_search.py. + We assume that the maximum number of sybmols per frame is 1. + + Args: + model: + The transducer model. + encoder_out: + A tensor of shape (N, T, C). + encoder_out_lens: + A 1-D tensor of shape (N,), containing number of valid frames in + encoder_out before padding. + ys_list: + A list of BPE token IDs list. We require that for each utterance i, + len(ys_list[i]) <= encoder_out_lens[i]. + beam_size: + Size of the beam used in beam search. + + Returns: + Return a list of frame indexes list for each utterance i, + where len(ans[i]) == len(ys_list[i]). 
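+
+    Example:
+      A minimal usage sketch (illustrative; `texts` is a list of transcripts
+      and `encoder_out`/`encoder_out_lens` come from the acoustic encoder):
+
+        ys_list = [sp.encode(text, out_type=int) for text in texts]
+        ali = batch_force_alignment(
+            model, sp, encoder_out, encoder_out_lens, ys_list
+        )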
+ """ + assert encoder_out.ndim == 3, encoder_out.shape + assert encoder_out.size(0) == len(ys_list), ( + encoder_out.size(0), + len(ys_list), + ) + assert encoder_out.size(0) > 0, encoder_out.size(0) + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = next(model.parameters()).device + + packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence( + input=encoder_out, + lengths=encoder_out_lens.cpu(), + batch_first=True, + enforce_sorted=False, + ) + batch_size_list = packed_encoder_out.batch_sizes.tolist() + N = encoder_out.size(0) + assert torch.all(encoder_out_lens > 0), encoder_out_lens + assert N == batch_size_list[0], (N, batch_size_list) + + sorted_indices = packed_encoder_out.sorted_indices.tolist() + encoder_out_lens = encoder_out_lens.tolist() + ys_lens = [len(ys) for ys in ys_list] + sorted_encoder_out_lens = [encoder_out_lens[i] for i in sorted_indices] + sorted_ys_lens = [ys_lens[i] for i in sorted_indices] + sorted_ys_list = [ys_list[i] for i in sorted_indices] + + B = [HypothesisList() for _ in range(N)] + for i in range(N): + B[i].add( + Hypothesis( + ys=[blank_id] * context_size, + log_prob=torch.zeros(1, dtype=torch.float32, device=device), + timestamp=[], + scores=[], + ) + ) + + encoder_out = model.joiner.encoder_proj(packed_encoder_out.data) + + offset = 0 + finalized_B = [] + for t, batch_size in enumerate(batch_size_list): + start = offset + end = offset + batch_size + current_encoder_out = encoder_out.data[start:end] + current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1) + # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim) + offset = end + + finalized_B = B[batch_size:] + finalized_B + B = B[:batch_size] + sorted_encoder_out_lens = sorted_encoder_out_lens[:batch_size] + sorted_ys_lens = sorted_ys_lens[:batch_size] + + # on cpu + hyps_row_splits = get_hyps_row_splits(B) + + A = [list(b) for b in B] + B = [HypothesisList() for _ in range(batch_size)] + + ys_log_probs = torch.cat( + [hyp.log_prob.reshape(1, 1) for hyps in A for hyp in hyps] + ) # (num_hyps, 1) + + decoder_input = torch.tensor( + [hyp.ys[-context_size:] for hyps in A for hyp in hyps], + device=device, + dtype=torch.int64, + ) # (num_hyps, context_size) + + decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1) + decoder_out = model.joiner.decoder_proj(decoder_out) + # decoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor + # as index, so we use `to(torch.int64)` below. 
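+        # Each utterance contributes several hypotheses; row_splits_to_row_ids
+        # repeats the encoder frame of utterance i once for every hypothesis in
+        # A[i], so the joiner sees matching (encoder, decoder) pairs.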
+ current_encoder_out = torch.index_select( + current_encoder_out, + dim=0, + index=row_splits_to_row_ids(hyps_row_splits).to(device), + ) # (num_hyps, 1, 1, encoder_out_dim) + + logits = model.joiner( + current_encoder_out, decoder_out, project_input=False + ) # (num_hyps, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1) # (num_hyps, vocab_size) + log_probs = logits.log_softmax(dim=-1) # (num_hyps, vocab_size) + log_probs.add_(ys_log_probs.to(device)) + vocab_size = log_probs.size(-1) + log_probs = log_probs.reshape(-1) + row_splits = hyps_row_splits * vocab_size + + for i in range(batch_size): + current_log_probs = log_probs[ + row_splits[i] : row_splits[i + 1] + ].cpu() + + for h, hyp in enumerate(A[i]): + pos_u = len(hyp.timestamp) + idx_offset = h * vocab_size + if (sorted_encoder_out_lens[i] - 1 - t) >= ( + sorted_ys_lens[i] - pos_u + ): + # emit blank token + new_hyp = Hypothesis( + log_prob=current_log_probs[idx_offset + blank_id], + ys=hyp.ys[:], + timestamp=hyp.timestamp[:], + scores=hyp.scores[:], + ) + B[i].add(new_hyp) + if pos_u < sorted_ys_lens[i]: + # emit non-blank token + new_token = sorted_ys_list[i][pos_u] + log_prob = current_log_probs[idx_offset + new_token] + new_hyp = Hypothesis( + log_prob=log_prob, + ys=hyp.ys + [new_token], + timestamp=hyp.timestamp + [t], + scores=hyp.scores + [float(log_prob.exp())], + ) + B[i].add(new_hyp) + + if len(B[i]) > beam_size: + B[i] = B[i].topk(beam_size, length_norm=True) + + B = B + finalized_B + sorted_hyps = [b.get_most_probable() for b in B] + unsorted_indices = packed_encoder_out.unsorted_indices.tolist() + hyps = [sorted_hyps[i] for i in unsorted_indices] + ans = [] + for i, hyp in enumerate(hyps): + assert hyp.ys[context_size:] == ys_list[i], ( + hyp.ys[context_size:], + ys_list[i], + ) + sym_list = [sp.id_to_piece(j) for j in ys_list[i]] + ans.append(list(zip(sym_list, hyp.scores))) + + return ans + + +def streaming_greedy_search( + model: torch.nn.Module, + encoder_out: torch.Tensor, + streams: List[DecodeStream], + blank_penalty: float = 0.0, +) -> None: + """Greedy search in batch mode. It hardcodes --max-sym-per-frame=1. + + Args: + model: + The transducer model. + encoder_out: + Output from the encoder. Its shape is (N, T, C), where N >= 1. + streams: + A list of Stream objects. 
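+      blank_penalty:
+        A constant subtracted from the blank logit at every frame before the
+        argmax is taken; 0.0 (the default) disables it.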
+ """ + assert len(streams) == encoder_out.size(0) + assert encoder_out.ndim == 3 + + blank_id = model.decoder.blank_id + context_size = model.decoder.context_size + device = model.device + T = encoder_out.size(1) + + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + # decoder_out is of shape (N, 1, decoder_out_dim) + decoder_out = model.decoder(decoder_input, need_pad=False) + decoder_out = model.joiner.decoder_proj(decoder_out) + + for t in range(T): + # current_encoder_out's shape: (batch_size, 1, encoder_out_dim) + current_encoder_out = encoder_out[:, t : t + 1, :] # noqa + + logits = model.joiner( + current_encoder_out.unsqueeze(2), + decoder_out.unsqueeze(1), + project_input=False, + ) + # logits'shape (batch_size, vocab_size) + logits = logits.squeeze(1).squeeze(1) + + if blank_penalty != 0.0: + logits[:, 0] -= blank_penalty + + assert logits.ndim == 2, logits.shape + log_probs = logits.log_softmax(dim=-1) + y = logits.argmax(dim=1).tolist() + emitted = False + for i, v in enumerate(y): + if v != blank_id: + streams[i].hyp.append(v) + streams[i].timestamps.append(t + streams[i].done_frames) + streams[i].scores.append(log_probs[i, v].item()) + emitted = True + if emitted: + # update decoder output + decoder_input = torch.tensor( + [stream.hyp[-context_size:] for stream in streams], + device=device, + dtype=torch.int64, + ) + decoder_out = model.decoder( + decoder_input, + need_pad=False, + ) + decoder_out = model.joiner.decoder_proj(decoder_out) diff --git a/examples/libriheavy/tools/cut_into_segment.py b/examples/libriheavy/tools/cut_into_segment.py new file mode 100644 index 0000000..09c2287 --- /dev/null +++ b/examples/libriheavy/tools/cut_into_segment.py @@ -0,0 +1,105 @@ +#!/usr/bin/env python3 +# Copyright 2024 Xiaomi Corporation (Author: Wei Kang) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import argparse +import gzip +import json +import logging +from pathlib import Path +import soundfile as sf + + +def get_args(): + parser = argparse.ArgumentParser( + """ + Cut the long wav into small segments according to the supervisions in cuts, and + also generate the corresponding manifests. + """ + ) + parser.add_argument( + "--manifest", + type=Path, + help="""The input file in lhotse manifest format, MUST be + a jsonl.gz file. 
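+        Each line is a cut that is expected to carry `recording`, `start`,
+        `duration` and `supervisions[0].custom.texts` fields (the fields
+        read by cut_into_segments below).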
+ """, + ) + + parser.add_argument( + "--output-dir", + type=Path, + help="""The directory that wavs and manifest will been written to.""", + ) + + parser.add_argument( + "--num-segments", + type=float, + default=-1, + help="The number of segments need to be processed, for debugging purpose, -1 means all.", + ) + + return parser.parse_args() + + +def cut_into_segments(ifile: Path, output_dir: Path, num_segments: int = -1): + index = 0 + with gzip.open(ifile, "r") as f, open( + output_dir / "manifests.txt", "w" + ) as fm: + prev_audio = "" + for line in f: + if num_segments != -1 and index == num_segments: + break + index += 1 + cut = json.loads(line) + id = cut["id"].replace("/", "_") + duration = cut["duration"] + start = cut["start"] + end = start + duration + audio = cut["recording"]["sources"][0]["source"] + text = cut["supervisions"][0]["custom"]["texts"][0] + if audio != prev_audio: + samples, sample_rate = sf.read(audio) + prev_audio = audio + current_samples = samples[ + int(start * sample_rate) : int((end + 0.5) * sample_rate) + ] + sf.write(output_dir / f"{id}.wav", current_samples, sample_rate) + fm.write(f"{id}\t{id}.wav\t{text}\n") + if index % 200 == 0: + logging.info("Processed {index} segments.") + + +def main(): + args = get_args() + ifile = args.manifest + assert ifile.is_file(), f"File not exists : {ifile}" + assert str(ifile).endswith( + "jsonl.gz" + ), f"Expect a jsonl gz file, given : {ifile}" + + output_dir = args.output_dir + output_dir.mkdir(parents=True, exist_ok=True) + cut_into_segments(ifile, output_dir, args.num_segments) + + +if __name__ == "__main__": + formatter = ( + "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s" + ) + logging.basicConfig(format=formatter, level=logging.INFO) + main() diff --git a/examples/libriheavy/tools/decode_stream.py b/examples/libriheavy/tools/decode_stream.py new file mode 100644 index 0000000..9aad3e3 --- /dev/null +++ b/examples/libriheavy/tools/decode_stream.py @@ -0,0 +1,120 @@ +# Copyright 2022-2024 Xiaomi Corp. (authors: Wei Kang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Any, List, Optional, Tuple + +import torch + +from textsearch.utils import AttributeDict + + +class DecodeStream(object): + def __init__( + self, + params: AttributeDict, + cut: Any, + initial_states: List[torch.Tensor], + device: torch.device = torch.device("cpu"), + ) -> None: + """ + Args: + initial_states: + Initial decode states of the model, e.g. the return value of + `get_init_state` in conformer.py + device: + The device to run this stream. + """ + + self.params = params + self.cut = cut + self.LOG_EPS = math.log(1e-10) + + self.states = initial_states + + # It contains a 2-D tensors representing the feature frames. + self.features: torch.Tensor = None + + self.num_frames: int = 0 + # how many frames have been processed. (before subsampling). + # we only modify this value in `func:get_feature_frames`. 
+ self.num_processed_frames: int = 0 + + self._done: bool = False + + # The transcript of current utterance. + self.ground_truth: str = "" + + # how many frames have been processed, at encoder output + self.done_frames: int = 0 + + # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + self.pad_length = 7 + 2 * 3 + + self.hyp = [-1] * (params.context_size - 1) + [params.blank_id] + + self.timestamps = [] + + self.scores = [] + + @property + def done(self) -> bool: + """Return True if all the features are processed.""" + return self._done + + @property + def id(self) -> str: + return self.cut.id + + def set_features( + self, + features: torch.Tensor, + tail_pad_len: int = 0, + ) -> None: + """Set features tensor of current utterance.""" + assert features.dim() == 2, features.dim() + self.features = torch.nn.functional.pad( + features, + (0, 0, 0, self.pad_length + tail_pad_len), + mode="constant", + value=self.LOG_EPS, + ) + self.num_frames = self.features.size(0) + + def get_feature_frames(self, chunk_size: int) -> Tuple[torch.Tensor, int]: + """Consume chunk_size frames of features""" + chunk_length = chunk_size + self.pad_length + + ret_length = min( + self.num_frames - self.num_processed_frames, chunk_length + ) + + ret_features = self.features[ + self.num_processed_frames : self.num_processed_frames + + ret_length # noqa + ] + + self.num_processed_frames += chunk_size + if self.num_processed_frames >= self.num_frames: + self._done = True + + return ret_features, ret_length + + def decoding_result(self) -> List[int]: + """Obtain current decoding result.""" + return self.hyp[self.params.context_size :] # noqa diff --git a/examples/libriheavy/tools/recognize.py b/examples/libriheavy/tools/recognize.py index b845f7d..bfdf279 100755 --- a/examples/libriheavy/tools/recognize.py +++ b/examples/libriheavy/tools/recognize.py @@ -427,8 +427,7 @@ def main(): run(rank=0, world_size=world_size, args=args, in_cuts=in_cuts) -torch.set_num_threads(1) -torch.set_num_interop_threads(1) - if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) main() diff --git a/examples/libriheavy/tools/streaming_recognize.py b/examples/libriheavy/tools/streaming_recognize.py new file mode 100755 index 0000000..6da663e --- /dev/null +++ b/examples/libriheavy/tools/streaming_recognize.py @@ -0,0 +1,703 @@ +#!/usr/bin/env python3 +# Copyright 2022-2023 Xiaomi Corporation (Authors: Wei Kang, +# Fangjun Kuang, +# Zengwei Yao) +# +# See ../../../../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
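+
+"""
+Usage sketch (illustrative only; every path and flag value below is a
+placeholder rather than a recommended setting):
+
+  ./tools/streaming_recognize.py \
+    --checkpoint path/to/jit_script.pt \
+    --chunk-size 32 \
+    --left-context-frames 256 \
+    --tokens path/to/tokens.txt \
+    --manifests-in path/to/cuts.jsonl.gz \
+    --manifests-out path/to/cuts_with_alignments.jsonl.gz \
+    --num-decode-streams 2000 \
+    --world-size 1
+
+It decodes the input cuts with a streaming transducer exported via
+torch.jit.script and writes them back with token-level alignments attached.
+"""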
+ + +import argparse +import torch.multiprocessing as mp +import logging +import math +from pathlib import Path +from typing import Dict, List, Optional, Tuple +from dataclasses import dataclass + +import numpy as np +import torch +from decode_stream import DecodeStream +from lhotse import CutSet, Fbank, FbankConfig, combine, load_manifest_lazy +from beam_search import streaming_greedy_search +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from concurrent.futures import ThreadPoolExecutor +from utils import SymbolTable, convert_timestamp +from lhotse.cut import Cut +from lhotse.supervision import AlignmentItem +from lhotse.serialization import SequentialJsonlWriter + +from textsearch.utils import ( + AttributeDict, + setup_logger, + str2bool, +) + +LOG_EPS = math.log(1e-10) + + +@dataclass +class DecodingResult: + timestamps: List[int] + + hyp: List[int] + + scores: List[float] + + +def get_params() -> AttributeDict: + """Return a dict containing decoding parameters.""" + params = AttributeDict( + { + "subsampling_factor": 4, + "frame_shift_ms": 10, + "beam_size": 4, + } + ) + return params + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--world-size", + type=int, + default=1, + help="Number of GPUs for DDP training.", + ) + + parser.add_argument( + "--master-port", + type=int, + default=12354, + help="Master port to use for DDP training.", + ) + + parser.add_argument( + "--checkpoint", + type=str, + required=True, + help="Path to the checkpoint. " + "The checkpoint is assumed to be a exported jit.script model.", + ) + + parser.add_argument( + "--chunk-size", + type=int, + required=True, + help="The decoding chunk size", + ) + + parser.add_argument( + "--left-context-frames", + type=int, + required=True, + help="The decoding left context frames", + ) + + parser.add_argument( + "--manifests-in", + type=str, + required=True, + help="The path to the input manifests.", + ) + + parser.add_argument( + "--manifests-out", + type=str, + required=True, + help="The path to the output manifests.", + ) + + parser.add_argument( + "--num-decode-streams", + type=int, + default=2000, + help="The number of streams that can be decoded parallel.", + ) + + parser.add_argument( + "--log-dir", + type=Path, + default=Path("logs"), + help="Path to directory to save logs.", + ) + + parser.add_argument( + "--tokens", + type=str, + default="path/to/tokens.txt", + help="Path to the tokens.txt", + ) + return parser + + +def get_init_states( + model: nn.Module, + batch_size: int = 1, + device: torch.device = torch.device("cpu"), +) -> List[torch.Tensor]: + """ + Returns a list of cached tensors of all encoder layers. For layer-i, states[i*6:(i+1)*6] + is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, cached_conv1, cached_conv2). + states[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + """ + states = model.encoder.get_init_states(batch_size, device) + + return states + + +def stack_states(state_list: List[List[torch.Tensor]]) -> List[torch.Tensor]: + """Stack list of zipformer states that correspond to separate utterances + into a single emformer state, so that it can be used as an input for + zipformer when those utterances are formed into a batch. 
+ + Args: + state_list: + Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. For element-n, + state_list[n] is a list of cached tensors of all encoder layers. For layer-i, + state_list[n][i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, + cached_val2, cached_conv1, cached_conv2). + state_list[n][-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + state_list[n][-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Note: + It is the inverse of :func:`unstack_states`. + """ + batch_size = len(state_list) + assert (len(state_list[0]) - 2) % 6 == 0, ( + len(state_list[0]), + len(state_list), + ) + tot_num_layers = (len(state_list[0]) - 2) // 6 + + batch_states = [] + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key = torch.cat( + [state_list[i][layer_offset] for i in range(batch_size)], dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn = torch.cat( + [state_list[i][layer_offset + 1] for i in range(batch_size)], dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1 = torch.cat( + [state_list[i][layer_offset + 2] for i in range(batch_size)], dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2 = torch.cat( + [state_list[i][layer_offset + 3] for i in range(batch_size)], dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1 = torch.cat( + [state_list[i][layer_offset + 4] for i in range(batch_size)], dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2 = torch.cat( + [state_list[i][layer_offset + 5] for i in range(batch_size)], dim=0 + ) + batch_states += [ + cached_key, + cached_nonlin_attn, + cached_val1, + cached_val2, + cached_conv1, + cached_conv2, + ] + + cached_embed_left_pad = torch.cat( + [state_list[i][-2] for i in range(batch_size)], dim=0 + ) + batch_states.append(cached_embed_left_pad) + + processed_lens = torch.cat( + [state_list[i][-1] for i in range(batch_size)], dim=0 + ) + batch_states.append(processed_lens) + + return batch_states + + +def unstack_states(batch_states: List[Tensor]) -> List[List[Tensor]]: + """Unstack the zipformer state corresponding to a batch of utterances + into a list of states, where the i-th entry is the state from the i-th + utterance in the batch. + + Note: + It is the inverse of :func:`stack_states`. + + Args: + batch_states: A list of cached tensors of all encoder layers. For layer-i, + states[i*6:(i+1)*6] is (cached_key, cached_nonlin_attn, cached_val1, cached_val2, + cached_conv1, cached_conv2). + state_list[-2] is the cached left padding for ConvNeXt module, + of shape (batch_size, num_channels, left_pad, num_freqs) + states[-1] is processed_lens of shape (batch,), which records the number + of processed frames (at 50hz frame rate, after encoder_embed) for each sample in batch. + + Returns: + state_list: A list of list. Each element in state_list corresponding to the internal state + of the zipformer model for a single utterance. 
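+
+    For example, `decode_one_chunk` below calls `unstack_states(new_states)`
+    after every chunk so that each `DecodeStream.states` can be updated
+    independently before the streams are batched again.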
+ """ + assert (len(batch_states) - 2) % 6 == 0, len(batch_states) + tot_num_layers = (len(batch_states) - 2) // 6 + + processed_lens = batch_states[-1] + batch_size = processed_lens.shape[0] + + state_list = [[] for _ in range(batch_size)] + + for layer in range(tot_num_layers): + layer_offset = layer * 6 + # cached_key: (left_context_len, batch_size, key_dim) + cached_key_list = batch_states[layer_offset].chunk( + chunks=batch_size, dim=1 + ) + # cached_nonlin_attn: (num_heads, batch_size, left_context_len, head_dim) + cached_nonlin_attn_list = batch_states[layer_offset + 1].chunk( + chunks=batch_size, dim=1 + ) + # cached_val1: (left_context_len, batch_size, value_dim) + cached_val1_list = batch_states[layer_offset + 2].chunk( + chunks=batch_size, dim=1 + ) + # cached_val2: (left_context_len, batch_size, value_dim) + cached_val2_list = batch_states[layer_offset + 3].chunk( + chunks=batch_size, dim=1 + ) + # cached_conv1: (#batch, channels, left_pad) + cached_conv1_list = batch_states[layer_offset + 4].chunk( + chunks=batch_size, dim=0 + ) + # cached_conv2: (#batch, channels, left_pad) + cached_conv2_list = batch_states[layer_offset + 5].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i] += [ + cached_key_list[i], + cached_nonlin_attn_list[i], + cached_val1_list[i], + cached_val2_list[i], + cached_conv1_list[i], + cached_conv2_list[i], + ] + + cached_embed_left_pad_list = batch_states[-2].chunk( + chunks=batch_size, dim=0 + ) + for i in range(batch_size): + state_list[i].append(cached_embed_left_pad_list[i]) + + processed_lens_list = batch_states[-1].chunk(chunks=batch_size, dim=0) + for i in range(batch_size): + state_list[i].append(processed_lens_list[i]) + + return state_list + + +def streaming_forward( + features: Tensor, + feature_lens: Tensor, + model: nn.Module, + states: List[Tensor], + chunk_size: int, + left_context_len: int, +) -> Tuple[Tensor, Tensor, List[Tensor]]: + """ + Returns encoder outputs, output lengths, and updated states. 
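+
+    `features` and `feature_lens` hold one padded chunk of fbank frames per
+    stream, `states` is the batched cache produced by `stack_states`,
+    `chunk_size` is the number of frames expected after the encoder_embed
+    subsampling (checked by the assert below), and `left_context_len` is the
+    number of cached left-context frames used to build the padding mask.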
+ """ + cached_embed_left_pad = states[-2] + ( + x, + x_lens, + new_cached_embed_left_pad, + ) = model.encoder_embed.streaming_forward( + x=features, + x_lens=feature_lens, + cached_left_pad=cached_embed_left_pad, + ) + assert x.size(1) == chunk_size, (x.size(1), chunk_size) + + src_key_padding_mask = make_pad_mask(x_lens) + + # processed_mask is used to mask out initial states + processed_mask = torch.arange(left_context_len, device=x.device).expand( + x.size(0), left_context_len + ) + processed_lens = states[-1] # (batch,) + # (batch, left_context_size) + processed_mask = (processed_lens.unsqueeze(1) <= processed_mask).flip(1) + # Update processed lengths + new_processed_lens = processed_lens + x_lens + + # (batch, left_context_size + chunk_size) + src_key_padding_mask = torch.cat( + [processed_mask, src_key_padding_mask], dim=1 + ) + + x = x.permute(1, 0, 2) # (N, T, C) -> (T, N, C) + encoder_states = states[:-2] + ( + encoder_out, + encoder_out_lens, + new_encoder_states, + ) = model.encoder.streaming_forward( + x=x, + x_lens=x_lens, + states=encoder_states, + src_key_padding_mask=src_key_padding_mask, + ) + encoder_out = encoder_out.permute(1, 0, 2) # (T, N, C) ->(N, T, C) + + new_states = new_encoder_states + [ + new_cached_embed_left_pad, + new_processed_lens, + ] + return encoder_out, encoder_out_lens, new_states + + +def decode_one_chunk( + params: AttributeDict, + model: nn.Module, + decode_streams: List[DecodeStream], +) -> List[int]: + """Decode one chunk frames of features for each decode_streams and + return the indexes of finished streams in a List. + + Args: + params: + It's the return value of :func:`get_params`. + model: + The neural model. + decode_streams: + A List of DecodeStream, each belonging to a utterance. + Returns: + Return a List containing which DecodeStreams are finished. + """ + device = params.device + chunk_size = int(params.chunk_size) + left_context_len = int(params.left_context_frames) + + features = [] + feature_lens = [] + states = [] + processed_lens = [] # Used in fast-beam-search + + for stream in decode_streams: + feat, feat_len = stream.get_feature_frames(chunk_size * 2) + features.append(feat) + feature_lens.append(feat_len) + states.append(stream.states) + processed_lens.append(stream.done_frames) + + feature_lens = torch.tensor(feature_lens, device=device) + features = pad_sequence(features, batch_first=True, padding_value=LOG_EPS) + + # Make sure the length after encoder_embed is at least 1. 
+ # The encoder_embed subsample features (T - 7) // 2 + # The ConvNeXt module needs (7 - 1) // 2 = 3 frames of right padding after subsampling + tail_length = chunk_size * 2 + 7 + 2 * 3 + if features.size(1) < tail_length: + pad_length = tail_length - features.size(1) + feature_lens += pad_length + features = torch.nn.functional.pad( + features, + (0, 0, 0, pad_length), + mode="constant", + value=LOG_EPS, + ) + + states = stack_states(states) + + encoder_out, encoder_out_lens, new_states = model.encoder.forward( + features=features, feature_lengths=feature_lens, states=states + ) + + encoder_out = model.joiner.encoder_proj(encoder_out) + + streaming_greedy_search( + model=model, encoder_out=encoder_out, streams=decode_streams + ) + + states = unstack_states(new_states) + + finished_streams = [] + for i in range(len(decode_streams)): + decode_streams[i].states = states[i] + decode_streams[i].done_frames += int(encoder_out_lens[i]) + if decode_streams[i].done: + finished_streams.append(i) + + return finished_streams + + +def decode_dataset( + cuts: CutSet, + params: AttributeDict, + model: nn.Module, + token_table: SymbolTable, + cuts_writer: SequentialJsonlWriter, +): + """Decode dataset. + + Args: + cuts: + Lhotse Cutset containing the dataset to decode. + params: + It is returned by :func:`get_params`. + model: + The neural model. + Returns: + Return a dict, whose key may be "greedy_search" if greedy search + is used, or it may be "beam_7" if beam size of 7 is used. + Its value is a list of tuples. Each tuple contains two elements: + The first is the reference transcript, and the second is the + predicted result. + """ + # Background worker to add alignemnt and save cuts to disk. + def _save_worker( + decode_streams: List[DecodeStream], + ): + for stream in decode_streams: + cut = stream.cut + symbol_list = [token_table[x] for x in stream.decoding_result()] + timestamps = convert_timestamp( + stream.timestamps, + params.subsampling_factor, + params.frame_shift_ms, + ) + + ali = [ + AlignmentItem( + symbol=symbol, start=start, duration=None, score=score + ) + for symbol, start, score in zip( + symbol_list, timestamps, stream.scores + ) + ] + assert len(cut.supervisions) == 1, len(cut.supervisions) + cut.supervisions[0].alignment = {"symbol": ali} + cuts_writer.write(cut, flush=True) + + device = params.device + opts = FbankConfig() + opts.sampling_rate = 16000 + opts.num_filters = 80 + + log_interval = 50 + + # Contain decode streams currently running. + decode_streams = [] + futures = [] + + with ThreadPoolExecutor(max_workers=1) as executor: + for num, cut in enumerate(cuts): + # each utterance has a DecodeStream. + initial_states = states = model.encoder.get_init_states(1, device) + decode_stream = DecodeStream( + params=params, + cut=cut, + initial_states=initial_states, + device=device, + ) + + audio: np.ndarray = cut.load_audio() + # audio.shape: (1, num_samples) + assert len(audio.shape) == 2 + assert audio.shape[0] == 1, "Should be single channel" + assert audio.dtype == np.float32, audio.dtype + + # The trained model is using normalized samples + # - this is to avoid sending [-32k,+32k] signal in... + # - some lhotse AudioTransform classes can make the signal + # be out of range [-1, 1], hence the tolerance 10 + assert ( + np.abs(audio).max() <= 10 + ), "Should be normalized to [-1, 1], 10 for tolerance..." 
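+
+            # Compute fbank features for the whole utterance up front; the
+            # streaming behaviour is simulated by feeding them to the encoder
+            # chunk by chunk in decode_one_chunk().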
+ + samples = torch.from_numpy(audio).squeeze(0) + + fbank = Fbank(opts) + feature = fbank.extract(samples, sampling_rate=16000).to(device) + decode_stream.set_features(feature, tail_pad_len=30) + + decode_streams.append(decode_stream) + + while len(decode_streams) >= params.num_decode_streams: + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + done_streams = [] + for i in sorted(finished_streams, reverse=True): + done_streams.append(decode_streams[i]) + del decode_streams[i] + if done_streams: + futures.append(executor.submit(_save_worker, done_streams)) + + if num % log_interval == 0: + logging.info(f"Cuts processed until now is {num}.") + + # decode final chunks of last sequences + while len(decode_streams): + finished_streams = decode_one_chunk( + params=params, model=model, decode_streams=decode_streams + ) + done_streams = [] + for i in sorted(finished_streams, reverse=True): + done_streams.append(decode_streams[i]) + del decode_streams[i] + if done_streams: + futures.append(executor.submit(_save_worker, done_streams)) + + for f in futures: + f.result() + + +@torch.no_grad() +def run(rank, world_size, args, in_cuts): + """ + Args: + rank: + It is a value between 0 and `world_size-1`. + world_size: + Number of GPUs for DDP training. + args: + The return value of get_parser().parse_args() + """ + params = get_params() + params.update(vars(args)) + + setup_logger( + f"{params.log_dir}/log-decode", + dist=(rank, world_size) if world_size > 1 else None, + ) + logging.info("Decoding started") + + token_table = SymbolTable.from_file(params.tokens) + params.blank_id = token_table[""] + + logging.info(f"{params}") + + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda", rank) + logging.info(f"device: {device}") + params.device = device + + logging.info("Loading jit model") + model = torch.jit.load(params.checkpoint) + model.to(device) + model.eval() + model.device = device + + params.context_size = model.decoder.context_size + + # we will store new cuts with recognition results. + args.return_cuts = True + + if world_size > 1: + in_cuts = in_cuts[rank] + out_cuts_filename = params.manifests_out_dir / ( + f"split/{params.cuts_filename}_{rank}" + params.suffix + ) + else: + out_cuts_filename = params.manifests_out_dir / ( + f"{params.cuts_filename}" + params.suffix + ) + + cuts_writer = CutSet.open_writer(out_cuts_filename, overwrite=True) + decode_dataset( + cuts=in_cuts, + params=params, + model=model, + token_table=token_table, + cuts_writer=cuts_writer, + ) + cuts_writer.close() + logging.info(f"Cuts saved to {out_cuts_filename}") + + logging.info("Done!") + + +def main(): + parser = get_parser() + args = parser.parse_args() + + args.manifests_in = Path(args.manifests_in) + args.manifests_out = Path(args.manifests_out) + + if args.manifests_in == args.manifests_out: + print( + f"Input manifest and output manifest share the same path : " + f"{args.manifests_in}, the filenames should be different." 
+ ) + + args.manifests_out_dir = args.manifests_out.parents[0] + args.manifests_out_dir.mkdir(parents=True, exist_ok=True) + + assert args.manifests_in.is_file(), args.manifests_in + + args.suffix = ".jsonl.gz" + args.cuts_filename = str(args.manifests_out.name).replace(args.suffix, "") + + if args.manifests_out.is_file(): + print(f"{args.manifests_out} already exists - skipping.") + return + + in_cuts = load_manifest_lazy(args.manifests_in) + + world_size = args.world_size + assert world_size >= 1 + if world_size > 1: + chunk_size = (len(in_cuts) + (world_size - 1)) // world_size + # Each manifest is saved at: ``{output_dir}/{prefix}.{split_idx}.jsonl.gz`` + splits = in_cuts.split_lazy( + output_dir=args.manifests_out_dir / "split", + chunk_size=chunk_size, + prefix=args.cuts_filename, + ) + assert len(splits) == world_size, (len(splits), world_size) + mp.spawn( + run, args=(world_size, args, splits), nprocs=world_size, join=True + ) + out_filenames = [] + for i in range(world_size): + out_filenames.append( + args.manifests_out_dir + / f"split/{args.cuts_filename}_{i}{args.suffix}" + ) + cuts = combine(*[load_manifest_lazy(x) for x in out_filenames]) + cuts.to_file(args.manifests_out) + print(f"Cuts saved to {args.manifests_out}") + else: + run(rank=0, world_size=world_size, args=args, in_cuts=in_cuts) + + +if __name__ == "__main__": + torch.set_num_threads(1) + torch.set_num_interop_threads(1) + main() diff --git a/textsearch/python/textsearch/match.py b/textsearch/python/textsearch/match.py index bc296e6..05ac141 100644 --- a/textsearch/python/textsearch/match.py +++ b/textsearch/python/textsearch/match.py @@ -1209,8 +1209,8 @@ def _split_into_segments( { "begin_byte": begin_pos, "end_byte": end_pos, - "start_time": start_time, - "duration": math.floor(1000 * (end_time - start_time)) / 1000, + "start_time": start_time + duration_add_on_left, + "duration": math.floor(1000 * (end_time + duration_add_on_right - start_time)) / 1000, "hyp": hyp, "ref": ref, "pre_ref": preceding_ref,