Skip to content

Commit

Permalink
Added audio filtering functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mbsantiago committed Sep 9, 2023
1 parent 53fe18f commit ac280c3
Show file tree
Hide file tree
Showing 4 changed files with 261 additions and 1 deletion.
2 changes: 2 additions & 0 deletions src/soundevent/audio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Soundevent functions for handling audio files and arrays."""

from .files import is_audio_file
from .filter import filter
from .io import load_clip, load_recording
from .media_info import MediaInfo, compute_md5_checksum, get_media_info
from .resample import resample
Expand All @@ -14,5 +15,6 @@
"load_clip",
"load_recording",
"is_audio_file",
"resample",
"filter",
]
123 changes: 123 additions & 0 deletions src/soundevent/audio/filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Functions for audio filtering."""

from typing import Optional

import numpy as np
import xarray as xr
from scipy import signal

# Names exported by a star-import of this module.
__all__ = [
"filter",
]


def _get_filter(
samplerate: int,
low_freq: Optional[float] = None,
high_freq: Optional[float] = None,
order: int = 5,
) -> np.ndarray:
if low_freq is None and high_freq is None:
raise ValueError(
"At least one of low_freq and high_freq must be specified."
)

if low_freq is None:
# Low pass filter
return signal.butter(
order,
high_freq,
btype="lowpass",
output="sos",
fs=samplerate,
)

if high_freq is None:
# High pass filter
return signal.butter(
order,
low_freq,
btype="highpass",
output="sos",
fs=samplerate,
)

if low_freq > high_freq:
raise ValueError("low_freq must be less than high_freq.")

# Band pass filter
return signal.butter(
order,
[low_freq, high_freq],
btype="bandpass",
output="sos",
fs=samplerate,
)


def filter(
    audio: xr.DataArray,
    low_freq: Optional[float] = None,
    high_freq: Optional[float] = None,
    order: int = 5,
) -> xr.DataArray:
    """Filter audio data.

    This function assumes that the input audio object is a
    :class:`xarray.DataArray` with a "samplerate" attribute and a "time"
    dimension.

    The filtering is done using a Butterworth filter of the specified
    order, applied along the "time" dimension with
    :func:`scipy.signal.sosfiltfilt`. The type of filter
    (lowpass/highpass/bandpass filter) is determined by the specified
    cutoff frequencies. If only ``high_freq`` is specified, a low pass
    filter is used; if only ``low_freq`` is specified, a high pass filter
    is used. If both cutoff frequencies are specified, a band pass filter
    is used.

    Parameters
    ----------
    audio : xr.DataArray
        The audio data to filter with a "samplerate" attribute and
        a "time" dimension.
    low_freq : float, optional
        The low cutoff frequency in Hz.
    high_freq : float, optional
        The high cutoff frequency in Hz.
    order : int, optional
        The order of the filter. By default, 5.

    Returns
    -------
    xr.DataArray
        The filtered audio data, with the same dimensions, coordinates,
        attributes and name as the input.

    Raises
    ------
    ValueError
        If ``audio`` is not an :class:`xarray.DataArray`, if it lacks a
        "samplerate" attribute or a "time" dimension, if neither
        low_freq nor high_freq is specified, or if both are specified
        and low_freq > high_freq.
    """
    if not isinstance(audio, xr.DataArray):
        raise ValueError("Audio must be an xarray.DataArray")

    if "samplerate" not in audio.attrs:
        raise ValueError("Audio must have a 'samplerate' attribute")

    if "time" not in audio.dims:
        raise ValueError("Audio must have a time dimension")

    # Position of the time dimension in the underlying array, so the filter
    # runs along time regardless of dimension order.
    axis: int = audio.get_axis_num("time")  # type: ignore
    sos = _get_filter(
        audio.attrs["samplerate"],
        low_freq,
        high_freq,
        order,
    )

    filtered = signal.sosfiltfilt(sos, audio.data, axis=axis)

    # Rebuild the array around the filtered samples, carrying over all the
    # metadata of the input, including its name.
    return xr.DataArray(
        data=filtered,
        dims=audio.dims,
        coords=audio.coords,
        attrs=audio.attrs,
        name=audio.name,
    )
2 changes: 1 addition & 1 deletion src/soundevent/features/geometric.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Feature(name='high_freq', value=1000),
Feature(name='bandwidth', value=1000)]
"""
from typing import Callable, Dict, Any, List
from typing import Any, Callable, Dict, List

from soundevent.data import Feature, geometries

Expand Down
135 changes: 135 additions & 0 deletions tests/test_audio/test_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
"""Test suite for filtering functions."""

from unittest import mock

import numpy as np
import pytest
import xarray as xr
from scipy import signal

from soundevent import audio


def test_filter_audio_fails_if_no_samplerate():
    """Test that audio.filter fails if the samplerate attribute is missing."""
    data = xr.DataArray(np.random.randn(100), dims=["time"])
    # Pin the message so that an unrelated ValueError cannot make this
    # test pass for the wrong reason.
    with pytest.raises(ValueError, match="'samplerate' attribute"):
        audio.filter(data, 16000)


def test_filter_audio_fails_if_not_an_xarray():
    """Test that audio.filter fails if not given an xarray.DataArray."""
    data = np.random.randn(100)
    # Pin the message so that an unrelated ValueError cannot make this
    # test pass for the wrong reason.
    with pytest.raises(ValueError, match="must be an xarray"):
        audio.filter(data, 16000)  # type: ignore


def test_filter_audio_fails_if_no_time_axis():
    """Test that audio.filter fails when there is no time dimension."""
    data = xr.DataArray(
        np.random.randn(100),
        dims=["channel"],
        coords={"channel": range(100)},
        attrs={"samplerate": 16000},
    )
    # Use a cutoff below Nyquist (8000 Hz here) and pin the message, so
    # the only thing that can satisfy this test is the missing-time check
    # and not a ValueError from the filter design.
    with pytest.raises(ValueError, match="time dimension"):
        audio.filter(data, 1000)


def test_filter_audio_returns_an_xarray():
    """Check that the output of audio.filter is an xarray.DataArray."""
    n_samples = 1000
    samples = np.random.randn(n_samples)
    times = np.linspace(0, 1, n_samples, endpoint=False)
    clip = xr.DataArray(
        samples,
        dims=["time"],
        coords={"time": times},
        attrs={"samplerate": 16000},
    )

    result = audio.filter(clip, 1000)

    assert isinstance(result, xr.DataArray)


def test_filter_audio_preserves_attrs():
    """Check that audio.filter carries the input attributes over."""
    n_samples = 1000
    samples = np.random.randn(n_samples)
    times = np.linspace(0, 1, n_samples, endpoint=False)
    clip = xr.DataArray(
        samples,
        dims=["time"],
        coords={"time": times},
        attrs={"samplerate": 16000, "other": "value"},
    )

    result = audio.filter(clip, 1000)

    assert result.attrs == clip.attrs


def test_filter_audio_fails_if_no_low_or_high_freq_provided():
    """Check audio.filter fails if low_freq and high_freq aren't given."""
    n_samples = 1000
    samples = np.random.randn(n_samples)
    times = np.linspace(0, 1, n_samples, endpoint=False)
    clip = xr.DataArray(
        samples,
        dims=["time"],
        coords={"time": times},
        attrs={"samplerate": 16000, "other": "value"},
    )

    # No cutoff at all: there is no filter type to choose from.
    with pytest.raises(ValueError):
        audio.filter(clip)


def test_filter_audio_applies_a_lowpass_filter():
    """Check that only giving high_freq designs a lowpass filter."""
    n_samples = 1000
    samples = np.random.randn(n_samples)
    times = np.linspace(0, 1, n_samples, endpoint=False)
    clip = xr.DataArray(
        samples,
        dims=["time"],
        coords={"time": times},
        attrs={"samplerate": 16000},
    )

    # Spy on the real scipy.signal.butter: calls are recorded and still
    # forwarded, so the filtering itself runs normally.
    with mock.patch.object(signal, "butter", wraps=signal.butter) as spy:
        audio.filter(clip, high_freq=6000)

    spy.assert_called_once_with(
        5,
        6000,
        btype="lowpass",
        fs=16000,
        output="sos",
    )


def test_filter_audio_applies_a_highpass_filter():
    """Check that only giving low_freq designs a highpass filter."""
    n_samples = 1000
    samples = np.random.randn(n_samples)
    times = np.linspace(0, 1, n_samples, endpoint=False)
    clip = xr.DataArray(
        samples,
        dims=["time"],
        coords={"time": times},
        attrs={"samplerate": 16000},
    )

    # Spy on the real scipy.signal.butter: calls are recorded and still
    # forwarded, so the filtering itself runs normally.
    with mock.patch.object(signal, "butter", wraps=signal.butter) as spy:
        audio.filter(clip, low_freq=6000)

    spy.assert_called_once_with(
        5,
        6000,
        btype="highpass",
        fs=16000,
        output="sos",
    )


def test_filter_audio_applies_a_bandpass_filter():
    """Check that giving both cutoffs designs a bandpass filter."""
    n_samples = 1000
    samples = np.random.randn(n_samples)
    times = np.linspace(0, 1, n_samples, endpoint=False)
    clip = xr.DataArray(
        samples,
        dims=["time"],
        coords={"time": times},
        attrs={"samplerate": 16000},
    )

    # Spy on the real scipy.signal.butter: calls are recorded and still
    # forwarded, so the filtering itself runs normally.
    with mock.patch.object(signal, "butter", wraps=signal.butter) as spy:
        audio.filter(clip, low_freq=1000, high_freq=6000)

    spy.assert_called_once_with(
        5,
        [1000, 6000],
        btype="bandpass",
        fs=16000,
        output="sos",
    )

0 comments on commit ac280c3

Please sign in to comment.