Skip to content

Commit

Permalink
Fix Audio feature mp3 resampling (#3096)
Browse files Browse the repository at this point in the history
* Test resampling dataset with mp3 Audio feature

* Fix resampling mp3 Audio feature

* Fix resampling mp3 Audio feature

* Fix test
  • Loading branch information
albertvillanova committed Oct 15, 2021
1 parent 324c84e commit baa2baa
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/datasets/features/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@ def _decode_example_with_torchaudio(self, value):
raise ImportError("To support decoding 'mp3' audio files, please install 'sox'.") from err

array, sampling_rate = torchaudio.load(value)
array = array.numpy()
if self.mono:
array = array.mean(axis=0)
if self.sampling_rate and self.sampling_rate != sampling_rate:
if not hasattr(self, "_resampler"):
self._resampler = T.Resample(sampling_rate, self.sampling_rate)
array = self._resampler(array, sampling_rate, self.sampling_rate)
array = self._resampler(array)
sampling_rate = self.sampling_rate
array = array.numpy()
if self.mono:
array = array.mean(axis=0)
return array, sampling_rate

def decode_batch(self, values):
Expand Down
59 changes: 59 additions & 0 deletions tests/features/test_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,34 @@ def test_resampling_at_loading_dataset_with_audio_feature(shared_datadir):
assert column[0]["sampling_rate"] == 16000


@require_sox
@require_sndfile
def test_resampling_at_loading_dataset_with_audio_feature_mp3(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.mp3")
data = {"audio": [audio_path]}
features = Features({"audio": Audio(sampling_rate=16000)})
dset = Dataset.from_dict(data, features=features)
item = dset[0]
assert item.keys() == {"audio"}
assert item["audio"].keys() == {"path", "array", "sampling_rate"}
assert item["audio"]["path"] == audio_path
assert item["audio"]["array"].shape == (39707,)
assert item["audio"]["sampling_rate"] == 16000
batch = dset[:1]
assert batch.keys() == {"audio"}
assert len(batch["audio"]) == 1
assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"}
assert batch["audio"][0]["path"] == audio_path
assert batch["audio"][0]["array"].shape == (39707,)
assert batch["audio"][0]["sampling_rate"] == 16000
column = dset["audio"]
assert len(column) == 1
assert column[0].keys() == {"path", "array", "sampling_rate"}
assert column[0]["path"] == audio_path
assert column[0]["array"].shape == (39707,)
assert column[0]["sampling_rate"] == 16000


@require_sndfile
def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.wav")
Expand Down Expand Up @@ -152,6 +180,37 @@ def test_resampling_after_loading_dataset_with_audio_feature(shared_datadir):
assert column[0]["sampling_rate"] == 16000


@require_sox
@require_sndfile
def test_resampling_after_loading_dataset_with_audio_feature_mp3(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.mp3")
data = {"audio": [audio_path]}
features = Features({"audio": Audio()})
dset = Dataset.from_dict(data, features=features)
item = dset[0]
assert item["audio"]["sampling_rate"] == 44100
dset = dset.cast_column("audio", Audio(sampling_rate=16000))
item = dset[0]
assert item.keys() == {"audio"}
assert item["audio"].keys() == {"path", "array", "sampling_rate"}
assert item["audio"]["path"] == audio_path
assert item["audio"]["array"].shape == (39707,)
assert item["audio"]["sampling_rate"] == 16000
batch = dset[:1]
assert batch.keys() == {"audio"}
assert len(batch["audio"]) == 1
assert batch["audio"][0].keys() == {"path", "array", "sampling_rate"}
assert batch["audio"][0]["path"] == audio_path
assert batch["audio"][0]["array"].shape == (39707,)
assert batch["audio"][0]["sampling_rate"] == 16000
column = dset["audio"]
assert len(column) == 1
assert column[0].keys() == {"path", "array", "sampling_rate"}
assert column[0]["path"] == audio_path
assert column[0]["array"].shape == (39707,)
assert column[0]["sampling_rate"] == 16000


@require_sndfile
def test_dataset_with_audio_feature_map_is_not_decoded(shared_datadir):
audio_path = str(shared_datadir / "test_audio_44100.wav")
Expand Down

1 comment on commit baa2baa

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Show benchmarks

PyArrow==3.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.011055 / 0.011353 (-0.000298) 0.004372 / 0.011008 (-0.006636) 0.035722 / 0.038508 (-0.002786) 0.042526 / 0.023109 (0.019417) 0.328516 / 0.275898 (0.052618) 0.445893 / 0.323480 (0.122413) 0.009560 / 0.007986 (0.001574) 0.003941 / 0.004328 (-0.000387) 0.010929 / 0.004250 (0.006679) 0.045232 / 0.037052 (0.008180) 0.337161 / 0.258489 (0.078672) 0.366355 / 0.293841 (0.072514) 0.026124 / 0.128546 (-0.102422) 0.009442 / 0.075646 (-0.066205) 0.295277 / 0.419271 (-0.123994) 0.053675 / 0.043533 (0.010142) 0.327997 / 0.255139 (0.072858) 0.354892 / 0.283200 (0.071692) 0.095180 / 0.141683 (-0.046503) 2.002112 / 1.452155 (0.549957) 2.024502 / 1.492716 (0.531785)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.238127 / 0.018006 (0.220120) 0.480187 / 0.000490 (0.479697) 0.016498 / 0.000200 (0.016298) 0.000328 / 0.000054 (0.000274)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.043442 / 0.037411 (0.006031) 0.024904 / 0.014526 (0.010378) 0.030772 / 0.176557 (-0.145784) 0.145098 / 0.737135 (-0.592038) 0.032611 / 0.296338 (-0.263727)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.492308 / 0.215209 (0.277099) 4.841280 / 2.077655 (2.763625) 2.147969 / 1.504120 (0.643849) 1.890224 / 1.541195 (0.349029) 1.966583 / 1.468490 (0.498092) 0.434780 / 4.584777 (-4.149996) 5.686560 / 3.745712 (1.940848) 1.009367 / 5.269862 (-4.260495) 0.942816 / 4.565676 (-3.622861) 0.046457 / 0.424275 (-0.377818) 0.005182 / 0.007607 (-0.002425) 0.598980 / 0.226044 (0.372935) 6.096269 / 2.268929 (3.827340) 2.677902 / 55.444624 (-52.766723) 2.219604 / 6.876477 (-4.656873) 2.268213 / 2.142072 (0.126141) 0.560577 / 4.805227 (-4.244651) 0.118682 / 6.500664 (-6.381982) 0.059112 / 0.075469 (-0.016357)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.793416 / 1.841788 (-0.048371) 14.954528 / 8.074308 (6.880220) 30.014784 / 10.191392 (19.823392) 0.851194 / 0.680424 (0.170770) 0.603207 / 0.534201 (0.069006) 0.258264 / 0.579283 (-0.321019) 0.607375 / 0.434364 (0.173011) 0.203045 / 0.540337 (-0.337292) 0.212940 / 1.386936 (-1.173996)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.010378 / 0.011353 (-0.000975) 0.004218 / 0.011008 (-0.006790) 0.037864 / 0.038508 (-0.000644) 0.040818 / 0.023109 (0.017708) 0.329640 / 0.275898 (0.053742) 0.380600 / 0.323480 (0.057120) 0.009128 / 0.007986 (0.001142) 0.005430 / 0.004328 (0.001102) 0.010644 / 0.004250 (0.006393) 0.048726 / 0.037052 (0.011673) 0.315424 / 0.258489 (0.056935) 0.379363 / 0.293841 (0.085522) 0.028374 / 0.128546 (-0.100173) 0.009216 / 0.075646 (-0.066431) 0.297174 / 0.419271 (-0.122098) 0.054501 / 0.043533 (0.010968) 0.330302 / 0.255139 (0.075163) 0.353423 / 0.283200 (0.070223) 0.099587 / 0.141683 (-0.042096) 1.936321 / 1.452155 (0.484166) 2.107654 / 1.492716 (0.614937)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.312650 / 0.018006 (0.294644) 0.488014 / 0.000490 (0.487525) 0.073518 / 0.000200 (0.073318) 0.000496 / 0.000054 (0.000442)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.039683 / 0.037411 (0.002272) 0.024103 / 0.014526 (0.009577) 0.029410 / 0.176557 (-0.147147) 0.143111 / 0.737135 (-0.594024) 0.031806 / 0.296338 (-0.264532)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.477746 / 0.215209 (0.262537) 4.719240 / 2.077655 (2.641586) 2.034875 / 1.504120 (0.530756) 1.790615 / 1.541195 (0.249421) 1.831245 / 1.468490 (0.362755) 0.432178 / 4.584777 (-4.152599) 5.802877 / 3.745712 (2.057165) 1.055287 / 5.269862 (-4.214574) 1.007373 / 4.565676 (-3.558304) 0.047696 / 0.424275 (-0.376579) 0.005728 / 0.007607 (-0.001879) 0.590052 / 0.226044 (0.364008) 5.885580 / 2.268929 (3.616652) 2.524633 / 55.444624 (-52.919991) 2.134341 / 6.876477 (-4.742136) 2.185671 / 2.142072 (0.043599) 0.543784 / 4.805227 (-4.261443) 0.116720 / 6.500664 (-6.383945) 0.061052 / 0.075469 (-0.014417)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.809815 / 1.841788 (-0.031973) 14.565557 / 8.074308 (6.491249) 29.775263 / 10.191392 (19.583871) 0.958253 / 0.680424 (0.277829) 0.604004 / 0.534201 (0.069803) 0.257290 / 0.579283 (-0.321993) 0.615924 / 0.434364 (0.181560) 0.219893 / 0.540337 (-0.320444) 0.227878 / 1.386936 (-1.159058)

CML watermark

Please sign in to comment.