[WIP] Enhance distance and gpf metrics #161

Draft
wants to merge 71 commits into
base: main
Commits
2a09d23
Add dev notebook to test stuff
vferat Feb 16, 2024
8694aaa
Change distance and gfp metrics
vferat Feb 16, 2024
711fca3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 16, 2024
0d0d26c
change correlation and gfp
vferat Feb 19, 2024
2b16b7c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2024
9a68b74
Update _checks.py
vferat Feb 19, 2024
9755705
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 19, 2024
2bc4d81
wip
vferat Feb 19, 2024
a83d6e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2024
76a7da8
fix _correlation
vferat Feb 19, 2024
5dbff57
Use pairwise_distances + cosine
vferat Feb 20, 2024
7217470
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 20, 2024
ddb4e5f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
5f235e0
wip
vferat Feb 20, 2024
017f769
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 20, 2024
1ab0af1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
7bbcaa8
Update _base.py
vferat Feb 20, 2024
4275359
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 20, 2024
22a3a7e
wip
vferat Feb 20, 2024
4203b0c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
dff2af8
Update extract_gfp_peaks.py
vferat Feb 20, 2024
f473bdd
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 20, 2024
829f87f
wip
vferat Feb 20, 2024
d173c3a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
d4627b6
Update _base.py
vferat Feb 20, 2024
f9c64e0
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 20, 2024
9387250
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
38a0b21
wip
vferat Feb 20, 2024
f9ba5fc
Update _checks.py
vferat Feb 20, 2024
c86ab92
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 20, 2024
41be55d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
1d0c636
Merge branch 'main' into dev-meeg
vferat Feb 20, 2024
bada884
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 20, 2024
a8ec527
fix some tests
vferat Feb 20, 2024
3208985
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
17da9ef
Update utils.py
vferat Feb 20, 2024
869894a
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 20, 2024
6ef7747
Update utils.py
vferat Feb 20, 2024
53e996a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
7a71f67
wip
vferat Feb 20, 2024
e6a16a8
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 20, 2024
d94e657
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2024
e6687ab
wip
vferat Feb 21, 2024
5602e4d
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 21, 2024
b124339
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
f8b7075
wip
vferat Feb 21, 2024
b918683
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
8c3687a
remove prints
vferat Feb 21, 2024
9d9de27
Fix tests
vferat Feb 21, 2024
34e4fd9
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 21, 2024
69c3787
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
4eab0fd
Update _base.py
vferat Feb 21, 2024
617b687
Update _base.py
vferat Feb 21, 2024
b6c63cd
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Feb 21, 2024
b4fed8f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 21, 2024
57b1aef
Update _base.py
vferat Apr 5, 2024
09641ea
wip ignore polarity
vferat Apr 15, 2024
9d70a6b
fix arguments
vferat Apr 15, 2024
8bf3cad
Update kmeans.py
vferat Apr 15, 2024
82764f3
Merge branch 'main' into dev-meeg
vferat Apr 15, 2024
0b688f5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2024
1a8426a
fix imports
vferat Apr 15, 2024
0dc9aa5
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Apr 15, 2024
bd98d8d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 15, 2024
02c16cb
save ingore polarity parameter
vferat Apr 15, 2024
ae153f2
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Apr 15, 2024
a458316
Squashed commit of the following:
vferat Apr 16, 2024
41aa160
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 16, 2024
af569af
wip save
vferat Apr 16, 2024
723b18a
Merge branch 'dev-meeg' of https://github.com/vferat/pycrostates into…
vferat Apr 16, 2024
d7fef13
Fix kmeans
vferat Apr 17, 2024
1,152 changes: 1,152 additions & 0 deletions dev.ipynb

Large diffs are not rendered by default.

876 changes: 876 additions & 0 deletions dev_main.ipynb

Large diffs are not rendered by default.

78 changes: 48 additions & 30 deletions pycrostates/cluster/_base.py
@@ -19,7 +19,7 @@
from mne.io.pick import _picks_to_idx

from ..segmentation import EpochsSegmentation, RawSegmentation
from ..utils import _corr_vectors
from ..utils import _correlation
from ..utils._checks import (
_check_picks_uniqueness,
_check_reject_by_annotation,
@@ -860,22 +860,32 @@ def _predict_raw(

data_ = data[:, onset:end]
segment = _BaseCluster._segment(
data_, cluster_centers_, factor, tol, half_window_size
data_,
cluster_centers_,
self._ignore_polarity,
factor,
tol,
half_window_size,
)
if reject_edges:
segment = _BaseCluster._reject_edge_segments(segment)
segmentation[onset:end] = segment

else:
segmentation = _BaseCluster._segment(
data, cluster_centers_, factor, tol, half_window_size
data,
cluster_centers_,
self._ignore_polarity,
factor,
tol,
half_window_size,
)
if reject_edges:
segmentation = _BaseCluster._reject_edge_segments(segmentation)

if 0 < min_segment_length:
segmentation = _BaseCluster._reject_short_segments(
segmentation, data, min_segment_length
segmentation, data, min_segment_length, self._ignore_polarity
)

# Provide properties to copy the arrays
@@ -914,12 +924,17 @@ def _predict_epochs(
segments = []
for epoch_data in data:
segment = _BaseCluster._segment(
epoch_data, cluster_centers_, factor, tol, half_window_size
epoch_data,
cluster_centers_,
self._ignore_polarity,
factor,
tol,
half_window_size,
)

if 0 < min_segment_length:
segment = _BaseCluster._reject_short_segments(
segment, epoch_data, min_segment_length
segment, epoch_data, min_segment_length, self._ignore_polarity
)
if reject_edges:
segment = _BaseCluster._reject_edge_segments(segment)
@@ -940,33 +955,28 @@ def _predict_epochs(
def _segment(
data: ScalarFloatArray,
states: ScalarFloatArray,
ignore_polarity: bool,
factor: int,
tol: Union[int, float],
half_window_size: int,
) -> ScalarIntArray:
"""Create segmentation. Must operate on a copy of states."""
data -= np.mean(data, axis=0)
std = np.std(data, axis=0)
std[std == 0] = 1 # std == 0 -> null map
data /= std

states -= np.mean(states, axis=1)[:, np.newaxis]
states /= np.std(states, axis=1)[:, np.newaxis]

labels = np.argmax(np.abs(np.dot(states, data)), axis=0)
corr = np.zeros((states.shape[0], data.shape[1]))
for k in range(0, states.shape[0]):
corr[k] = _correlation(data, states[k], ignore_polarity=ignore_polarity)
labels = np.argmax(corr, axis=0)

if factor != 0:
labels = _BaseCluster._smooth_segmentation(
data, states, labels, factor, tol, half_window_size
)

return labels

@staticmethod
def _smooth_segmentation(
data: ScalarFloatArray,
states: ScalarFloatArray,
labels: ScalarIntArray,
data: NDArray[float],
states: NDArray[float],
labels: NDArray[int],
factor: int,
tol: Union[int, float],
half_window_size: int,
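
The new `_segment` body in this hunk labels each sample with the state it correlates with best, passing `ignore_polarity` through to `_correlation`. A minimal vectorized sketch of that idea follows. `assign_labels_sketch` is a hypothetical name, and the centering and scaling are an assumption carried over from the lines this hunk removes; the real `_correlation` helper is not shown in this diff and may differ.

```python
import numpy as np


def assign_labels_sketch(data, states, ignore_polarity=True):
    """Label each sample with its best-correlated state (illustrative only).

    data: (n_channels, n_samples), states: (n_states, n_channels).
    """
    data = np.asarray(data, dtype=float)
    states = np.asarray(states, dtype=float)

    # Standardize every sample across channels; guard against null maps.
    data = data - data.mean(axis=0)
    std = data.std(axis=0)
    std[std == 0] = 1.0
    data = data / std

    # Standardize every state across channels.
    states = states - states.mean(axis=1, keepdims=True)
    states = states / states.std(axis=1, keepdims=True)

    # Pearson correlation between every state and every sample.
    corr = states @ data / data.shape[0]
    if ignore_polarity:
        corr = np.abs(corr)  # a map and its inverse count as the same state
    return np.argmax(corr, axis=0)


states = np.array([[1.0, 2.0, 3.0], [3.0, 1.0, 2.0]])
# Columns: state 0, inverted state 0, state 1.
data = np.array([[1.0, -1.0, 3.0], [2.0, -2.0, 1.0], [3.0, -3.0, 2.0]])
print(assign_labels_sketch(data, states, ignore_polarity=True))   # [0 0 1]
print(assign_labels_sketch(data, states, ignore_polarity=False))  # [0 1 1]
```

The second sample is the inverse of state 0, so it only keeps label 0 when polarity is ignored.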
@@ -984,6 +994,7 @@ def _smooth_segmentation(
vol. 42, no. 7, pp. 658-665, July 1995,
https://doi.org/10.1109/10.391164.
"""
# TODO: ignore_polarity
Ne, Nt = data.shape
Nu = states.shape[0]
Vvar = np.sum(data * data, axis=0)
@@ -1021,7 +1032,8 @@ def _reject_short_segments(
segmentation: ScalarIntArray,
data: ScalarFloatArray,
min_segment_length: int,
) -> ScalarIntArray:
ignore_polarity: bool,
) -> NDArray[int]:
"""Reject segments that are too short.

Reject segments that are too short by replacing the labels with the adjacent
@@ -1048,16 +1060,16 @@ def _reject_short_segments(

while len(new_segment) != 0:
# compute correlation left/right side
left_corr = np.abs(
_corr_vectors(
data[:, left - 1].T,
data[:, left].T,
)
)
right_corr = np.abs(
_corr_vectors(data[:, right].T, data[:, right + 1].T)
)

left_corr = _correlation(
data[:, left - 1],
data[:, left],
ignore_polarity=ignore_polarity,
)[0]
right_corr = _correlation(
data[:, right],
data[:, right + 1],
ignore_polarity=ignore_polarity,
)[0]
if np.abs(right_corr - left_corr) <= 1e-8:
# equal corr, try to do both sides
if len(new_segment) == 1:
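
The hunk above compares the correlation of a short segment's left edge with its left neighbour against that of its right edge with its right neighbour, now with `ignore_polarity` passed through. A rough sketch of that comparison follows. `correlation_sketch` and `closer_side_sketch` are hypothetical names, not the library's helpers, and the rule that the more strongly correlated side wins is an inference from the expected labels in `test_base.py` in this PR, since the rest of the loop is not shown here.

```python
import numpy as np


def correlation_sketch(a, b, ignore_polarity=True):
    """Pearson correlation across channels between two maps (1-D arrays)."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a = a - a.mean()
    b = b - b.mean()
    den = np.linalg.norm(a) * np.linalg.norm(b)
    corr = 0.0 if den == 0 else float(a @ b / den)
    return abs(corr) if ignore_polarity else corr


def closer_side_sketch(data, left, right, ignore_polarity=True):
    """Say which neighbour a short segment [left, right] resembles more."""
    left_corr = correlation_sketch(
        data[:, left - 1], data[:, left], ignore_polarity
    )
    right_corr = correlation_sketch(
        data[:, right], data[:, right + 1], ignore_polarity
    )
    if abs(right_corr - left_corr) <= 1e-8:
        return "tie"
    return "left" if left_corr > right_corr else "right"


# Data from the first polarity test in test_base.py; the short segment is index 8.
data = np.array(
    [
        [1, 1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1],
        [2, 2, 2, 2, 2, 2, 3, 0.5, -0.5, -0.5, 1, 1, 1],
        [3, 3, 3, 3, 3, 3, 3, 4, -4, -4, 2, 2, 2],
    ]
)
print(closer_side_sketch(data, 8, 8, ignore_polarity=False))  # right
print(closer_side_sketch(data, 8, 8, ignore_polarity=True))   # tie
```

With polarity respected, the sample at index 8 is the exact inverse of its left neighbour (correlation -1) and identical to its right neighbour (correlation 1), so it joins the right side. With polarity ignored both sides score 1 and the tie branch applies.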
@@ -1217,3 +1229,9 @@ def _check_n_clusters(n_clusters: int) -> int:
f"Provided: '{n_clusters}'."
)
return n_clusters

@staticmethod
def _check_ignore_polarity(ignore_polarity: bool) -> bool:
"""Check that ignore_polarity is a boolean."""
_check_type(ignore_polarity, (bool,), item_name="ignore_polarity")
return ignore_polarity
12 changes: 2 additions & 10 deletions pycrostates/cluster/aahc.py
@@ -8,7 +8,7 @@
from mne.io import BaseRaw

from .._typing import Picks, ScalarFloatArray, ScalarIntArray
from ..utils import _corr_vectors
from ..utils import _gev
from ..utils._checks import _check_type
from ..utils._docs import copy_doc, fill_doc
from ..utils._logs import logger
@@ -184,12 +184,10 @@ def _aahc(
normalize_input: bool,
) -> tuple[float, ScalarFloatArray, ScalarIntArray]:
"""Run the AAHC algorithm."""
gfp_sum_sq = np.sum(data**2)
maps, segmentation = AAHCluster._compute_maps(
data, n_clusters, ignore_polarity, normalize_input
)
map_corr = _corr_vectors(data, maps[segmentation].T)
gev = np.sum((data * map_corr) ** 2) / gfp_sum_sq
gev = _gev(data, maps, segmentation)
return gev, maps, segmentation

# pylint: disable=too-many-locals
@@ -275,12 +273,6 @@ def fitted(self, fitted):
if not fitted:
self._GEV_ = None

@staticmethod
def _check_ignore_polarity(ignore_polarity: bool) -> bool:
"""Check that ignore_polarity is a boolean."""
_check_type(ignore_polarity, (bool,), item_name="ignore_polarity")
return ignore_polarity

@staticmethod
def _check_normalize_input(normalize_input: bool) -> bool:
"""Check that normalize_input is a boolean."""
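
Both `aahc.py` and `kmeans.py` now delegate the global explained variance to `_gev(data, maps, segmentation)`, whose body is not part of this diff. The sketch below only restates the formula from the lines removed in `_aahc`; `gev_sketch` is a hypothetical name, and the centered Pearson correlation stands in for `_corr_vectors`, which is an assumption about that helper.

```python
import numpy as np


def gev_sketch(data, maps, segmentation):
    """Global explained variance, following the removed inline formula.

    data: (n_channels, n_samples), maps: (n_states, n_channels),
    segmentation: (n_samples,) integer labels.
    """
    data = np.asarray(data, dtype=float)
    maps = np.asarray(maps, dtype=float)
    segmentation = np.asarray(segmentation, dtype=int)

    # Map assigned to every sample, arranged like data.
    assigned = maps[segmentation].T

    # Per-sample correlation between the data and its assigned map.
    d = data - data.mean(axis=0)
    m = assigned - assigned.mean(axis=0)
    den = np.linalg.norm(d, axis=0) * np.linalg.norm(m, axis=0)
    den[den == 0] = 1.0
    map_corr = np.sum(d * m, axis=0) / den

    # Same expression as the removed lines: sum((data * corr)^2) / sum(data^2).
    return float(np.sum((data * map_corr) ** 2) / np.sum(data**2))


maps = np.array([[1.0, -1.0, 0.0], [0.0, 1.0, -1.0]])
# Column 0 is 2 * maps[0]; column 1 is -3 * maps[1].
data = np.array([[2.0, 0.0], [-2.0, -3.0], [0.0, 3.0]])
print(gev_sketch(data, maps, [0, 1]))
```

Because the correlation is squared in this formula, its sign does not change the result, so an inverted sample explains as much variance as a non-inverted one.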
44 changes: 31 additions & 13 deletions pycrostates/cluster/kmeans.py
@@ -11,7 +11,8 @@
from mne.parallel import parallel_func
from numpy.random import Generator, RandomState

from ..utils import _corr_vectors
from .._typing import Picks
from ..utils import _gev
from ..utils._checks import _check_n_jobs, _check_random_state, _check_type
from ..utils._docs import copy_doc, fill_doc
from ..utils._logs import logger
@@ -73,6 +74,9 @@ def __init__(
# fit variables
self._GEV_ = None

# ignore polarity
self._ignore_polarity = True

def _repr_html_(self, caption=None):
from ..html_templates import repr_templates_env

@@ -116,6 +120,7 @@ def __eq__(self, other: Any) -> bool:
# '_random_state',
# TODO: think about comparison and I/O for random states
"_GEV_",
"_ignore_polarity",
)
for attribute in attributes:
try:
@@ -191,7 +196,12 @@ def fit(
count_converged = 0
for init in inits:
gev, maps, segmentation, converged = ModKMeans._kmeans(
data, self._n_clusters, self._max_iter, init, self._tol
data,
self._n_clusters,
self._ignore_polarity,
self._max_iter,
init,
self._tol,
)
if not converged:
continue
@@ -207,7 +217,14 @@ def fit(
ModKMeans._kmeans, n_jobs, total=self._n_init
)
runs = parallel(
p_fun(data, self._n_clusters, self._max_iter, init, self._tol)
p_fun(
data,
self._n_clusters,
self._ignore_polarity,
self._max_iter,
init,
self._tol,
)
for init in inits
)
try:
@@ -238,7 +255,6 @@ def fit(
self._cluster_centers_ = best_maps
self._labels_ = best_segmentation
self._fitted = True
self._ignore_polarity = True

@copy_doc(_BaseCluster.save)
def save(self, fname: Union[str, Path]):
@@ -258,6 +274,7 @@ def save(self, fname: Union[str, Path]):
n_init=self._n_init,
max_iter=self._max_iter,
tol=self._tol,
ignore_polarity=self._ignore_polarity,
GEV_=self._GEV_,
)

@@ -266,25 +283,23 @@ def save(self, fname: Union[str, Path]):
def _kmeans(
data: ScalarFloatArray,
n_clusters: int,
ignore_polarity: bool,
max_iter: int,
random_state: Union[RandomState, Generator],
tol: Union[int, float],
) -> tuple[float, ScalarFloatArray, ScalarIntArray, bool]:
"""Run the k-means algorithm."""
gfp_sum_sq = np.sum(data**2)
maps, converged = ModKMeans._compute_maps(
data, n_clusters, max_iter, random_state, tol
maps, segmentation, converged = ModKMeans._compute_maps(
data, n_clusters, ignore_polarity, max_iter, random_state, tol
)
activation = maps.dot(data)
segmentation = np.argmax(np.abs(activation), axis=0)
map_corr = _corr_vectors(data, maps[segmentation].T)
gev = np.sum((data * map_corr) ** 2) / gfp_sum_sq
gev = _gev(data, maps, segmentation)
return gev, maps, segmentation, converged

@staticmethod
def _compute_maps(
data: ScalarFloatArray,
n_clusters: int,
ignore_polarity: bool,
max_iter: int,
random_state: Union[RandomState, Generator],
tol: Union[int, float],
@@ -317,7 +332,10 @@ def _compute_maps(
for _ in range(max_iter):
# Assign each sample to the best matching microstate
activation = maps.dot(data)
segmentation = np.argmax(np.abs(activation), axis=0)
if ignore_polarity:
segmentation = np.argmax(np.abs(activation), axis=0)
else:
segmentation = np.argmax(activation, axis=0)

# Recompute the topographic maps of the microstates, based on the
# samples that were assigned to each state.
@@ -346,7 +364,7 @@ def _compute_maps(
else:
converged = False

return maps, converged
return maps, segmentation, converged

# --------------------------------------------------------------------
@property
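
The assignment step changed in `_compute_maps` can be tried in isolation. The sketch below mirrors the `ignore_polarity` branch from this diff on a toy input; `assign_step_sketch` is a hypothetical standalone name, and the toy maps and data are made up for illustration.

```python
import numpy as np


def assign_step_sketch(maps, data, ignore_polarity=True):
    """Assign each sample to the best matching map (illustrative only).

    maps: (n_states, n_channels), data: (n_channels, n_samples).
    """
    activation = maps.dot(data)
    if ignore_polarity:
        # |activation|: a sample and its inverse pick the same map.
        return np.argmax(np.abs(activation), axis=0)
    # Signed activation: an inverted sample no longer matches its map.
    return np.argmax(activation, axis=0)


maps = np.array([[1.0, 0.0], [0.0, 1.0]])
data = np.array([[2.0, -2.0, 0.1], [0.5, 0.5, 1.0]])
print(assign_step_sketch(maps, data, ignore_polarity=True))   # [0 0 1]
print(assign_step_sketch(maps, data, ignore_polarity=False))  # [0 1 1]
```

Only the second sample changes label: its projection on the first map is large but negative, so it wins under the absolute value and loses under the signed comparison.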
31 changes: 28 additions & 3 deletions pycrostates/cluster/tests/test_base.py
@@ -40,7 +40,7 @@ def test_reject_short_segments():
[3, 3, 3, 3, 3, 3, 3, 6, 4, 5, 2, 2, 2],
]
)
segmentation = _BaseCluster._reject_short_segments(segmentation, data, 3)
segmentation = _BaseCluster._reject_short_segments(segmentation, data, 3, True)
# solo 1 should turn to 2; initial 0 should not change
assert [0, 0, 1, 1, 1, 3, 3, 3, 2, 2, 2, 2, 2] == segmentation

@@ -53,7 +53,7 @@ def test_reject_short_segments():
[3, 3, 3, 3, 3, 3, 3, 6, 4, 4, 6, 2, 2, 2],
]
)
segmentation = _BaseCluster._reject_short_segments(segmentation, data, 3)
segmentation = _BaseCluster._reject_short_segments(segmentation, data, 3, True)
assert [0, 0, 1, 1, 1, 3, 3, 3, 3, 2, 2, 2, 2, 2] == segmentation

# singleton, same correlation
@@ -65,7 +65,32 @@ def test_reject_short_segments():
[3, 3, 3, 3, 3, 3, 3, 6, 4, 6, 2, 2, 2],
]
)
segmentation = _BaseCluster._reject_short_segments(segmentation, data, 3)
segmentation = _BaseCluster._reject_short_segments(segmentation, data, 3, True)
assert [0, 0, 1, 1, 1, 3, 3, 3, 3, 2, 2, 2, 2] == segmentation

# respect polarity (ignore_polarity=False)
segmentation = [0, 0, 1, 1, 1, 3, 3, 3, 1, 2, 2, 2, 2]
data = np.array(
[
[1, 1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 3, 0.5, -0.5, -0.5, 1, 1, 1],
[3, 3, 3, 3, 3, 3, 3, 4, -4, -4, 2, 2, 2],
]
)
segmentation = _BaseCluster._reject_short_segments(segmentation, data, 3, False)
# solo 1 should turn to 2; initial 0 should not change
assert [0, 0, 1, 1, 1, 3, 3, 3, 2, 2, 2, 2, 2] == segmentation

segmentation = [0, 0, 1, 1, 1, 3, 3, 3, 1, 2, 2, 2, 2]
data = np.array(
[
[1, 1, 1, 1, 1, 1, 1, -1, -1, 1, 1, 1, 1],
[2, 2, 2, 2, 2, 2, 3, -0.5, -0.5, 0.5, 1, 1, 1],
[3, 3, 3, 3, 3, 3, 3, -4, -4, 4, 2, 2, 2],
]
)
segmentation = _BaseCluster._reject_short_segments(segmentation, data, 3, False)
# solo 1 should turn to 3; initial 0 should not change
assert [0, 0, 1, 1, 1, 3, 3, 3, 3, 2, 2, 2, 2] == segmentation


7 changes: 4 additions & 3 deletions pycrostates/cluster/tests/test_kmeans.py
@@ -320,10 +320,11 @@ def test_reorder(caplog):
assert ModK._cluster_names[0] == ModK_._cluster_names[1]

# test ._labels_ reordering
x = ModK_._labels_[:20]
# x: before re-order:
# x = [3, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
ModK_ = ModK.copy()
ModK_._labels_[:20] = [3, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
ModK_.reorder_clusters(order=np.array([1, 0, 2, 3]))
# y: expected re-ordered _labels
x = ModK_._labels_[:20]
y = [3, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
assert np.all(x == y)
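
The label remapping this test expects can be reproduced with plain NumPy. In the sketch below, `relabel_sketch` is a hypothetical helper and the convention that `order[new] == old` is an assumption: the order `[1, 0, 2, 3]` is a simple swap, so this test alone does not distinguish that convention from its inverse.

```python
import numpy as np


def relabel_sketch(labels, order):
    """Map old cluster labels to new ones for a permutation `order`.

    Assumes order[new] == old, so the lookup table is the inverse permutation.
    """
    labels = np.asarray(labels, dtype=int)
    inverse = np.argsort(order)  # inverse[old] == new
    return inverse[labels]


x = [3, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]
y = [3, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
print(np.array_equal(relabel_sketch(x, np.array([1, 0, 2, 3])), y))  # True
```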
