Skip to content

Commit

Permalink
Minor Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
edeno committed Oct 14, 2020
1 parent e816f03 commit 0a6c881
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 100 deletions.
2 changes: 1 addition & 1 deletion spectral_connectivity/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# flake8: noqa
from .connectivity import Connectivity
from .wrapper import multitaper_connectivity
from .transforms import Multitaper
from .wrapper import multitaper_connectivity
57 changes: 34 additions & 23 deletions spectral_connectivity/wrapper.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from logging import getLogger

import numpy as np
import xarray as xr
from spectral_connectivity.connectivity import Connectivity
from spectral_connectivity.transforms import Multitaper
import xarray as xr
import numpy as np
from logging import getLogger

logger = getLogger(__name__)


def connectivity_to_xarray(m, method='coherence_magnitude', signal_names=None, squeeze=False, **kwargs):
def connectivity_to_xarray(m, method='coherence_magnitude', signal_names=None,
squeeze=False, **kwargs):
"""
calculate connectivity using `method`. Returns an xarray
with dimensions of ['Time', 'Frequency', 'Source', 'Target']
Expand All @@ -16,10 +18,11 @@ def connectivity_to_xarray(m, method='coherence_magnitude', signal_names=None, s
Parameters
-----------
signal_names : iterable of strings
Sames of time series used to name the 'Source' and 'Target' axes of xarray
Sames of time series used to name the 'Source' and 'Target' axes of
xarray.
squeeze : bool
Whether to only take the first and last source and target time series. Only makes sense for
one pair of signals and symmetrical measures
Whether to only take the first and last source and target time series.
Only makes sense for one pair of signals and symmetrical measures
"""
assert method not in ['power', 'group_delay', 'canonical_coherence'], \
Expand All @@ -29,19 +32,23 @@ def connectivity_to_xarray(m, method='coherence_magnitude', signal_names=None, s
connectivity_mat, labels = getattr(connectivity, method)(**kwargs)
else:
connectivity_mat = getattr(connectivity, method)(**kwargs)
if (m.time_series.shape[-1] == 2) and squeeze: # Only one couple (only makes sense for symmetrical metrics)
logger.warning(f'Squeeze is on, but there are {m.time_series.shape[-1]} pairs!')
# Only one couple (only makes sense for symmetrical metrics)
if (m.time_series.shape[-1] == 2) and squeeze:
logger.warning(
f'Squeeze is on, but there are {m.time_series.shape[-1]} pairs!')
connectivity_mat = connectivity_mat[:, :, 0, -1]
xar = xr.DataArray(connectivity_mat,
coords=[connectivity.time, connectivity.frequencies],
coords=[connectivity.time,
connectivity.frequencies],
dims=['Time', 'Frequency'])

else: # Name the source and target axes
if signal_names is None:
signal_names = np.arange(m.time_series.shape[-1])

xar = xr.DataArray(connectivity_mat,
coords=[connectivity.time, connectivity.frequencies, signal_names, signal_names],
coords=[connectivity.time, connectivity.frequencies,
signal_names, signal_names],
dims=['Time', 'Frequency', 'Source', 'Target'])

xar.name = method
Expand All @@ -50,17 +57,18 @@ def connectivity_to_xarray(m, method='coherence_magnitude', signal_names=None, s
if (attr[0] == '_') or (attr == 'time_series'):
continue
# If we don't add 'mt_', get:
# TypeError: '.dt' accessor only available for DataArray with datetime64 timedelta64 dtype
# TypeError: '.dt' accessor only available for DataArray with
# datetime64 timedelta64 dtype
# or for arrays containing cftime datetime objects.
xar.attrs['mt_' + attr] = getattr(m, attr)

return xar


def multitaper_connectivity(time_series, sampling_frequency, time_window_duration=None,
method='coherence_magnitude', signal_names=None, squeeze=False,
connectivity_kwargs=None,
**kwargs):
def multitaper_connectivity(time_series, sampling_frequency,
time_window_duration=None,
method='coherence_magnitude', signal_names=None,
squeeze=False, connectivity_kwargs=None, **kwargs):
"""
Transform time series to multitaper and
calculate connectivity using `method`. Returns an xarray.DataArray
Expand All @@ -70,10 +78,11 @@ def multitaper_connectivity(time_series, sampling_frequency, time_window_duratio
Parameters
-----------
signal_names : iterable of strings
Sames of time series used to name the 'Source' and 'Target' axes of xarray
Sames of time series used to name the 'Source' and 'Target' axes of
xarray.
squeeze : bool
Whether to only take the first and last source and target time series. Only makes sense for
one pair of signals and symmetrical measures
Whether to only take the first and last source and target time series.
Only makes sense for one pair of signals and symmetrical measures.
Attributes
----------
Expand All @@ -87,10 +96,11 @@ def multitaper_connectivity(time_series, sampling_frequency, time_window_duratio
Duration of sliding window in which to compute the fft. Defaults to
the entire time if not set.
signal_names : iterable of strings
Sames of time series used to name the 'Source' and 'Target' axes of xarray
Sames of time series used to name the 'Source' and 'Target' axes of
xarray.
squeeze : bool
Whether to only take the first and last source and target time series. Only makes sense for
one pair of signals and symmetrical measures
Whether to only take the first and last source and target time series.
Only makes sense for one pair of signals and symmetrical measures.
connectivity_kwargs : dict
Arguments to pass to connectivity calculation
Expand All @@ -102,4 +112,5 @@ def multitaper_connectivity(time_series, sampling_frequency, time_window_duratio
sampling_frequency=sampling_frequency,
time_window_duration=time_window_duration,
**kwargs)
return connectivity_to_xarray(m, method, signal_names, squeeze, **connectivity_kwargs)
return connectivity_to_xarray(m, method, signal_names, squeeze,
**connectivity_kwargs)
22 changes: 8 additions & 14 deletions tests/test_connectivity.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
import numpy as np
from pytest import mark
from unittest.mock import PropertyMock

from spectral_connectivity.connectivity import (Connectivity, _bandpass,
_complex_inner_product,
_conjugate_transpose,
_find_largest_independent_group,
_find_largest_significant_group,
_get_independent_frequencies,
_get_independent_frequency_step,
_inner_combination,
_remove_instantaneous_causality,
_reshape, _set_diagonal_to_zero,
_squared_magnitude, _total_inflow,
_total_outflow)
import numpy as np
from pytest import mark
from spectral_connectivity.connectivity import (
Connectivity, _bandpass, _complex_inner_product, _conjugate_transpose,
_find_largest_independent_group, _find_largest_significant_group,
_get_independent_frequencies, _get_independent_frequency_step,
_inner_combination, _remove_instantaneous_causality, _reshape,
_set_diagonal_to_zero, _squared_magnitude, _total_inflow, _total_outflow)


@mark.parametrize('axis', [(0), (1), (2), (3)])
Expand Down
9 changes: 3 additions & 6 deletions tests/test_minimum_phase_decomposition.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import numpy as np
from scipy.fftpack import fft, ifft
from scipy.signal import freqz_zpk

from spectral_connectivity.minimum_phase_decomposition import (_check_convergence,
_conjugate_transpose,
_get_causal_signal,
_get_intial_conditions,
minimum_phase_decomposition)
from spectral_connectivity.minimum_phase_decomposition import (
_check_convergence, _conjugate_transpose, _get_causal_signal,
_get_intial_conditions, minimum_phase_decomposition)


def test__check_convergence():
Expand Down
9 changes: 4 additions & 5 deletions tests/test_statistics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import numpy as np
from pytest import mark

from spectral_connectivity.statistics import (Benjamini_Hochberg_procedure,
Bonferroni_correction,
fisher_z_transform,
get_normal_distribution_p_values,
coherence_bias)
Bonferroni_correction,
coherence_bias,
fisher_z_transform,
get_normal_distribution_p_values)


def test_get_normal_distribution_p_values():
Expand Down
14 changes: 7 additions & 7 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import numpy as np
from nitime.algorithms.spectral import dpss_windows as nitime_dpss_windows
from pytest import mark
from scipy.signal import correlate

from nitime.algorithms.spectral import dpss_windows as nitime_dpss_windows
from spectral_connectivity.transforms import (Multitaper, _add_axes,
_auto_correlation, _fix_taper_sign,
_get_low_bias_tapers,
_get_taper_eigenvalues,
_multitaper_fft, _sliding_window,
dpss_windows)
_auto_correlation,
_fix_taper_sign,
_get_low_bias_tapers,
_get_taper_eigenvalues,
_multitaper_fft, _sliding_window,
dpss_windows)


def test__add_axes():
Expand Down
44 changes: 0 additions & 44 deletions tests/test_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,50 +67,6 @@ def test_multitaper_connectivity():
assert not (np.isnan(m.values)).all()


# def test_multitaper_canonical_coherence():
# time_window_duration = .2
# time_halfbandwidth_product = 2
# frequency_of_interest = 200
# sampling_frequency = 1500
# time_extent = (0, 2.400)
# n_trials = 100
# n_signals = 6
# n_time_samples = int(((time_extent[1] - time_extent[0]) * sampling_frequency) + 1)
# time = np.linspace(time_extent[0], time_extent[1], num=n_time_samples, endpoint=True)
#
# signal = np.zeros((n_time_samples, n_trials, n_signals))
# signal[:, :, 0:2] = (
# np.sin(2 * np.pi * time * frequency_of_interest)[:, np.newaxis, np.newaxis] *
# np.ones((1, n_trials, 2)))
#
# expected_time = np.arange(time_extent[0], time_extent[-1], time_window_duration)
#
# if not np.allclose(expected_time[-1] + time_window_duration, time_extent[-1]):
# expected_time = expected_time[:-1]
#
# other_signals = (n_signals + 1) // 2
# n_other_signals = n_signals - other_signals
# phase_offset = np.random.uniform(-np.pi, np.pi, size=(n_time_samples, n_trials, n_other_signals))
# phase_offset[np.where(time > 1.5), :] = np.pi / 2
# signal[:, :, other_signals:] = np.sin(
# (2 * np.pi * time[:, np.newaxis, np.newaxis] * frequency_of_interest) + phase_offset)
# noise = np.random.normal(10, 7, signal.shape)
# group_labels = (['a'] * (n_signals - n_other_signals)) + (['b'] * n_other_signals)
#
# m = multitaper_connectivity(signal + noise,
# sampling_frequency=sampling_frequency,
# time_halfbandwidth_product=time_halfbandwidth_product,
# time_window_duration=time_window_duration,
# time_window_step=0.080,
# method='canonical_coherence',
# connectivity_kwargs={"group_labels": group_labels}
# )
#
# assert np.allclose(m.Time.values, expected_time)
# assert not (m.values == 0).all()
# assert not (np.isnan(m.values)).all()


@mark.parametrize('n_signals', range(2, 5))
def test_multitaper_n_signals(n_signals):
time_window_duration = .1
Expand Down

0 comments on commit 0a6c881

Please sign in to comment.