Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 24, 2025
1 parent 2d73418 commit 0b0d1fc
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
10 changes: 8 additions & 2 deletions pycrostates/cluster/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@

from .._typing import Picks
from ..utils import _gev
from ..utils._checks import _check_n_jobs, _check_random_state, _check_type, _ensure_gfp_function
from ..utils._checks import (
_check_n_jobs,
_check_random_state,
_check_type,
)
from ..utils._docs import copy_doc, fill_doc
from ..utils._logs import logger
from ._base import _BaseCluster
Expand Down Expand Up @@ -206,7 +210,9 @@ def fit(
)
if not converged:
continue
gev = _gev(data, maps, segmentation, ch_type=self.get_channel_types()[0])
gev = _gev(
data, maps, segmentation, ch_type=self.get_channel_types()[0]
)
if best_gev is None or gev > best_gev:
best_gev, best_maps, best_segmentation = (
gev,
Expand Down
8 changes: 6 additions & 2 deletions pycrostates/segmentation/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,9 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False):
# create a 1D view of the labels array
labels = labels.reshape(-1)

gfp_function = _ensure_gfp_function(method='auto', ch_type=self._inst.info.get_channel_types()[0])
gfp_function = _ensure_gfp_function(
method="auto", ch_type=self._inst.info.get_channel_types()[0]
)
gfp = gfp_function(data)
if norm_gfp:
labeled = np.argwhere(labels != -1) # ignore unlabeled segments
Expand All @@ -177,7 +179,9 @@ def compute_parameters(self, norm_gfp: bool = True, return_dist: bool = False):
labeled_gfp = gfp[arg_where][:, 0]
# Correlation (i.e explained variance)
dist_corr = _correlation(
labeled_tp, state, ignore_polarity=self._predict_parameters["ignore_polarity"]
labeled_tp,
state,
ignore_polarity=self._predict_parameters["ignore_polarity"],
)
params[f"{state_name}_mean_corr"] = np.mean(dist_corr)
# Global Explained Variance
Expand Down

0 comments on commit 0b0d1fc

Please sign in to comment.