Skip to content

Commit

Permalink
Merge pull request #182 from sensein/178-task-refactor-extract_mel_fi…
Browse files Browse the repository at this point in the history
…lter_bank_from_audios

added extract_mel_filter_bank_from_spectrograms
  • Loading branch information
fabiocat93 authored Nov 14, 2024
2 parents cb2b6f2 + 3798cc1 commit 98b5cbe
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,7 @@ def extract_praat_parselmouth_features_from_audios(
jitter: bool = True,
shimmer: bool = True,
plugin: str = "cf",
plugin_args: Optional[Dict[str, Any]] = None,
plugin_args: Optional[Dict[str, Any]] = {},
) -> List[Dict[str, Any]]:
"""Extract features from a list of Audio objects and return a JSON-like dictionary.
Expand Down
71 changes: 52 additions & 19 deletions src/senselab/audio/tasks/features_extraction/torchaudio.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,33 @@ def extract_mel_filter_bank_from_audios(
return mel_filter_banks


def extract_mel_filter_bank_from_spectrograms(
spectrograms: List[Dict[str, torch.Tensor]],
sampling_rate: int,
n_mels: int = 128,
) -> List[Dict[str, torch.Tensor]]:
"""Extract mel filter bank from a list of audio objects.
Args:
spectrograms (List[torch.Tensor]): List of spectrograms.
sampling_rate (int): Sampling rate of the audio.
n_mels (int): Number of mel filter banks. Default is 128.
Returns:
List[Dict[str, torch.Tensor]]: List of Dict objects containing mel filter banks.
"""
mel_filter_banks = []
for spectrogram in spectrograms:
try:
melscale_transform = torchaudio.transforms.MelScale(
sample_rate=sampling_rate, n_mels=n_mels, n_stft=spectrogram["spectrogram"].shape[0]
)
mel_filter_banks.append({"mel_filter_bank": melscale_transform(spectrogram["spectrogram"]).squeeze(0)})
except RuntimeError:
mel_filter_banks.append({"mel_filter_bank": np.nan})
return mel_filter_banks


def extract_pitch_from_audios(
audios: List[Audio], freq_low: int = 80, freq_high: int = 500
) -> List[Dict[str, torch.Tensor]]:
Expand Down Expand Up @@ -214,7 +241,7 @@ def extract_torchaudio_features_from_audios(
win_length: Optional[int] = None,
hop_length: Optional[int] = None,
plugin: str = "cf",
plugin_args: Optional[Dict[str, Any]] = None,
plugin_args: Optional[Dict[str, Any]] = {},
) -> List[Dict[str, Any]]:
"""Extract torchaudio features from a list of audio objects.
Expand All @@ -236,55 +263,61 @@ def extract_torchaudio_features_from_audios(
List[Dict[str, Any]]: The list of feature dictionaries for each audio.
"""
extract_pitch_from_audios_pt = pydra.mark.task(extract_pitch_from_audios)
extract_mel_filter_bank_from_audios_pt = pydra.mark.task(extract_mel_filter_bank_from_audios)
extract_mel_filter_bank_from_spectrograms_pt = pydra.mark.task(extract_mel_filter_bank_from_spectrograms)
extract_mfcc_from_audios_pt = pydra.mark.task(extract_mfcc_from_audios)
extract_mel_spectrogram_from_audios_pt = pydra.mark.task(extract_mel_spectrogram_from_audios)
extract_spectrogram_from_audios_pt = pydra.mark.task(extract_spectrogram_from_audios)

def _extract_sampling_rate(audios: List[Audio]) -> int:
"""Extract the sampling rate from an Audio object."""
return audios[0].sampling_rate

_extract_sampling_rate_pt = pydra.mark.task(_extract_sampling_rate)

formatted_audios = [[audio] for audio in audios]
wf = pydra.Workflow(name="wf", input_spec=["x"])
wf.split("x", x=formatted_audios)
wf.add(_extract_sampling_rate_pt(name="_extract_sampling_rate_pt", audios=wf.lzin.x))
wf.add(
extract_pitch_from_audios_pt(
name="extract_pitch_from_audios_pt", audios=wf.lzin.x, freq_low=freq_low, freq_high=freq_high
)
)
wf.add(
extract_mel_filter_bank_from_audios_pt(
name="extract_mel_filter_bank_from_audios_pt",
extract_spectrogram_from_audios_pt(
name="extract_spectrogram_from_audios_pt",
audios=wf.lzin.x,
n_mels=n_mels,
n_fft=n_fft,
n_nfft=n_fft,
win_length=win_length,
hop_length=hop_length,
)
)
wf.add(
extract_mfcc_from_audios_pt(
name="extract_mfcc_from_audios_pt",
extract_mel_spectrogram_from_audios_pt(
name="extract_mel_spectrogram_from_audios_pt",
audios=wf.lzin.x,
n_mfcc=n_mfcc,
n_fft=n_fft,
n_mels=n_mels,
n_nfft=n_fft,
win_length=win_length,
hop_length=hop_length,
)
)
wf.add(
extract_mel_spectrogram_from_audios_pt(
name="extract_mel_spectrogram_from_audios_pt",
audios=wf.lzin.x,
extract_mel_filter_bank_from_spectrograms_pt(
name="extract_mel_filter_bank_from_spectrograms_pt",
spectrograms=wf.extract_spectrogram_from_audios_pt.lzout.out,
sampling_rate=wf._extract_sampling_rate_pt.lzout.out,
n_mels=n_mels,
n_nfft=n_fft,
win_length=win_length,
hop_length=hop_length,
)
)
wf.add(
extract_spectrogram_from_audios_pt(
name="extract_spectrogram_from_audios_pt",
extract_mfcc_from_audios_pt(
name="extract_mfcc_from_audios_pt",
audios=wf.lzin.x,
n_nfft=n_fft,
n_mfcc=n_mfcc,
n_fft=n_fft,
n_mels=n_mels,
win_length=win_length,
hop_length=hop_length,
)
Expand All @@ -294,7 +327,7 @@ def extract_torchaudio_features_from_audios(
wf.set_output(
[
("pitch_out", wf.extract_pitch_from_audios_pt.lzout.out),
("mel_filter_bank_out", wf.extract_mel_filter_bank_from_audios_pt.lzout.out),
("mel_filter_bank_out", wf.extract_mel_filter_bank_from_spectrograms_pt.lzout.out),
("mfcc_out", wf.extract_mfcc_from_audios_pt.lzout.out),
("mel_spectrogram_out", wf.extract_mel_spectrogram_from_audios_pt.lzout.out),
("spectrogram_out", wf.extract_spectrogram_from_audios_pt.lzout.out),
Expand Down
14 changes: 7 additions & 7 deletions tutorials/audio/features_extraction.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -35,7 +35,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -51,15 +51,15 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pydra:Added SpecInfo(name='Output', fields=[('pitch_values_out', typing.Dict[str, float], {'help_string': ' (from extract_pitch_values_pt)'}), ('speech_rate_out', typing.Dict[str, float], {'help_string': ' (from extract_speech_rate_pt)'}), ('pitch_out', typing.Dict[str, float], {'help_string': ' (from extract_pitch_descriptors_pt)'}), ('intensity_out', typing.Dict[str, float], {'help_string': ' (from extract_intensity_descriptors_pt)'}), ('harmonicity_out', typing.Dict[str, float], {'help_string': ' (from extract_harmonicity_descriptors_pt)'}), ('formants_out', typing.Dict[str, float], {'help_string': ' (from measure_formants_pt)'}), ('spectral_moments_out', typing.Dict[str, float], {'help_string': ' (from extract_spectral_moments_pt)'}), ('slope_tilt_out', typing.Dict[str, float], {'help_string': ' (from extract_slope_tilt_pt)'}), ('cpp_out', typing.Dict[str, float], {'help_string': ' (from extract_cpp_descriptors_pt)'}), ('audio_duration', typing.Dict[str, float], {'help_string': ' (from extract_audio_duration_pt)'}), ('jitter_out', typing.Dict[str, float], {'help_string': ' (from extract_jitter_pt)'}), ('shimmer_out', typing.Dict[str, float], {'help_string': ' (from extract_shimmer_pt)'})], bases=(<class 'pydra.engine.specs.BaseSpec'>,)) to wf\n",
"INFO:pydra:Added SpecInfo(name='Output', fields=[('pitch_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_pitch_from_audios_pt)'}), ('mel_filter_bank_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_mel_filter_bank_from_audios_pt)'}), ('mfcc_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_mfcc_from_audios_pt)'}), ('mel_spectrogram_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_mel_spectrogram_from_audios_pt)'}), ('spectrogram_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_spectrogram_from_audios_pt)'})], bases=(<class 'pydra.engine.specs.BaseSpec'>,)) to wf\n",
"INFO:pydra:Added SpecInfo(name='Output', fields=[('pitch_values_out', typing.Dict[str, float], {'help_string': ' (from extract_pitch_values_pt)'}), ('speech_rate_out', typing.Dict[str, float], {'help_string': ' (from extract_speech_rate_pt)'}), ('pitch_out', typing.Dict[str, float], {'help_string': ' (from extract_pitch_descriptors_pt)'}), ('intensity_out', typing.Dict[str, float], {'help_string': ' (from extract_intensity_descriptors_pt)'}), ('harmonicity_out', typing.Dict[str, float], {'help_string': ' (from extract_harmonicity_descriptors_pt)'}), ('formants_out', typing.Dict[str, float], {'help_string': ' (from measure_f1f2_formants_bandwidths_pt)'}), ('spectral_moments_out', typing.Dict[str, float], {'help_string': ' (from extract_spectral_moments_pt)'}), ('slope_tilt_out', typing.Dict[str, float], {'help_string': ' (from extract_slope_tilt_pt)'}), ('cpp_out', typing.Dict[str, float], {'help_string': ' (from extract_cpp_descriptors_pt)'}), ('audio_duration', typing.Dict[str, float], {'help_string': ' (from extract_audio_duration_pt)'}), ('jitter_out', typing.Dict[str, float], {'help_string': ' (from extract_jitter_pt)'}), ('shimmer_out', typing.Dict[str, float], {'help_string': ' (from extract_shimmer_pt)'})], bases=(<class 'pydra.engine.specs.BaseSpec'>,)) to wf\n",
"INFO:pydra:Added SpecInfo(name='Output', fields=[('pitch_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_pitch_from_audios_pt)'}), ('mel_filter_bank_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_mel_filter_bank_from_spectrograms_pt)'}), ('mfcc_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_mfcc_from_audios_pt)'}), ('mel_spectrogram_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_mel_spectrogram_from_audios_pt)'}), ('spectrogram_out', typing.List[typing.Dict[str, torch.Tensor]], {'help_string': ' (from extract_spectrogram_from_audios_pt)'})], bases=(<class 'pydra.engine.specs.BaseSpec'>,)) to wf\n",
"/Users/fabiocat/Library/Caches/pypoetry/virtualenvs/senselab-3xKECMzt-py3.10/lib/python3.10/site-packages/torchaudio/functional/functional.py:584: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (201) may be set too low.\n",
" warnings.warn(\n"
]
Expand Down Expand Up @@ -348,14 +348,14 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:pydra:Added SpecInfo(name='Output', fields=[('pitch_values_out', typing.Dict[str, float], {'help_string': ' (from extract_pitch_values_pt)'}), ('speech_rate_out', typing.Dict[str, float], {'help_string': ' (from extract_speech_rate_pt)'}), ('pitch_out', typing.Dict[str, float], {'help_string': ' (from extract_pitch_descriptors_pt)'}), ('intensity_out', typing.Dict[str, float], {'help_string': ' (from extract_intensity_descriptors_pt)'}), ('harmonicity_out', typing.Dict[str, float], {'help_string': ' (from extract_harmonicity_descriptors_pt)'}), ('formants_out', typing.Dict[str, float], {'help_string': ' (from measure_formants_pt)'}), ('spectral_moments_out', typing.Dict[str, float], {'help_string': ' (from extract_spectral_moments_pt)'}), ('slope_tilt_out', typing.Dict[str, float], {'help_string': ' (from extract_slope_tilt_pt)'}), ('cpp_out', typing.Dict[str, float], {'help_string': ' (from extract_cpp_descriptors_pt)'})], bases=(<class 'pydra.engine.specs.BaseSpec'>,)) to wf\n"
"INFO:pydra:Added SpecInfo(name='Output', fields=[('pitch_values_out', typing.Dict[str, float], {'help_string': ' (from extract_pitch_values_pt)'}), ('speech_rate_out', typing.Dict[str, float], {'help_string': ' (from extract_speech_rate_pt)'}), ('pitch_out', typing.Dict[str, float], {'help_string': ' (from extract_pitch_descriptors_pt)'}), ('intensity_out', typing.Dict[str, float], {'help_string': ' (from extract_intensity_descriptors_pt)'}), ('harmonicity_out', typing.Dict[str, float], {'help_string': ' (from extract_harmonicity_descriptors_pt)'}), ('formants_out', typing.Dict[str, float], {'help_string': ' (from measure_f1f2_formants_bandwidths_pt)'}), ('spectral_moments_out', typing.Dict[str, float], {'help_string': ' (from extract_spectral_moments_pt)'}), ('slope_tilt_out', typing.Dict[str, float], {'help_string': ' (from extract_slope_tilt_pt)'}), ('cpp_out', typing.Dict[str, float], {'help_string': ' (from extract_cpp_descriptors_pt)'})], bases=(<class 'pydra.engine.specs.BaseSpec'>,)) to wf\n"
]
},
{
Expand Down

0 comments on commit 98b5cbe

Please sign in to comment.