diff --git a/docs/whats_new.rst b/docs/whats_new.rst index 97ac72f..432e68b 100644 --- a/docs/whats_new.rst +++ b/docs/whats_new.rst @@ -27,7 +27,7 @@ Version 0.5.0 (Unreleased) Changelog ~~~~~~~~~ - :meth:`~pyprep.NoisyChannels.find_bad_by_nan_flat` now accepts a ``flat_threshold`` argument, by `Nabil Alibou`_ (:gh:`144`) -- changed _mad function in utils.py to use median_abs_deviation from the sciPy module, by `Ayush Agarwal`_ (:gh:`153`). +- replaced an internal implementation of the MAD algorithm with :func:`scipy.stats.median_abs_deviation`, by `Ayush Agarwal`_ (:gh:`153`) and `Stefan Appelhoff`_ (:gh:`154`) Bug ~~~ diff --git a/pyprep/find_noisy_channels.py b/pyprep/find_noisy_channels.py index 4dd7e8a..88dc7e3 100644 --- a/pyprep/find_noisy_channels.py +++ b/pyprep/find_noisy_channels.py @@ -3,10 +3,11 @@ import numpy as np from mne.utils import check_random_state, logger from scipy import signal +from scipy.stats import median_abs_deviation from pyprep.ransac import find_bad_by_ransac from pyprep.removeTrend import removeTrend -from pyprep.utils import _filter_design, _mad, _mat_iqr, _mat_quantile +from pyprep.utils import _filter_design, _mat_iqr, _mat_quantile class NoisyChannels: @@ -247,7 +248,7 @@ def find_bad_by_nan_flat(self, flat_threshold=1e-15): nan_channels = self.ch_names_original[nan_channel_mask] # Detect channels with flat or extremely weak signals - flat_by_mad = _mad(EEGData, axis=1) < flat_threshold + flat_by_mad = median_abs_deviation(EEGData, axis=1) < flat_threshold flat_by_stdev = np.std(EEGData, axis=1) < flat_threshold flat_channel_mask = flat_by_mad | flat_by_stdev flat_channels = self.ch_names_original[flat_channel_mask] @@ -336,8 +337,8 @@ def find_bad_by_hfnoise(self, HF_zscore_threshold=5.0): # < 50 Hz amplitude for each channel and get robust z-scores of values if self.sample_rate > 100: noisiness = np.divide( - _mad(self.EEGData - self.EEGFiltered, axis=1), - _mad(self.EEGFiltered, axis=1), + median_abs_deviation(self.EEGData - self.EEGFiltered, axis=1), + median_abs_deviation(self.EEGFiltered, axis=1), ) noise_median = np.nanmedian(noisiness) noise_sd = np.median(np.abs(noisiness - noise_median)) * MAD_TO_SD @@ -421,7 +422,7 @@ def find_bad_by_correlation( channel_amplitudes[w, usable] = _mat_iqr(eeg_raw, axis=1) * IQR_TO_SD # Check for any channel dropouts (flat signal) within the window - eeg_amplitude = _mad(eeg_filtered, axis=1) + eeg_amplitude = median_abs_deviation(eeg_filtered, axis=1) dropout[w, usable] = eeg_amplitude == 0 # Exclude any dropout chans from further calculations (avoids div-by-zero) @@ -431,7 +432,7 @@ def find_bad_by_correlation( eeg_amplitude = eeg_amplitude[eeg_amplitude > 0] # Get high-frequency noise ratios for the window - high_freq_amplitude = _mad(eeg_raw - eeg_filtered, axis=1) + high_freq_amplitude = median_abs_deviation(eeg_raw - eeg_filtered, axis=1) noiselevels[w, usable] = high_freq_amplitude / eeg_amplitude # Get inter-channel correlations for the window diff --git a/pyprep/tests/test_utils.py b/pyprep/tests/test_utils.py index 03ba2ee..0cb5aad 100644 --- a/pyprep/tests/test_utils.py +++ b/pyprep/tests/test_utils.py @@ -6,7 +6,6 @@ _correlate_arrays, _eeglab_create_highpass, _get_random_subset, - _mad, _mat_iqr, _mat_quantile, _mat_round, @@ -150,20 +149,3 @@ def test_eeglab_create_highpass(): expected_val = 0.9961 actual_val = vals[len(vals) // 2] assert np.isclose(expected_val, actual_val, atol=0.001) - - -def test_mad(): - """Test the median absolute deviation from the median (MAD) function.""" - # Generate test data - tst = np.array([[1, 2, 3, 4, 8], [80, 10, 20, 30, 40], [100, 200, 800, 300, 400]]) - expected = np.asarray([1, 10, 100]) - - # Compare output to expected results - assert all(np.equal(_mad(tst, axis=1), expected)) - assert all(np.equal(_mad(tst.T, axis=0), expected)) - assert _mad(tst) == 28 # Matches robust.mad from statsmodels - - # Test exception with > 2-D arrays - tst = np.random.rand(3, 3, 3) - with pytest.raises(ValueError): - _mad(tst, axis=0) diff --git a/pyprep/utils.py b/pyprep/utils.py index c32f016..c7ba549 100644 --- a/pyprep/utils.py +++ b/pyprep/utils.py @@ -10,7 +10,6 @@ from psutil import virtual_memory from scipy import linalg from scipy.signal import firwin, lfilter, lfilter_zi -from scipy.stats import median_abs_deviation def _union(list1, list2): @@ -462,36 +461,6 @@ def _correlate_arrays(a, b, matlab_strict=False): return np.diag(np.corrcoef(a, b)[:n_chan, n_chan:]) -def _mad(x, axis=None): - """Calculate median absolute deviations from the median (MAD) for an array. - - Parameters - ---------- - x : np.ndarray - A 1-D or 2-D numeric array to summarize. - axis : {int, tuple of int, None}, optional - Axis along which MADs should be calculated. If ``None``, the MAD will - be calculated for the full input array. Defaults to ``None``. - - Returns - ------- - mad : scalar or np.ndarray - If no axis is specified, returns the MAD for the full input array as a - single numeric value. Otherwise, returns an ``np.ndarray`` containing - the MAD for each index along the specified axis. - - """ - # Ensure array is either 1D or 2D - x = np.asarray(x) - if x.ndim > 2: - e = "Only 1D and 2D arrays are supported (input has {0} dimensions)" - raise ValueError(e.format(x.ndim)) - - # Calculate the median absolute deviation from the median - mad = median_abs_deviation(x, axis=axis) - return mad - - def _filter_design(N_order, amp, freq): """Create FIR low-pass filter for EEG data using frequency sampling method.