Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add epoch selection to EpochsSegmentation #139

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pycrostates/segmentation/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import itertools
from abc import abstractmethod
from copy import deepcopy
from typing import Optional, Union

import numpy as np
Expand Down Expand Up @@ -264,6 +265,10 @@ def compute_expected_transition_matrix(
ignore_repetitions=ignore_repetitions,
)

def copy(self):
"""Return copy of the segmentation instance."""
return deepcopy(self)

@fill_doc
def plot_cluster_centers(
self,
Expand Down
82 changes: 81 additions & 1 deletion pycrostates/segmentation/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Segmentation module for segmented data."""

from typing import Optional, Union
from __future__ import annotations

from typing import TYPE_CHECKING

from matplotlib.axes import Axes
from mne import BaseEpochs
Expand All @@ -11,6 +13,11 @@
from ..viz import plot_epoch_segmentation, plot_raw_segmentation
from ._base import _BaseSegmentation

if TYPE_CHECKING:
from typing import Optional, Union

from pandas import DataFrame


@fill_doc
class RawSegmentation(_BaseSegmentation):
Expand Down Expand Up @@ -130,6 +137,74 @@ def __init__(self, *args, **kwargs):
f"samples, while the 'labels' has {self._labels.shape[-1]} samples."
)

def __getitem__(self, item):
"""Select epochs in a :class:`~pycrostates.segmentation.EpochsSegmentation`.

Parameters
----------
item : slice, array-like, str, or list
See below for use cases.

Returns
-------
epochs : instance of EpochsSegmentation
Returns a copy of the original instance. See below for use cases.

Notes
-----
:class:`~pycrostates.segmentation.EpochsSegmentation` can be accessed as
``segmentation[...]`` in several ways:

1. **Integer or slice:** ``segmentation[idx]`` will return an
:class:`~pycrostates.segmentation.EpochsSegmentation` object with a subset of
epochs chosen by index (supports single index and Python-style slicing).

2. **String:** ``segmentation['name']`` will return an
:class:`~pycrostates.segmentation.EpochsSegmentation` object comprising only
the epochs labeled ``'name'`` (i.e., epochs created around events with the
label ``'name'``).

If there are no epochs labeled ``'name'`` but there are epochs
labeled with /-separated tags (e.g. ``'name/left'``,
``'name/right'``), then ``segmentation['name']`` will select the epochs
with labels that contain that tag (e.g., ``segmentation['left']`` selects
epochs labeled ``'audio/left'`` and ``'visual/left'``, but not
``'audio_left'``).

If multiple tags are provided *as a single string* (e.g.,
``segmentation['name_1/name_2']``), this selects epochs containing *all*
provided tags. For example, ``segmentation['audio/left']`` selects
``'audio/left'`` and ``'audio/quiet/left'``, but not
``'audio/right'``. Note that tag-based selection is insensitive to
order: tags like ``'audio/left'`` and ``'left/audio'`` will be
treated the same way when selecting via tag.

3. **List of strings:** ``segmentation[['name_1', 'name_2', ... ]]`` will
return an :class:`~pycrostates.segmentation.EpochsSegmentation` object
comprising epochs that match *any* of the provided names (i.e., the list of
names is treated as an inclusive-or condition). If *none* of the provided
names match any epoch labels, a ``KeyError`` will be raised.

If epoch labels are /-separated tags, then providing multiple tags
*as separate list entries* will likewise act as an inclusive-or
filter. For example, ``segmentation[['audio', 'left']]`` would select
``'audio/left'``, ``'audio/right'``, and ``'visual/left'``, but not
``'visual/right'``.

4. **Pandas query:** ``segmentation['pandas query']`` will return an
:class:`~pycrostates.segmentation.EpochsSegmentation` object with a subset of
epochs (and matching metadata) selected by the query called with
``self.metadata.eval``, e.g.::

epochs["col_a > 2 and col_b == 'foo'"]

would return all epochs whose associated ``col_a`` metadata was
greater than two, and whose ``col_b`` metadata was the string 'foo'.
Query-based indexing only works if Pandas is installed and
``self.metadata`` is a :class:`pandas.DataFrame`.
"""
inst = self.copy() # noqa: F841

@fill_doc
def plot(
self,
Expand Down Expand Up @@ -173,3 +248,8 @@ def plot(
def epochs(self) -> BaseEpochs:
"""`~mne.Epochs` instance from which the segmentation was computed."""
return self._inst.copy()

@property
def metadata(self) -> Optional[DataFrame]:
"""Epochs metadata."""
return self._inst.metadata
Comment on lines +252 to +255
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: remove and edit docstring above, we should probably not start exposing all methods from an epoch object here.