From 3fc06cc2b9120a79a3e061bf35cef8d7220a42f3 Mon Sep 17 00:00:00 2001
From: Xiaoyu Yang <45973641+marcoyang1998@users.noreply.github.com>
Date: Thu, 22 Aug 2024 15:27:25 +0800
Subject: [PATCH] Support AudioSet training with weighted sampler (#1727)

---
 egs/audioset/AT/RESULTS.md                 |  36 +++++--
 egs/audioset/AT/local/compute_weight.py    |  73 ++++++++++++++
 egs/audioset/AT/prepare.sh                 |  13 ++-
 egs/audioset/AT/zipformer/at_datamodule.py | 107 ++++++++++++++++-----
 egs/audioset/AT/zipformer/train.py         |  11 ++-
 5 files changed, 207 insertions(+), 33 deletions(-)
 create mode 100644 egs/audioset/AT/local/compute_weight.py

diff --git a/egs/audioset/AT/RESULTS.md b/egs/audioset/AT/RESULTS.md
index 0128b70184..36613db031 100644
--- a/egs/audioset/AT/RESULTS.md
+++ b/egs/audioset/AT/RESULTS.md
@@ -35,16 +35,40 @@ python zipformer/train.py \
   --master-port 13455
 ```
 
+We recommend training the model with the weighted sampler, as it converges
+faster and achieves better performance:
+
+| Model | mAP |
+| ------ | ------- |
+| Zipformer-AT, trained with weighted sampler | 46.6 |
+
-The evaluation command is:
+The training command with the weighted sampler is:
 
 ```bash
-python zipformer/evaluate.py \
-  --epoch 32 \
-  --avg 8 \
-  --exp-dir zipformer/exp_at_as_full \
-  --max-duration 500
+export CUDA_VISIBLE_DEVICES="4,5,6,7"
+subset=full
+weighted_sampler=1
+bucket_sampler=0
+lr_epochs=15
+
+python zipformer/train.py \
+  --world-size 4 \
+  --audioset-subset $subset \
+  --num-epochs 120 \
+  --start-epoch 1 \
+  --use-fp16 1 \
+  --num-events 527 \
+  --lr-epochs $lr_epochs \
+  --exp-dir zipformer/exp_AS_${subset}_weighted_sampler${weighted_sampler} \
+  --weighted-sampler $weighted_sampler \
+  --bucketing-sampler $bucket_sampler \
+  --max-duration 1000 \
+  --enable-musan True \
+  --master-port 13452
 ```
 
+The evaluation command is the same as for the model above. The pre-trained
+model can be downloaded from
+https://huggingface.co/marcoyang/icefall-audio-tagging-audioset-zipformer-M-weighted-sampler
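+
+Each cut's sampling weight is the sum of the inverse frequencies of its labels
+(see `local/compute_weight.py`), so cuts containing rare events are drawn more
+often. As a hypothetical example, a cut labeled with two events that occur 10
+times and 1000 times in the training set gets a weight of roughly
+`1000/10 + 1000/1000 = 101`, about 100 times that of a cut whose only event
+occurs 1000 times.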
+""" + +import argparse + +import lhotse +from lhotse import load_manifest + + +def get_parser(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + + parser.add_argument( + "--input-manifest", type=str, default="data/fbank/cuts_audioset_full.jsonl.gz" + ) + + parser.add_argument( + "--output", + type=str, + required=True, + ) + return parser + + +def main(): + # Reference: https://github.com/YuanGongND/ast/blob/master/egs/audioset/gen_weight_file.py + parser = get_parser() + args = parser.parse_args() + + cuts = load_manifest(args.input_manifest) + + print(f"A total of {len(cuts)} cuts.") + + label_count = [0] * 527 # a total of 527 classes + for c in cuts: + audio_event = c.supervisions[0].audio_event + labels = list(map(int, audio_event.split(";"))) + for label in labels: + label_count[label] += 1 + + with open(args.output, "w") as f: + for c in cuts: + audio_event = c.supervisions[0].audio_event + labels = list(map(int, audio_event.split(";"))) + weight = 0 + for label in labels: + weight += 1000 / (label_count[label] + 0.01) + f.write(f"{c.id} {weight}\n") + + +if __name__ == "__main__": + main() diff --git a/egs/audioset/AT/prepare.sh b/egs/audioset/AT/prepare.sh index f7f73a008c..8beaf2d86a 100755 --- a/egs/audioset/AT/prepare.sh +++ b/egs/audioset/AT/prepare.sh @@ -10,6 +10,7 @@ stage=-1 stop_stage=4 dl_dir=$PWD/download +fbank_dir=data/fbank # we assume that you have your downloaded the AudioSet and placed # it under $dl_dir/audioset, the folder structure should look like @@ -49,7 +50,6 @@ fi if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then log "Stage 0: Construct the audioset manifest and compute the fbank features for balanced set" - fbank_dir=data/fbank if [! -e $fbank_dir/.balanced.done]; then python local/generate_audioset_manifest.py \ --dataset-dir $dl_dir/audioset \ @@ -102,3 +102,14 @@ if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then touch data/fbank/.musan.done fi fi + +# The following stages are required to do weighted-sampling training +if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then + log "Stage 5: Prepare for weighted-sampling training" + if [ ! -e $fbank_dir/cuts_audioset_full.jsonl.gz ]; then + lhotse combine $fbank_dir/cuts_audioset_balanced.jsonl.gz $fbank_dir/cuts_audioset_unbalanced.jsonl.gz $fbank_dir/cuts_audioset_full.jsonl.gz + fi + python ./local/compute_weight.py \ + --input-manifest $fbank_dir/cuts_audioset_full.jsonl.gz \ + --output $fbank_dir/sampling_weights_full.txt +fi diff --git a/egs/audioset/AT/zipformer/at_datamodule.py b/egs/audioset/AT/zipformer/at_datamodule.py index ac8671fa61..b7df015390 100644 --- a/egs/audioset/AT/zipformer/at_datamodule.py +++ b/egs/audioset/AT/zipformer/at_datamodule.py @@ -31,6 +31,7 @@ PrecomputedFeatures, SimpleCutSampler, SpecAugment, + WeightedSimpleCutSampler, ) from lhotse.dataset.input_strategies import ( # noqa F401 For AudioSamples AudioSamples, @@ -99,6 +100,20 @@ def add_arguments(cls, parser: argparse.ArgumentParser): help="Maximum pooled recordings duration (seconds) in a " "single batch. You can reduce it if it causes CUDA OOM.", ) + group.add_argument( + "--weighted-sampler", + type=str2bool, + default=False, + help="When enabled, samples are drawn from by their weights. " + "It cannot be used together with bucketing sampler", + ) + group.add_argument( + "--num-samples", + type=int, + default=200000, + help="The number of samples to be drawn in each epoch. 
Only be used" + "for weighed sampler", + ) group.add_argument( "--bucketing-sampler", type=str2bool, @@ -295,6 +310,9 @@ def train_dataloaders( ) if self.args.bucketing_sampler: + assert ( + not self.args.weighted_sampler + ), "weighted sampling is not supported in bucket sampler" logging.info("Using DynamicBucketingSampler.") train_sampler = DynamicBucketingSampler( cuts_train, @@ -304,13 +322,26 @@ def train_dataloaders( drop_last=self.args.drop_last, ) else: - logging.info("Using SimpleCutSampler.") - train_sampler = SimpleCutSampler( - cuts_train, - max_duration=self.args.max_duration, - shuffle=self.args.shuffle, - drop_last=self.args.drop_last, - ) + if self.args.weighted_sampler: + # assert self.args.audioset_subset == "full", "Only use weighted sampling for full audioset" + logging.info("Using weighted SimpleCutSampler") + weights = self.audioset_sampling_weights() + train_sampler = WeightedSimpleCutSampler( + cuts_train, + weights, + num_samples=self.args.num_samples, + max_duration=self.args.max_duration, + shuffle=False, # do not support shuffle + drop_last=self.args.drop_last, + ) + else: + logging.info("Using SimpleCutSampler.") + train_sampler = SimpleCutSampler( + cuts_train, + max_duration=self.args.max_duration, + shuffle=self.args.shuffle, + drop_last=self.args.drop_last, + ) logging.info("About to create train dataloader") if sampler_state_dict is not None: @@ -373,11 +404,9 @@ def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader: def test_dataloaders(self, cuts: CutSet) -> DataLoader: logging.debug("About to create test dataset") test = AudioTaggingDataset( - input_strategy=( - OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) - if self.args.on_the_fly_feats - else eval(self.args.input_strategy)() - ), + input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))) + if self.args.on_the_fly_feats + else eval(self.args.input_strategy)(), return_cuts=self.args.return_cuts, ) sampler = DynamicBucketingSampler( @@ -397,21 +426,30 @@ def test_dataloaders(self, cuts: CutSet) -> DataLoader: @lru_cache() def audioset_train_cuts(self) -> CutSet: logging.info("About to get the audioset training cuts.") - balanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" - ) - if self.args.audioset_subset == "full": - unbalanced_cuts = load_manifest_lazy( - self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" - ) - cuts = CutSet.mux( - balanced_cuts, - unbalanced_cuts, - weights=[20000, 2000000], - stop_early=True, + if not self.args.weighted_sampler: + balanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_balanced.jsonl.gz" ) + if self.args.audioset_subset == "full": + unbalanced_cuts = load_manifest_lazy( + self.args.manifest_dir / "cuts_audioset_unbalanced.jsonl.gz" + ) + cuts = CutSet.mux( + balanced_cuts, + unbalanced_cuts, + weights=[20000, 2000000], + stop_early=True, + ) + else: + cuts = balanced_cuts else: - cuts = balanced_cuts + # assert self.args.audioset_subset == "full", "Only do weighted sampling for full AudioSet" + cuts = load_manifest( + self.args.manifest_dir + / f"cuts_audioset_{self.args.audioset_subset}.jsonl.gz" + ) + logging.info(f"Get {len(cuts)} cuts in total.") + return cuts @lru_cache() @@ -420,3 +458,22 @@ def audioset_eval_cuts(self) -> CutSet: return load_manifest_lazy( self.args.manifest_dir / "cuts_audioset_eval.jsonl.gz" ) + + @lru_cache() + def audioset_sampling_weights(self): + logging.info( + f"About to get the sampling weight for {self.args.audioset_subset} in 
AudioSet" + ) + weights = [] + with open( + self.args.manifest_dir / f"sample_weights_{self.args.audioset_subset}.txt", + "r", + ) as f: + while True: + line = f.readline() + if not line: + break + weight = float(line.split()[1]) + weights.append(weight) + logging.info(f"Get the sampling weight for {len(weights)} cuts") + return weights diff --git a/egs/audioset/AT/zipformer/train.py b/egs/audioset/AT/zipformer/train.py index 2d193030a8..67c7033642 100644 --- a/egs/audioset/AT/zipformer/train.py +++ b/egs/audioset/AT/zipformer/train.py @@ -789,12 +789,14 @@ def save_bad_model(suffix: str = ""): rank=0, ) + num_samples = 0 for batch_idx, batch in enumerate(train_dl): if batch_idx % 10 == 0: set_batch_count(model, get_adjusted_batch_count(params)) params.batch_idx_train += 1 batch_size = batch["inputs"].size(0) + num_samples += batch_size try: with torch.cuda.amp.autocast(enabled=params.use_fp16): @@ -919,6 +921,12 @@ def save_bad_model(suffix: str = ""): tb_writer, "train/valid_", params.batch_idx_train ) + if num_samples > params.num_samples: + logging.info( + f"Number of training samples exceeds {params.num_samples} in this epoch, move on to next epoch" + ) + break + loss_value = tot_loss["loss"] / tot_loss["frames"] params.train_loss = loss_value if params.train_loss < params.best_train_loss: @@ -1032,7 +1040,8 @@ def remove_short_and_long_utt(c: Cut): return True - train_cuts = train_cuts.filter(remove_short_and_long_utt) + if not params.weighted_sampler: + train_cuts = train_cuts.filter(remove_short_and_long_utt) if params.start_batch > 0 and checkpoints and "sampler" in checkpoints: # We only load the sampler's state dict when it loads a checkpoint