From ac280c3167139943fd68528a20e079a1d4a79da3 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Sat, 9 Sep 2023 18:24:51 +0100 Subject: [PATCH] Added audio filtering functions --- src/soundevent/audio/__init__.py | 2 + src/soundevent/audio/filter.py | 123 ++++++++++++++++++++++++ src/soundevent/features/geometric.py | 2 +- tests/test_audio/test_filter.py | 135 +++++++++++++++++++++++++++ 4 files changed, 261 insertions(+), 1 deletion(-) create mode 100644 src/soundevent/audio/filter.py create mode 100644 tests/test_audio/test_filter.py diff --git a/src/soundevent/audio/__init__.py b/src/soundevent/audio/__init__.py index 278866b..097e6d3 100644 --- a/src/soundevent/audio/__init__.py +++ b/src/soundevent/audio/__init__.py @@ -1,6 +1,7 @@ """Soundevent functions for handling audio files and arrays.""" from .files import is_audio_file +from .filter import filter from .io import load_clip, load_recording from .media_info import MediaInfo, compute_md5_checksum, get_media_info from .resample import resample @@ -14,5 +15,6 @@ "load_clip", "load_recording", "is_audio_file", + "resample", "filter", ] diff --git a/src/soundevent/audio/filter.py b/src/soundevent/audio/filter.py new file mode 100644 index 0000000..5297f18 --- /dev/null +++ b/src/soundevent/audio/filter.py @@ -0,0 +1,123 @@ +"""Functions for audio filtering.""" + +from typing import Optional + +import numpy as np +import xarray as xr +from scipy import signal + +__all__ = [ + "filter", +] + + +def _get_filter( + samplerate: int, + low_freq: Optional[float] = None, + high_freq: Optional[float] = None, + order: int = 5, +) -> np.ndarray: + if low_freq is None and high_freq is None: + raise ValueError( + "At least one of low_freq and high_freq must be specified."
+ ) + + if low_freq is None: + # Low pass filter + return signal.butter( + order, + high_freq, + btype="lowpass", + output="sos", + fs=samplerate, + ) + + if high_freq is None: + # High pass filter + return signal.butter( + order, + low_freq, + btype="highpass", + output="sos", + fs=samplerate, + ) + + if low_freq > high_freq: + raise ValueError("low_freq must be less than high_freq.") + + # Band pass filter + return signal.butter( + order, + [low_freq, high_freq], + btype="bandpass", + output="sos", + fs=samplerate, + ) + + +def filter( + audio: xr.DataArray, + low_freq: Optional[float] = None, + high_freq: Optional[float] = None, + order: int = 5, +) -> xr.DataArray: + """Filter audio data. + + This function assumes that the input audio object is a + :class:`xarray.DataArray` with a "samplerate" attribute and a "time" + dimension. + + The filtering is done using a Butterworth filter of the specified order. + The type of filter (lowpass/highpass/bandpass filter) is determined + by the specified cutoff frequencies. If only one cutoff frequency is + specified, a low pass or high pass filter is used. If both cutoff + frequencies are specified, a band pass filter is used. + + Parameters + ---------- + audio : xr.DataArray + The audio data to filter with a "samplerate" attribute and + a "time" dimension. + low_freq : float, optional + The low cutoff frequency in Hz. + high_freq : float, optional + The high cutoff frequency in Hz. + order : int, optional + The order of the filter. By default, 5. + + Returns + ------- + xr.DataArray + The filtered audio data. + + Raises + ------ + ValueError + If neither low_freq nor high_freq is specified, or if both + are specified and low_freq > high_freq.
+ + """ + if not isinstance(audio, xr.DataArray): + raise ValueError("Audio must be an xarray.DataArray") + + if "samplerate" not in audio.attrs: + raise ValueError("Audio must have a 'samplerate' attribute") + + if "time" not in audio.dims: + raise ValueError("Audio must have a time dimension") + + axis: int = audio.get_axis_num("time") # type: ignore + sos = _get_filter( + audio.attrs["samplerate"], + low_freq, + high_freq, + order, + ) + + filtered = signal.sosfiltfilt(sos, audio.data, axis=axis) + return xr.DataArray( + data=filtered, + dims=audio.dims, + coords=audio.coords, + attrs=audio.attrs, + ) diff --git a/src/soundevent/features/geometric.py b/src/soundevent/features/geometric.py index 5063eb4..107fd9c 100644 --- a/src/soundevent/features/geometric.py +++ b/src/soundevent/features/geometric.py @@ -34,7 +34,7 @@ Feature(name='high_freq', value=1000), Feature(name='bandwidth', value=1000)] """ -from typing import Callable, Dict, Any, List +from typing import Any, Callable, Dict, List from soundevent.data import Feature, geometries diff --git a/tests/test_audio/test_filter.py b/tests/test_audio/test_filter.py new file mode 100644 index 0000000..8563697 --- /dev/null +++ b/tests/test_audio/test_filter.py @@ -0,0 +1,135 @@ +"""Test suite for filtering functions.""" + +from unittest import mock + +import numpy as np +import pytest +import xarray as xr +from scipy import signal + +from soundevent import audio + + +def test_filter_audio_fails_if_no_samplerate(): + """Test that filter_audio fails if samplerate is missing.""" + data = xr.DataArray(np.random.randn(100), dims=["time"]) + with pytest.raises(ValueError): + audio.filter(data, 16000) + + +def test_filter_audio_fails_if_not_an_xarray(): + """Test that filter_audio fails if not an xarray.DataArray.""" + data = np.random.randn(100) + with pytest.raises(ValueError): + audio.filter(data, 16000) # type: ignore + + +def test_filter_audio_fails_if_no_time_axis(): + """Test that filter_audio fails with missing 
+ time axis.""" + data = xr.DataArray( + np.random.randn(100), + dims=["channel"], + coords={"channel": range(100)}, + attrs={"samplerate": 16000}, + ) + with pytest.raises(ValueError): + audio.filter(data, 16000) + + +def test_filter_audio_returns_an_xarray(): + """Test that filter_audio returns an xarray.DataArray.""" + data = xr.DataArray( + np.random.randn(1000), + dims=["time"], + coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, + attrs={"samplerate": 16000}, + ) + filtered = audio.filter(data, 1000) + assert isinstance(filtered, xr.DataArray) + + +def test_filter_audio_preserves_attrs(): + """Test that filter_audio preserves attributes.""" + data = xr.DataArray( + np.random.randn(1000), + dims=["time"], + coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, + attrs={"samplerate": 16000, "other": "value"}, + ) + filtered = audio.filter(data, 1000) + assert filtered.attrs == data.attrs + + +def test_filter_audio_fails_if_no_low_or_high_freq_provided(): + """Test filter_audio fails if low_freq and high_freq aren't provided.""" + data = xr.DataArray( + np.random.randn(1000), + dims=["time"], + coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, + attrs={"samplerate": 16000, "other": "value"}, + ) + with pytest.raises(ValueError): + audio.filter(data) + + +def test_filter_audio_applies_a_lowpass_filter(): + """Test that filter_audio applies a lowpass filter.""" + data = xr.DataArray( + np.random.randn(1000), + dims=["time"], + coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, + attrs={"samplerate": 16000}, + ) + + mock_butter = mock.Mock(side_effect=signal.butter) + with mock.patch.object(signal, "butter", mock_butter): + audio.filter(data, high_freq=6000) + mock_butter.assert_called_once_with( + 5, + 6000, + btype="lowpass", + fs=16000, + output="sos", + ) + + +def test_filter_audio_applies_a_highpass_filter(): + """Test that filter_audio applies a highpass filter.""" + data = xr.DataArray( + np.random.randn(1000), + dims=["time"],
coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, + attrs={"samplerate": 16000}, + ) + + mock_butter = mock.Mock(side_effect=signal.butter) + with mock.patch.object(signal, "butter", mock_butter): + audio.filter(data, low_freq=6000) + mock_butter.assert_called_once_with( + 5, + 6000, + btype="highpass", + fs=16000, + output="sos", + ) + + +def test_filter_audio_applies_a_bandpass_filter(): + """Test that filter_audio applies a bandpass filter.""" + data = xr.DataArray( + np.random.randn(1000), + dims=["time"], + coords={"time": np.linspace(0, 1, 1000, endpoint=False)}, + attrs={"samplerate": 16000}, + ) + + mock_butter = mock.Mock(side_effect=signal.butter) + with mock.patch.object(signal, "butter", mock_butter): + audio.filter(data, low_freq=1000, high_freq=6000) + mock_butter.assert_called_once_with( + 5, + [1000, 6000], + btype="bandpass", + fs=16000, + output="sos", + )