Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Version 0.5.0 (Unreleased)
Changelog
~~~~~~~~~
- :meth:`~pyprep.NoisyChannels.find_bad_by_nan_flat` now accepts a ``flat_threshold`` argument, by `Nabil Alibou`_ (:gh:`144`)
- changed _mad function in utils.py to use median_abs_deviation from the sciPy module, by `Ayush Agarwal`_ (:gh:`153`).
- replaced an internal implementation of the MAD algorithm with :func:`scipy.stats.median_abs_deviation`, by `Ayush Agarwal`_ (:gh:`153`) and `Stefan Appelhoff`_ (:gh:`154`)

Bug
~~~
Expand Down
13 changes: 7 additions & 6 deletions pyprep/find_noisy_channels.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import numpy as np
from mne.utils import check_random_state, logger
from scipy import signal
from scipy.stats import median_abs_deviation

from pyprep.ransac import find_bad_by_ransac
from pyprep.removeTrend import removeTrend
from pyprep.utils import _filter_design, _mad, _mat_iqr, _mat_quantile
from pyprep.utils import _filter_design, _mat_iqr, _mat_quantile


class NoisyChannels:
Expand Down Expand Up @@ -247,7 +248,7 @@ def find_bad_by_nan_flat(self, flat_threshold=1e-15):
nan_channels = self.ch_names_original[nan_channel_mask]

# Detect channels with flat or extremely weak signals
flat_by_mad = _mad(EEGData, axis=1) < flat_threshold
flat_by_mad = median_abs_deviation(EEGData, axis=1) < flat_threshold
flat_by_stdev = np.std(EEGData, axis=1) < flat_threshold
flat_channel_mask = flat_by_mad | flat_by_stdev
flat_channels = self.ch_names_original[flat_channel_mask]
Expand Down Expand Up @@ -336,8 +337,8 @@ def find_bad_by_hfnoise(self, HF_zscore_threshold=5.0):
# < 50 Hz amplitude for each channel and get robust z-scores of values
if self.sample_rate > 100:
noisiness = np.divide(
_mad(self.EEGData - self.EEGFiltered, axis=1),
_mad(self.EEGFiltered, axis=1),
median_abs_deviation(self.EEGData - self.EEGFiltered, axis=1),
median_abs_deviation(self.EEGFiltered, axis=1),
)
noise_median = np.nanmedian(noisiness)
noise_sd = np.median(np.abs(noisiness - noise_median)) * MAD_TO_SD
Expand Down Expand Up @@ -421,7 +422,7 @@ def find_bad_by_correlation(
channel_amplitudes[w, usable] = _mat_iqr(eeg_raw, axis=1) * IQR_TO_SD

# Check for any channel dropouts (flat signal) within the window
eeg_amplitude = _mad(eeg_filtered, axis=1)
eeg_amplitude = median_abs_deviation(eeg_filtered, axis=1)
dropout[w, usable] = eeg_amplitude == 0

# Exclude any dropout chans from further calculations (avoids div-by-zero)
Expand All @@ -431,7 +432,7 @@ def find_bad_by_correlation(
eeg_amplitude = eeg_amplitude[eeg_amplitude > 0]

# Get high-frequency noise ratios for the window
high_freq_amplitude = _mad(eeg_raw - eeg_filtered, axis=1)
high_freq_amplitude = median_abs_deviation(eeg_raw - eeg_filtered, axis=1)
noiselevels[w, usable] = high_freq_amplitude / eeg_amplitude

# Get inter-channel correlations for the window
Expand Down
18 changes: 0 additions & 18 deletions pyprep/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
_correlate_arrays,
_eeglab_create_highpass,
_get_random_subset,
_mad,
_mat_iqr,
_mat_quantile,
_mat_round,
Expand Down Expand Up @@ -150,20 +149,3 @@ def test_eeglab_create_highpass():
expected_val = 0.9961
actual_val = vals[len(vals) // 2]
assert np.isclose(expected_val, actual_val, atol=0.001)


def test_mad():
"""Test the median absolute deviation from the median (MAD) function."""
# Generate test data
tst = np.array([[1, 2, 3, 4, 8], [80, 10, 20, 30, 40], [100, 200, 800, 300, 400]])
expected = np.asarray([1, 10, 100])

# Compare output to expected results
assert all(np.equal(_mad(tst, axis=1), expected))
assert all(np.equal(_mad(tst.T, axis=0), expected))
assert _mad(tst) == 28 # Matches robust.mad from statsmodels

# Test exception with > 2-D arrays
tst = np.random.rand(3, 3, 3)
with pytest.raises(ValueError):
_mad(tst, axis=0)
31 changes: 0 additions & 31 deletions pyprep/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from psutil import virtual_memory
from scipy import linalg
from scipy.signal import firwin, lfilter, lfilter_zi
from scipy.stats import median_abs_deviation


def _union(list1, list2):
Expand Down Expand Up @@ -462,36 +461,6 @@ def _correlate_arrays(a, b, matlab_strict=False):
return np.diag(np.corrcoef(a, b)[:n_chan, n_chan:])


def _mad(x, axis=None):
"""Calculate median absolute deviations from the median (MAD) for an array.

Parameters
----------
x : np.ndarray
A 1-D or 2-D numeric array to summarize.
axis : {int, tuple of int, None}, optional
Axis along which MADs should be calculated. If ``None``, the MAD will
be calculated for the full input array. Defaults to ``None``.

Returns
-------
mad : scalar or np.ndarray
If no axis is specified, returns the MAD for the full input array as a
single numeric value. Otherwise, returns an ``np.ndarray`` containing
the MAD for each index along the specified axis.

"""
# Ensure array is either 1D or 2D
x = np.asarray(x)
if x.ndim > 2:
e = "Only 1D and 2D arrays are supported (input has {0} dimensions)"
raise ValueError(e.format(x.ndim))

# Calculate the median absolute deviation from the median
mad = median_abs_deviation(x, axis=axis)
return mad


def _filter_design(N_order, amp, freq):
"""Create FIR low-pass filter for EEG data using frequency sampling method.

Expand Down
Loading