diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 45b4c1bec..d92c09f91 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -19,11 +19,7 @@ jitunion, ) from .config import nap_config -from .metadata_class import ( - _MetadataMixin, - add_meta_docstring, - add_or_convert_metadata, -) +from .metadata_class import _MetadataMixin, add_meta_docstring, add_or_convert_metadata from .time_index import TsIndex from .utils import ( _convert_iter_to_str, diff --git a/pynapple/io/__init__.py b/pynapple/io/__init__.py index f4eb2a70b..20b194c7b 100644 --- a/pynapple/io/__init__.py +++ b/pynapple/io/__init__.py @@ -1,10 +1,4 @@ from .folder import Folder from .interface_npz import NPZFile from .interface_nwb import NWBFile -from .misc import ( - append_NWB_LFP, - load_eeg, - load_file, - load_folder, - load_session, -) +from .misc import append_NWB_LFP, load_eeg, load_file, load_folder, load_session diff --git a/pynapple/io/interface_nwb.py b/pynapple/io/interface_nwb.py index 907c75d5f..338192781 100644 --- a/pynapple/io/interface_nwb.py +++ b/pynapple/io/interface_nwb.py @@ -90,6 +90,9 @@ def iterate_over_nwb(nwbfile): elif len(obj.data.shape) == 1: yield obj, {"id": oid, "type": "Tsd"} + elif obj.__class__.__name__ == "EventsTable": # TODO + yield obj, {"id": oid, "type": "TsdFrame"} + def _extract_compatible_data_from_nwbfile(nwbfile): """Extract all the NWB objects that can be converted to a pynapple object. If two objects have the same names, they @@ -214,52 +217,64 @@ def _make_tsd_frame(obj, lazy_loading=True): """ pynwb = importlib.import_module("pynwb") - d = obj.data - if not lazy_loading: - d = d[:] - - if obj.timestamps is not None: - t = obj.timestamps[:] + if not hasattr(obj, "data"): + return _make_tsdframe_from_eventstable(obj) else: - t = obj.starting_time + np.arange(obj.num_samples) / obj.rate + d = obj.data + if not lazy_loading: + d = d[:] - if isinstance(obj, pynwb.behavior.SpatialSeries): - if obj.data.shape[1] == 2: - columns = ["x", "y"] - elif obj.data.shape[1] == 3: - columns = ["x", "y", "z"] + if obj.timestamps is not None: + t = obj.timestamps[:] else: - columns = np.arange(obj.data.shape[1]) + t = obj.starting_time + np.arange(obj.num_samples) / obj.rate - elif isinstance(obj, pynwb.ecephys.ElectricalSeries): - # (channel mapping) - try: - df = obj.electrodes.to_dataframe() - if hasattr(df, "label"): - columns = df["label"].values + if isinstance(obj, pynwb.behavior.SpatialSeries): + if obj.data.shape[1] == 2: + columns = ["x", "y"] + elif obj.data.shape[1] == 3: + columns = ["x", "y", "z"] else: - columns = df.index.values - except Exception: + columns = np.arange(obj.data.shape[1]) + + elif isinstance(obj, pynwb.ecephys.ElectricalSeries): + # (channel mapping) + try: + df = obj.electrodes.to_dataframe() + if hasattr(df, "label"): + columns = df["label"].values + else: + columns = df.index.values + except Exception: + columns = np.arange(obj.data.shape[1]) + + elif isinstance(obj, pynwb.ophys.RoiResponseSeries): + # (cell number) + try: + columns = obj.rois["id"][:] + except Exception: + columns = np.arange(obj.data.shape[1]) + + else: columns = np.arange(obj.data.shape[1]) - elif isinstance(obj, pynwb.ophys.RoiResponseSeries): - # (cell number) - try: - columns = obj.rois["id"][:] - except Exception: + if len(columns) >= d.shape[1]: # Weird sometimes if background ID added + columns = columns[0 : obj.data.shape[1]] + else: columns = np.arange(obj.data.shape[1]) - else: - columns = np.arange(obj.data.shape[1]) + data = nap.TsdFrame(t=t, d=d, columns=columns, load_array=not lazy_loading) - if len(columns) >= d.shape[1]: # Weird sometimes if background ID added - columns = columns[0 : obj.data.shape[1]] - else: - columns = np.arange(obj.data.shape[1]) + return data - data = nap.TsdFrame(t=t, d=d, columns=columns, load_array=not lazy_loading) - return data +def _make_tsdframe_from_eventstable(obj): + if hasattr(obj, "to_dataframe"): + df = obj.to_dataframe().set_index("timestamp") + df = df.select_dtypes(include=np.number) + return nap.TsdFrame(df) + else: + return None def _make_tsgroup(obj, **kwargs): diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index e98f938a7..e8cb7541b 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -9,10 +9,7 @@ import numpy as np from .. import core as nap -from ._process_functions import ( - _perievent_continuous, - _perievent_trigger_average, -) +from ._process_functions import _perievent_continuous, _perievent_trigger_average def _validate_perievent_inputs(func): diff --git a/tests/test_nwb.py b/tests/test_nwb.py index b24f2a1bc..fa74fae97 100644 --- a/tests/test_nwb.py +++ b/tests/test_nwb.py @@ -5,12 +5,12 @@ # @Last Modified time: 2023-09-18 10:28:42 """Tests of nwb reading for `pynapple` package.""" - import warnings import numpy as np import pynwb import pytest +import requests from pynwb.testing.mock.file import mock_NWBFile from pynwb.testing.mock.utils import name_generator_registry @@ -663,3 +663,30 @@ def test_path_utility_func(full_path_to_key, expected): out = _get_unique_identifier(full_path_to_key) for k in full_path_to_key: assert expected[k] == out[k] + + +def test_events_tables_load(tmp_path): + url = "https://osf.io/7grz4/download" + temp_file = tmp_path / "test_events_table.nwb" + + response = requests.get(url, allow_redirects=True) + response.raise_for_status() + + with open(temp_file, "wb") as f: + f.write(response.content) + + # Now load with nap + dat = nap.NWBFile(temp_file) + np.testing.assert_array_equal( + dat["ttl_events"]["pulse_value"].d, np.array([55, 1, 2, 3, 31]) + ) + np.testing.assert_array_equal( + dat["ttl_events"].t, + np.array([6820.092244, 6821.208244, 6822.210644, 6822.711364, 6825.934244]), + ) + np.testing.assert_array_equal( + dat["stimulus_presentations"].d, np.array([[0.0, 1.0024], [1.0, 0.99484]]) + ) + np.testing.assert_array_equal( + dat["stimulus_presentations"].t, np.array([6821.208244, 6825.208244]) + )