Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,7 @@
jitunion,
)
from .config import nap_config
from .metadata_class import (
_MetadataMixin,
add_meta_docstring,
add_or_convert_metadata,
)
from .metadata_class import _MetadataMixin, add_meta_docstring, add_or_convert_metadata
from .time_index import TsIndex
from .utils import (
_convert_iter_to_str,
Expand Down
8 changes: 1 addition & 7 deletions pynapple/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
from .folder import Folder
from .interface_npz import NPZFile
from .interface_nwb import NWBFile
from .misc import (
append_NWB_LFP,
load_eeg,
load_file,
load_folder,
load_session,
)
from .misc import append_NWB_LFP, load_eeg, load_file, load_folder, load_session
83 changes: 49 additions & 34 deletions pynapple/io/interface_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def iterate_over_nwb(nwbfile):
elif len(obj.data.shape) == 1:
yield obj, {"id": oid, "type": "Tsd"}

elif obj.__class__.__name__ == "EventsTable": # TODO
yield obj, {"id": oid, "type": "TsdFrame"}


def _extract_compatible_data_from_nwbfile(nwbfile):
"""Extract all the NWB objects that can be converted to a pynapple object. If two objects have the same names, they
Expand Down Expand Up @@ -214,52 +217,64 @@ def _make_tsd_frame(obj, lazy_loading=True):
"""
pynwb = importlib.import_module("pynwb")

d = obj.data
if not lazy_loading:
d = d[:]

if obj.timestamps is not None:
t = obj.timestamps[:]
if not hasattr(obj, "data"):
return _make_tsdframe_from_eventstable(obj)
else:
t = obj.starting_time + np.arange(obj.num_samples) / obj.rate
d = obj.data
if not lazy_loading:
d = d[:]

if isinstance(obj, pynwb.behavior.SpatialSeries):
if obj.data.shape[1] == 2:
columns = ["x", "y"]
elif obj.data.shape[1] == 3:
columns = ["x", "y", "z"]
if obj.timestamps is not None:
t = obj.timestamps[:]
else:
columns = np.arange(obj.data.shape[1])
t = obj.starting_time + np.arange(obj.num_samples) / obj.rate

elif isinstance(obj, pynwb.ecephys.ElectricalSeries):
# (channel mapping)
try:
df = obj.electrodes.to_dataframe()
if hasattr(df, "label"):
columns = df["label"].values
if isinstance(obj, pynwb.behavior.SpatialSeries):
if obj.data.shape[1] == 2:
columns = ["x", "y"]
elif obj.data.shape[1] == 3:
columns = ["x", "y", "z"]
else:
columns = df.index.values
except Exception:
columns = np.arange(obj.data.shape[1])

elif isinstance(obj, pynwb.ecephys.ElectricalSeries):
# (channel mapping)
try:
df = obj.electrodes.to_dataframe()
if hasattr(df, "label"):
columns = df["label"].values
else:
columns = df.index.values
except Exception:
columns = np.arange(obj.data.shape[1])

elif isinstance(obj, pynwb.ophys.RoiResponseSeries):
# (cell number)
try:
columns = obj.rois["id"][:]
except Exception:
columns = np.arange(obj.data.shape[1])

else:
columns = np.arange(obj.data.shape[1])

elif isinstance(obj, pynwb.ophys.RoiResponseSeries):
# (cell number)
try:
columns = obj.rois["id"][:]
except Exception:
if len(columns) >= d.shape[1]: # Weird sometimes if background ID added
columns = columns[0 : obj.data.shape[1]]
else:
columns = np.arange(obj.data.shape[1])

else:
columns = np.arange(obj.data.shape[1])
data = nap.TsdFrame(t=t, d=d, columns=columns, load_array=not lazy_loading)

if len(columns) >= d.shape[1]: # Weird sometimes if background ID added
columns = columns[0 : obj.data.shape[1]]
else:
columns = np.arange(obj.data.shape[1])
return data

data = nap.TsdFrame(t=t, d=d, columns=columns, load_array=not lazy_loading)

return data
def _make_tsdframe_from_eventstable(obj):
if hasattr(obj, "to_dataframe"):
df = obj.to_dataframe().set_index("timestamp")
df = df.select_dtypes(include=np.number)
return nap.TsdFrame(df)
else:
return None


def _make_tsgroup(obj, **kwargs):
Expand Down
5 changes: 1 addition & 4 deletions pynapple/process/perievent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@
import numpy as np

from .. import core as nap
from ._process_functions import (
_perievent_continuous,
_perievent_trigger_average,
)
from ._process_functions import _perievent_continuous, _perievent_trigger_average


def _validate_perievent_inputs(func):
Expand Down
29 changes: 28 additions & 1 deletion tests/test_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
# @Last Modified time: 2023-09-18 10:28:42

"""Tests of nwb reading for `pynapple` package."""

import warnings

import numpy as np
import pynwb
import pytest
import requests
from pynwb.testing.mock.file import mock_NWBFile
from pynwb.testing.mock.utils import name_generator_registry

Expand Down Expand Up @@ -663,3 +663,30 @@ def test_path_utility_func(full_path_to_key, expected):
out = _get_unique_identifier(full_path_to_key)
for k in full_path_to_key:
assert expected[k] == out[k]


def test_events_tables_load(tmp_path):
url = "https://osf.io/7grz4/download"
temp_file = tmp_path / "test_events_table.nwb"

response = requests.get(url, allow_redirects=True)
response.raise_for_status()

with open(temp_file, "wb") as f:
f.write(response.content)

# Now load with nap
dat = nap.NWBFile(temp_file)
np.testing.assert_array_equal(
dat["ttl_events"]["pulse_value"].d, np.array([55, 1, 2, 3, 31])
)
np.testing.assert_array_equal(
dat["ttl_events"].t,
np.array([6820.092244, 6821.208244, 6822.210644, 6822.711364, 6825.934244]),
)
np.testing.assert_array_equal(
dat["stimulus_presentations"].d, np.array([[0.0, 1.0024], [1.0, 0.99484]])
)
np.testing.assert_array_equal(
dat["stimulus_presentations"].t, np.array([6821.208244, 6825.208244])
)
Loading