From baa2baaf3f3f5e1edf59da938bed219774b4399c Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 15 Oct 2021 17:38:29 +0200 Subject: [PATCH] Fix Audio feature mp3 resampling (#3096) * Test resampling dataset with mp3 Audio feature * Fix resampling mp3 Audio feature * Fix resampling mp3 Audio feature * Fix test --- src/datasets/features/audio.py | 8 ++--- tests/features/test_audio.py | 59 ++++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 4 deletions(-) diff --git a/src/datasets/features/audio.py b/src/datasets/features/audio.py index af05745fbf4..c2261b40344 100644 --- a/src/datasets/features/audio.py +++ b/src/datasets/features/audio.py @@ -64,14 +64,14 @@ def _decode_example_with_torchaudio(self, value): raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err array, sampling_rate = torchaudio.load(value) - array = array.numpy() - if self.mono: - array = array.mean(axis=0) if self.sampling_rate and self.sampling_rate != sampling_rate: if not hasattr(self, "_resampler"): self._resampler = T.Resample(sampling_rate, self.sampling_rate) - array = self._resampler(array, sampling_rate, self.sampling_rate) + array = self._resampler(array) sampling_rate = self.sampling_rate + array = array.numpy() + if self.mono: + array = array.mean(axis=0) return array, sampling_rate def decode_batch(self, values): diff --git a/tests/features/test_audio.py b/tests/features/test_audio.py index df9fea2d4fa..293b7cd50b4 100644 --- a/tests/features/test_audio.py +++ b/tests/features/test_audio.py @@ -122,6 +122,34 @@ def test_resampling_at_loading_dataset_with_audio_feature(shared_datadir): assert column[0]["sampling_rate"] == 16000 +@require_sox +@require_sndfile +def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir): + audio_path = str(shared_datadir / "test_audio_44100.mp3") + data = {"audio": [audio_path]} + features = Features({"audio": Audio(sampling_rate=16000)}) + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item.keys() == {"audio"} + assert item["audio"].keys() == {"path", "array", "sampling_rate"} + assert item["audio"]["path"] == audio_path + assert item["audio"]["array"].shape == (39707,) + assert item["audio"]["sampling_rate"] == 16000 + batch = dset[:1] + assert batch.keys() == {"audio"} + assert len(batch["audio"]) == 1 + assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"} + assert batch["audio"][0]["path"] == audio_path + assert batch["audio"][0]["array"].shape == (39707,) + assert batch["audio"][0]["sampling_rate"] == 16000 + column = dset["audio"] + assert len(column) == 1 + assert column[0].keys() == {"path", "array", "sampling_rate"} + assert column[0]["path"] == audio_path + assert column[0]["array"].shape == (39707,) + assert column[0]["sampling_rate"] == 16000 + + @require_sndfile def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir): audio_path = str(shared_datadir / "test_audio_44100.wav") @@ -152,6 +180,37 @@ def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir): assert column[0]["sampling_rate"] == 16000 +@require_sox +@require_sndfile +def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir): + audio_path = str(shared_datadir / "test_audio_44100.mp3") + data = {"audio": [audio_path]} + features = Features({"audio": Audio()}) + dset = Dataset.from_dict(data, features=features) + item = dset[0] + assert item["audio"]["sampling_rate"] == 44100 + dset = dset.cast_column("audio", Audio(sampling_rate=16000)) + item = dset[0] + assert item.keys() == {"audio"} + assert item["audio"].keys() == {"path", "array", "sampling_rate"} + assert item["audio"]["path"] == audio_path + assert item["audio"]["array"].shape == (39707,) + assert item["audio"]["sampling_rate"] == 16000 + batch = dset[:1] + assert batch.keys() == {"audio"} + assert len(batch["audio"]) == 1 + assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"} + assert batch["audio"][0]["path"] == audio_path + assert batch["audio"][0]["array"].shape == (39707,) + assert batch["audio"][0]["sampling_rate"] == 16000 + column = dset["audio"] + assert len(column) == 1 + assert column[0].keys() == {"path", "array", "sampling_rate"} + assert column[0]["path"] == audio_path + assert column[0]["array"].shape == (39707,) + assert column[0]["sampling_rate"] == 16000 + + @require_sndfile def test_dataset_with_audio_feature_map_is_not_decoded(shared_datadir): audio_path = str(shared_datadir / "test_audio_44100.wav")