Support long audio recognition (#980)
* support long file transcription
* rename recipe as long_file_recog
* add docs
* support multi-gpu decoding
* style fix
1 parent f18b539 · commit a7e142b
Showing 8 changed files with 1,681 additions and 1 deletion.
@@ -0,0 +1,94 @@
#!/usr/bin/env bash

# fix segmentation fault reported in https://github.com/k2-fsa/icefall/issues/674
export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python

export CUDA_VISIBLE_DEVICES="0,1,2,3"

set -eou pipefail

# This script is used to recognize long audios. The process is as follows:
# 1) Split long audios into chunks with overlaps.
# 2) Perform speech recognition on chunks, getting tokens and timestamps.
# 3) Merge the overlapped chunks into utterances according to the timestamps.

# Each chunk (except the first and the last) is padded with extra left side and right side.
# The chunk length is: left_side + chunk_size + right_side.
chunk=30.0
extra=2.0
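# For example (with chunk=30.0 and extra=2.0 as above): the second chunk
# covers [30.0 - 2.0, 60.0 + 2.0] = [28.0, 62.0] seconds of the original
# audio; the overlaps are trimmed away again when chunks are merged in Stage 4.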

stage=1
stop_stage=4

# We assume that you have downloaded the LibriLight dataset
# with audio files in $corpus_dir and texts in $text_dir
corpus_dir=$PWD/download/libri-light
text_dir=$PWD/download/librilight_text
# Path to save the manifests
output_dir=$PWD/data/librilight

world_size=4

log() {
  # This function is from espnet
  local fname=${BASH_SOURCE[1]##*/}
  echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
}

if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
  # We will get librilight_recordings_{subset}.jsonl.gz and librilight_supervisions_{subset}.jsonl.gz
  # saved in $output_dir/manifests
  log "Stage 1: Prepare LibriLight manifest"
  lhotse prepare librilight $corpus_dir $text_dir $output_dir/manifests -j 10
fi

if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
  # Chunk manifests are saved to $output_dir/manifests_chunk/librilight_cuts_{subset}.jsonl.gz
  log "Stage 2: Split long audio into chunks"
  ./long_file_recog/split_into_chunks.py \
    --manifest-in-dir $output_dir/manifests \
    --manifest-out-dir $output_dir/manifests_chunk \
    --chunk $chunk \
    --extra $extra  # Extra duration (in seconds) at both sides
fi

if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
  # Recognized tokens and timestamps are saved to $output_dir/manifests_chunk_recog/librilight_cuts_{subset}.jsonl.gz

  # This script loads a torchscript model, exported by `torch.jit.script()`,
  # and uses it to decode waves.
  # You can download the jit model from https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11
log "Stage 3: Perform speech recognition on splitted chunks" | ||
for subset in small median large; do | ||
./long_file_recog/recognize.py \ | ||
--world-size $world_size \ | ||
--num-workers 8 \ | ||
--subset $subset \ | ||
--manifest-in-dir $output_dir/manifests_chunk \ | ||
--manifest-out-dir $output_dir/manifests_chunk_recog \ | ||
--nn-model-filename long_file_recog/exp/jit_model.pt \ | ||
--bpe-model data/lang_bpe_500/bpe.model \ | ||
--max-duration 2400 \ | ||
--decoding-method greedy_search | ||
--master 12345 | ||
|
||

    if [ $world_size -gt 1 ]; then
      # Combine manifests from different jobs
      lhotse combine $(find $output_dir/manifests_chunk_recog -name librilight_cuts_${subset}_job_*.jsonl.gz | tr "\n" " ") $output_dir/manifests_chunk_recog/librilight_cuts_${subset}.jsonl.gz
    fi
  done
fi

if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
  # Final results are saved in $output_dir/manifests/librilight_cuts_{subset}.jsonl.gz
  log "Stage 4: Merge split chunks into utterances."
  ./long_file_recog/merge_chunks.py \
    --manifest-in-dir $output_dir/manifests_chunk_recog \
    --manifest-out-dir $output_dir/manifests \
    --bpe-model data/lang_bpe_500/bpe.model \
    --extra $extra
fi
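
The merge in Stage 4 relies on the token timestamps produced in Stage 3: every chunk except the first carries `extra` seconds of padded audio on its left, so tokens that fall inside the padding are dropped and the remaining ones are shifted back to utterance time. A minimal Python sketch of that idea (a hypothetical helper, not the actual ./long_file_recog/merge_chunks.py):

from typing import List, Tuple

def merge_chunks(
    chunks: List[List[Tuple[str, float]]],  # per chunk: (token, time within chunk)
    chunk: float = 30.0,
    extra: float = 2.0,
) -> List[Tuple[str, float]]:
    merged = []
    for i, tokens in enumerate(chunks):
        # Core region in chunk-relative time: the first chunk has no left
        # padding, and the last chunk keeps everything up to its end.
        left = 0.0 if i == 0 else extra
        right = float("inf") if i == len(chunks) - 1 else left + chunk
        offset = i * chunk - left  # chunk-relative time -> utterance time
        for token, t in tokens:
            if left <= t < right:
                merged.append((token, t + offset))
    return merged

The half-open interval [left, right) ensures a token sitting exactly on a chunk boundary is emitted by only one of the two adjacent chunks.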
@@ -0,0 +1,189 @@
# Copyright 2021 Piotr Żelasko
# Copyright 2022 Xiaomi Corporation (Author: Mingshuang Luo)
# Copyright 2023 Xiaomi Corporation (Author: Zengwei Yao)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Union

import torch
from lhotse import CutSet, Fbank, FbankConfig, load_manifest_lazy
from lhotse.cut import Cut
from lhotse.dataset import (  # noqa F401 for PrecomputedFeatures
    CutConcatenate,
    CutMix,
    DynamicBucketingSampler,
    K2SpeechRecognitionDataset,
    PrecomputedFeatures,
    SimpleCutSampler,
    SpecAugment,
)
from lhotse.dataset.input_strategies import (  # noqa F401 for AudioSamples
    AudioSamples,
    BatchIO,
    OnTheFlyFeatures,
)
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class SpeechRecognitionDataset(K2SpeechRecognitionDataset):
    def __init__(
        self,
        return_cuts: bool = False,
        input_strategy: BatchIO = PrecomputedFeatures(),
    ):
        super().__init__(return_cuts=return_cuts, input_strategy=input_strategy)

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[Cut]]]:
        """
        Return a new batch, with the batch size automatically determined using the constraints
        of max_frames and max_cuts.
        """
        self.hdf5_fix.update()

        # Note: don't sort cuts here.
        # Sorting the cuts by duration would let the first one determine the batch time dimensions:
        # cuts = cuts.sort_by_duration(ascending=False)

        # Get a tensor with batched feature matrices, shape (B, T, F).
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        batch = {"inputs": inputs, "supervisions": supervision_intervals}
        if self.return_cuts:
            batch["supervisions"]["cut"] = [cut for cut in cuts]

        return batch


class AsrDataModule:
    """
    DataModule for k2 ASR experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).
    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - augmentation,
    - on-the-fly feature extraction
    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/manifests_chunk"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=int,
            default=600,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=8,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )

        group.add_argument(
            "--input-strategy",
            type=str,
            default="PrecomputedFeatures",
            help="AudioSamples or PrecomputedFeatures",
        )

    def dataloaders(self, cuts: CutSet) -> DataLoader:
        logging.debug("About to create test dataset")
        test = SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
            return_cuts=self.args.return_cuts,
        )

        sampler = SimpleCutSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
            drop_last=False,
        )

        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )
        return test_dl

    @lru_cache()
    def load_subset(self, cuts_filename: Path) -> CutSet:
        return load_manifest_lazy(cuts_filename)
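
For orientation, here is a minimal usage sketch of the data module above; the manifest path and argument values are illustrative and assume the chunk manifests from Stage 2 of the recipe exist:

import argparse
from pathlib import Path

parser = argparse.ArgumentParser()
AsrDataModule.add_arguments(parser)
args = parser.parse_args(["--max-duration", "2400", "--num-workers", "8"])

data_module = AsrDataModule(args)
cuts = data_module.load_subset(
    Path("data/manifests_chunk/librilight_cuts_small.jsonl.gz")  # illustrative path
)
test_dl = data_module.dataloaders(cuts)

for batch in test_dl:
    feats = batch["inputs"]  # (B, T, 80) on-the-fly fbank features
    supervisions = batch["supervisions"]  # sequence_idx, start_frame, num_frames, cut
    break

Note that dataloaders() always extracts 80-dim fbank features on the fly, regardless of the --input-strategy option.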