From 5db013f909f06322941679ee0b80842ba244c5c6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 7 Mar 2024 16:16:58 +0000 Subject: [PATCH 01/47] fix-62-63 --- src/expipe_plugin_cinpla/scripts/utils.py | 19 +++++++++---------- src/expipe_plugin_cinpla/widgets/register.py | 20 ++++++++++---------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/src/expipe_plugin_cinpla/scripts/utils.py b/src/expipe_plugin_cinpla/scripts/utils.py index cc75495..c6572e7 100644 --- a/src/expipe_plugin_cinpla/scripts/utils.py +++ b/src/expipe_plugin_cinpla/scripts/utils.py @@ -176,19 +176,18 @@ def _make_data_path(action, overwrite, suffix=".nwb"): def _get_data_path(action): - if "main" not in action.data: - return try: + if "main" not in action.data: + return data_path = action.data_path("main") + if not data_path.is_dir(): + action_path = action._backend.path + project_path = action_path.parent.parent + # data_path = action.data['main'] + data_path = project_path / str(Path(PureWindowsPath(action.data["main"]))) + return data_path except: - data_path = Path("None") - pass - if not data_path.is_dir(): - action_path = action._backend.path - project_path = action_path.parent.parent - # data_path = action.data['main'] - data_path = project_path / str(Path(PureWindowsPath(action.data["main"]))) - return data_path + return def register_templates(action, templates, overwrite=False): diff --git a/src/expipe_plugin_cinpla/widgets/register.py b/src/expipe_plugin_cinpla/widgets/register.py index be0e8b1..4caa6e9 100644 --- a/src/expipe_plugin_cinpla/widgets/register.py +++ b/src/expipe_plugin_cinpla/widgets/register.py @@ -140,7 +140,7 @@ def register_adjustment_view(project): adjustment = MultiInput(["*Key", "*Probe", "*Adjustment", "*Unit"], "Add adjustment") depth = MultiInput(["Key", "Probe", "Depth", "Unit"], "Add depth") depth_from_surgery = ipywidgets.Checkbox(description="Get depth from surgery", value=True) - register = ipywidgets.Button(description="Register") + register_button = ipywidgets.Button(description="Register") fields = ipywidgets.VBox([user, date, adjustment, register]) main_box = ipywidgets.VBox([depth_from_surgery, ipywidgets.HBox([fields, entity_id])]) @@ -171,7 +171,7 @@ def on_register(change): yes=True, ) - register.on_click(on_register) + register_button.on_click(on_register) return main_box @@ -196,7 +196,7 @@ def register_annotate_view(project): message = ipywidgets.Text(placeholder="Message") tag = ipywidgets.Text(placeholder="Tags (; to separate)") templates = SearchSelectMultiple(project.templates, description="Templates") - register = ipywidgets.Button(description="Register") + register_button = ipywidgets.Button(description="Register") fields = ipywidgets.VBox([user, date, location, message, action_type, tag, depth, entity_id, register]) main_box = ipywidgets.VBox([ipywidgets.HBox([fields, action_id, templates])]) @@ -221,7 +221,7 @@ def on_register(change): correct_depth_answer=True, ) - register.on_click(on_register) + register_button.on_click(on_register) return main_box @@ -248,7 +248,7 @@ def register_entity_view(project): templates = SearchSelectMultiple(project.templates, description="Templates") overwrite = ipywidgets.Checkbox(description="Overwrite", value=False) - register = ipywidgets.Button(description="Register") + register_button = ipywidgets.Button(description="Register") fields = ipywidgets.VBox([entity_id, user, species, sex, location, birthday, message, tag, register]) main_box = ipywidgets.VBox([overwrite, ipywidgets.HBox([fields, 
templates])]) @@ -273,7 +273,7 @@ def on_register(change): templates=templates.value, ) - register.on_click(on_register) + register_button.on_click(on_register) return view @@ -307,7 +307,7 @@ def register_surgery_view(project): angle = MultiInput(["*Key", "*Probe", "*Angle", "*Unit"], "Add angle") templates = SearchSelectMultiple(project.templates, description="Templates") overwrite = ipywidgets.Checkbox(description="Overwrite", value=False) - register = ipywidgets.Button(description="Register") + register_button = ipywidgets.Button(description="Register") fields = ipywidgets.VBox([user, location, date, weight, position, angle, message, procedure, tag, register]) main_box = ipywidgets.VBox([overwrite, ipywidgets.HBox([fields, ipywidgets.VBox([entity_id, templates])])]) @@ -336,7 +336,7 @@ def on_register(change): tags=tags, ) - register.on_click(on_register) + register_button.on_click(on_register) return view @@ -367,7 +367,7 @@ def register_perfuse_view(project): templates = SearchSelectMultiple(project.templates, description="Templates") overwrite = ipywidgets.Checkbox(description="Overwrite", value=False) - register = ipywidgets.Button(description="Register") + register_button = ipywidgets.Button(description="Register") fields = ipywidgets.VBox([user, location, date, weight, message, register]) main_box = ipywidgets.VBox([overwrite, ipywidgets.HBox([fields, entity_id, templates])]) view = BaseViewWithLog(main_box=main_box, project=project) @@ -389,5 +389,5 @@ def on_register(change): message=none_if_empty(message.value), ) - register.on_click(on_register) + register_button.on_click(on_register) return view From 10b17df9930d36b924f94190f3259f4bc19c8d58 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 8 Mar 2024 12:10:38 +0100 Subject: [PATCH 02/47] Fix register_button --- src/expipe_plugin_cinpla/widgets/register.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/expipe_plugin_cinpla/widgets/register.py b/src/expipe_plugin_cinpla/widgets/register.py index 4caa6e9..c536b5a 100644 --- a/src/expipe_plugin_cinpla/widgets/register.py +++ b/src/expipe_plugin_cinpla/widgets/register.py @@ -142,7 +142,7 @@ def register_adjustment_view(project): depth_from_surgery = ipywidgets.Checkbox(description="Get depth from surgery", value=True) register_button = ipywidgets.Button(description="Register") - fields = ipywidgets.VBox([user, date, adjustment, register]) + fields = ipywidgets.VBox([user, date, adjustment, register_button]) main_box = ipywidgets.VBox([depth_from_surgery, ipywidgets.HBox([fields, entity_id])]) def on_manual_depth(change): @@ -198,7 +198,7 @@ def register_annotate_view(project): templates = SearchSelectMultiple(project.templates, description="Templates") register_button = ipywidgets.Button(description="Register") - fields = ipywidgets.VBox([user, date, location, message, action_type, tag, depth, entity_id, register]) + fields = ipywidgets.VBox([user, date, location, message, action_type, tag, depth, entity_id, register_button]) main_box = ipywidgets.VBox([ipywidgets.HBox([fields, action_id, templates])]) def on_register(change): @@ -249,7 +249,7 @@ def register_entity_view(project): overwrite = ipywidgets.Checkbox(description="Overwrite", value=False) register_button = ipywidgets.Button(description="Register") - fields = ipywidgets.VBox([entity_id, user, species, sex, location, birthday, message, tag, register]) + fields = ipywidgets.VBox([entity_id, user, species, sex, location, birthday, message, tag, register_button]) main_box = 
ipywidgets.VBox([overwrite, ipywidgets.HBox([fields, templates])]) view = BaseViewWithLog(main_box=main_box, project=project) @@ -309,7 +309,7 @@ def register_surgery_view(project): overwrite = ipywidgets.Checkbox(description="Overwrite", value=False) register_button = ipywidgets.Button(description="Register") - fields = ipywidgets.VBox([user, location, date, weight, position, angle, message, procedure, tag, register]) + fields = ipywidgets.VBox([user, location, date, weight, position, angle, message, procedure, tag, register_button]) main_box = ipywidgets.VBox([overwrite, ipywidgets.HBox([fields, ipywidgets.VBox([entity_id, templates])])]) view = BaseViewWithLog(main_box=main_box, project=project) @@ -368,7 +368,7 @@ def register_perfuse_view(project): overwrite = ipywidgets.Checkbox(description="Overwrite", value=False) register_button = ipywidgets.Button(description="Register") - fields = ipywidgets.VBox([user, location, date, weight, message, register]) + fields = ipywidgets.VBox([user, location, date, weight, message, register_button]) main_box = ipywidgets.VBox([overwrite, ipywidgets.HBox([fields, entity_id, templates])]) view = BaseViewWithLog(main_box=main_box, project=project) From ce0275412222c14eb739a3a1c020a889617556a2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 8 Mar 2024 16:35:05 +0100 Subject: [PATCH 03/47] data_processing.py -> data_loader.py --- src/expipe_plugin_cinpla/data_loader.py | 337 ++++++++++++++++++++++++ 1 file changed, 337 insertions(+) create mode 100644 src/expipe_plugin_cinpla/data_loader.py diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py new file mode 100644 index 0000000..72a11f5 --- /dev/null +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -0,0 +1,337 @@ +"""Utils for loading data from NWB files""" +import numpy as np + +import quantities as pq +import neo +import spikeinterface as si +import spikeinterface.extractors as se + +from pynwb import NWBHDF5IO +from .utils import _get_data_path + + +def get_data_path(action): + """Returns the path to the main.nwb file""" + return str(_get_data_path(action)) + + +def get_sample_rate(data_path): + """ + Return the sampling rate of the recording + + Parameters + ---------- + data_path: Path + The action data path + + Returns + ------- + sr: pq.Quantity + The sampling rate of the recording + """ + recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") + sr = recording.get_sampling_frequency() * pq.Hz + return sr + + +def get_duration(data_path): + """ + Return the duration of the recording in s + + Parameters + ---------- + data_path: Path + The action data path + + Returns + ------- + duration: pq.Quantity + The duration of the recording + """ + recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") + duration = recording.get_total_duration() * pq.s + return duration + + +def view_active_channels(action, sorter): + """ + Returns the active channels for a given action and sorter + + Parameters + ---------- + action: Action + The action + sorter: str + The sorter name + + Returns + ------- + active_channels: list + The active channels + """ + path = _get_data_path(action) + sorter_path = path.parent / "spikeinterface" / sorter + if not sorter_path.is_dir(): + raise ValueError(f"Action {action.id} has not been sorted with {sorter}") + waveforms_folder = sorter_path / "waveforms" + we = si.load_waveforms(waveforms_folder, with_recording=False) + return 
we.channel_ids + + +def load_leds(data_path): + """ + Returns the positions of the LEDs (red + green) + + Parameters + ---------- + data_path: Path + The action data path + + Returns + ------- + x1, y1, t1, x2, y2, t2, stop_time: tuple + The x and y positions of the red and green LEDs, the timestamps and the stop time + """ + io = NWBHDF5IO(str(data_path), "r") + nwbfile = io.read() + + behavior = nwbfile.processing["behavior"] + + # tracking data + open_field_position = behavior["Open Field Position"] + red_spatial_series = open_field_position["LedRed"] + green_spatial_series = open_field_position["LedGreen"] + red_data = red_spatial_series.data + green_data = green_spatial_series.data + x1, y1 = red_data[:, 0], red_data[:, 1] + x2, y2 = green_data[:, 0], green_data[:, 1] + t1 = red_spatial_series.timestamps + t2 = green_spatial_series.timestamps + stop_time = np.max([t1[-1], t2[-1]]) + + return x1, y1, t1, x2, y2, t2, stop_time + + +def load_lfp(data_path, channel_group=None, lim=None): + """ + Returns the LFP signal + + Parameters + ---------- + data_path: Path + The action data path + channel_group: str, optional + The channel group to load. If None, all channel groups are loaded + lim: list, optional + The time limits to load the LFP signal. If None, the entire signal is loaded + + Returns + ------- + LFP: neo.AnalogSignal + The LFP signal + """ + recording_lfp = se.read_nwb_recording( + str(data_path), electrical_series_path="processing/ecephys/LFP/ElectricalSeriesLFP" + ) + # LFP + units = pq.uV + sampling_rate = recording_lfp.sampling_frequency * pq.Hz + + if channel_group is not None: + available_channel_groups = np.unique(recording_lfp.get_channel_groups()) + assert ( + channel_group in recording_lfp.get_channel_groups() + ), f"Channel group {channel_group} not found in available channel groups: {available_channel_groups}" + # this returns a sub-extractor with the requested channel group + recording_lfp_group = recording_lfp.split_by("group")[channel_group] + (electrode_idx,) = np.nonzero(np.isin(recording_lfp.channel_ids, recording_lfp_group.channel_ids)) + else: + recording_lfp_group = recording_lfp + electrode_idx = np.arange(recording_lfp.get_num_channels()) + + if lim is None: + lfp_traces = recording_lfp_group.get_traces(return_scaled=True) + t_start = recording_lfp.get_times()[0] * pq.s + t_stop = recording_lfp.get_times()[-1] * pq.s + else: + assert len(lim) == 2, "lim must be a list of two elements with t_start and t_stop" + times_all = recording_lfp_group.get_times() + start_frame, end_frame = np.searchsorted(times_all, lim) + times = times_all[start_frame:end_frame] + t_start = times[0] * pq.s + t_stop = times[-1] * pq.s + lfp_traces = recording_lfp_group.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=True) + + LFP = neo.AnalogSignal( + lfp_traces, + units=units, + t_start=t_start, + t_stop=t_stop, + sampling_rate=sampling_rate, + **{"electrode_idx": electrode_idx}, + ) + LFP = LFP.rescale("mV") + return LFP + + +def load_epochs(data_path, label_column=None): + """ + Returns the trials as NEO epochs + + Parameters + ---------- + data_path: Path + The action data path + label_column: str, optional + The column name to use as labels + + Returns + ------- + epochs: neo.Epoch + The trials as NEO epochs + """ + with NWBHDF5IO(str(data_path), "r") as io: + nwbfile = io.read() + trials = nwbfile.trials.to_dataframe() + + start_times = trials["start_time"].values * pq.s + stop_times = trials["stop_time"].values * pq.s + durations = stop_times - 
start_times + + if label_column is not None and label_column in trials.columns: + labels = trials[label_column].values + else: + labels = None + + epochs = neo.Epoch( + times=start_times, + durations=durations, + labels=labels, + ) + return epochs + + +def get_channel_groups(data_path): + """ + Returns channel groups of session + + Parameters + ---------- + data_path: Path + The action data path + + Returns + ------- + channel groups: list + The channel groups + """ + recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") + channel_groups = list(np.unique(recording.get_channel_groups())) + return channel_groups + + +def load_spiketrains(data_path, channel_group=None, lim=None): + """ + Returns the spike trains as a list of NEO spike trains + + Parameters + ---------- + data_path: str / Path + The action data path + channel_group: str, optional + The channel group to load. If None, all channel groups are loaded + lim: list, optional + The time limits to load the spike trains. If None, the entire spike train is loaded + + Returns + ------- + spiketrains: list of NEO spike trains + The spike trains + """ + recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") + sorting = se.read_nwb_sorting(str(data_path), electrical_series_path="acquisition/ElectricalSeries") + + if channel_group is None: + unit_ids = sorting.unit_ids + else: + assert "group" in sorting.get_property_keys(), "group property not found in sorting" + groups = sorting.get_property("group") + unit_ids = [ + unit_id for unit_index, unit_id in enumerate(sorting.unit_ids) if groups[unit_index] == channel_group + ] + sptr = [] + # build neo pbjects + for unit in unit_ids: + times = sorting.get_unit_spike_train(unit, return_times=True) * pq.s + if lim is None: + times = recording.get_times() * pq.s + t_start = times[0] + t_stop = times[-1] + else: + t_start = pq.Quantity(lim[0], "s") + t_stop = pq.Quantity(lim[1], "s") + mask = (times >= t_start) & (times <= t_stop) + times = times[mask] + + st = neo.SpikeTrain( + times=times, t_start=t_start, t_stop=t_stop, sampling_rate=sorting.sampling_frequency * pq.Hz + ) + for p in sorting.get_property_keys(): + st.annotations.update({p: sorting.get_unit_property(unit, p)}) + sptr.append(st) + + return sptr + + +def load_unit_annotations(data_path, channel_group=None): + """ + Returns the annotations of the units + + Parameters + ---------- + data_path: str/Path + The action data path + channel_group: str, optional + The channel group to load. 
If None, all channel groups are loaded + + Returns + ------- + annotations: list of dicts + The annotations of the units + """ + sorting = se.read_nwb_sorting(str(data_path), electrical_series_path="acquisition/ElectricalSeries") + + units = [] + + if channel_group is None: + unit_ids = sorting.unit_ids + else: + assert "group" in sorting.get_property_keys(), "group property not found in sorting" + groups = sorting.get_property("group") + unit_ids = [ + unit_id for unit_index, unit_id in enumerate(sorting.unit_ids) if groups[unit_index] == channel_group + ] + + for unit in unit_ids: + annotations = {} + for p in sorting.get_property_keys(): + annotations.update({p: sorting.get_unit_property(unit, p)}) + units.append(annotations) + return units + + +# These functions are not relevant anymore +# def get_unit_id(unit): +# try: +# uid = int(unit.annotations['name'].split('#')[-1]) +# except AttributeError: +# uid = int(unit['name'].split('#')[-1]) +# return uid + +# def sort_by_cluster_id(spike_trains): +# if len(spike_trains) == 0: +# return spike_trains +# if "name" not in spike_trains[0].annotations: +# print("Unable to get cluster_id, save with phy to create") +# sorted_sptrs = sorted(spike_trains, key=lambda x: int(x.annotations["name"].lower().replace("unit #", ""))) +# return sorted_sptrs From 4cee52227901671e66834088a5ddcdff24b8d26a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 8 Mar 2024 16:40:21 +0100 Subject: [PATCH 04/47] fix data loader import --- src/expipe_plugin_cinpla/data_loader.py | 2 +- .../scripts/data_processing.py | 337 ------------------ 2 files changed, 1 insertion(+), 338 deletions(-) delete mode 100644 src/expipe_plugin_cinpla/scripts/data_processing.py diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py index 72a11f5..309e4ca 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -7,7 +7,7 @@ import spikeinterface.extractors as se from pynwb import NWBHDF5IO -from .utils import _get_data_path +from .scripts.utils import _get_data_path def get_data_path(action): diff --git a/src/expipe_plugin_cinpla/scripts/data_processing.py b/src/expipe_plugin_cinpla/scripts/data_processing.py deleted file mode 100644 index 72a11f5..0000000 --- a/src/expipe_plugin_cinpla/scripts/data_processing.py +++ /dev/null @@ -1,337 +0,0 @@ -"""Utils for loading data from NWB files""" -import numpy as np - -import quantities as pq -import neo -import spikeinterface as si -import spikeinterface.extractors as se - -from pynwb import NWBHDF5IO -from .utils import _get_data_path - - -def get_data_path(action): - """Returns the path to the main.nwb file""" - return str(_get_data_path(action)) - - -def get_sample_rate(data_path): - """ - Return the sampling rate of the recording - - Parameters - ---------- - data_path: Path - The action data path - - Returns - ------- - sr: pq.Quantity - The sampling rate of the recording - """ - recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") - sr = recording.get_sampling_frequency() * pq.Hz - return sr - - -def get_duration(data_path): - """ - Return the duration of the recording in s - - Parameters - ---------- - data_path: Path - The action data path - - Returns - ------- - duration: pq.Quantity - The duration of the recording - """ - recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") - duration = recording.get_total_duration() * pq.s - return duration 
- - -def view_active_channels(action, sorter): - """ - Returns the active channels for a given action and sorter - - Parameters - ---------- - action: Action - The action - sorter: str - The sorter name - - Returns - ------- - active_channels: list - The active channels - """ - path = _get_data_path(action) - sorter_path = path.parent / "spikeinterface" / sorter - if not sorter_path.is_dir(): - raise ValueError(f"Action {action.id} has not been sorted with {sorter}") - waveforms_folder = sorter_path / "waveforms" - we = si.load_waveforms(waveforms_folder, with_recording=False) - return we.channel_ids - - -def load_leds(data_path): - """ - Returns the positions of the LEDs (red + green) - - Parameters - ---------- - data_path: Path - The action data path - - Returns - ------- - x1, y1, t1, x2, y2, t2, stop_time: tuple - The x and y positions of the red and green LEDs, the timestamps and the stop time - """ - io = NWBHDF5IO(str(data_path), "r") - nwbfile = io.read() - - behavior = nwbfile.processing["behavior"] - - # tracking data - open_field_position = behavior["Open Field Position"] - red_spatial_series = open_field_position["LedRed"] - green_spatial_series = open_field_position["LedGreen"] - red_data = red_spatial_series.data - green_data = green_spatial_series.data - x1, y1 = red_data[:, 0], red_data[:, 1] - x2, y2 = green_data[:, 0], green_data[:, 1] - t1 = red_spatial_series.timestamps - t2 = green_spatial_series.timestamps - stop_time = np.max([t1[-1], t2[-1]]) - - return x1, y1, t1, x2, y2, t2, stop_time - - -def load_lfp(data_path, channel_group=None, lim=None): - """ - Returns the LFP signal - - Parameters - ---------- - data_path: Path - The action data path - channel_group: str, optional - The channel group to load. If None, all channel groups are loaded - lim: list, optional - The time limits to load the LFP signal. 
If None, the entire signal is loaded - - Returns - ------- - LFP: neo.AnalogSignal - The LFP signal - """ - recording_lfp = se.read_nwb_recording( - str(data_path), electrical_series_path="processing/ecephys/LFP/ElectricalSeriesLFP" - ) - # LFP - units = pq.uV - sampling_rate = recording_lfp.sampling_frequency * pq.Hz - - if channel_group is not None: - available_channel_groups = np.unique(recording_lfp.get_channel_groups()) - assert ( - channel_group in recording_lfp.get_channel_groups() - ), f"Channel group {channel_group} not found in available channel groups: {available_channel_groups}" - # this returns a sub-extractor with the requested channel group - recording_lfp_group = recording_lfp.split_by("group")[channel_group] - (electrode_idx,) = np.nonzero(np.isin(recording_lfp.channel_ids, recording_lfp_group.channel_ids)) - else: - recording_lfp_group = recording_lfp - electrode_idx = np.arange(recording_lfp.get_num_channels()) - - if lim is None: - lfp_traces = recording_lfp_group.get_traces(return_scaled=True) - t_start = recording_lfp.get_times()[0] * pq.s - t_stop = recording_lfp.get_times()[-1] * pq.s - else: - assert len(lim) == 2, "lim must be a list of two elements with t_start and t_stop" - times_all = recording_lfp_group.get_times() - start_frame, end_frame = np.searchsorted(times_all, lim) - times = times_all[start_frame:end_frame] - t_start = times[0] * pq.s - t_stop = times[-1] * pq.s - lfp_traces = recording_lfp_group.get_traces(start_frame=start_frame, end_frame=end_frame, return_scaled=True) - - LFP = neo.AnalogSignal( - lfp_traces, - units=units, - t_start=t_start, - t_stop=t_stop, - sampling_rate=sampling_rate, - **{"electrode_idx": electrode_idx}, - ) - LFP = LFP.rescale("mV") - return LFP - - -def load_epochs(data_path, label_column=None): - """ - Returns the trials as NEO epochs - - Parameters - ---------- - data_path: Path - The action data path - label_column: str, optional - The column name to use as labels - - Returns - ------- - epochs: neo.Epoch - The trials as NEO epochs - """ - with NWBHDF5IO(str(data_path), "r") as io: - nwbfile = io.read() - trials = nwbfile.trials.to_dataframe() - - start_times = trials["start_time"].values * pq.s - stop_times = trials["stop_time"].values * pq.s - durations = stop_times - start_times - - if label_column is not None and label_column in trials.columns: - labels = trials[label_column].values - else: - labels = None - - epochs = neo.Epoch( - times=start_times, - durations=durations, - labels=labels, - ) - return epochs - - -def get_channel_groups(data_path): - """ - Returns channel groups of session - - Parameters - ---------- - data_path: Path - The action data path - - Returns - ------- - channel groups: list - The channel groups - """ - recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") - channel_groups = list(np.unique(recording.get_channel_groups())) - return channel_groups - - -def load_spiketrains(data_path, channel_group=None, lim=None): - """ - Returns the spike trains as a list of NEO spike trains - - Parameters - ---------- - data_path: str / Path - The action data path - channel_group: str, optional - The channel group to load. If None, all channel groups are loaded - lim: list, optional - The time limits to load the spike trains. 
If None, the entire spike train is loaded - - Returns - ------- - spiketrains: list of NEO spike trains - The spike trains - """ - recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") - sorting = se.read_nwb_sorting(str(data_path), electrical_series_path="acquisition/ElectricalSeries") - - if channel_group is None: - unit_ids = sorting.unit_ids - else: - assert "group" in sorting.get_property_keys(), "group property not found in sorting" - groups = sorting.get_property("group") - unit_ids = [ - unit_id for unit_index, unit_id in enumerate(sorting.unit_ids) if groups[unit_index] == channel_group - ] - sptr = [] - # build neo pbjects - for unit in unit_ids: - times = sorting.get_unit_spike_train(unit, return_times=True) * pq.s - if lim is None: - times = recording.get_times() * pq.s - t_start = times[0] - t_stop = times[-1] - else: - t_start = pq.Quantity(lim[0], "s") - t_stop = pq.Quantity(lim[1], "s") - mask = (times >= t_start) & (times <= t_stop) - times = times[mask] - - st = neo.SpikeTrain( - times=times, t_start=t_start, t_stop=t_stop, sampling_rate=sorting.sampling_frequency * pq.Hz - ) - for p in sorting.get_property_keys(): - st.annotations.update({p: sorting.get_unit_property(unit, p)}) - sptr.append(st) - - return sptr - - -def load_unit_annotations(data_path, channel_group=None): - """ - Returns the annotations of the units - - Parameters - ---------- - data_path: str/Path - The action data path - channel_group: str, optional - The channel group to load. If None, all channel groups are loaded - - Returns - ------- - annotations: list of dicts - The annotations of the units - """ - sorting = se.read_nwb_sorting(str(data_path), electrical_series_path="acquisition/ElectricalSeries") - - units = [] - - if channel_group is None: - unit_ids = sorting.unit_ids - else: - assert "group" in sorting.get_property_keys(), "group property not found in sorting" - groups = sorting.get_property("group") - unit_ids = [ - unit_id for unit_index, unit_id in enumerate(sorting.unit_ids) if groups[unit_index] == channel_group - ] - - for unit in unit_ids: - annotations = {} - for p in sorting.get_property_keys(): - annotations.update({p: sorting.get_unit_property(unit, p)}) - units.append(annotations) - return units - - -# These functions are not relevant anymore -# def get_unit_id(unit): -# try: -# uid = int(unit.annotations['name'].split('#')[-1]) -# except AttributeError: -# uid = int(unit['name'].split('#')[-1]) -# return uid - -# def sort_by_cluster_id(spike_trains): -# if len(spike_trains) == 0: -# return spike_trains -# if "name" not in spike_trains[0].annotations: -# print("Unable to get cluster_id, save with phy to create") -# sorted_sptrs = sorted(spike_trains, key=lambda x: int(x.annotations["name"].lower().replace("unit #", ""))) -# return sorted_sptrs From 467463441cf0f3e91c4e595a8a930017b2cd4e47 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 8 Mar 2024 17:02:09 +0100 Subject: [PATCH 05/47] Add unit name annotation in load_spiketrains --- src/expipe_plugin_cinpla/data_loader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py index 309e4ca..404db1e 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -276,6 +276,7 @@ def load_spiketrains(data_path, channel_group=None, lim=None): st = neo.SpikeTrain( times=times, t_start=t_start, t_stop=t_stop, 
sampling_rate=sorting.sampling_frequency * pq.Hz ) + st.annotations.update({"name": unit}) for p in sorting.get_property_keys(): st.annotations.update({p: sorting.get_unit_property(unit, p)}) sptr.append(st) @@ -313,7 +314,7 @@ def load_unit_annotations(data_path, channel_group=None): ] for unit in unit_ids: - annotations = {} + annotations = {"name": unit} for p in sorting.get_property_keys(): annotations.update({p: sorting.get_unit_property(unit, p)}) units.append(annotations) From 17dc78ddc202262901f85f8ba784f874db87e891 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 9 Mar 2024 12:12:09 +0100 Subject: [PATCH 06/47] Add trial columns only if needed --- .../interfaces/openephystrackinginterface.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py b/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py index 84b20f6..2a6a0db 100644 --- a/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py +++ b/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py @@ -78,14 +78,16 @@ def add_to_nwbfile( rising = rising[:-1] if len(rising) == len(falling): - nwbfile.add_trial_column( - name="channel", - description="Open Ephys channel", - ) - nwbfile.add_trial_column( - name="processor", - description="Open Ephys processor that recorded the event", - ) + if "channel" not in nwbfile.trials.colnames: + nwbfile.add_trial_column( + name="channel", + description="Open Ephys channel", + ) + if "processor" not in nwbfile.trials.colnames: + nwbfile.add_trial_column( + name="processor", + description="Open Ephys processor that recorded the event", + ) start_times = times[rising].rescale("s").magnitude stop_times = times[falling].rescale("s").magnitude for start, stop in zip(start_times, stop_times): From 8a32e86798731ccf7e3e1b37e5286db5d44c16ce Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 9 Mar 2024 12:16:06 +0100 Subject: [PATCH 07/47] Add trial columns only if needed2 --- .../nwbutils/interfaces/openephystrackinginterface.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py b/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py index 2a6a0db..9c406d9 100644 --- a/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py +++ b/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py @@ -78,12 +78,11 @@ def add_to_nwbfile( rising = rising[:-1] if len(rising) == len(falling): - if "channel" not in nwbfile.trials.colnames: + if nwbfile.trials is None: nwbfile.add_trial_column( name="channel", description="Open Ephys channel", ) - if "processor" not in nwbfile.trials.colnames: nwbfile.add_trial_column( name="processor", description="Open Ephys processor that recorded the event", From 278e086f03536424a3c57517a41fced47a9e8281 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 9 Mar 2024 12:39:40 +0100 Subject: [PATCH 08/47] Re-copy actions after failures --- .../scripts/convert_old_project.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/src/expipe_plugin_cinpla/scripts/convert_old_project.py b/src/expipe_plugin_cinpla/scripts/convert_old_project.py index edb2461..dc9d352 100644 --- a/src/expipe_plugin_cinpla/scripts/convert_old_project.py +++ b/src/expipe_plugin_cinpla/scripts/convert_old_project.py @@ -130,18 +130,23 @@ def 
convert_old_project( old_action = old_actions[action_id] new_action = new_project.actions[action_id] + old_data_folder = old_project.path / "actions" / action_id / "data" + new_data_folder = new_project.path / "actions" / action_id / "data" + # main.exdir + old_exdir_folder = old_data_folder / "main.exdir" + + if exist_ok and not (new_project.path / "actions" / action_id).is_dir(): + # Copy action that previously failed + print(f">>> Re-copying action {action_id} to new project\n") + shutil.copytree(old_data_folder.parent, new_data_folder.parent, + ignore=shutil.ignore_patterns("main.exdir", ".git")) + # replace file in attributes.yaml attributes_file = new_project.path / "actions" / action_id / "attributes.yaml" attributes_str = attributes_file.read_text() attributes_str = attributes_str.replace("main.exdir", "main.nwb") attributes_file.write_text(attributes_str) - old_data_folder = old_project.path / "actions" / action_id / "data" - new_data_folder = new_project.path / "actions" / action_id / "data" - - # main.exdir - old_exdir_folder = old_data_folder / "main.exdir" - # find open-ephys folder acquisition_folder = old_exdir_folder / "acquisition" openephys_folders = [p for p in acquisition_folder.iterdir() if p.is_dir()] From 370123ecddabc14186df192469ac19b63e5f96da Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 9 Mar 2024 12:43:57 +0100 Subject: [PATCH 09/47] Re-copy actions after failures 2 --- .../scripts/convert_old_project.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/expipe_plugin_cinpla/scripts/convert_old_project.py b/src/expipe_plugin_cinpla/scripts/convert_old_project.py index dc9d352..3cd9200 100644 --- a/src/expipe_plugin_cinpla/scripts/convert_old_project.py +++ b/src/expipe_plugin_cinpla/scripts/convert_old_project.py @@ -128,18 +128,20 @@ def convert_old_project( delimiter = "*" * len(process_msg) print(f"\n{delimiter}\n{process_msg}\n{delimiter}\n") old_action = old_actions[action_id] - new_action = new_project.actions[action_id] - old_data_folder = old_project.path / "actions" / action_id / "data" - new_data_folder = new_project.path / "actions" / action_id / "data" + old_action_folder = old_project.path / "actions" / action_id + new_action_folder = new_project.path / "actions" / action_id + old_data_folder = old_action_folder / "data" + new_data_folder = new_action_folder / "data" # main.exdir old_exdir_folder = old_data_folder / "main.exdir" - if exist_ok and not (new_project.path / "actions" / action_id).is_dir(): + if exist_ok and not new_action_folder.is_dir(): # Copy action that previously failed print(f">>> Re-copying action {action_id} to new project\n") - shutil.copytree(old_data_folder.parent, new_data_folder.parent, + shutil.copytree(old_action_folder, new_action_folder, ignore=shutil.ignore_patterns("main.exdir", ".git")) + new_action = new_project.actions[action_id] # replace file in attributes.yaml attributes_file = new_project.path / "actions" / action_id / "attributes.yaml" From fa7b780ea1ec2c7ca61ec0ad6c74a3a9ca38181a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 13 Mar 2024 18:28:22 +0100 Subject: [PATCH 10/47] Remove excess spikes --- src/expipe_plugin_cinpla/scripts/curation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/expipe_plugin_cinpla/scripts/curation.py b/src/expipe_plugin_cinpla/scripts/curation.py index f1cf693..1f3c546 100644 --- a/src/expipe_plugin_cinpla/scripts/curation.py +++ b/src/expipe_plugin_cinpla/scripts/curation.py @@ -142,6 +142,9 @@ def 
apply_curation(self, sorter, curated_sorting): if "group" not in curated_sorting.get_property_keys(): compute_and_set_unit_groups(curated_sorting, recording) + # remove excess spikes + curated_sorting = sc.remove_excess_spikes(curated_sorting, recording=recording) + print("Extracting waveforms on curated sorting") self.curated_we = si.extract_waveforms( recording, From aa3b82869d9d9a9b105b766831dba3c915116531 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 14 Mar 2024 09:54:05 +0100 Subject: [PATCH 11/47] Add print statement --- src/expipe_plugin_cinpla/scripts/curation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/expipe_plugin_cinpla/scripts/curation.py b/src/expipe_plugin_cinpla/scripts/curation.py index 1f3c546..8d3650d 100644 --- a/src/expipe_plugin_cinpla/scripts/curation.py +++ b/src/expipe_plugin_cinpla/scripts/curation.py @@ -143,6 +143,7 @@ def apply_curation(self, sorter, curated_sorting): compute_and_set_unit_groups(curated_sorting, recording) # remove excess spikes + print("Removing excess spikes from curated sorting") curated_sorting = sc.remove_excess_spikes(curated_sorting, recording=recording) print("Extracting waveforms on curated sorting") From aeda5261df0ace6236cecc3ffd482b13839370a1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 14 Mar 2024 12:54:48 +0100 Subject: [PATCH 12/47] fix registration with depth --- src/expipe_plugin_cinpla/scripts/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/expipe_plugin_cinpla/scripts/utils.py b/src/expipe_plugin_cinpla/scripts/utils.py index c6572e7..4c52258 100644 --- a/src/expipe_plugin_cinpla/scripts/utils.py +++ b/src/expipe_plugin_cinpla/scripts/utils.py @@ -1,6 +1,6 @@ import sys import shutil -from datetime import datetime +from datetime import datetime, timedelta from pathlib import Path, PureWindowsPath import numpy as np @@ -47,7 +47,7 @@ def query_yes_no(question, default="yes", answer=None): def deltadate(adjustdate, regdate): - delta = regdate - adjustdate if regdate > adjustdate else datetime.timedelta.max + delta = regdate - adjustdate if regdate > adjustdate else timedelta.max return delta From 6219b9cd253bfde635a2f76185cb342e849e207f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 14 Mar 2024 12:57:15 +0100 Subject: [PATCH 13/47] Add log for adjustment and annotate tabs --- src/expipe_plugin_cinpla/widgets/register.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/expipe_plugin_cinpla/widgets/register.py b/src/expipe_plugin_cinpla/widgets/register.py index c536b5a..defc915 100644 --- a/src/expipe_plugin_cinpla/widgets/register.py +++ b/src/expipe_plugin_cinpla/widgets/register.py @@ -158,6 +158,9 @@ def on_manual_depth(change): depth_from_surgery.observe(on_manual_depth, names="value") + view = BaseViewWithLog(main_box=main_box, project=project) + + @view.output.capture() def on_register(change): if not required_values_filled(entity_id, user, adjustment): return @@ -172,7 +175,8 @@ def on_register(change): ) register_button.on_click(on_register) - return main_box + + return view ### Annotation ### @@ -201,6 +205,9 @@ def register_annotate_view(project): fields = ipywidgets.VBox([user, date, location, message, action_type, tag, depth, entity_id, register_button]) main_box = ipywidgets.VBox([ipywidgets.HBox([fields, action_id, templates])]) + view = BaseViewWithLog(main_box=main_box, project=project) + + @view.output.capture() def on_register(change): if not required_values_filled(action_id, user): return @@ 
-222,7 +229,7 @@ def on_register(change): ) register_button.on_click(on_register) - return main_box + return view ### Entity ### From 78898b15d910495b207db0d79c16541c4a290854 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Mar 2024 10:15:45 +0100 Subject: [PATCH 14/47] Remove excess spikes at the right place! --- src/expipe_plugin_cinpla/scripts/curation.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/expipe_plugin_cinpla/scripts/curation.py b/src/expipe_plugin_cinpla/scripts/curation.py index 8d3650d..a6bf726 100644 --- a/src/expipe_plugin_cinpla/scripts/curation.py +++ b/src/expipe_plugin_cinpla/scripts/curation.py @@ -138,14 +138,14 @@ def apply_curation(self, sorter, curated_sorting): else: recording = self.load_processed_recording(sorter) - # if not sort by group, extract dense and estimate group - if "group" not in curated_sorting.get_property_keys(): - compute_and_set_unit_groups(curated_sorting, recording) - # remove excess spikes print("Removing excess spikes from curated sorting") curated_sorting = sc.remove_excess_spikes(curated_sorting, recording=recording) + # if not sort by group, extract dense and estimate group + if "group" not in curated_sorting.get_property_keys(): + compute_and_set_unit_groups(curated_sorting, recording) + print("Extracting waveforms on curated sorting") self.curated_we = si.extract_waveforms( recording, From 95d7ae78ff82f6a46768749161a7e05c70a45170 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Mar 2024 10:17:14 +0100 Subject: [PATCH 15/47] Raise error for multiple experiments/openephys folders --- src/expipe_plugin_cinpla/scripts/convert_old_project.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/expipe_plugin_cinpla/scripts/convert_old_project.py b/src/expipe_plugin_cinpla/scripts/convert_old_project.py index 3cd9200..a07fed3 100644 --- a/src/expipe_plugin_cinpla/scripts/convert_old_project.py +++ b/src/expipe_plugin_cinpla/scripts/convert_old_project.py @@ -154,7 +154,7 @@ def convert_old_project( openephys_folders = [p for p in acquisition_folder.iterdir() if p.is_dir()] if len(openephys_folders) != 1: print(f"Found {len(openephys_folders)} openephys folders in {acquisition_folder}!") - continue + raise ValueError("Expected to find exactly one openephys folder") openephys_path = openephys_folders[0] # here we assume the following action name: {entity_id}-{date}-{session} entity_id = action_id.split("-")[0] From 3a0eab4ff7e99a63dc31a7bbe6965f896f2e84b4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Mar 2024 11:07:34 +0100 Subject: [PATCH 16/47] Sort actions in widget --- src/expipe_plugin_cinpla/widgets/curation.py | 2 +- src/expipe_plugin_cinpla/widgets/process.py | 2 +- src/expipe_plugin_cinpla/widgets/viewer.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/expipe_plugin_cinpla/widgets/curation.py b/src/expipe_plugin_cinpla/widgets/curation.py index 57a05a4..2ef2eb2 100644 --- a/src/expipe_plugin_cinpla/widgets/curation.py +++ b/src/expipe_plugin_cinpla/widgets/curation.py @@ -70,7 +70,7 @@ def __init__(self, project): si_path = data_path.parent / "spikeinterface" if si_path.is_dir(): actions_processed.append(action_name) - + actions_processed = sorted(actions_processed) actions_list = ipywidgets.Select( options=actions_processed, rows=10, description="Actions: ", disabled=False, layout={"width": "300px"} ) diff --git a/src/expipe_plugin_cinpla/widgets/process.py b/src/expipe_plugin_cinpla/widgets/process.py index 
4d60634..65473b8 100644 --- a/src/expipe_plugin_cinpla/widgets/process.py +++ b/src/expipe_plugin_cinpla/widgets/process.py @@ -41,7 +41,7 @@ def process_ecephys_view(project): action_names.append(f"{action_name} -- (P)") else: action_names.append(f"{action_name} -- (U)") - + action_names = sorted(action_names) action_ids = SearchSelectMultiple(action_names, description="*Actions") overwrite = ipywidgets.Checkbox(description="Overwrite", value=True) diff --git a/src/expipe_plugin_cinpla/widgets/viewer.py b/src/expipe_plugin_cinpla/widgets/viewer.py index e9937ea..408ca5b 100644 --- a/src/expipe_plugin_cinpla/widgets/viewer.py +++ b/src/expipe_plugin_cinpla/widgets/viewer.py @@ -32,6 +32,7 @@ def get_options(self): data_path = _get_data_path(action) if data_path is not None and data_path.name == "main.nwb": options.append(action_name) + options = sorted(options) return options def on_change(self, change): From ebd4b0217ae51a79111482f5bc86822f87a3117b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 15 Mar 2024 11:14:57 +0100 Subject: [PATCH 17/47] Fix spike train loader --- src/expipe_plugin_cinpla/data_loader.py | 12 ++++++------ .../scripts/convert_old_project.py | 5 +++-- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py index 404db1e..7dc5063 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -260,9 +260,9 @@ def load_spiketrains(data_path, channel_group=None, lim=None): unit_id for unit_index, unit_id in enumerate(sorting.unit_ids) if groups[unit_index] == channel_group ] sptr = [] - # build neo pbjects - for unit in unit_ids: - times = sorting.get_unit_spike_train(unit, return_times=True) * pq.s + # build neo objects + for unit in sorting.unit_ids: + spike_times = sorting.get_unit_spike_train(unit, return_times=True) * pq.s if lim is None: times = recording.get_times() * pq.s t_start = times[0] @@ -270,11 +270,11 @@ def load_spiketrains(data_path, channel_group=None, lim=None): else: t_start = pq.Quantity(lim[0], "s") t_stop = pq.Quantity(lim[1], "s") - mask = (times >= t_start) & (times <= t_stop) - times = times[mask] + mask = (spike_times >= t_start) & (spike_times <= t_stop) + spike_times = spike_times[mask] st = neo.SpikeTrain( - times=times, t_start=t_start, t_stop=t_stop, sampling_rate=sorting.sampling_frequency * pq.Hz + times=spike_times, t_start=t_start, t_stop=t_stop, sampling_rate=sorting.sampling_frequency * pq.Hz ) st.annotations.update({"name": unit}) for p in sorting.get_property_keys(): diff --git a/src/expipe_plugin_cinpla/scripts/convert_old_project.py b/src/expipe_plugin_cinpla/scripts/convert_old_project.py index a07fed3..b0899b3 100644 --- a/src/expipe_plugin_cinpla/scripts/convert_old_project.py +++ b/src/expipe_plugin_cinpla/scripts/convert_old_project.py @@ -139,8 +139,9 @@ def convert_old_project( if exist_ok and not new_action_folder.is_dir(): # Copy action that previously failed print(f">>> Re-copying action {action_id} to new project\n") - shutil.copytree(old_action_folder, new_action_folder, - ignore=shutil.ignore_patterns("main.exdir", ".git")) + shutil.copytree( + old_action_folder, new_action_folder, ignore=shutil.ignore_patterns("main.exdir", ".git") + ) new_action = new_project.actions[action_id] # replace file in attributes.yaml From c526f51e9a0ff085509155bf96d95b7e5329175d Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 5 Apr 2024 12:19:17 +0200 Subject: [PATCH 18/47] Set include_events 
to True and fix load_spiketrain for tetrodes --- src/expipe_plugin_cinpla/data_loader.py | 2 +- src/expipe_plugin_cinpla/widgets/register.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py index 7dc5063..269775f 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -261,7 +261,7 @@ def load_spiketrains(data_path, channel_group=None, lim=None): ] sptr = [] # build neo objects - for unit in sorting.unit_ids: + for unit in unit_ids: spike_times = sorting.get_unit_spike_train(unit, return_times=True) * pq.s if lim is None: times = recording.get_times() * pq.s diff --git a/src/expipe_plugin_cinpla/widgets/register.py b/src/expipe_plugin_cinpla/widgets/register.py index defc915..6dec24b 100644 --- a/src/expipe_plugin_cinpla/widgets/register.py +++ b/src/expipe_plugin_cinpla/widgets/register.py @@ -31,7 +31,7 @@ def register_openephys_view(project): # buttons depth = MultiInput(["Key", "Probe", "Depth", "Unit"], "Add depth") register_depth = ipywidgets.Checkbox(description="Register depth", value=False) - include_events = ipywidgets.Checkbox(description="Include events", value=False) + include_events = ipywidgets.Checkbox(description="Include events", value=True) register_depth_from_adjustment = ipywidgets.Checkbox(description="Find adjustments", value=True) register_depth_from_adjustment.layout.visibility = "hidden" From 3ee8a8db93b88f47918db348975ea7aaf47f6616 Mon Sep 17 00:00:00 2001 From: Mikkel Date: Sat, 6 Apr 2024 11:54:52 +0200 Subject: [PATCH 19/47] add tools for processing, unit tracking and notebook registration --- src/expipe_plugin_cinpla/tools/__init__.py | 0 .../tools/data_processing.py | 498 ++++++++++++++++++ .../tools/registration.py | 15 + .../tools/track_units_tools.py | 220 ++++++++ .../tools/trackunitcomparison.py | 172 ++++++ .../tools/trackunitmulticomparison.py | 284 ++++++++++ 6 files changed, 1189 insertions(+) create mode 100644 src/expipe_plugin_cinpla/tools/__init__.py create mode 100644 src/expipe_plugin_cinpla/tools/data_processing.py create mode 100644 src/expipe_plugin_cinpla/tools/registration.py create mode 100644 src/expipe_plugin_cinpla/tools/track_units_tools.py create mode 100644 src/expipe_plugin_cinpla/tools/trackunitcomparison.py create mode 100644 src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py diff --git a/src/expipe_plugin_cinpla/tools/__init__.py b/src/expipe_plugin_cinpla/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py new file mode 100644 index 0000000..3244375 --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -0,0 +1,498 @@ +# This is work in progress, +import numpy as np +from expipe_plugin_cinpla.data_loader import ( + load_epochs, get_channel_groups, load_spiketrains, load_unit_annotations, + load_leds, get_duration, load_lfp, get_sample_rate, get_data_path +) +import pathlib +import expipe +import spatial_maps as sp +import warnings + + +def view_active_channels(action, sorter): + path = action.data_path() + sorter_path = path / 'spikeinterface' / sorter / 'phy' + return np.load(sorter_path / 'channel_map_si.npy') + + +def _cut_to_same_len(*args): + out = [] + lens = [] + for arg in args: + lens.append(len(arg)) + minlen = min(lens) + for arg in args: + out.append(arg[:minlen]) + return out + + +def velocity_filter(x, y, t, threshold): 
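+    # A minimal usage sketch, assuming x and y are position samples in meters
+    # and t the matching timestamps in seconds (per the docstring below);
+    # samples whose instantaneous speed exceeds `threshold` (in m/s) are
+    # dropped, as in load_tracking's default of 5:
+    #     x, y, t = velocity_filter(x, y, t, threshold=5)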
+ """ + Removes values above threshold + Parameters + ---------- + x : quantities.Quantity array in m + 1d vector of x positions + y : quantities.Quantity array in m + 1d vector of y positions + t : quantities.Quantity array in s + 1d vector of times at x, y positions + threshold : float + """ + assert len(x) == len(y) == len(t), 'x, y, t must have same length' + vel = np.gradient([x, y], axis=1) / np.gradient(t) + speed = np.linalg.norm(vel, axis=0) + speed_mask = (speed < threshold) + speed_mask = np.append(speed_mask, 0) + x = x[np.where(speed_mask)] + y = y[np.where(speed_mask)] + t = t[np.where(speed_mask)] + return x, y, t + + +def interp_filt_position(x, y, tm, fs=100 , f_cut=10 ): + """ + rapid head movements will contribute to velocity artifacts, + these can be removed by low-pass filtering + see http://www.ncbi.nlm.nih.gov/pmc/articles/PMC1876586/ + code addapted from Espen Hagen + Parameters + ---------- + x : quantities.Quantity array in m + 1d vector of x positions + y : quantities.Quantity array in m + 1d vector of y positions + tm : quantities.Quantity array in s + 1d vector of times at x, y positions + fs : quantities scalar in Hz + return radians + Returns + ------- + out : angles, resized t + """ + import scipy.signal as ss + assert len(x) == len(y) == len(tm), 'x, y, t must have same length' + t = np.arange(tm.min(), tm.max() + 1. / fs, 1. / fs) + x = np.interp(t, tm, x) + y = np.interp(t, tm, y) + # rapid head movements will contribute to velocity artifacts, + # these can be removed by low-pass filteringpar + # see http://www.ncbi.nlm.nih.gov/pmc/articles/PMC1876586/ + # code addapted from Espen Hagen + b, a = ss.butter(N=1, Wn=f_cut * 2 / fs) + # zero phase shift filter + x = ss.filtfilt(b, a, x) + y = ss.filtfilt(b, a, y) + # we tolerate small interpolation errors + x[(x > -1e-3) & (x < 0.0)] = 0.0 + y[(y > -1e-3) & (y < 0.0)] = 0.0 + + return x, y, t + + +def rm_nans(*args): + """ + Removes nan from all corresponding arrays + Parameters + ---------- + args : arrays, lists or quantities which should have removed nans in + all the same indices + Returns + ------- + out : args with removed nans + """ + nan_indices = [] + for arg in args: + nan_indices.extend(np.where(np.isnan(arg))[0].tolist()) + nan_indices = np.unique(nan_indices) + out = [] + for arg in args: + out.append(np.delete(arg, nan_indices)) + return out + + +def filter_xy_zero(x, y, t): + idxs, = np.where((x == 0) & (y == 0)) + return [np.delete(a, idxs) for a in [x, y, t]] + +def filter_xy_box_size(x, y, t, box_size): + idxs, = np.where((x > box_size[0]) | (x < 0) | (y > box_size[1]) | (y < 0)) + return [np.delete(a, idxs) for a in [x, y, t]] + + +def filter_t_zero_duration(x, y, t, duration): + idxs, = np.where((t < 0) | (t > duration)) + return [np.delete(a, idxs) for a in [x, y, t]] + + +def load_head_direction(data_path, sampling_rate, low_pass_frequency, box_size): + from head_direction.head import head_direction + x1, y1, t1, x2, y2, t2, stop_time = load_leds(data_path) + + x1, y1, t1 = rm_nans(x1, y1, t1) + x2, y2, t2 = rm_nans(x2, y2, t2) + + x1, y1, t1 = filter_t_zero_duration(x1, y1, t1, stop_time) + x2, y2, t2 = filter_t_zero_duration(x2, y2, t2, stop_time) + + # OE saves 0.0 when signal is lost, these can be removed + x1, y1, t1 = filter_xy_zero(x1, y1, t1) + x2, y2, t2 = filter_xy_zero(x2, y2, t2) + + # x1, y1, t1 = filter_xy_box_size(x1, y1, t1, box_size) + # x2, y2, t2 = filter_xy_box_size(x2, y2, t2, box_size) + + x1, y1, t1 = interp_filt_position(x1, y1, t1, + fs=sampling_rate, 
f_cut=low_pass_frequency) + x2, y2, t2 = interp_filt_position(x2, y2, t2, + fs=sampling_rate, f_cut=low_pass_frequency) + + x1, y1, t1, x2, y2, t2 = _cut_to_same_len(x1, y1, t1, x2, y2, t2) + + check_valid_tracking(x1, y1, box_size) + check_valid_tracking(x2, y2, box_size) + + angles, times = head_direction(x1, y1, x2, y2, t1) + return angles, times + + +def check_valid_tracking(x, y, box_size): + if np.isnan(x).any() and np.isnan(y).any(): + raise ValueError('nans found in position, ' + + 'x nans = %i, y nans = %i' % (sum(np.isnan(x)), sum(np.isnan(y)))) + + if (x.min() < 0 or x.max() > box_size[0] or y.min() < 0 or y.max() > box_size[1]): + warnings.warn( + "Invalid values found " + + "outside box: min [x, y] = [{}, {}], ".format(x.min(), y.min()) + + "max [x, y] = [{}, {}]".format(x.max(), y.max())) + + +def load_tracking(data_path, sampling_rate, low_pass_frequency, box_size, velocity_threshold=5): + x1, y1, t1, x2, y2, t2, stop_time = load_leds(data_path) + x1, y1, t1 = rm_nans(x1, y1, t1) + x2, y2, t2 = rm_nans(x2, y2, t2) + + x1, y1, t1 = filter_t_zero_duration(x1, y1, t1, stop_time) + x2, y2, t2 = filter_t_zero_duration(x2, y2, t2, stop_time) + + # select data with least nan + if len(x1) > len(x2): + x, y, t = x1, y1, t1 + else: + x, y, t = x2, y2, t2 + + # OE saves 0.0 when signal is lost, these can be removed + x, y, t = filter_xy_zero(x, y, t) + + # x, y, t = filter_xy_box_size(x, y, t, box_size) + + # remove velocity artifacts + x, y, t = velocity_filter(x, y, t, velocity_threshold) + + x, y, t = interp_filt_position( + x, y, t, fs=sampling_rate, f_cut=low_pass_frequency) + + check_valid_tracking(x, y, box_size) + + vel = np.gradient([x, y], axis=1) / np.gradient(t) + speed = np.linalg.norm(vel, axis=0) + x, y, t, speed = np.array(x), np.array(y), np.array(t), np.array(speed) + return x, y, t, speed + +def sort_by_cluster_id(spike_trains): + if len(spike_trains) == 0: + return spike_trains + if 'name' not in spike_trains[0].annotations: + print('Unable to get cluster_id, save with phy to create') + sorted_sptrs = sorted( + spike_trains, + key=lambda x: str(x.annotations['name'])) + return sorted_sptrs + + +def get_unit_id(unit): + return str(int(unit.annotations['name'])) + + +class Template: + def __init__(self, sptr): + self.data = np.array(sptr.annotations["waveform_mean"]) + self.sampling_rate = float(sptr.sampling_rate) + + +class Data: + def __init__(self, project, stim_mask=False, baseline_duration=None, stim_channels=None, **kwargs): + self.project_path = project.path + self.params = kwargs + self.project = expipe.get_project(self.project_path) + self.actions = self.project.actions + self._spike_trains = {} + self._templates = {} + self._stim_times = {} + self._unit_names = {} + self._tracking = {} + self._head_direction = {} + self._lfp = {} + self._occupancy = {} + self._rate_maps = {} + self._tracking_split = {} + self._rate_maps_split = {} + self._prob_dist = {} + self._spatial_bins = None + self.stim_mask = stim_mask + self.baseline_duration = baseline_duration + self._channel_groups = {} + self.stim_channels = stim_channels + + def channel_groups(self, action_id): + if action_id not in self._channel_groups: + self._channel_groups[action_id] = get_channel_groups(self.data_path(action_id)) + return self._channel_groups[action_id] + + def data_path(self, action_id): + return pathlib.Path(self.project_path) / "actions" / action_id / "data" / "main.nwb" + + def get_lim(self, action_id): + stim_times = self.stim_times(action_id) + if stim_times is None: + if 
self.baseline_duration is None: + return [0, float(get_duration(self.data_path(action_id)).magnitude)] + else: + return [0, float(self.baseline_duration)] + stim_times = np.array(stim_times) + return [stim_times.min(), stim_times.max()] + + def duration(self, action_id): + return get_duration(self.data_path(action_id)) + + def tracking(self, action_id): + if action_id not in self._tracking: + x, y, t, speed = load_tracking( + self.data_path(action_id), + sampling_rate=self.params['position_sampling_rate'], + low_pass_frequency=self.params['position_low_pass_frequency'], + box_size=self.params['box_size']) + if self.stim_mask: + t1, t2 = self.get_lim(action_id) + mask = (t >= t1) & (t <= t2) + x = x[mask] + y = y[mask] + t = t[mask] + speed = speed[mask] + self._tracking[action_id] = { + 'x': x, 'y': y, 't': t, 'v': speed + } + return self._tracking[action_id] + + @property + def spatial_bins(self): + if self._spatial_bins is None: + box_size_, bin_size_ = sp.maps._adjust_bin_size( + box_size=self.params['box_size'], + bin_size=self.params['bin_size']) + xbins, ybins = sp.maps._make_bins(box_size_, bin_size_) + self._spatial_bins = (xbins, ybins) + self.box_size_, self.bin_size_ = box_size_, bin_size_ + return self._spatial_bins + + def occupancy(self, action_id): + if action_id not in self._occupancy: + xbins, ybins = self.spatial_bins + + occupancy_map = sp.maps._occupancy_map( + self.tracking(action_id)['x'], + self.tracking(action_id)['y'], + self.tracking(action_id)['t'], xbins, ybins) + threshold = self.params.get('occupancy_threshold') + if threshold is not None: + occupancy_map[occupancy_map <= threshold] = 0 + self._occupancy[action_id] = occupancy_map + return self._occupancy[action_id] + + def prob_dist(self, action_id): + if action_id not in self._prob_dist: + xbins, ybins = xbins, ybins = self.spatial_bins + prob_dist = sp.stats.prob_dist( + self.tracking(action_id)['x'], + self.tracking(action_id)['y'], bins=(xbins, ybins)) + self._prob_dist[action_id] = prob_dist + return self._prob_dist[action_id] + + def tracking_split(self, action_id): + if action_id not in self._tracking_split: + x, y, t, v = map(self.tracking(action_id).get, ['x', 'y', 't', 'v']) + + t_split = t[-1] / 2 + mask_1 = t < t_split + mask_2 = t >= t_split + x1, y1, t1, v1 = x[mask_1], y[mask_1], t[mask_1], v[mask_1] + x2, y2, t2, v2 = x[mask_2], y[mask_2], t[mask_2], v[mask_2] + + + self._tracking_split[action_id] = { + 'x1': x1, 'y1': y1, 't1': t1, 'v1': v1, + 'x2': x2, 'y2': y2, 't2': t2, 'v2': v2 + } + return self._tracking_split[action_id] + + def spike_train_split(self, action_id, channel_group, unit_name): + spikes = self.spike_train(action_id, channel_group, unit_name) + t_split = self.duration(action_id) / 2 + spikes_1 = spikes[spikes < t_split] + spikes_2 = spikes[spikes >= t_split] + return spikes_1, spikes_2, t_split + + def rate_map_split(self, action_id, channel_group, unit_name, smoothing): + make_rate_map = False + if action_id not in self._rate_maps_split: + self._rate_maps_split[action_id] = {} + if channel_group not in self._rate_maps_split[action_id]: + self._rate_maps_split[action_id][channel_group] = {} + if unit_name not in self._rate_maps_split[action_id][channel_group]: + self._rate_maps_split[action_id][channel_group][unit_name] = {} + if smoothing not in self._rate_maps_split[action_id][channel_group][unit_name]: + make_rate_map = True + + + if make_rate_map: + xbins, ybins = self.spatial_bins + x, y, t = map(self.tracking(action_id).get, ['x', 'y', 't']) + spikes = 
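# -- Editor's note: in prob_dist() above, "xbins, ybins = xbins, ybins =
# self.spatial_bins" is a harmless but redundant chained assignment; a single
# "xbins, ybins = self.spatial_bins" suffices. The rate-map methods that
# follow all implement the same estimator: smoothed spike count divided by
# smoothed occupancy on a shared grid. In outline (sketch, assuming sp is the
# spatial-maps package this module imports):
spike_map = sp.maps._spike_map(x, y, t, spikes, xbins, ybins)
occ_map = sp.maps._occupancy_map(x, y, t, xbins, ybins)
rate_map = (sp.maps.smooth_map(spike_map, bin_size=bin_size, smoothing=smoothing)
            / sp.maps.smooth_map(occ_map, bin_size=bin_size, smoothing=smoothing))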
self.spike_train(action_id, channel_group, unit_name) + t_split = t[-1] / 2 + mask_1 = t < t_split + mask_2 = t >= t_split + x_1, y_1, t_1 = x[mask_1], y[mask_1], t[mask_1] + x_2, y_2, t_2 = x[mask_2], y[mask_2], t[mask_2] + spikes_1 = spikes[spikes < t_split] + spikes_2 = spikes[spikes >= t_split] + occupancy_map_1 = sp.maps._occupancy_map( + x_1, y_1, t_1, xbins, ybins) + occupancy_map_2 = sp.maps._occupancy_map( + x_2, y_2, t_2, xbins, ybins) + + spike_map_1 = sp.maps._spike_map( + x_1, y_1, t_1, spikes_1, xbins, ybins) + spike_map_2 = sp.maps._spike_map( + x_2, y_2, t_2, spikes_2, xbins, ybins) + + smooth_spike_map_1 = sp.maps.smooth_map( + spike_map_1, bin_size=self.bin_size_, smoothing=smoothing) + smooth_spike_map_2 = sp.maps.smooth_map( + spike_map_2, bin_size=self.bin_size_, smoothing=smoothing) + smooth_occupancy_map_1 = sp.maps.smooth_map( + occupancy_map_1, bin_size=self.bin_size_, smoothing=smoothing) + smooth_occupancy_map_2 = sp.maps.smooth_map( + occupancy_map_2, bin_size=self.bin_size_, smoothing=smoothing) + + rate_map_1 = smooth_spike_map_1 / smooth_occupancy_map_1 + rate_map_2 = smooth_spike_map_2 / smooth_occupancy_map_2 + self._rate_maps_split[action_id][channel_group][unit_name][smoothing] = [rate_map_1, rate_map_2] + + return self._rate_maps_split[action_id][channel_group][unit_name][smoothing] + + def rate_map(self, action_id, channel_group, unit_name, smoothing): + make_rate_map = False + if action_id not in self._rate_maps: + self._rate_maps[action_id] = {} + if channel_group not in self._rate_maps[action_id]: + self._rate_maps[action_id][channel_group] = {} + if unit_name not in self._rate_maps[action_id][channel_group]: + self._rate_maps[action_id][channel_group][unit_name] = {} + if smoothing not in self._rate_maps[action_id][channel_group][unit_name]: + make_rate_map = True + + if make_rate_map: + xbins, ybins = self.spatial_bins + + spike_map = sp.maps._spike_map( + self.tracking(action_id)['x'], + self.tracking(action_id)['y'], + self.tracking(action_id)['t'], + self.spike_train(action_id, channel_group, unit_name), + xbins, ybins) + + smooth_spike_map = sp.maps.smooth_map( + spike_map, bin_size=self.bin_size_, smoothing=smoothing) + smooth_occupancy_map = sp.maps.smooth_map( + self.occupancy(action_id), bin_size=self.bin_size_, smoothing=smoothing) + rate_map = smooth_spike_map / smooth_occupancy_map + self._rate_maps[action_id][channel_group][unit_name][smoothing] = rate_map + + return self._rate_maps[action_id][channel_group][unit_name][smoothing] + + def head_direction(self, action_id): + if action_id not in self._head_direction: + a, t = load_head_direction( + self.data_path(action_id), + sampling_rate=self.params['position_sampling_rate'], + low_pass_frequency=self.params['position_low_pass_frequency'], + box_size=self.params['box_size']) + if self.stim_mask: + t1, t2 = self.get_lim(action_id) + mask = (t >= t1) & (t <= t2) + a = a[mask] + t = t[mask] + self._head_direction[action_id] = { + 'a': a, 't': t + } + return self._head_direction[action_id] + + def lfp(self, action_id, channel_group, clean_memory=False): + lim = self.get_lim(action_id) if self.stim_mask else None + if clean_memory: + return load_lfp( + self.data_path(action_id), channel_group, lim) + if action_id not in self._lfp: + self._lfp[action_id] = {} + if channel_group not in self._lfp[action_id]: + self._lfp[action_id][channel_group] = load_lfp( + self.data_path(action_id), channel_group, lim) + return self._lfp[action_id][channel_group] + + def template(self, action_id, 
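# -- Editor's note: the chained "if key not in ...: ...[key] = {}" guards in
# rate_map()/rate_map_split() can be collapsed with a recursive defaultdict;
# a behaviour-equivalent sketch (not part of the patch):
from collections import defaultdict
def tree():
    return defaultdict(tree)
rate_maps = tree()          # e.g. self._rate_maps = tree() in __init__
rate_maps["action"]["group"]["unit"]["smoothing"] = 1.0   # no guards needed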
channel_group, unit_id): + self.spike_trains(action_id) + return Template(self._spike_trains[action_id][channel_group][unit_id]) + + def spike_train(self, action_id, channel_group, unit_id): + self.spike_trains(action_id) + return self._spike_trains[action_id][channel_group][unit_id] + + def spike_trains(self, action_id): + if action_id not in self._spike_trains: + self._spike_trains[action_id] = {} + lim = self.get_lim(action_id) if self.stim_mask else None + + sts = load_spiketrains(self.data_path(action_id), lim=lim) + for st in sts: + channel_group = st.annotations['group'] + unit_id = get_unit_id(st) + self._spike_trains[action_id][channel_group] = {unit_id: st} + + return self._spike_trains[action_id] + + def unit_names(self, action_id, channel_group): + self.spike_trains(action_id) + return list(self._spike_trains[action_id][channel_group].keys()) + + def stim_times(self, action_id): + if action_id not in self._stim_times: + try: + trials = load_epochs( + self.data_path(action_id), label_column='channel') + if len(set(trials.labels)) > 1: + stim_times = trials.times[trials.labels==self.stim_channels[action_id]] + else: + stim_times = trials.times + stim_times = np.sort(np.abs(np.array(stim_times))) + # there are some 0 times and inf times, remove those + # stim_times = stim_times[stim_times >= 1e-20] + self._stim_times[action_id] = stim_times + except AttributeError as e: + if str(e)=="'NoneType' object has no attribute 'to_dataframe'": + self._stim_times[action_id] = None + else: + raise e + + return self._stim_times[action_id] diff --git a/src/expipe_plugin_cinpla/tools/registration.py b/src/expipe_plugin_cinpla/tools/registration.py new file mode 100644 index 0000000..1d3b311 --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/registration.py @@ -0,0 +1,15 @@ +import os +import shutil +import pathlib + +def store_notebook(action, notebook_path): + notebook_path = pathlib.Path(notebook_path) + action.data["notebook"] = notebook_path.name + notebook_output_path = action.data_path('notebook') + shutil.copy(notebook_path, notebook_output_path) + # As HTML + os.system('jupyter nbconvert --to html {}'.format(notebook_path)) + html_path = notebook_path.with_suffix(".html") + action.data["html"] = html_path.name + html_output_path = action.data_path('html') + shutil.copy(html_path, html_output_path) diff --git a/src/expipe_plugin_cinpla/tools/track_units_tools.py b/src/expipe_plugin_cinpla/tools/track_units_tools.py new file mode 100644 index 0000000..0dced8c --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/track_units_tools.py @@ -0,0 +1,220 @@ +import numpy as np +import pandas as pd +from scipy.optimize import linear_sum_assignment +from matplotlib import gridspec +import matplotlib.pyplot as plt + + +def dissimilarity(template_0, template_1): + """ + Returns a value of dissimilarity of the mean between two or more + spike templates. + Parameters + ---------- + templates : list object (see Notes) + List containing the mean waveform over each electrode of spike sorted + spiketrains from at least one electrode. All elements in the list must + be of equal size, that is, the number of electrodes must be equal, and + the number of points on the waveform must be equal. + Returns + ------- + diss : numpy array-like + Returns a matrix containing the computed dissimilarity between the mean + of the spiketrain, for the same channel. 
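# -- Editor's note: spike_trains() above re-creates the per-group dict on
# every iteration ("self._spike_trains[action_id][channel_group] =
# {unit_id: st}"), so only the last unit of each channel group survives.
# A sketch of the accumulating variant of that loop, assuming all units in a
# group should be kept:
for st in sts:
    group = st.annotations["group"]
    self._spike_trains[action_id].setdefault(group, {})[get_unit_id(st)] = st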
+ """ + max_val = np.max([np.max(np.abs(template_0)), np.max(np.abs(template_1))]) + + t_i_lin = template_0.ravel() + t_j_lin = template_1.ravel() + + + return np.mean(np.abs(t_i_lin / max_val - t_j_lin / max_val)) + # return np.mean(np.abs(t_i_lin - t_j_lin)) + + +def dissimilarity_weighted(templates_0, templates_1): + """ + Returns a value of dissimilarity of the mean between two or more + spike templates. + Parameters + ---------- + templates : list object (see Notes) + List containing the mean waveform over each electrode of spike sorted + spiketrains from at least one electrode. All elements in the list must + be of equal size, that is, the number of electrodes must be equal, and + the number of points on the waveform must be equal. + Returns + ------- + diss : numpy array-like + Returns a matrix containing the computed dissimilarity between the mean + of the spiketrain, for the same channel. + """ + + max_val = np.max([np.max(np.abs(templates_0)), np.max(np.abs(templates_1))]) + + templates_0 /= max_val + templates_1 /= max_val + + return np.sqrt(np.sum([(templates_0[i] - templates_1[i])**2 for i in range(templates_0.shape[0])], axis=0)).mean() + + +def make_dissimilary_matrix(comp_object, channel_group): + templates_0, templates_1 = comp_object.templates[channel_group] + diss_matrix = np.zeros((len(templates_0), len(templates_1))) + + unit_ids_0, unit_ids_1 = comp_object.unit_ids[channel_group] + + for i, w0 in enumerate(templates_0): + for j, w1 in enumerate(templates_1): + diss_matrix[i, j] = dissimilarity_weighted(w0, w1) + + diss_matrix = pd.DataFrame( + diss_matrix, + index=unit_ids_0, + columns=unit_ids_1) + + return diss_matrix + + +def make_possible_match(dissimilarity_scores, max_dissimilarity): + """ + Given an agreement matrix and a max_dissimilarity threhold. + Return as a dict all possible match for each spiketrain in each side. + + Note : this is symmetric. + + + Parameters + ---------- + dissimilarity_scores: pd.DataFrame + + max_dissimilarity: float + + + Returns + ----------- + best_match_12: pd.Series + + best_match_21: pd.Series + + """ + unit1_ids = np.array(dissimilarity_scores.index) + unit2_ids = np.array(dissimilarity_scores.columns) + + # threhold the matrix + scores = dissimilarity_scores.values.copy() + scores[scores > max_dissimilarity] = np.inf + + possible_match_12 = {} + for i1, u1 in enumerate(unit1_ids): + inds_match = np.isfinite(scores[i1, :]) + possible_match_12[u1] = unit2_ids[inds_match] + + possible_match_21 = {} + for i2, u2 in enumerate(unit2_ids): + inds_match = np.isfinite(scores[:, i2]) + possible_match_21[u2] = unit1_ids[inds_match] + + return possible_match_12, possible_match_21 + + +def make_best_match(dissimilarity_scores, max_dissimilarity): + """ + Given an agreement matrix and a max_dissimilarity threhold. + return a dict a best match for each units independently of others. + + Note : this is symmetric. 
+ + Parameters + ---------- + dissimilarity_scores: pd.DataFrame + + max_dissimilarity: float + + + Returns + ----------- + best_match_12: pd.Series + + best_match_21: pd.Series + + + """ + unit1_ids = np.array(dissimilarity_scores.index) + unit2_ids = np.array(dissimilarity_scores.columns) + + scores = dissimilarity_scores.values.copy() + + best_match_12 = pd.Series(index=unit1_ids, dtype='int64') + for i1, u1 in enumerate(unit1_ids): + ind_min = np.argmin(scores[i1, :]) + if scores[i1, ind_min] <= max_dissimilarity: + best_match_12[u1] = unit2_ids[ind_min] + else: + best_match_12[u1] = -1 + + best_match_21 = pd.Series(index=unit2_ids, dtype='int64') + for i2, u2 in enumerate(unit2_ids): + ind_min = np.argmin(scores[:, i2]) + if scores[ind_min, i2] <= max_dissimilarity: + best_match_21[u2] = unit1_ids[ind_min] + else: + best_match_21[u2] = -1 + + return best_match_12, best_match_21 + + +def make_hungarian_match(dissimilarity_scores, max_dissimilarity): + """ + Given an agreement matrix and a max_dissimilarity threhold. + return the "optimal" match with the "hungarian" algo. + This use internally the scipy.optimze.linear_sum_assignment implementation. + + Parameters + ---------- + dissimilarity_scores: pd.DataFrame + + max_dissimilarity: float + + + Returns + ----------- + hungarian_match_12: pd.Series + + hungarian_match_21: pd.Series + + """ + unit1_ids = np.array(dissimilarity_scores.index) + unit2_ids = np.array(dissimilarity_scores.columns) + + # threhold the matrix + scores = dissimilarity_scores.values.copy() + + [inds1, inds2] = linear_sum_assignment(scores) + + hungarian_match_12 = pd.Series(index=unit1_ids, dtype='int64') + hungarian_match_12[:] = -1 + hungarian_match_21 = pd.Series(index=unit2_ids, dtype='int64') + hungarian_match_21[:] = -1 + + for i1, i2 in zip(inds1, inds2): + u1 = unit1_ids[i1] + u2 = unit2_ids[i2] + if dissimilarity_scores.at[u1, u2] < max_dissimilarity: + hungarian_match_12[u1] = u2 + hungarian_match_21[u2] = u1 + + return hungarian_match_12, hungarian_match_21 + + +def plot_template(template, fig, gs, axs=None, **kwargs): + nrc = template.shape[1] + if axs is None: + gs0 = gridspec.GridSpecFromSubplotSpec(1, nrc, subplot_spec=gs) + axs = [fig.add_subplot(gs0[0])] + axs.extend([fig.add_subplot(gs0[i], sharey=axs[0], sharex=axs[0]) for i in range(1, nrc)]) + for c in range(nrc): + axs[c].plot(template[:, c], **kwargs) + if c > 0: + plt.setp(axs[c].get_yticklabels(), visible=False) + return axs diff --git a/src/expipe_plugin_cinpla/tools/trackunitcomparison.py b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py new file mode 100644 index 0000000..f42cb4b --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py @@ -0,0 +1,172 @@ +from .track_units_tools import make_dissimilary_matrix, make_possible_match, make_best_match, \ + make_hungarian_match +from expipe_plugin_cinpla.data_loader import get_data_path, load_spiketrains, get_channel_groups +import matplotlib.pylab as plt +import numpy as np +from pathlib import Path + + +class TrackingSession: + """ + Base class shared by SortingComparison and GroundTruthComparison + """ + + def __init__(self, action_id_0, action_id_1, actions, channel_group=None, + max_dissimilarity=10, dissimilarity_function=None, verbose=False): + + data_path_0 = get_data_path(actions[action_id_0]) + data_path_1 = get_data_path(actions[action_id_1]) + + self._actions = actions + self.action_id_0 = action_id_0 + self.action_id_1 = action_id_1 + self._channel_group = channel_group + self.action_ids = [action_id_0, 
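# -- Editor's note: make_hungarian_match (below) reduces to a single call to
# scipy's assignment solver; a toy illustration of that core step:
import numpy as np
from scipy.optimize import linear_sum_assignment
cost = np.array([[0.02, 0.40],
                 [0.35, 0.03]])
rows, cols = linear_sum_assignment(cost)     # -> rows=[0, 1], cols=[0, 1]
# (0, 0) and (1, 1) minimise the summed dissimilarity; the wrapper then keeps
# only pairs whose score is strictly below max_dissimilarity.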
action_id_1] + self.max_dissimilarity = max_dissimilarity + self.dissimilarity_function = dissimilarity_function + self._verbose = verbose + + if channel_group is None: + channel_groups = get_channel_groups(data_path_0) + self.matches = {} + self.templates = {} + self.unit_ids = {} + for chan in channel_groups: + self.matches[chan] = dict() + self.templates[chan] = list() + self.unit_ids[chan] = list() + else: + self.matches = {channel_group: dict()} + self.templates = {channel_group: list()} + self.unit_ids = {channel_group: list()} + + units_0 = load_spiketrains(data_path_0) + units_1 = load_spiketrains(data_path_1) + for channel_group in self.matches.keys(): + units_0 = [unit for unit in units_0 if unit.annotations['group']==channel_group] + units_1 = [unit for unit in units_1 if unit.annotations['group']==channel_group] + + self.unit_ids[channel_group] = [ + np.array([int(st.annotations['name']) for st in units_0]), + np.array([int(st.annotations['name']) for st in units_1]) + ] + self.templates[channel_group] = [ + [st.annotations["waveform_mean"] for st in units_0], + [st.annotations["waveform_mean"] for st in units_1] + ] + if len(units_0) > 0 and len(units_1) > 0: + + self._do_dissimilarity(channel_group) + self._do_matching(channel_group) + + def save_dissimilarity_matrix(self, path=None): + path = path or Path.cwd() + for channel_group in self.matches: + if 'dissimilarity_scores' not in self.matches[channel_group]: + continue + filename = f'{self.action_id_0}_{self.action_id_1}_{channel_group}' + self.matches[channel_group]['dissimilarity_scores'].to_csv( + path / (filename + '.csv')) + + @property + def session_0_name(self): + return self.name_list[0] + + @property + def session_1_name(self): + return self.name_list[1] + + def _do_dissimilarity(self, channel_group): + if self._verbose: + print('Agreement scores...') + + # agreement matrix score for each pair + self.matches[channel_group]['dissimilarity_scores'] = make_dissimilary_matrix( + self, channel_group) + + def _do_matching(self, channel_group): + # must be implemented in subclass + if self._verbose: + print("Matching...") + + self.matches[channel_group]['possible_match_01'], self.matches[channel_group]['possible_match_10'] = \ + make_possible_match(self.matches[channel_group]['dissimilarity_scores'], self.max_dissimilarity) + self.matches[channel_group]['best_match_01'], self.matches[channel_group]['best_match_10'] = \ + make_best_match(self.matches[channel_group]['dissimilarity_scores'], self.max_dissimilarity) + self.matches[channel_group]['hungarian_match_01'], self.matches[channel_group]['hungarian_match_10'] = \ + make_hungarian_match(self.matches[channel_group]['dissimilarity_scores'], self.max_dissimilarity) + + def plot_matched_units(self, match_mode='hungarian', channel_group=None, ylim=[-200, 50], figsize=(15, 15)): + ''' + + Parameters + ---------- + match_mode + + Returns + ------- + + ''' + if channel_group is None: + ch_groups = self.matches.keys() + else: + ch_groups = [channel_group] + + for ch_group in ch_groups: + if 'hungarian_match_01' not in self.matches[ch_group].keys(): + print('Not units for group', ch_group) + continue + + if match_mode == 'hungarian': + match12 = self.matches[ch_group]['hungarian_match_01'] + elif match_mode == 'best': + match12 = self.matches[ch_group]['best_match_01'] + + num_matches = len(np.where(match12 != -1)[0]) + + if num_matches > 0: + + fig, ax_list = plt.subplots(nrows=2, ncols=num_matches, figsize=figsize) + fig.suptitle('Channel group ' + str(ch_group)) + + if 
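# -- Editor's note: in TrackingSession.__init__ above, units_0/units_1 are
# overwritten inside the per-group loop, so every group after the first
# filters an already-reduced list (typically to empty). A sketch of the
# intended selection, keeping the full lists intact:
for channel_group in self.matches:
    group_units_0 = [u for u in units_0 if u.annotations["group"] == channel_group]
    group_units_1 = [u for u in units_1 if u.annotations["group"] == channel_group]
    # ... then build unit_ids / templates from group_units_* as above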
num_matches == 1: + i = np.where(match12 != -1)[0][0] + j = match12.iloc[i] + i1 = np.where(self.matches[ch_group]['unit_ids_0'] == match12.index[i]) + i2 = np.where(self.matches[ch_group]['unit_ids_1'] == j) + template1 = np.squeeze( + self.matches[ch_group]['templates_0'][i1]) + ax_list[0].plot(template1, color='C0') + ax_list[0].set_title('Unit ' + str(match12.index[i])) + template2 = np.squeeze( + self.matches[ch_group]['templates_1'][i2]) + ax_list[1].plot(template2, color='C0') + ax_list[1].set_title('Unit ' + str(j)) + ax_list[0].set_ylabel(self.name_list[0]) + ax_list[1].set_ylabel(self.name_list[1]) + ax_list[0].set_ylim(ylim) + ax_list[1].set_ylim(ylim) + else: + id_ax = 0 + for i, j in enumerate(match12): + if j != -1: + i1 = np.where(self.matches[ch_group]['unit_ids_0'] == match12.index[i]) + i2 = np.where(self.matches[ch_group]['unit_ids_1'] == j) + + if id_ax == 0: + ax_list[0, id_ax].set_ylabel(self.name_list[0]) + ax_list[1, id_ax].set_ylabel(self.name_list[1]) + template1 = np.squeeze( + self.matches[ch_group]['templates_0'][i1]) + ax_list[0, id_ax].plot(template1, color='C'+str(id_ax)) + ax_list[0, id_ax].set_title('Unit ' + str(match12.index[i])) + template2 = np.squeeze( + self.matches[ch_group]['templates_1'][i1]) + ax_list[1, id_ax].plot(template2, color='C'+str(id_ax)) + ax_list[1, id_ax].set_title('Unit ' + str(j)) + ax_list[0, id_ax].set_ylim(ylim) + ax_list[1, id_ax].set_ylim(ylim) + id_ax += 1 + else: + print('No matched units for group', ch_group) + continue diff --git a/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py new file mode 100644 index 0000000..172887b --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py @@ -0,0 +1,284 @@ +import numpy as np +import networkx as nx +import yaml +from .trackunitcomparison import TrackingSession +from expipe_plugin_cinpla.data_loader import get_data_path, get_channel_groups, load_spiketrains +from .track_units_tools import plot_template +import matplotlib.pylab as plt +from tqdm import tqdm +import uuid +from matplotlib import gridspec +from collections import defaultdict +from pathlib import Path +import datetime + +class TrackMultipleSessions: + def __init__(self, actions, action_list=None, channel_group=None, + max_dissimilarity=None, max_timedelta=None, verbose=False, + progress_bar=None, data_path=None): + self.data_path = Path.cwd() if data_path is None else Path(data_path) + self.data_path.mkdir(parents=True, exist_ok=True) + self.action_list = [a for a in actions] if action_list is None else action_list + self._actions = actions + self._channel_group = channel_group + self.max_dissimilarity = max_dissimilarity or np.inf + self.max_timedelta = max_timedelta or datetime.MAXYEAR + self._verbose = verbose + self._pbar = tqdm if progress_bar is None else progress_bar + self._templates = {} + if self._channel_group is None: + dp = get_data_path(self._actions[self.action_list[0]]) + self._channel_groups = get_channel_groups(dp) + if len(self._channel_groups) == 0: + print('Unable to locate channel groups, please provide a working action_list') + else: + self._channel_groups = [self._channel_group] + + def do_matching(self): + # do pairwise matching + if self._verbose: + print('Multicomaprison step1: pairwise comparison') + + self.comparisons = [] + N = len(self.action_list) + pbar = self._pbar(total=int((N**2 - N) / 2)) + for i in range(N): + for j in range(i + 1, N): + if self._verbose: + print(" Comparing: ", 
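# -- Editor's note: in the multi-unit branch of plot_matched_units above,
# template2 is indexed with i1, but the matched unit j was located at i2, so
# the intended lookup is presumably:
template2 = np.squeeze(self.matches[ch_group]["templates_1"][i2])
# Note also that unit_ids_0/templates_1 etc. are stored on self.unit_ids and
# self.templates rather than under self.matches, so these lookups may need
# self.unit_ids[ch_group][1] / self.templates[ch_group][1] instead.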
self.action_list[i], " and ", self.action_list[j]) + comp = TrackingSession( + self.action_list[i], self.action_list[j], + actions=self._actions, + max_dissimilarity=np.inf, + channel_group=self._channel_group, + verbose=self._verbose) + # comp.save_dissimilarity_matrix() + self.comparisons.append(comp) + pbar.update(1) + pbar.close() + + def make_graphs_from_matches(self): + if self._verbose: + print('Multicomaprison step2: make graph') + + self.graphs = {} + + for ch in self._channel_groups: + if self._verbose: + print('Processing channel', ch) + self.graphs[ch] = nx.Graph() + + # nodes + for comp in self.comparisons: + # if same node is added twice it's only created once + for i, action_id in enumerate(comp.action_ids): + for u in comp.unit_ids[ch][i]: + node_name = action_id + '_' + str(int(u)) + self.graphs[ch].add_node( + node_name, action_id=action_id, + unit_id=int(u)) + + # edges + for comp in self.comparisons: + if 'hungarian_match_01' not in comp.matches[ch]: + continue + for u1 in comp.unit_ids[ch][0]: + u2 = comp.matches[ch]['hungarian_match_01'][u1] + if u2 != -1: + score = comp.matches[ch]['dissimilarity_scores'].loc[u1, u2] + node1_name = comp.action_id_0 + '_' + str(int(u1)) + node2_name = comp.action_id_1 + '_' + str(int(u2)) + self.graphs[ch].add_edge( + node1_name, node2_name, weight=float(score)) + + # the graph is symmetrical + self.graphs[ch] = self.graphs[ch].to_undirected() + + def compute_time_delta_edges(self): + ''' + adds a timedelta to each of the edges + ''' + for graph in self.graphs.values(): + for n0, n1 in graph.edges(): + action_id_0 = graph.nodes[n0]['action_id'] + action_id_1 = graph.nodes[n1]['action_id'] + time_delta = abs( + self._actions[action_id_0].datetime - + self._actions[action_id_1].datetime) + graph.add_edge(n0, n1, time_delta=time_delta) + + def compute_depth_delta_edges(self): + ''' + adds a depthdelta to each of the edges + ''' + for ch, graph in self.graphs.items(): + ch_num = int(ch[-1]) + for n0, n1 in graph.edges(): + action_id_0 = graph.nodes[n0]['action_id'] + action_id_1 = graph.nodes[n1]['action_id'] + loc_0 = self._actions[action_id_0].modules['channel_group_location'][ch_num] + loc_1 = self._actions[action_id_1].modules['channel_group_location'][ch_num] + assert loc_0 == loc_1 + depth_0 = self._actions[action_id_0].modules['depth'][loc_0]['probe_0'] + depth_1 = self._actions[action_id_0].modules['depth'][loc_1]['probe_0'] + depth_0 = float(depth_0.rescale('um')) + depth_1 = float(depth_1.rescale('um')) + depth_delta = abs(depth_0 - depth_1) + graph.add_edge(n0, n1, depth_delta=depth_delta) + + def remove_edges_above_threshold(self, key='weight', threshold=0.05): + ''' + key: weight, depth_delta, time_delta + ''' + for ch in self.graphs: + graph = self.graphs[ch] + edges_to_remove = [] + for sub_graph in nx.connected_components(graph): + for node_id in sub_graph: + for n1, n2, d in graph.edges(node_id, data=True): + if d[key] > threshold and n2 in sub_graph: # remove all edges from the subgraph + edge = set((n1, n2)) + if edge not in edges_to_remove: + edges_to_remove.append(edge) + for n1, n2 in edges_to_remove: + graph.remove_edge(n1, n2) + self.graphs[ch] = graph + + def remove_edges_with_duplicate_actions(self): + for graph in self.graphs.values(): + edges_to_remove = [] + for sub_graph in nx.connected_components(graph): + sub_graph_action_ids = {node: graph.nodes[node]['action_id'] for node in sub_graph} + action_ids = np.array(list(sub_graph_action_ids.values())) + node_ids = np.array(list(sub_graph_action_ids.keys())) + 
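# -- Editor's note: a compact illustration of the graph bookkeeping used by
# make_graphs_from_matches and the pruning steps above (toy values only):
import networkx as nx
g = nx.Graph()
g.add_node("sessionA_3", action_id="sessionA", unit_id=3)
g.add_node("sessionB_7", action_id="sessionB", unit_id=7)
g.add_edge("sessionA_3", "sessionB_7", weight=0.04)   # a putative match
for nodes in nx.connected_components(g):
    print(nodes)          # each component is one putative same-unit chain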
unique_action_ids, action_id_counts = np.unique(action_ids, return_counts=True) + if len(unique_action_ids) != len(action_ids): + + duplicates = unique_action_ids[action_id_counts > 1] + + for duplicate in duplicates: + idxs, = np.where(action_ids == duplicate) + weights = {} + for node_id in node_ids[idxs]: + weights[node_id] = np.mean([ + d['weight'] + for n1, n2, d in graph.edges(node_id, data=True) + if n2 in sub_graph_action_ids + ]) + min_weight = np.min(list(weights.values())) + for node_id, weight in weights.items(): + if weight > min_weight: # remove all edges from the subgraph + for n1, n2 in graph.edges(node_id): + if n2 in sub_graph_action_ids: + edge = set((n1, n2)) + if edge not in edges_to_remove: + edges_to_remove.append(edge) + for n1, n2 in edges_to_remove: + graph.remove_edge(n1, n2) + + def save_graphs(self): + for ch, graph in self.graphs.items(): + with open(self.data_path / f'graph-group-{ch}.yaml', "w") as f: + yaml.dump(graph, f) + + def load_graphs(self): + self.graphs = {} + for path in self.data_path.iterdir(): + if path.name.startswith('graph-group') and path.suffix == '.yaml': + ch = path.stem.split('-')[-1] + with open(path, "r") as f: + self.graphs[ch] = yaml.load(f, Loader=yaml.Loader) + + def identify_units(self): + if self._verbose: + print('Multicomaprison step3: extract agreement from graph') + self.identified_units = {} + for ch, graph in self.graphs.items(): + # extract agrrement from graph + self._new_units = {} + for node_set in nx.connected_components(graph): + unit_id = str(uuid.uuid4()) + edges = graph.edges(node_set, data=True) + + if len(edges) == 0: + average_dissimilarity = None + else: + average_dissimilarity = np.mean( + [d['weight'] for _, _, d in edges]) + + original_ids = defaultdict(list) + for node in node_set: + original_ids[graph.nodes[node]['action_id']].append( + graph.nodes[node]['unit_id'] + ) + + self._new_units[unit_id] = { + 'average_dissimilarity': average_dissimilarity, + 'original_unit_ids': original_ids} + + self.identified_units[ch] = self._new_units + + def load_template(self, action_id, channel_group, unit_id): + group_unit_hash = str(channel_group) + '_' + str(unit_id) + if action_id in self._templates: + return self._templates[action_id][group_unit_hash] + + action = self._actions[action_id] + + data_path = get_data_path(action) + + spike_trains = load_spiketrains(data_path) + + self._templates[action_id] = {} + for sptr in spike_trains: + group_unit_hash_ = sptr.annotations['group'] + '_' + str(int(sptr.annotations['name'])) + self._templates[action_id][group_unit_hash_] = sptr.annotations["waveform_mean"] + + return self._templates[action_id][group_unit_hash] + + def plot_matches(self, chan_group=None, figsize=(10, 3), step_color=True): + ''' + + Parameters + ---------- + + + Returns + ------- + + ''' + if chan_group is None: + ch_groups = self.identified_units.keys() + else: + ch_groups = [chan_group] + for ch_group in ch_groups: + identified_units = self.identified_units[ch_group] + units = [ + (unit['original_unit_ids'], unit['average_dissimilarity']) + for unit in identified_units.values() + if len(unit['original_unit_ids']) > 1] + num_units = sum([len(u) for u in units]) + if num_units == 0: + print(f"Zero units found on channel group {ch_group}") + continue + fig = plt.figure(figsize=(figsize[0], figsize[1] * num_units)) + gs = gridspec.GridSpec(num_units, 1) + id_ax = 0 + for unit, avg_dsim in units: + axs = None + for action_id, unit_ids in unit.items(): + for unit_id in unit_ids: + label = f"{action_id} 
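# -- Editor's note: identify_units() (below) produces, per channel group, a
# dict keyed by fresh uuid4 strings. A sketch of consuming its output (the
# "track" instance and values are hypothetical):
for ch, units in track.identified_units.items():
    for uid, info in units.items():
        if len(info["original_unit_ids"]) > 1:       # matched across sessions
            print(ch, uid, info["average_dissimilarity"],
                  dict(info["original_unit_ids"]))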
Unit {unit_id} {avg_dsim:.2f}" + template = self.load_template(action_id, ch_group, unit_id) + if template is None: + print(f'Unable to plot "{unit_id}" from action "{action_id}" ch group "{ch_group}"') + continue + # print(f'plotting {action_id}, {ch_group}, {unit_id}') + axs = plot_template( + template, + fig=fig, gs=gs[id_ax], axs=axs, + label=label) + id_ax += 1 + plt.legend(loc='center left', bbox_to_anchor=(1, 0.5)) + fig.suptitle('Channel group ' + str(ch_group)) + plt.tight_layout(rect=[0, 0.03, 1, 0.98]) From 5a449caa8b0e580caa4b7c0c69deaeba425ff7ae Mon Sep 17 00:00:00 2001 From: Mikkel Date: Sat, 6 Apr 2024 12:16:07 +0200 Subject: [PATCH 20/47] correct channel avg --- src/expipe_plugin_cinpla/tools/track_units_tools.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/expipe_plugin_cinpla/tools/track_units_tools.py b/src/expipe_plugin_cinpla/tools/track_units_tools.py index 0dced8c..eb25556 100644 --- a/src/expipe_plugin_cinpla/tools/track_units_tools.py +++ b/src/expipe_plugin_cinpla/tools/track_units_tools.py @@ -54,8 +54,9 @@ def dissimilarity_weighted(templates_0, templates_1): templates_0 /= max_val templates_1 /= max_val - - return np.sqrt(np.sum([(templates_0[i] - templates_1[i])**2 for i in range(templates_0.shape[0])], axis=0)).mean() + # root sum square, averaged over channels + weighted = np.sqrt(np.sum([(templates_0[:,i] - templates_1[:,i])**2 for i in range(templates_0.shape[1])], axis=0)).mean() + return weighted def make_dissimilary_matrix(comp_object, channel_group): From f9caf6347ca6045d83cfda8e64f547ee74036f75 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 8 Apr 2024 15:36:34 +0200 Subject: [PATCH 21/47] Speed up import time --- src/expipe_plugin_cinpla/cli/register.py | 3 - src/expipe_plugin_cinpla/data_loader.py | 5 +- src/expipe_plugin_cinpla/imports.py | 180 ---------------- .../nwbutils/nwbwidgetsunitviewer.py | 17 +- src/expipe_plugin_cinpla/scripts/curation.py | 28 ++- .../tools/data_processing.py | 194 +++++++++--------- .../tools/registration.py | 7 +- .../tools/track_units_tools.py | 20 +- .../tools/trackunitcomparison.py | 117 ++++++----- .../tools/trackunitmulticomparison.py | 159 +++++++------- src/expipe_plugin_cinpla/widgets/viewer.py | 2 +- 11 files changed, 286 insertions(+), 446 deletions(-) diff --git a/src/expipe_plugin_cinpla/cli/register.py b/src/expipe_plugin_cinpla/cli/register.py index 08d11b2..b5d690e 100644 --- a/src/expipe_plugin_cinpla/cli/register.py +++ b/src/expipe_plugin_cinpla/cli/register.py @@ -8,9 +8,6 @@ from expipe_plugin_cinpla.cli.utils import validate_depth, validate_position, validate_angle, validate_adjustment -import spikeinterface.sorters as ss - - def attach_to_register(cli): ### OpenEphys ### @cli.command("openephys", short_help="Register an open-ephys recording-action to database.") diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py index 269775f..70a854c 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -6,7 +6,6 @@ import spikeinterface as si import spikeinterface.extractors as se -from pynwb import NWBHDF5IO from .scripts.utils import _get_data_path @@ -92,6 +91,8 @@ def load_leds(data_path): x1, y1, t1, x2, y2, t2, stop_time: tuple The x and y positions of the red and green LEDs, the timestamps and the stop time """ + from pynwb import NWBHDF5IO + io = NWBHDF5IO(str(data_path), "r") nwbfile = io.read() @@ -190,6 +191,8 @@ def load_epochs(data_path, label_column=None): epochs: 
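# -- Editor's note: after the "correct channel avg" fix above,
# dissimilarity_weighted computes, for max-normalised templates of shape
# (n_samples, n_channels), the root-sum-square across channels averaged over
# samples. A vectorised equivalent (sketch; unlike the original's in-place
# /=, it also leaves the caller's arrays untouched):
import numpy as np
def dissimilarity_weighted_vec(templates_0, templates_1):
    max_val = max(np.abs(templates_0).max(), np.abs(templates_1).max())
    t0, t1 = templates_0 / max_val, templates_1 / max_val
    return np.sqrt(((t0 - t1) ** 2).sum(axis=1)).mean()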
neo.Epoch The trials as NEO epochs """ + from pynwb import NWBHDF5IO + with NWBHDF5IO(str(data_path), "r") as io: nwbfile = io.read() trials = nwbfile.trials.to_dataframe() diff --git a/src/expipe_plugin_cinpla/imports.py b/src/expipe_plugin_cinpla/imports.py index 1038c0e..f5e5bd6 100644 --- a/src/expipe_plugin_cinpla/imports.py +++ b/src/expipe_plugin_cinpla/imports.py @@ -1,18 +1,6 @@ -# import click -# from expipe.cliutils.misc import lazy_import - import expipe from pathlib import Path -# @lazy_import -# def expipe(): -# import expipe -# return expipe - -# @lazy_import -# def pathlib(): -# import pathlib -# return pathlib local_root, _ = expipe.config._load_local_config(Path.cwd()) if local_root is not None: @@ -23,171 +11,3 @@ class P: config = {} project = P - - -# @lazy_import -# def pd(): -# import pandas as pd -# return pd - -# @lazy_import -# def dt(): -# import datetime as dt -# return dt - -# @lazy_import -# def yaml(): -# import yaml -# return yaml - -# @lazy_import -# def ipywidgets(): -# import ipywidgets -# return ipywidgets - -# @lazy_import -# def pyopenephys(): -# import pyopenephys -# return pyopenephys - -# # @lazy_import -# # def openephys_io(): -# # from expipe_io_neuro.openephys import openephys as openephys_io -# # return openephys_io - -# @lazy_import -# def pyintan(): -# import pyintan -# return pyintan - -# @lazy_import -# def pyxona(): -# import pyxona -# return pyxona - -# @lazy_import -# def platform(): -# import platform -# return platform - -# @lazy_import -# def csv(): -# import csv -# return csv - -# @lazy_import -# def json(): -# import json -# return json - -# # @lazy_import -# # def axona(): -# # from expipe_io_neuro import axona -# # return axona - -# @lazy_import -# def os(): -# import os -# return os - -# @lazy_import -# def shutil(): -# import shutil -# return shutil - -# @lazy_import -# def datetime(): -# import datetime -# return datetime - -# @lazy_import -# def subprocess(): -# import subprocess -# return subprocess - -# @lazy_import -# def tarfile(): -# import tarfile -# return tarfile - -# @lazy_import -# def paramiko(): -# import paramiko -# return paramiko - -# @lazy_import -# def getpass(): -# import getpass -# return getpass - -# @lazy_import -# def tqdm(): -# from tqdm import tqdm -# return tqdm - -# @lazy_import -# def scp(): -# import scp -# return scp - -# @lazy_import -# def neo(): -# import neo -# return neo - -# @lazy_import -# def exdir(): -# import exdir -# import exdir.plugins.quantities -# return exdir - -# @lazy_import -# def pq(): -# import quantities as pq -# return pq - -# @lazy_import -# def logging(): -# import logging -# return logging - -# @lazy_import -# def np(): -# import numpy as np -# return np - -# @lazy_import -# def copy(): -# import copy -# return copy - -# @lazy_import -# def scipy(): -# import scipy -# import scipy.io -# return scipy - -# @lazy_import -# def glob(): -# import glob -# return glob - -# @lazy_import -# def el(): -# import elephant as el -# return el - -# @lazy_import -# def sys(): -# import sys -# return sys - -# @lazy_import -# def pprint(): -# import pprint -# return pprint - -# @lazy_import -# def collections(): -# import collections -# return collections diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index c402c6a..041bbc0 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -2,13 +2,9 @@ import numpy as np import ipywidgets 
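# -- Editor's note: the recurring change in this commit -- moving heavy
# imports (pynwb, nwbwidgets, spikeinterface submodules) from module level
# into the functions that need them -- is what buys the import-time speedup.
# The shape of the pattern, in miniature:
def load_epochs(data_path, label_column=None):
    from pynwb import NWBHDF5IO   # deferred: cost paid on first call, not at import
    ...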
as widgets from ipywidgets import Layout, interactive_output -from pynwb.misc import Units -from pynwb.behavior import SpatialSeries import matplotlib.pyplot as plt -from nwbwidgets.view import default_neurodata_vis_spec - color_wheel = plt.rcParams["axes.prop_cycle"].by_key()["color"] @@ -16,7 +12,7 @@ class UnitWaveformsWidget(widgets.VBox): def __init__( self, - units: Units, + units: "pynwb.misc.Units", ): super().__init__() @@ -55,7 +51,7 @@ def on_unit_change(self, change): self.unit_group_text.value = f"Group: {unit_group}" -def show_unit_waveforms(units: Units, unit_index=None, ax=None): +def show_unit_waveforms(units: "pynwb.mis.Units", unit_index=None, ax=None): """ TODO: add docstring @@ -105,8 +101,8 @@ def show_unit_waveforms(units: Units, unit_index=None, ax=None): class UnitRateMapWidget(widgets.VBox): def __init__( self, - units: Units, - spatial_series: SpatialSeries = None, + units: "pynwb.mis.Units", + spatial_series: "SpatialSeries" = None, ): super().__init__() @@ -186,6 +182,8 @@ def on_unit_change(self, change): self.unit_group_text.value = f"Group: {unit_group}" def get_spatial_series(self): + from pynwb.behavior import SpatialSeries + spatial_series = dict() nwbfile = self.units.get_ancestor("NWBFile") for item in nwbfile.all_children(): @@ -265,6 +263,9 @@ def show_unit_rate_maps(self, unit_index=None, spatial_series_selector=None, num def get_custom_spec(): + from pynwb.misc import Units + from nwbwidgets.view import default_neurodata_vis_spec + custom_neurodata_vis_spec = default_neurodata_vis_spec.copy() # remove irrelevant widgets diff --git a/src/expipe_plugin_cinpla/scripts/curation.py b/src/expipe_plugin_cinpla/scripts/curation.py index a6bf726..f945806 100644 --- a/src/expipe_plugin_cinpla/scripts/curation.py +++ b/src/expipe_plugin_cinpla/scripts/curation.py @@ -2,17 +2,9 @@ import json from pathlib import Path import numpy as np -from pynwb import NWBHDF5IO -from pynwb.testing.mock.file import mock_NWBFile import warnings -import spikeinterface.full as si -import spikeinterface.extractors as se -import spikeinterface.postprocessing as spost -import spikeinterface.qualitymetrics as sqm -import spikeinterface.curation as sc - -from spikeinterface.extractors.nwbextractors import _retrieve_unit_table_pynwb +import spikeinterface as si from .utils import _get_data_path, add_units_from_waveform_extractor, compute_and_set_unit_groups @@ -75,6 +67,8 @@ def check_sortings_equal(self, sorting1, sorting2): return True def load_raw_sorting(self, sorter): + import spikeinterface.extractors as se + raw_units_path = f"processing/ecephys/RawUnits-{sorter}" try: sorting_raw = se.read_nwb_sorting( @@ -88,6 +82,9 @@ def load_raw_sorting(self, sorter): return sorting_raw def load_raw_units(self, sorter): + from pynwb import NWBHDF5IO + from spikeinterface.extractors.nwbextractors import _retrieve_unit_table_pynwb + raw_units_path = f"processing/ecephys/RawUnits-{sorter}" self.io = NWBHDF5IO(self.nwb_path_main, "r") nwbfile = self.io.read() @@ -98,6 +95,8 @@ def load_raw_units(self, sorter): return units def load_main_units(self): + from pynwb import NWBHDF5IO + self.io = NWBHDF5IO(self.nwb_path_main, "r") nwbfile = self.io.read() return nwbfile.units @@ -106,6 +105,8 @@ def construct_curated_units(self): if len(self.curated_we.unit_ids) == 0: print("No units left after curation.") return + from pynwb import NWBHDF5IO + self.io = NWBHDF5IO(self.nwb_path_main, "r") nwbfile = self.io.read() add_units_from_waveform_extractor( @@ -136,6 +137,10 @@ def apply_curation(self, 
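# -- Editor's note: two of the stringified annotations above read
# "pynwb.mis.Units" (missing the 'c' of pynwb.misc). Harmless at runtime,
# since these strings are never resolved, but the consistent form would be:
units: "pynwb.misc.Units"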
sorter, curated_sorting): print(f"No curation was performed for {sorter}. Using raw sorting") self.curated_we = None else: + import spikeinterface.postprocessing as spost + import spikeinterface.qualitymetrics as sqm + import spikeinterface.curation as sc + recording = self.load_processed_recording(sorter) # remove excess spikes @@ -175,6 +180,8 @@ def apply_curation(self, sorter, curated_sorting): print("Done applying curation") def load_from_phy(self, sorter): + import spikeinterface.extractors as se + phy_path = self.si_path / sorter / "phy" sorting_phy = se.read_phy(phy_path, exclude_cluster_groups=["noise"]) @@ -198,6 +205,8 @@ def get_sortingview_link(self, sorter): return sortingview_links["raw"] def apply_sortingview_curation(self, sorter, curated_link): + import spikeinterface.curation as sc + sorting_raw = self.load_raw_sorting(sorter) assert sorting_raw is not None, f"Could not load raw sorting for {sorter}." sorting_raw = sorting_raw.save(format="memory") @@ -247,6 +256,7 @@ def save_to_nwb(self): if self.curated_we is None: print("No curation was performed.") return + from pynwb import NWBHDF5IO # trick to get rid of Units first with NWBHDF5IO(self.nwb_path_main, mode="r") as read_io: diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py index 3244375..2e975b0 100644 --- a/src/expipe_plugin_cinpla/tools/data_processing.py +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -1,8 +1,15 @@ # This is work in progress, import numpy as np from expipe_plugin_cinpla.data_loader import ( - load_epochs, get_channel_groups, load_spiketrains, load_unit_annotations, - load_leds, get_duration, load_lfp, get_sample_rate, get_data_path + load_epochs, + get_channel_groups, + load_spiketrains, + load_unit_annotations, + load_leds, + get_duration, + load_lfp, + get_sample_rate, + get_data_path, ) import pathlib import expipe @@ -12,8 +19,8 @@ def view_active_channels(action, sorter): path = action.data_path() - sorter_path = path / 'spikeinterface' / sorter / 'phy' - return np.load(sorter_path / 'channel_map_si.npy') + sorter_path = path / "spikeinterface" / sorter / "phy" + return np.load(sorter_path / "channel_map_si.npy") def _cut_to_same_len(*args): @@ -40,10 +47,10 @@ def velocity_filter(x, y, t, threshold): 1d vector of times at x, y positions threshold : float """ - assert len(x) == len(y) == len(t), 'x, y, t must have same length' + assert len(x) == len(y) == len(t), "x, y, t must have same length" vel = np.gradient([x, y], axis=1) / np.gradient(t) speed = np.linalg.norm(vel, axis=0) - speed_mask = (speed < threshold) + speed_mask = speed < threshold speed_mask = np.append(speed_mask, 0) x = x[np.where(speed_mask)] y = y[np.where(speed_mask)] @@ -51,7 +58,7 @@ def velocity_filter(x, y, t, threshold): return x, y, t -def interp_filt_position(x, y, tm, fs=100 , f_cut=10 ): +def interp_filt_position(x, y, tm, fs=100, f_cut=10): """ rapid head movements will contribute to velocity artifacts, these can be removed by low-pass filtering @@ -72,8 +79,9 @@ def interp_filt_position(x, y, tm, fs=100 , f_cut=10 ): out : angles, resized t """ import scipy.signal as ss - assert len(x) == len(y) == len(tm), 'x, y, t must have same length' - t = np.arange(tm.min(), tm.max() + 1. / fs, 1. 
/ fs) + + assert len(x) == len(y) == len(tm), "x, y, t must have same length" + t = np.arange(tm.min(), tm.max() + 1.0 / fs, 1.0 / fs) x = np.interp(t, tm, x) y = np.interp(t, tm, y) # rapid head movements will contribute to velocity artifacts, @@ -113,21 +121,23 @@ def rm_nans(*args): def filter_xy_zero(x, y, t): - idxs, = np.where((x == 0) & (y == 0)) + (idxs,) = np.where((x == 0) & (y == 0)) return [np.delete(a, idxs) for a in [x, y, t]] + def filter_xy_box_size(x, y, t, box_size): - idxs, = np.where((x > box_size[0]) | (x < 0) | (y > box_size[1]) | (y < 0)) + (idxs,) = np.where((x > box_size[0]) | (x < 0) | (y > box_size[1]) | (y < 0)) return [np.delete(a, idxs) for a in [x, y, t]] def filter_t_zero_duration(x, y, t, duration): - idxs, = np.where((t < 0) | (t > duration)) + (idxs,) = np.where((t < 0) | (t > duration)) return [np.delete(a, idxs) for a in [x, y, t]] def load_head_direction(data_path, sampling_rate, low_pass_frequency, box_size): from head_direction.head import head_direction + x1, y1, t1, x2, y2, t2, stop_time = load_leds(data_path) x1, y1, t1 = rm_nans(x1, y1, t1) @@ -143,10 +153,8 @@ def load_head_direction(data_path, sampling_rate, low_pass_frequency, box_size): # x1, y1, t1 = filter_xy_box_size(x1, y1, t1, box_size) # x2, y2, t2 = filter_xy_box_size(x2, y2, t2, box_size) - x1, y1, t1 = interp_filt_position(x1, y1, t1, - fs=sampling_rate, f_cut=low_pass_frequency) - x2, y2, t2 = interp_filt_position(x2, y2, t2, - fs=sampling_rate, f_cut=low_pass_frequency) + x1, y1, t1 = interp_filt_position(x1, y1, t1, fs=sampling_rate, f_cut=low_pass_frequency) + x2, y2, t2 = interp_filt_position(x2, y2, t2, fs=sampling_rate, f_cut=low_pass_frequency) x1, y1, t1, x2, y2, t2 = _cut_to_same_len(x1, y1, t1, x2, y2, t2) @@ -159,14 +167,16 @@ def load_head_direction(data_path, sampling_rate, low_pass_frequency, box_size): def check_valid_tracking(x, y, box_size): if np.isnan(x).any() and np.isnan(y).any(): - raise ValueError('nans found in position, ' + - 'x nans = %i, y nans = %i' % (sum(np.isnan(x)), sum(np.isnan(y)))) + raise ValueError( + "nans found in position, " + "x nans = %i, y nans = %i" % (sum(np.isnan(x)), sum(np.isnan(y))) + ) - if (x.min() < 0 or x.max() > box_size[0] or y.min() < 0 or y.max() > box_size[1]): + if x.min() < 0 or x.max() > box_size[0] or y.min() < 0 or y.max() > box_size[1]: warnings.warn( - "Invalid values found " + - "outside box: min [x, y] = [{}, {}], ".format(x.min(), y.min()) + - "max [x, y] = [{}, {}]".format(x.max(), y.max())) + "Invalid values found " + + "outside box: min [x, y] = [{}, {}], ".format(x.min(), y.min()) + + "max [x, y] = [{}, {}]".format(x.max(), y.max()) + ) def load_tracking(data_path, sampling_rate, low_pass_frequency, box_size, velocity_threshold=5): @@ -191,8 +201,7 @@ def load_tracking(data_path, sampling_rate, low_pass_frequency, box_size, veloci # remove velocity artifacts x, y, t = velocity_filter(x, y, t, velocity_threshold) - x, y, t = interp_filt_position( - x, y, t, fs=sampling_rate, f_cut=low_pass_frequency) + x, y, t = interp_filt_position(x, y, t, fs=sampling_rate, f_cut=low_pass_frequency) check_valid_tracking(x, y, box_size) @@ -201,19 +210,18 @@ def load_tracking(data_path, sampling_rate, low_pass_frequency, box_size, veloci x, y, t, speed = np.array(x), np.array(y), np.array(t), np.array(speed) return x, y, t, speed + def sort_by_cluster_id(spike_trains): if len(spike_trains) == 0: return spike_trains - if 'name' not in spike_trains[0].annotations: - print('Unable to get cluster_id, save with phy to create') - 
sorted_sptrs = sorted( - spike_trains, - key=lambda x: str(x.annotations['name'])) + if "name" not in spike_trains[0].annotations: + print("Unable to get cluster_id, save with phy to create") + sorted_sptrs = sorted(spike_trains, key=lambda x: str(x.annotations["name"])) return sorted_sptrs def get_unit_id(unit): - return str(int(unit.annotations['name'])) + return str(int(unit.annotations["name"])) class Template: @@ -271,9 +279,10 @@ def tracking(self, action_id): if action_id not in self._tracking: x, y, t, speed = load_tracking( self.data_path(action_id), - sampling_rate=self.params['position_sampling_rate'], - low_pass_frequency=self.params['position_low_pass_frequency'], - box_size=self.params['box_size']) + sampling_rate=self.params["position_sampling_rate"], + low_pass_frequency=self.params["position_low_pass_frequency"], + box_size=self.params["box_size"], + ) if self.stim_mask: t1, t2 = self.get_lim(action_id) mask = (t >= t1) & (t <= t2) @@ -281,17 +290,15 @@ def tracking(self, action_id): y = y[mask] t = t[mask] speed = speed[mask] - self._tracking[action_id] = { - 'x': x, 'y': y, 't': t, 'v': speed - } + self._tracking[action_id] = {"x": x, "y": y, "t": t, "v": speed} return self._tracking[action_id] @property def spatial_bins(self): if self._spatial_bins is None: box_size_, bin_size_ = sp.maps._adjust_bin_size( - box_size=self.params['box_size'], - bin_size=self.params['bin_size']) + box_size=self.params["box_size"], bin_size=self.params["bin_size"] + ) xbins, ybins = sp.maps._make_bins(box_size_, bin_size_) self._spatial_bins = (xbins, ybins) self.box_size_, self.bin_size_ = box_size_, bin_size_ @@ -302,10 +309,13 @@ def occupancy(self, action_id): xbins, ybins = self.spatial_bins occupancy_map = sp.maps._occupancy_map( - self.tracking(action_id)['x'], - self.tracking(action_id)['y'], - self.tracking(action_id)['t'], xbins, ybins) - threshold = self.params.get('occupancy_threshold') + self.tracking(action_id)["x"], + self.tracking(action_id)["y"], + self.tracking(action_id)["t"], + xbins, + ybins, + ) + threshold = self.params.get("occupancy_threshold") if threshold is not None: occupancy_map[occupancy_map <= threshold] = 0 self._occupancy[action_id] = occupancy_map @@ -315,14 +325,14 @@ def prob_dist(self, action_id): if action_id not in self._prob_dist: xbins, ybins = xbins, ybins = self.spatial_bins prob_dist = sp.stats.prob_dist( - self.tracking(action_id)['x'], - self.tracking(action_id)['y'], bins=(xbins, ybins)) + self.tracking(action_id)["x"], self.tracking(action_id)["y"], bins=(xbins, ybins) + ) self._prob_dist[action_id] = prob_dist return self._prob_dist[action_id] - + def tracking_split(self, action_id): if action_id not in self._tracking_split: - x, y, t, v = map(self.tracking(action_id).get, ['x', 'y', 't', 'v']) + x, y, t, v = map(self.tracking(action_id).get, ["x", "y", "t", "v"]) t_split = t[-1] / 2 mask_1 = t < t_split @@ -330,13 +340,18 @@ def tracking_split(self, action_id): x1, y1, t1, v1 = x[mask_1], y[mask_1], t[mask_1], v[mask_1] x2, y2, t2, v2 = x[mask_2], y[mask_2], t[mask_2], v[mask_2] - self._tracking_split[action_id] = { - 'x1': x1, 'y1': y1, 't1': t1, 'v1': v1, - 'x2': x2, 'y2': y2, 't2': t2, 'v2': v2 + "x1": x1, + "y1": y1, + "t1": t1, + "v1": v1, + "x2": x2, + "y2": y2, + "t2": t2, + "v2": v2, } return self._tracking_split[action_id] - + def spike_train_split(self, action_id, channel_group, unit_name): spikes = self.spike_train(action_id, channel_group, unit_name) t_split = self.duration(action_id) / 2 @@ -355,10 +370,9 @@ def 
rate_map_split(self, action_id, channel_group, unit_name, smoothing): if smoothing not in self._rate_maps_split[action_id][channel_group][unit_name]: make_rate_map = True - if make_rate_map: xbins, ybins = self.spatial_bins - x, y, t = map(self.tracking(action_id).get, ['x', 'y', 't']) + x, y, t = map(self.tracking(action_id).get, ["x", "y", "t"]) spikes = self.spike_train(action_id, channel_group, unit_name) t_split = t[-1] / 2 mask_1 = t < t_split @@ -367,24 +381,16 @@ def rate_map_split(self, action_id, channel_group, unit_name, smoothing): x_2, y_2, t_2 = x[mask_2], y[mask_2], t[mask_2] spikes_1 = spikes[spikes < t_split] spikes_2 = spikes[spikes >= t_split] - occupancy_map_1 = sp.maps._occupancy_map( - x_1, y_1, t_1, xbins, ybins) - occupancy_map_2 = sp.maps._occupancy_map( - x_2, y_2, t_2, xbins, ybins) - - spike_map_1 = sp.maps._spike_map( - x_1, y_1, t_1, spikes_1, xbins, ybins) - spike_map_2 = sp.maps._spike_map( - x_2, y_2, t_2, spikes_2, xbins, ybins) - - smooth_spike_map_1 = sp.maps.smooth_map( - spike_map_1, bin_size=self.bin_size_, smoothing=smoothing) - smooth_spike_map_2 = sp.maps.smooth_map( - spike_map_2, bin_size=self.bin_size_, smoothing=smoothing) - smooth_occupancy_map_1 = sp.maps.smooth_map( - occupancy_map_1, bin_size=self.bin_size_, smoothing=smoothing) - smooth_occupancy_map_2 = sp.maps.smooth_map( - occupancy_map_2, bin_size=self.bin_size_, smoothing=smoothing) + occupancy_map_1 = sp.maps._occupancy_map(x_1, y_1, t_1, xbins, ybins) + occupancy_map_2 = sp.maps._occupancy_map(x_2, y_2, t_2, xbins, ybins) + + spike_map_1 = sp.maps._spike_map(x_1, y_1, t_1, spikes_1, xbins, ybins) + spike_map_2 = sp.maps._spike_map(x_2, y_2, t_2, spikes_2, xbins, ybins) + + smooth_spike_map_1 = sp.maps.smooth_map(spike_map_1, bin_size=self.bin_size_, smoothing=smoothing) + smooth_spike_map_2 = sp.maps.smooth_map(spike_map_2, bin_size=self.bin_size_, smoothing=smoothing) + smooth_occupancy_map_1 = sp.maps.smooth_map(occupancy_map_1, bin_size=self.bin_size_, smoothing=smoothing) + smooth_occupancy_map_2 = sp.maps.smooth_map(occupancy_map_2, bin_size=self.bin_size_, smoothing=smoothing) rate_map_1 = smooth_spike_map_1 / smooth_occupancy_map_1 rate_map_2 = smooth_spike_map_2 / smooth_occupancy_map_2 @@ -407,16 +413,18 @@ def rate_map(self, action_id, channel_group, unit_name, smoothing): xbins, ybins = self.spatial_bins spike_map = sp.maps._spike_map( - self.tracking(action_id)['x'], - self.tracking(action_id)['y'], - self.tracking(action_id)['t'], + self.tracking(action_id)["x"], + self.tracking(action_id)["y"], + self.tracking(action_id)["t"], self.spike_train(action_id, channel_group, unit_name), - xbins, ybins) + xbins, + ybins, + ) - smooth_spike_map = sp.maps.smooth_map( - spike_map, bin_size=self.bin_size_, smoothing=smoothing) + smooth_spike_map = sp.maps.smooth_map(spike_map, bin_size=self.bin_size_, smoothing=smoothing) smooth_occupancy_map = sp.maps.smooth_map( - self.occupancy(action_id), bin_size=self.bin_size_, smoothing=smoothing) + self.occupancy(action_id), bin_size=self.bin_size_, smoothing=smoothing + ) rate_map = smooth_spike_map / smooth_occupancy_map self._rate_maps[action_id][channel_group][unit_name][smoothing] = rate_map @@ -426,29 +434,26 @@ def head_direction(self, action_id): if action_id not in self._head_direction: a, t = load_head_direction( self.data_path(action_id), - sampling_rate=self.params['position_sampling_rate'], - low_pass_frequency=self.params['position_low_pass_frequency'], - box_size=self.params['box_size']) + 
sampling_rate=self.params["position_sampling_rate"], + low_pass_frequency=self.params["position_low_pass_frequency"], + box_size=self.params["box_size"], + ) if self.stim_mask: t1, t2 = self.get_lim(action_id) mask = (t >= t1) & (t <= t2) a = a[mask] t = t[mask] - self._head_direction[action_id] = { - 'a': a, 't': t - } + self._head_direction[action_id] = {"a": a, "t": t} return self._head_direction[action_id] def lfp(self, action_id, channel_group, clean_memory=False): lim = self.get_lim(action_id) if self.stim_mask else None if clean_memory: - return load_lfp( - self.data_path(action_id), channel_group, lim) + return load_lfp(self.data_path(action_id), channel_group, lim) if action_id not in self._lfp: self._lfp[action_id] = {} if channel_group not in self._lfp[action_id]: - self._lfp[action_id][channel_group] = load_lfp( - self.data_path(action_id), channel_group, lim) + self._lfp[action_id][channel_group] = load_lfp(self.data_path(action_id), channel_group, lim) return self._lfp[action_id][channel_group] def template(self, action_id, channel_group, unit_id): @@ -463,13 +468,13 @@ def spike_trains(self, action_id): if action_id not in self._spike_trains: self._spike_trains[action_id] = {} lim = self.get_lim(action_id) if self.stim_mask else None - + sts = load_spiketrains(self.data_path(action_id), lim=lim) for st in sts: - channel_group = st.annotations['group'] + channel_group = st.annotations["group"] unit_id = get_unit_id(st) self._spike_trains[action_id][channel_group] = {unit_id: st} - + return self._spike_trains[action_id] def unit_names(self, action_id, channel_group): @@ -479,10 +484,9 @@ def unit_names(self, action_id, channel_group): def stim_times(self, action_id): if action_id not in self._stim_times: try: - trials = load_epochs( - self.data_path(action_id), label_column='channel') + trials = load_epochs(self.data_path(action_id), label_column="channel") if len(set(trials.labels)) > 1: - stim_times = trials.times[trials.labels==self.stim_channels[action_id]] + stim_times = trials.times[trials.labels == self.stim_channels[action_id]] else: stim_times = trials.times stim_times = np.sort(np.abs(np.array(stim_times))) @@ -490,9 +494,9 @@ def stim_times(self, action_id): # stim_times = stim_times[stim_times >= 1e-20] self._stim_times[action_id] = stim_times except AttributeError as e: - if str(e)=="'NoneType' object has no attribute 'to_dataframe'": + if str(e) == "'NoneType' object has no attribute 'to_dataframe'": self._stim_times[action_id] = None else: raise e - + return self._stim_times[action_id] diff --git a/src/expipe_plugin_cinpla/tools/registration.py b/src/expipe_plugin_cinpla/tools/registration.py index 1d3b311..dbff094 100644 --- a/src/expipe_plugin_cinpla/tools/registration.py +++ b/src/expipe_plugin_cinpla/tools/registration.py @@ -2,14 +2,15 @@ import shutil import pathlib + def store_notebook(action, notebook_path): notebook_path = pathlib.Path(notebook_path) action.data["notebook"] = notebook_path.name - notebook_output_path = action.data_path('notebook') + notebook_output_path = action.data_path("notebook") shutil.copy(notebook_path, notebook_output_path) # As HTML - os.system('jupyter nbconvert --to html {}'.format(notebook_path)) + os.system("jupyter nbconvert --to html {}".format(notebook_path)) html_path = notebook_path.with_suffix(".html") action.data["html"] = html_path.name - html_output_path = action.data_path('html') + html_output_path = action.data_path("html") shutil.copy(html_path, html_output_path) diff --git 
a/src/expipe_plugin_cinpla/tools/track_units_tools.py b/src/expipe_plugin_cinpla/tools/track_units_tools.py index eb25556..20a3748 100644 --- a/src/expipe_plugin_cinpla/tools/track_units_tools.py +++ b/src/expipe_plugin_cinpla/tools/track_units_tools.py @@ -27,7 +27,6 @@ def dissimilarity(template_0, template_1): t_i_lin = template_0.ravel() t_j_lin = template_1.ravel() - return np.mean(np.abs(t_i_lin / max_val - t_j_lin / max_val)) # return np.mean(np.abs(t_i_lin - t_j_lin)) @@ -55,24 +54,23 @@ def dissimilarity_weighted(templates_0, templates_1): templates_0 /= max_val templates_1 /= max_val # root sum square, averaged over channels - weighted = np.sqrt(np.sum([(templates_0[:,i] - templates_1[:,i])**2 for i in range(templates_0.shape[1])], axis=0)).mean() + weighted = np.sqrt( + np.sum([(templates_0[:, i] - templates_1[:, i]) ** 2 for i in range(templates_0.shape[1])], axis=0) + ).mean() return weighted def make_dissimilary_matrix(comp_object, channel_group): templates_0, templates_1 = comp_object.templates[channel_group] diss_matrix = np.zeros((len(templates_0), len(templates_1))) - + unit_ids_0, unit_ids_1 = comp_object.unit_ids[channel_group] for i, w0 in enumerate(templates_0): for j, w1 in enumerate(templates_1): diss_matrix[i, j] = dissimilarity_weighted(w0, w1) - diss_matrix = pd.DataFrame( - diss_matrix, - index=unit_ids_0, - columns=unit_ids_1) + diss_matrix = pd.DataFrame(diss_matrix, index=unit_ids_0, columns=unit_ids_1) return diss_matrix @@ -146,7 +144,7 @@ def make_best_match(dissimilarity_scores, max_dissimilarity): scores = dissimilarity_scores.values.copy() - best_match_12 = pd.Series(index=unit1_ids, dtype='int64') + best_match_12 = pd.Series(index=unit1_ids, dtype="int64") for i1, u1 in enumerate(unit1_ids): ind_min = np.argmin(scores[i1, :]) if scores[i1, ind_min] <= max_dissimilarity: @@ -154,7 +152,7 @@ def make_best_match(dissimilarity_scores, max_dissimilarity): else: best_match_12[u1] = -1 - best_match_21 = pd.Series(index=unit2_ids, dtype='int64') + best_match_21 = pd.Series(index=unit2_ids, dtype="int64") for i2, u2 in enumerate(unit2_ids): ind_min = np.argmin(scores[:, i2]) if scores[ind_min, i2] <= max_dissimilarity: @@ -193,9 +191,9 @@ def make_hungarian_match(dissimilarity_scores, max_dissimilarity): [inds1, inds2] = linear_sum_assignment(scores) - hungarian_match_12 = pd.Series(index=unit1_ids, dtype='int64') + hungarian_match_12 = pd.Series(index=unit1_ids, dtype="int64") hungarian_match_12[:] = -1 - hungarian_match_21 = pd.Series(index=unit2_ids, dtype='int64') + hungarian_match_21 = pd.Series(index=unit2_ids, dtype="int64") hungarian_match_21[:] = -1 for i1, i2 in zip(inds1, inds2): diff --git a/src/expipe_plugin_cinpla/tools/trackunitcomparison.py b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py index f42cb4b..ae11e53 100644 --- a/src/expipe_plugin_cinpla/tools/trackunitcomparison.py +++ b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py @@ -1,5 +1,4 @@ -from .track_units_tools import make_dissimilary_matrix, make_possible_match, make_best_match, \ - make_hungarian_match +from .track_units_tools import make_dissimilary_matrix, make_possible_match, make_best_match, make_hungarian_match from expipe_plugin_cinpla.data_loader import get_data_path, load_spiketrains, get_channel_groups import matplotlib.pylab as plt import numpy as np @@ -11,9 +10,16 @@ class TrackingSession: Base class shared by SortingComparison and GroundTruthComparison """ - def __init__(self, action_id_0, action_id_1, actions, channel_group=None, - max_dissimilarity=10, 
dissimilarity_function=None, verbose=False): - + def __init__( + self, + action_id_0, + action_id_1, + actions, + channel_group=None, + max_dissimilarity=10, + dissimilarity_function=None, + verbose=False, + ): data_path_0 = get_data_path(actions[action_id_0]) data_path_1 = get_data_path(actions[action_id_1]) @@ -43,30 +49,28 @@ def __init__(self, action_id_0, action_id_1, actions, channel_group=None, units_0 = load_spiketrains(data_path_0) units_1 = load_spiketrains(data_path_1) for channel_group in self.matches.keys(): - units_0 = [unit for unit in units_0 if unit.annotations['group']==channel_group] - units_1 = [unit for unit in units_1 if unit.annotations['group']==channel_group] + units_0 = [unit for unit in units_0 if unit.annotations["group"] == channel_group] + units_1 = [unit for unit in units_1 if unit.annotations["group"] == channel_group] self.unit_ids[channel_group] = [ - np.array([int(st.annotations['name']) for st in units_0]), - np.array([int(st.annotations['name']) for st in units_1]) + np.array([int(st.annotations["name"]) for st in units_0]), + np.array([int(st.annotations["name"]) for st in units_1]), ] self.templates[channel_group] = [ [st.annotations["waveform_mean"] for st in units_0], - [st.annotations["waveform_mean"] for st in units_1] + [st.annotations["waveform_mean"] for st in units_1], ] if len(units_0) > 0 and len(units_1) > 0: - self._do_dissimilarity(channel_group) self._do_matching(channel_group) def save_dissimilarity_matrix(self, path=None): path = path or Path.cwd() for channel_group in self.matches: - if 'dissimilarity_scores' not in self.matches[channel_group]: + if "dissimilarity_scores" not in self.matches[channel_group]: continue - filename = f'{self.action_id_0}_{self.action_id_1}_{channel_group}' - self.matches[channel_group]['dissimilarity_scores'].to_csv( - path / (filename + '.csv')) + filename = f"{self.action_id_0}_{self.action_id_1}_{channel_group}" + self.matches[channel_group]["dissimilarity_scores"].to_csv(path / (filename + ".csv")) @property def session_0_name(self): @@ -78,26 +82,30 @@ def session_1_name(self): def _do_dissimilarity(self, channel_group): if self._verbose: - print('Agreement scores...') + print("Agreement scores...") # agreement matrix score for each pair - self.matches[channel_group]['dissimilarity_scores'] = make_dissimilary_matrix( - self, channel_group) + self.matches[channel_group]["dissimilarity_scores"] = make_dissimilary_matrix(self, channel_group) def _do_matching(self, channel_group): # must be implemented in subclass if self._verbose: print("Matching...") - self.matches[channel_group]['possible_match_01'], self.matches[channel_group]['possible_match_10'] = \ - make_possible_match(self.matches[channel_group]['dissimilarity_scores'], self.max_dissimilarity) - self.matches[channel_group]['best_match_01'], self.matches[channel_group]['best_match_10'] = \ - make_best_match(self.matches[channel_group]['dissimilarity_scores'], self.max_dissimilarity) - self.matches[channel_group]['hungarian_match_01'], self.matches[channel_group]['hungarian_match_10'] = \ - make_hungarian_match(self.matches[channel_group]['dissimilarity_scores'], self.max_dissimilarity) - - def plot_matched_units(self, match_mode='hungarian', channel_group=None, ylim=[-200, 50], figsize=(15, 15)): - ''' + ( + self.matches[channel_group]["possible_match_01"], + self.matches[channel_group]["possible_match_10"], + ) = make_possible_match(self.matches[channel_group]["dissimilarity_scores"], self.max_dissimilarity) + 
self.matches[channel_group]["best_match_01"], self.matches[channel_group]["best_match_10"] = make_best_match( + self.matches[channel_group]["dissimilarity_scores"], self.max_dissimilarity + ) + ( + self.matches[channel_group]["hungarian_match_01"], + self.matches[channel_group]["hungarian_match_10"], + ) = make_hungarian_match(self.matches[channel_group]["dissimilarity_scores"], self.max_dissimilarity) + + def plot_matched_units(self, match_mode="hungarian", channel_group=None, ylim=[-200, 50], figsize=(15, 15)): + """ Parameters ---------- @@ -106,42 +114,39 @@ def plot_matched_units(self, match_mode='hungarian', channel_group=None, ylim=[- Returns ------- - ''' + """ if channel_group is None: ch_groups = self.matches.keys() else: ch_groups = [channel_group] for ch_group in ch_groups: - if 'hungarian_match_01' not in self.matches[ch_group].keys(): - print('Not units for group', ch_group) + if "hungarian_match_01" not in self.matches[ch_group].keys(): + print("Not units for group", ch_group) continue - if match_mode == 'hungarian': - match12 = self.matches[ch_group]['hungarian_match_01'] - elif match_mode == 'best': - match12 = self.matches[ch_group]['best_match_01'] + if match_mode == "hungarian": + match12 = self.matches[ch_group]["hungarian_match_01"] + elif match_mode == "best": + match12 = self.matches[ch_group]["best_match_01"] num_matches = len(np.where(match12 != -1)[0]) if num_matches > 0: - fig, ax_list = plt.subplots(nrows=2, ncols=num_matches, figsize=figsize) - fig.suptitle('Channel group ' + str(ch_group)) + fig.suptitle("Channel group " + str(ch_group)) if num_matches == 1: i = np.where(match12 != -1)[0][0] j = match12.iloc[i] - i1 = np.where(self.matches[ch_group]['unit_ids_0'] == match12.index[i]) - i2 = np.where(self.matches[ch_group]['unit_ids_1'] == j) - template1 = np.squeeze( - self.matches[ch_group]['templates_0'][i1]) - ax_list[0].plot(template1, color='C0') - ax_list[0].set_title('Unit ' + str(match12.index[i])) - template2 = np.squeeze( - self.matches[ch_group]['templates_1'][i2]) - ax_list[1].plot(template2, color='C0') - ax_list[1].set_title('Unit ' + str(j)) + i1 = np.where(self.matches[ch_group]["unit_ids_0"] == match12.index[i]) + i2 = np.where(self.matches[ch_group]["unit_ids_1"] == j) + template1 = np.squeeze(self.matches[ch_group]["templates_0"][i1]) + ax_list[0].plot(template1, color="C0") + ax_list[0].set_title("Unit " + str(match12.index[i])) + template2 = np.squeeze(self.matches[ch_group]["templates_1"][i2]) + ax_list[1].plot(template2, color="C0") + ax_list[1].set_title("Unit " + str(j)) ax_list[0].set_ylabel(self.name_list[0]) ax_list[1].set_ylabel(self.name_list[1]) ax_list[0].set_ylim(ylim) @@ -150,23 +155,21 @@ def plot_matched_units(self, match_mode='hungarian', channel_group=None, ylim=[- id_ax = 0 for i, j in enumerate(match12): if j != -1: - i1 = np.where(self.matches[ch_group]['unit_ids_0'] == match12.index[i]) - i2 = np.where(self.matches[ch_group]['unit_ids_1'] == j) + i1 = np.where(self.matches[ch_group]["unit_ids_0"] == match12.index[i]) + i2 = np.where(self.matches[ch_group]["unit_ids_1"] == j) if id_ax == 0: ax_list[0, id_ax].set_ylabel(self.name_list[0]) ax_list[1, id_ax].set_ylabel(self.name_list[1]) - template1 = np.squeeze( - self.matches[ch_group]['templates_0'][i1]) - ax_list[0, id_ax].plot(template1, color='C'+str(id_ax)) - ax_list[0, id_ax].set_title('Unit ' + str(match12.index[i])) - template2 = np.squeeze( - self.matches[ch_group]['templates_1'][i1]) - ax_list[1, id_ax].plot(template2, color='C'+str(id_ax)) - ax_list[1, 
id_ax].set_title('Unit ' + str(j)) + template1 = np.squeeze(self.matches[ch_group]["templates_0"][i1]) + ax_list[0, id_ax].plot(template1, color="C" + str(id_ax)) + ax_list[0, id_ax].set_title("Unit " + str(match12.index[i])) + template2 = np.squeeze(self.matches[ch_group]["templates_1"][i1]) + ax_list[1, id_ax].plot(template2, color="C" + str(id_ax)) + ax_list[1, id_ax].set_title("Unit " + str(j)) ax_list[0, id_ax].set_ylim(ylim) ax_list[1, id_ax].set_ylim(ylim) id_ax += 1 else: - print('No matched units for group', ch_group) + print("No matched units for group", ch_group) continue diff --git a/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py index 172887b..9ceb1c0 100644 --- a/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py +++ b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py @@ -12,10 +12,19 @@ from pathlib import Path import datetime + class TrackMultipleSessions: - def __init__(self, actions, action_list=None, channel_group=None, - max_dissimilarity=None, max_timedelta=None, verbose=False, - progress_bar=None, data_path=None): + def __init__( + self, + actions, + action_list=None, + channel_group=None, + max_dissimilarity=None, + max_timedelta=None, + verbose=False, + progress_bar=None, + data_path=None, + ): self.data_path = Path.cwd() if data_path is None else Path(data_path) self.data_path.mkdir(parents=True, exist_ok=True) self.action_list = [a for a in actions] if action_list is None else action_list @@ -30,14 +39,14 @@ def __init__(self, actions, action_list=None, channel_group=None, dp = get_data_path(self._actions[self.action_list[0]]) self._channel_groups = get_channel_groups(dp) if len(self._channel_groups) == 0: - print('Unable to locate channel groups, please provide a working action_list') + print("Unable to locate channel groups, please provide a working action_list") else: self._channel_groups = [self._channel_group] def do_matching(self): # do pairwise matching if self._verbose: - print('Multicomaprison step1: pairwise comparison') + print("Multicomaprison step1: pairwise comparison") self.comparisons = [] N = len(self.action_list) @@ -47,11 +56,13 @@ def do_matching(self): if self._verbose: print(" Comparing: ", self.action_list[i], " and ", self.action_list[j]) comp = TrackingSession( - self.action_list[i], self.action_list[j], + self.action_list[i], + self.action_list[j], actions=self._actions, max_dissimilarity=np.inf, channel_group=self._channel_group, - verbose=self._verbose) + verbose=self._verbose, + ) # comp.save_dissimilarity_matrix() self.comparisons.append(comp) pbar.update(1) @@ -59,13 +70,13 @@ def do_matching(self): def make_graphs_from_matches(self): if self._verbose: - print('Multicomaprison step2: make graph') + print("Multicomaprison step2: make graph") self.graphs = {} for ch in self._channel_groups: if self._verbose: - print('Processing channel', ch) + print("Processing channel", ch) self.graphs[ch] = nx.Graph() # nodes @@ -73,70 +84,65 @@ def make_graphs_from_matches(self): # if same node is added twice it's only created once for i, action_id in enumerate(comp.action_ids): for u in comp.unit_ids[ch][i]: - node_name = action_id + '_' + str(int(u)) - self.graphs[ch].add_node( - node_name, action_id=action_id, - unit_id=int(u)) + node_name = action_id + "_" + str(int(u)) + self.graphs[ch].add_node(node_name, action_id=action_id, unit_id=int(u)) # edges for comp in self.comparisons: - if 'hungarian_match_01' not in comp.matches[ch]: + if "hungarian_match_01" 
not in comp.matches[ch]: continue for u1 in comp.unit_ids[ch][0]: - u2 = comp.matches[ch]['hungarian_match_01'][u1] + u2 = comp.matches[ch]["hungarian_match_01"][u1] if u2 != -1: - score = comp.matches[ch]['dissimilarity_scores'].loc[u1, u2] - node1_name = comp.action_id_0 + '_' + str(int(u1)) - node2_name = comp.action_id_1 + '_' + str(int(u2)) - self.graphs[ch].add_edge( - node1_name, node2_name, weight=float(score)) + score = comp.matches[ch]["dissimilarity_scores"].loc[u1, u2] + node1_name = comp.action_id_0 + "_" + str(int(u1)) + node2_name = comp.action_id_1 + "_" + str(int(u2)) + self.graphs[ch].add_edge(node1_name, node2_name, weight=float(score)) # the graph is symmetrical self.graphs[ch] = self.graphs[ch].to_undirected() def compute_time_delta_edges(self): - ''' + """ adds a timedelta to each of the edges - ''' + """ for graph in self.graphs.values(): for n0, n1 in graph.edges(): - action_id_0 = graph.nodes[n0]['action_id'] - action_id_1 = graph.nodes[n1]['action_id'] - time_delta = abs( - self._actions[action_id_0].datetime - - self._actions[action_id_1].datetime) + action_id_0 = graph.nodes[n0]["action_id"] + action_id_1 = graph.nodes[n1]["action_id"] + time_delta = abs(self._actions[action_id_0].datetime - self._actions[action_id_1].datetime) graph.add_edge(n0, n1, time_delta=time_delta) def compute_depth_delta_edges(self): - ''' + """ adds a depthdelta to each of the edges - ''' + """ for ch, graph in self.graphs.items(): ch_num = int(ch[-1]) for n0, n1 in graph.edges(): - action_id_0 = graph.nodes[n0]['action_id'] - action_id_1 = graph.nodes[n1]['action_id'] - loc_0 = self._actions[action_id_0].modules['channel_group_location'][ch_num] - loc_1 = self._actions[action_id_1].modules['channel_group_location'][ch_num] + action_id_0 = graph.nodes[n0]["action_id"] + action_id_1 = graph.nodes[n1]["action_id"] + loc_0 = self._actions[action_id_0].modules["channel_group_location"][ch_num] + loc_1 = self._actions[action_id_1].modules["channel_group_location"][ch_num] assert loc_0 == loc_1 - depth_0 = self._actions[action_id_0].modules['depth'][loc_0]['probe_0'] - depth_1 = self._actions[action_id_0].modules['depth'][loc_1]['probe_0'] - depth_0 = float(depth_0.rescale('um')) - depth_1 = float(depth_1.rescale('um')) + depth_0 = self._actions[action_id_0].modules["depth"][loc_0]["probe_0"] + depth_1 = self._actions[action_id_0].modules["depth"][loc_1]["probe_0"] + depth_0 = float(depth_0.rescale("um")) + depth_1 = float(depth_1.rescale("um")) depth_delta = abs(depth_0 - depth_1) graph.add_edge(n0, n1, depth_delta=depth_delta) - def remove_edges_above_threshold(self, key='weight', threshold=0.05): - ''' + def remove_edges_above_threshold(self, key="weight", threshold=0.05): + """ key: weight, depth_delta, time_delta - ''' + """ for ch in self.graphs: graph = self.graphs[ch] edges_to_remove = [] for sub_graph in nx.connected_components(graph): for node_id in sub_graph: for n1, n2, d in graph.edges(node_id, data=True): - if d[key] > threshold and n2 in sub_graph: # remove all edges from the subgraph + if d[key] > threshold and n2 in sub_graph: # remove all edges from the subgraph edge = set((n1, n2)) if edge not in edges_to_remove: edges_to_remove.append(edge) @@ -148,26 +154,27 @@ def remove_edges_with_duplicate_actions(self): for graph in self.graphs.values(): edges_to_remove = [] for sub_graph in nx.connected_components(graph): - sub_graph_action_ids = {node: graph.nodes[node]['action_id'] for node in sub_graph} + sub_graph_action_ids = {node: graph.nodes[node]["action_id"] for node in 
sub_graph} action_ids = np.array(list(sub_graph_action_ids.values())) node_ids = np.array(list(sub_graph_action_ids.keys())) unique_action_ids, action_id_counts = np.unique(action_ids, return_counts=True) if len(unique_action_ids) != len(action_ids): - duplicates = unique_action_ids[action_id_counts > 1] for duplicate in duplicates: - idxs, = np.where(action_ids == duplicate) + (idxs,) = np.where(action_ids == duplicate) weights = {} for node_id in node_ids[idxs]: - weights[node_id] = np.mean([ - d['weight'] - for n1, n2, d in graph.edges(node_id, data=True) - if n2 in sub_graph_action_ids - ]) + weights[node_id] = np.mean( + [ + d["weight"] + for n1, n2, d in graph.edges(node_id, data=True) + if n2 in sub_graph_action_ids + ] + ) min_weight = np.min(list(weights.values())) for node_id, weight in weights.items(): - if weight > min_weight: # remove all edges from the subgraph + if weight > min_weight: # remove all edges from the subgraph for n1, n2 in graph.edges(node_id): if n2 in sub_graph_action_ids: edge = set((n1, n2)) @@ -178,22 +185,22 @@ def remove_edges_with_duplicate_actions(self): def save_graphs(self): for ch, graph in self.graphs.items(): - with open(self.data_path / f'graph-group-{ch}.yaml', "w") as f: + with open(self.data_path / f"graph-group-{ch}.yaml", "w") as f: yaml.dump(graph, f) def load_graphs(self): self.graphs = {} for path in self.data_path.iterdir(): - if path.name.startswith('graph-group') and path.suffix == '.yaml': - ch = path.stem.split('-')[-1] + if path.name.startswith("graph-group") and path.suffix == ".yaml": + ch = path.stem.split("-")[-1] with open(path, "r") as f: self.graphs[ch] = yaml.load(f, Loader=yaml.Loader) def identify_units(self): - if self._verbose: - print('Multicomaprison step3: extract agreement from graph') - self.identified_units = {} - for ch, graph in self.graphs.items(): + if self._verbose: + print("Multicomaprison step3: extract agreement from graph") + self.identified_units = {} + for ch, graph in self.graphs.items(): # extract agrrement from graph self._new_units = {} for node_set in nx.connected_components(graph): @@ -203,26 +210,24 @@ def identify_units(self): if len(edges) == 0: average_dissimilarity = None else: - average_dissimilarity = np.mean( - [d['weight'] for _, _, d in edges]) + average_dissimilarity = np.mean([d["weight"] for _, _, d in edges]) original_ids = defaultdict(list) for node in node_set: - original_ids[graph.nodes[node]['action_id']].append( - graph.nodes[node]['unit_id'] - ) + original_ids[graph.nodes[node]["action_id"]].append(graph.nodes[node]["unit_id"]) self._new_units[unit_id] = { - 'average_dissimilarity': average_dissimilarity, - 'original_unit_ids': original_ids} + "average_dissimilarity": average_dissimilarity, + "original_unit_ids": original_ids, + } self.identified_units[ch] = self._new_units def load_template(self, action_id, channel_group, unit_id): - group_unit_hash = str(channel_group) + '_' + str(unit_id) + group_unit_hash = str(channel_group) + "_" + str(unit_id) if action_id in self._templates: return self._templates[action_id][group_unit_hash] - + action = self._actions[action_id] data_path = get_data_path(action) @@ -231,22 +236,22 @@ def load_template(self, action_id, channel_group, unit_id): self._templates[action_id] = {} for sptr in spike_trains: - group_unit_hash_ = sptr.annotations['group'] + '_' + str(int(sptr.annotations['name'])) + group_unit_hash_ = sptr.annotations["group"] + "_" + str(int(sptr.annotations["name"])) self._templates[action_id][group_unit_hash_] = 
sptr.annotations["waveform_mean"]

         return self._templates[action_id][group_unit_hash]

     def plot_matches(self, chan_group=None, figsize=(10, 3), step_color=True):
-        '''
+        """

         Parameters
         ----------
-
+
         Returns
         -------

-        '''
+        """
         if chan_group is None:
             ch_groups = self.identified_units.keys()
         else:
@@ -254,9 +259,10 @@ def plot_matches(self, chan_group=None, figsize=(10, 3), step_color=True):
         for ch_group in ch_groups:
             identified_units = self.identified_units[ch_group]
             units = [
-                (unit['original_unit_ids'], unit['average_dissimilarity'])
+                (unit["original_unit_ids"], unit["average_dissimilarity"])
                 for unit in identified_units.values()
-                if len(unit['original_unit_ids']) > 1]
+                if len(unit["original_unit_ids"]) > 1
+            ]
             num_units = sum([len(u) for u in units])
             if num_units == 0:
                 print(f"Zero units found on channel group {ch_group}")
@@ -274,11 +280,8 @@
                     print(f'Unable to plot "{unit_id}" from action "{action_id}" ch group "{ch_group}"')
                     continue
                 # print(f'plotting {action_id}, {ch_group}, {unit_id}')
-                axs = plot_template(
-                    template,
-                    fig=fig, gs=gs[id_ax], axs=axs,
-                    label=label)
+                axs = plot_template(template, fig=fig, gs=gs[id_ax], axs=axs, label=label)
                 id_ax += 1
-        plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
-        fig.suptitle('Channel group ' + str(ch_group))
+        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
+        fig.suptitle("Channel group " + str(ch_group))
         plt.tight_layout(rect=[0, 0.03, 1, 0.98])
diff --git a/src/expipe_plugin_cinpla/widgets/viewer.py b/src/expipe_plugin_cinpla/widgets/viewer.py
index 408ca5b..1324068 100644
--- a/src/expipe_plugin_cinpla/widgets/viewer.py
+++ b/src/expipe_plugin_cinpla/widgets/viewer.py
@@ -1,5 +1,4 @@
 import ipywidgets
-from pynwb import NWBHDF5IO

 import expipe

@@ -38,6 +37,7 @@ def get_options(self):

     def on_change(self, change):
         if change["type"] == "change" and change["name"] == "value":
             from nwbwidgets import nwb2widget
+            from pynwb import NWBHDF5IO

             action_id = change["new"]
             if action_id is None:

From e924703d61e4ddf5153ea2fd5224329e5280fd73 Mon Sep 17 00:00:00 2001
From: Mikkel
Date: Mon, 15 Apr 2024 22:29:47 +0200
Subject: [PATCH 22/47] bug in unit retrieval plus some cleanup

---
 requirements.txt                      |   2 +-
 .../tools/track_units_tools.py        |  15 --
 .../tools/trackunitcomparison.py      | 142 ++++++------------
 .../tools/trackunitmulticomparison.py |  16 +-
 4 files changed, 50 insertions(+), 125 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index bd0846c..3f35eac 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,4 +7,4 @@ pyopenephys>=1.2.0
 tbb>=2021.11.0
 spikeinterface[full,widgets]>=0.100.0
 pynwb>=2.6.0
-neuroconv>=0.4.6
\ No newline at end of file
+neuroconv>=0.4.6
diff --git a/src/expipe_plugin_cinpla/tools/track_units_tools.py b/src/expipe_plugin_cinpla/tools/track_units_tools.py
index 20a3748..fc1ce64 100644
--- a/src/expipe_plugin_cinpla/tools/track_units_tools.py
+++ b/src/expipe_plugin_cinpla/tools/track_units_tools.py
@@ -60,21 +60,6 @@ def dissimilarity_weighted(templates_0, templates_1):
     return weighted


-def make_dissimilary_matrix(comp_object, channel_group):
-    templates_0, templates_1 = comp_object.templates[channel_group]
-    diss_matrix = np.zeros((len(templates_0), len(templates_1)))
-
-    unit_ids_0, unit_ids_1 = comp_object.unit_ids[channel_group]
-
-    for i, w0 in enumerate(templates_0):
-        for j, w1 in enumerate(templates_1):
-            diss_matrix[i, j] = dissimilarity_weighted(w0, w1)
-
-    diss_matrix = pd.DataFrame(diss_matrix,
index=unit_ids_0, columns=unit_ids_1) - - return diss_matrix - - def make_possible_match(dissimilarity_scores, max_dissimilarity): """ Given an agreement matrix and a max_dissimilarity threhold. diff --git a/src/expipe_plugin_cinpla/tools/trackunitcomparison.py b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py index ae11e53..980a33f 100644 --- a/src/expipe_plugin_cinpla/tools/trackunitcomparison.py +++ b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py @@ -1,9 +1,9 @@ -from .track_units_tools import make_dissimilary_matrix, make_possible_match, make_best_match, make_hungarian_match +from .track_units_tools import dissimilarity_weighted, make_possible_match, make_best_match, make_hungarian_match from expipe_plugin_cinpla.data_loader import get_data_path, load_spiketrains, get_channel_groups import matplotlib.pylab as plt import numpy as np from pathlib import Path - +import pandas as pd class TrackingSession: """ @@ -15,7 +15,7 @@ def __init__( action_id_0, action_id_1, actions, - channel_group=None, + channel_groups=None, max_dissimilarity=10, dissimilarity_function=None, verbose=False, @@ -26,47 +26,45 @@ def __init__( self._actions = actions self.action_id_0 = action_id_0 self.action_id_1 = action_id_1 - self._channel_group = channel_group + self.channel_groups = channel_groups self.action_ids = [action_id_0, action_id_1] self.max_dissimilarity = max_dissimilarity self.dissimilarity_function = dissimilarity_function self._verbose = verbose - if channel_group is None: - channel_groups = get_channel_groups(data_path_0) - self.matches = {} - self.templates = {} - self.unit_ids = {} - for chan in channel_groups: - self.matches[chan] = dict() - self.templates[chan] = list() - self.unit_ids[chan] = list() - else: - self.matches = {channel_group: dict()} - self.templates = {channel_group: list()} - self.unit_ids = {channel_group: list()} - - units_0 = load_spiketrains(data_path_0) - units_1 = load_spiketrains(data_path_1) - for channel_group in self.matches.keys(): - units_0 = [unit for unit in units_0 if unit.annotations["group"] == channel_group] - units_1 = [unit for unit in units_1 if unit.annotations["group"] == channel_group] + if self.channel_groups is None: + self.channel_groups = get_channel_groups(data_path_0) + self.matches = {} + self.templates = {} + self.unit_ids = {} + for chan in self.channel_groups: + self.matches[chan] = dict() + self.templates[chan] = list() + self.unit_ids[chan] = list() + + self.units_0 = load_spiketrains(data_path_0) + self.units_1 = load_spiketrains(data_path_1) + for channel_group in self.channel_groups: + us_0 = [u for u in self.units_0 if u.annotations["group"] == channel_group] + us_1 = [u for u in self.units_1 if u.annotations["group"] == channel_group] self.unit_ids[channel_group] = [ - np.array([int(st.annotations["name"]) for st in units_0]), - np.array([int(st.annotations["name"]) for st in units_1]), + [int(u.annotations["name"]) for u in us_0], + [int(u.annotations["name"]) for u in us_1], ] self.templates[channel_group] = [ - [st.annotations["waveform_mean"] for st in units_0], - [st.annotations["waveform_mean"] for st in units_1], + [u.annotations["waveform_mean"] for u in us_0], + [u.annotations["waveform_mean"] for u in us_1], ] - if len(units_0) > 0 and len(units_1) > 0: + if len(us_0) > 0 and len(us_1) > 0: self._do_dissimilarity(channel_group) self._do_matching(channel_group) + elif self._verbose: + print(f'Found no units in {channel_group}') def save_dissimilarity_matrix(self, path=None): path = path or Path.cwd() - for 
channel_group in self.matches: + for channel_group in self.channel_groups: if "dissimilarity_scores" not in self.matches[channel_group]: continue filename = f"{self.action_id_0}_{self.action_id_1}_{channel_group}" @@ -79,13 +77,27 @@ def session_0_name(self): @property def session_1_name(self): return self.name_list[1] + + def make_dissimilary_matrix(self, channel_group): + templates_0, templates_1 = self.templates[channel_group] + diss_matrix = np.zeros((len(templates_0), len(templates_1))) + + unit_ids_0, unit_ids_1 = self.unit_ids[channel_group] + + for i, w0 in enumerate(templates_0): + for j, w1 in enumerate(templates_1): + diss_matrix[i, j] = dissimilarity_weighted(w0, w1) + + diss_matrix = pd.DataFrame(diss_matrix, index=unit_ids_0, columns=unit_ids_1) + + return diss_matrix def _do_dissimilarity(self, channel_group): if self._verbose: print("Agreement scores...") # agreement matrix score for each pair - self.matches[channel_group]["dissimilarity_scores"] = make_dissimilary_matrix(self, channel_group) + self.matches[channel_group]["dissimilarity_scores"] = self.make_dissimilary_matrix(channel_group) def _do_matching(self, channel_group): # must be implemented in subclass @@ -103,73 +115,3 @@ def _do_matching(self, channel_group): self.matches[channel_group]["hungarian_match_01"], self.matches[channel_group]["hungarian_match_10"], ) = make_hungarian_match(self.matches[channel_group]["dissimilarity_scores"], self.max_dissimilarity) - - def plot_matched_units(self, match_mode="hungarian", channel_group=None, ylim=[-200, 50], figsize=(15, 15)): - """ - - Parameters - ---------- - match_mode - - Returns - ------- - - """ - if channel_group is None: - ch_groups = self.matches.keys() - else: - ch_groups = [channel_group] - - for ch_group in ch_groups: - if "hungarian_match_01" not in self.matches[ch_group].keys(): - print("Not units for group", ch_group) - continue - - if match_mode == "hungarian": - match12 = self.matches[ch_group]["hungarian_match_01"] - elif match_mode == "best": - match12 = self.matches[ch_group]["best_match_01"] - - num_matches = len(np.where(match12 != -1)[0]) - - if num_matches > 0: - fig, ax_list = plt.subplots(nrows=2, ncols=num_matches, figsize=figsize) - fig.suptitle("Channel group " + str(ch_group)) - - if num_matches == 1: - i = np.where(match12 != -1)[0][0] - j = match12.iloc[i] - i1 = np.where(self.matches[ch_group]["unit_ids_0"] == match12.index[i]) - i2 = np.where(self.matches[ch_group]["unit_ids_1"] == j) - template1 = np.squeeze(self.matches[ch_group]["templates_0"][i1]) - ax_list[0].plot(template1, color="C0") - ax_list[0].set_title("Unit " + str(match12.index[i])) - template2 = np.squeeze(self.matches[ch_group]["templates_1"][i2]) - ax_list[1].plot(template2, color="C0") - ax_list[1].set_title("Unit " + str(j)) - ax_list[0].set_ylabel(self.name_list[0]) - ax_list[1].set_ylabel(self.name_list[1]) - ax_list[0].set_ylim(ylim) - ax_list[1].set_ylim(ylim) - else: - id_ax = 0 - for i, j in enumerate(match12): - if j != -1: - i1 = np.where(self.matches[ch_group]["unit_ids_0"] == match12.index[i]) - i2 = np.where(self.matches[ch_group]["unit_ids_1"] == j) - - if id_ax == 0: - ax_list[0, id_ax].set_ylabel(self.name_list[0]) - ax_list[1, id_ax].set_ylabel(self.name_list[1]) - template1 = np.squeeze(self.matches[ch_group]["templates_0"][i1]) - ax_list[0, id_ax].plot(template1, color="C" + str(id_ax)) - ax_list[0, id_ax].set_title("Unit " + str(match12.index[i])) - template2 = np.squeeze(self.matches[ch_group]["templates_1"][i1]) - ax_list[1, 
id_ax].plot(template2, color="C" + str(id_ax)) - ax_list[1, id_ax].set_title("Unit " + str(j)) - ax_list[0, id_ax].set_ylim(ylim) - ax_list[1, id_ax].set_ylim(ylim) - id_ax += 1 - else: - print("No matched units for group", ch_group) - continue diff --git a/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py index 9ceb1c0..e24eab9 100644 --- a/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py +++ b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py @@ -18,7 +18,7 @@ def __init__( self, actions, action_list=None, - channel_group=None, + channel_groups=None, max_dissimilarity=None, max_timedelta=None, verbose=False, @@ -29,19 +29,17 @@ def __init__( self.data_path.mkdir(parents=True, exist_ok=True) self.action_list = [a for a in actions] if action_list is None else action_list self._actions = actions - self._channel_group = channel_group + self.channel_groups = channel_groups self.max_dissimilarity = max_dissimilarity or np.inf self.max_timedelta = max_timedelta or datetime.MAXYEAR self._verbose = verbose self._pbar = tqdm if progress_bar is None else progress_bar self._templates = {} - if self._channel_group is None: + if self.channel_groups is None: dp = get_data_path(self._actions[self.action_list[0]]) - self._channel_groups = get_channel_groups(dp) - if len(self._channel_groups) == 0: + self.channel_groups = get_channel_groups(dp) + if len(self.channel_groups) == 0: print("Unable to locate channel groups, please provide a working action_list") - else: - self._channel_groups = [self._channel_group] def do_matching(self): # do pairwise matching @@ -60,7 +58,7 @@ def do_matching(self): self.action_list[j], actions=self._actions, max_dissimilarity=np.inf, - channel_group=self._channel_group, + channel_groups=self.channel_groups, verbose=self._verbose, ) # comp.save_dissimilarity_matrix() @@ -74,7 +72,7 @@ def make_graphs_from_matches(self): self.graphs = {} - for ch in self._channel_groups: + for ch in self.channel_groups: if self._verbose: print("Processing channel", ch) self.graphs[ch] = nx.Graph() From 65417598fe6be39024782374c84aedfacaf4ea5e Mon Sep 17 00:00:00 2001 From: Mikkel Date: Tue, 16 Apr 2024 22:10:11 +0200 Subject: [PATCH 23/47] bugfix --- src/expipe_plugin_cinpla/tools/data_processing.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py index 2e975b0..efcb403 100644 --- a/src/expipe_plugin_cinpla/tools/data_processing.py +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -472,8 +472,9 @@ def spike_trains(self, action_id): sts = load_spiketrains(self.data_path(action_id), lim=lim) for st in sts: channel_group = st.annotations["group"] - unit_id = get_unit_id(st) - self._spike_trains[action_id][channel_group] = {unit_id: st} + if channel_group not in self._spike_trains[action_id]: + self._spike_trains[action_id][channel_group] = {} + self._spike_trains[action_id][channel_group][int(get_unit_id(st))] = st return self._spike_trains[action_id] From afc15b02f463cd7fe43c50e431ff3db761dd83dd Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 10:40:21 +0200 Subject: [PATCH 24/47] More comprehensive .gitignore --- .gitignore | 948 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 936 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 5be366e..fb19775 100644 --- a/.gitignore +++ b/.gitignore @@ -1,14 +1,793 @@ + +### 
Custom ### + +### C ### +# Prerequisites +*.d + +# Object files +*.o +*.ko +*.obj +*.elf + +# Linker output +*.ilk +*.map +*.exp + +# Precompiled Headers +*.gch +*.pch + +# Libraries +*.lib +*.a +*.la +*.lo + +# Shared objects (inc. Windows DLLs) +*.dll +*.so +*.so.* +*.dylib + +# Executables +*.exe +*.out +*.app +*.i*86 +*.x86_64 +*.hex + +# Debug files +*.dSYM/ +*.su +*.idb +*.pdb + +# Kernel Module Compile Results +*.mod* +*.cmd +.tmp_versions/ +modules.order +Module.symvers +Mkfile.old +dkms.conf + +### C++ ### +# Prerequisites + +# Compiled Object files +*.slo + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai + +### CMake ### +CMakeLists.txt.user +CMakeCache.txt +CMakeFiles +CMakeScripts +Testing +Makefile +cmake_install.cmake +install_manifest.txt +compile_commands.json +CTestTestfile.cmake +_deps + +### CMake Patch ### +# External projects +*-prefix/ + +### Java ### +# Compiled class file +*.class + +# Log file +*.log + +# BlueJ files +*.ctxt + +# Mobile Tools for Java (J2ME) +.mtj.tmp/ + +# Package Files # +*.jar +*.war +*.nar +*.ear +*.zip +*.tar.gz +*.rar + +# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml +hs_err_pid* +replay_pid* + +### Julia ### +# Files generated by invoking Julia with --code-coverage +*.jl.cov +*.jl.*.cov + +# Files generated by invoking Julia with --track-allocation +*.jl.mem + +# System-specific files and directories generated by the BinaryProvider and BinDeps packages +# They contain absolute paths specific to the host computer, and so should not be committed +deps/deps.jl +deps/build.log +deps/downloads/ +deps/usr/ +deps/src/ + +# Build artifacts for creating documentation generated by the Documenter package +docs/build/ +docs/site/ + +# File generated by Pkg, the package manager, based on a corresponding Project.toml +# It records a fixed state of all packages used by the project. As such, it should not be +# committed for packages, but should be committed for applications that require a static +# environment. +Manifest.toml + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### LaTeX ### +## Core latex/pdflatex auxiliary files: +*.aux +*.lof +*.lot +*.fls +*.toc +*.fmt +*.fot +*.cb +*.cb2 +.*.lb + +## Intermediate documents: +*.dvi +*.xdv +*-converted-to.* +# these rules might exclude image files for figures etc. 
+# *.ps +# *.eps +# *.pdf + +## Generated if empty string is given at "Please type another file name for output:" +.pdf + +## Bibliography auxiliary files (bibtex/biblatex/biber): +*.bbl +*.bcf +*.blg +*-blx.aux +*-blx.bib +*.run.xml + +## Build tool auxiliary files: +*.fdb_latexmk +*.synctex +*.synctex(busy) +*.synctex.gz +*.synctex.gz(busy) +*.pdfsync + +## Build tool directories for auxiliary files +# latexrun +latex.out/ + +## Auxiliary and intermediate files from other packages: +# algorithms +*.alg +*.loa + +# achemso +acs-*.bib + +# amsthm +*.thm + +# beamer +*.nav +*.pre +*.snm +*.vrb + +# changes +*.soc + +# comment +*.cut + +# cprotect +*.cpt + +# elsarticle (documentclass of Elsevier journals) +*.spl + +# endnotes +*.ent + +# fixme +*.lox + +# feynmf/feynmp +*.mf +*.mp +*.t[1-9] +*.t[1-9][0-9] +*.tfm + +#(r)(e)ledmac/(r)(e)ledpar +*.end +*.?end +*.[1-9] +*.[1-9][0-9] +*.[1-9][0-9][0-9] +*.[1-9]R +*.[1-9][0-9]R +*.[1-9][0-9][0-9]R +*.eledsec[1-9] +*.eledsec[1-9]R +*.eledsec[1-9][0-9] +*.eledsec[1-9][0-9]R +*.eledsec[1-9][0-9][0-9] +*.eledsec[1-9][0-9][0-9]R + +# glossaries +*.acn +*.acr +*.glg +*.glo +*.gls +*.glsdefs +*.lzo +*.lzs +*.slg +*.sls + +# uncomment this for glossaries-extra (will ignore makeindex's style files!) +# *.ist + +# gnuplot +*.gnuplot +*.table + +# gnuplottex +*-gnuplottex-* + +# gregoriotex +*.gaux +*.glog +*.gtex + +# htlatex +*.4ct +*.4tc +*.idv +*.lg +*.trc +*.xref + +# hyperref +*.brf + +# knitr +*-concordance.tex +# TODO Uncomment the next line if you use knitr and want to ignore its generated tikz files +# *.tikz +*-tikzDictionary + +# listings +*.lol + +# luatexja-ruby +*.ltjruby + +# makeidx +*.idx +*.ilg +*.ind + +# minitoc +*.maf +*.mlf +*.mlt +*.mtc[0-9]* +*.slf[0-9]* +*.slt[0-9]* +*.stc[0-9]* + +# minted +_minted* +*.pyg + +# morewrites +*.mw + +# newpax +*.newpax + +# nomencl +*.nlg +*.nlo +*.nls + +# pax +*.pax + +# pdfpcnotes +*.pdfpc + +# sagetex +*.sagetex.sage +*.sagetex.py +*.sagetex.scmd + +# scrwfile +*.wrt + +# svg +svg-inkscape/ + +# sympy +*.sout +*.sympy +sympy-plots-for-*.tex/ + +# pdfcomment +*.upa +*.upb + +# pythontex +*.pytxcode +pythontex-files-*/ + +# tcolorbox +*.listing + +# thmtools +*.loe + +# TikZ & PGF +*.dpth +*.md5 +*.auxlock + +# titletoc +*.ptc + +# todonotes +*.tdo + +# vhistory +*.hst +*.ver + +# easy-todo +*.lod + +# xcolor +*.xcp + +# xmpincl +*.xmpi + +# xindy +*.xdy + +# xypic precompiled matrices and outlines +*.xyc +*.xyd + +# endfloat +*.ttt +*.fff + +# Latexian +TSWLatexianTemp* + +## Editors: +# WinEdt +*.bak +*.sav + +# Texpad +.texpadtmp + +# LyX +*.lyx~ + +# Kile +*.backup + +# gummi +.*.swp + +# KBibTeX +*~[0-9]* + +# TeXnicCenter +*.tps + +# auto folder when using emacs and auctex +./auto/* +*.el + +# expex forward references with \gathertags +*-tags.tex + +# standalone packages +*.sta + +# Makeindex log files +*.lpz + +# xwatermark package +*.xwm + +# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib +# option is specified. Footnotes are the stored in a file with suffix Notes.bib. +# Uncomment the next line to have this generated file ignored. 
+#*Notes.bib + +### LaTeX Patch ### +# LIPIcs / OASIcs +*.vtc + +# glossaries +*.glstex + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +### MATLAB ### +# Windows default autosave extension +*.asv + +# OSX / *nix default autosave extension +*.m~ + +# Compiled MEX binaries (all platforms) +*.mex* + +# Packaged app and toolbox files +*.mlappinstall +*.mltbx + +# Generated helpsearch folders +helpsearch*/ + +# Simulink code generation folders +slprj/ +sccprj/ + +# Matlab code generation folders +codegen/ + +# Simulink autosave extension +*.autosave + +# Simulink cache files +*.slxc + +# Octave session info +octave-workspace + +### Node ### +# Logs +logs +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* +.pnpm-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ +jspm_packages/ + +# Snowpack dependency directory (https://snowpack.dev/) +web_modules/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional stylelint cache +.stylelintcache + +# Microbundle cache +.rpt2_cache/ +.rts2_cache_cjs/ +.rts2_cache_es/ +.rts2_cache_umd/ + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variable files +.env +.env.development.local +.env.test.local +.env.production.local +.env.local + +# parcel-bundler cache (https://parceljs.org/) +.cache +.parcel-cache + +# Next.js build output +.next +out + +# Nuxt.js build / generate output +.nuxt +dist + +# Gatsby files +.cache/ +# Comment in the public line in if your project uses Gatsby and not Next.js +# https://nextjs.org/blog/next-9-1#public-directory-support +# public + +# vuepress build output +.vuepress/dist + +# vuepress v2.x temp and cache directory +.temp + +# Docusaurus cache and generated files +.docusaurus + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# TernJS port file +.tern-port + +# Stores VSCode versions used for testing VSCode extensions +.vscode-test + +# yarn v2 +.yarn/cache 
+.yarn/unplugged +.yarn/build-state.yml +.yarn/install-state.gz +.pnp.* + +### Node Patch ### +# Serverless Webpack directories +.webpack/ + +# Optional stylelint cache + +# SvelteKit build / generate output +.svelte-kit + +### PyCharm ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### PyCharm Patch ### +# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 + +# *.iml +# modules.xml +# .idea/misc.xml +# *.ipr + +# Sonarlint plugin +# https://plugins.jetbrains.com/plugin/7973-sonarlint +.idea/**/sonarlint/ + +# SonarQube Plugin +# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin +.idea/**/sonarIssues.xml + +# Markdown Navigator plugin +# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced +.idea/**/markdown-navigator.xml +.idea/**/markdown-navigator-enh.xml +.idea/**/markdown-navigator/ + +# Cache file creation bug +# See https://youtrack.jetbrains.com/issue/JBR-2257 +.idea/$CACHE_FILE$ + +# CodeStream plugin +# https://plugins.jetbrains.com/plugin/12206-codestream +.idea/codestream.xml + +# Azure Toolkit for IntelliJ plugin +# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij +.idea/**/azureSettings.xml + +### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions -*.so # Distribution / packaging .Python -env/ build/ develop-eggs/ dist/ @@ -20,9 +799,12 @@ lib64/ parts/ sdist/ var/ +wheels/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg +MANIFEST # PyInstaller # Usually these files are written by a python script from a template @@ -37,21 +819,25 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage .coverage.* -.cache nosetests.xml coverage.xml -*,cover +*.cover +*.py,cover .hypothesis/ +.pytest_cache/ +cover/ # Translations *.mo *.pot # Django stuff: -*.log local_settings.py +db.sqlite3 +db.sqlite3-journal # Flask stuff: instance/ @@ -64,33 +850,171 @@ instance/ docs/_build/ # PyBuilder +.pybuilder/ target/ -# IPython Notebook -.ipynb_checkpoints +# Jupyter Notebook + +# IPython # pyenv -.python-version +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check 
them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock -# celery beat schedule file +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff celerybeat-schedule +celerybeat.pid -# dotenv -.env +# SageMath parsed files +*.sage.py -# virtualenv +# Environments +.venv +env/ venv/ ENV/ +env.bak/ +venv.bak/ # Spyder project settings .spyderproject +.spyproject # Rope project settings .ropeproject +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### Rust ### +# Generated by Cargo +# will have compiled files and executables +debug/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information + ### VisualStudioCode ### .vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix ### VisualStudioCode Patch ### # Ignore all local history of files .history +.ionide + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +### Xcode ### +## User settings +xcuserdata/ + +## Xcode 8 and earlier +*.xcscmblueprint +*.xccheckout + +### Xcode Patch ### +*.xcodeproj/* +!*.xcodeproj/project.pbxproj +!*.xcodeproj/xcshareddata/ +!*.xcodeproj/project.xcworkspace/ +!*.xcworkspace/contents.xcworkspacedata +/*.gcno +**/xcshareddata/WorkspaceSettings.xcsettings From 579de7402c397480d8c441f9e69259dd26bd92e3 Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 10:43:00 +0200 Subject: [PATCH 25/47] Bump Python version from 3.10 to 3.11 --- environment.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/environment.yml b/environment.yml index e6c8fa7..39966db 100644 --- a/environment.yml +++ b/environment.yml @@ -3,7 +3,7 @@ name: cinpla channels: - defaults dependencies: - - python=3.10 + - python=3.11 - pip - pip: - expipe-plugin-cinpla From 4d373f048c946868521acd3c119ea0a63798969e Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 10:43:58 +0200 Subject: [PATCH 26/47] Configure pre-commit hooks --- .pre-commit-config.yaml | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..7132741 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,31 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: fix-encoding-pragma + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-docstring-first + - id: debug-statements + - id: check-toml + - id: check-yaml + - id: requirements-txt-fixer + - id: detect-private-key + - id: check-merge-conflict + + - repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black + - id: black-jupyter + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black"] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.4 + hooks: + - id: ruff From aac8d4f744af1361ba0504d983c103657e1f991c Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 10:48:56 +0200 Subject: [PATCH 27/47] Fix formatting, add ruff specs and new deps --- pyproject.toml | 95 
+++++++++++++++++++++++++++++++++++--------------- 1 file changed, 66 insertions(+), 29 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6673ad2..9711e54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,13 +2,13 @@ name = "expipe_plugin_cinpla" version = "0.1.5" authors = [ - { name="Mikkel Lepperod", email="mikkel@simula.no" }, - { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, + { name = "Mikkel Lepperod", email = "mikkel@simula.no" }, + { name = "Alessio Buccino", email = "alessiop.buccino@gmail.com" }, ] description = "Expipe plugins for the CINPLA lab." readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", @@ -16,29 +16,26 @@ classifiers = [ ] dependencies = [ - "expipe>=0.6.0", + "expipe>=0.6.0", "neuroconv>=0.4.6", "pyopenephys>=1.2.0", "spikeinterface[full,widgets]>=0.100.0", "pynwb>=2.5.0", - "neuroconv>=0.4.6", "ipywidgets>=8.1.1", "nwbwidgets>=0.11.3", - "tbb>=2021.11.0", + "tbb>=2021.11.0", # TODO: pip can't find tbb or tbb4py (at least on macOS). Is it needed? "pynapple>=0.5.1", - "spython>=0.3.13", + "spython>=0.3.13", # TODO: is this needed? ] [project.urls] homepage = "https://github.com/CINPLA/expipe-plugin-cinpla" repository = "https://github.com/CINPLA/expipe-plugin-cinpla" - [build-system] requires = ["setuptools>=62.0"] build-backend = "setuptools.build_meta" - [tool.setuptools] include-package-data = true @@ -48,30 +45,70 @@ include = ["expipe_plugin_cinpla*"] namespaces = false [project.optional-dependencies] - -dev = [ - "pytest", - "pytest-cov", - "pytest-dependency", - "black" +dev = ["pre-commit", "black[jupyter]", "isort", "ruff"] +test = ["pytest", "pytest-cov", "pytest-dependency", "mountainsort5"] +docs = ["sphinx-gallery", "sphinx_rtd_theme"] +full = [ + "expipe_plugin_cinpla[dev]", + "expipe_plugin_cinpla[test]", + "expipe_plugin_cinpla[docs]", ] -test = [ - "pytest", - "pytest-cov", - "pytest-dependency", - "mountainsort5" -] +[tool.coverage.run] +omit = ["tests/*"] -docs = [ - "sphinx-gallery", - "sphinx_rtd_theme", -] +[tool.black] +line-length = 120 -[tool.coverage.run] -omit = [ - "tests/*", +[tool.ruff] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", ] -[tool.black] +# Same as Black. line-length = 120 +indent-width = 4 + +# Assume Python 3.11. +target-version = "py311" + +[tool.ruff.lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or +# McCabe complexity (`C901`) by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. 
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+[tool.ruff.lint.per-file-ignores]
+"src/expipe_plugin_cinpla/cli/utils.py" = ["F403"]
+"src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py" = ["F821"]
+"src/expipe_plugin_cinpla/widgets/utils.py" = ["F841"] # TODO: fix warning
+"tests/test_cli.py" = ["F841"] # TODO: fix warning
+"tests/test_script.py" = ["F841"] # TODO: fix warning

From 6d575411e4fde514cb74d9f489ab3ce174551f33 Mon Sep 17 00:00:00 2001
From: Nicolai Haug
Date: Thu, 16 May 2024 10:52:55 +0200
Subject: [PATCH 28/47] Update installation description and add 'How to contribute'

---
 README.md | 62 +++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 47 insertions(+), 15 deletions(-)

diff --git a/README.md b/README.md
index 5417d71..5946af0 100644
--- a/README.md
+++ b/README.md
@@ -1,47 +1,42 @@
 # Expipe plugin CINPLA

-Expipe plugin for CINPLA laboratory
+Expipe plugin for the CINPLA laboratory.

 ## Installation

-You can install the package with pip:
+`expipe-plugin-cinpla` can be installed by running

-```bash
->>> pip install expipe-plugin-cinpla
-```
+    $ pip install expipe-plugin-cinpla

-or from sources:
+It requires Python 3.10+ to run.
+
+If you want the latest features and can't wait for the next release, install from GitHub:
+
+    $ pip install git+https://github.com/CINPLA/expipe-plugin-cinpla.git

-```bash
-git clone
-cd expipe-plugin-cinpla
-pip install -e .
-```

 ## Usage

-The starting point is a valid `expipe` project. Refer to the [expipe docs]() to read more on how
-to create one.
+The starting point is a valid `expipe` project. Refer to the [expipe docs](https://expipe.readthedocs.io/en/latest/) to read more on how to create one.

 The recommended usage is via Jupyter Notebook / Lab, using the interactive widgets to Register,
 Process, Curate, and View your actions.

 To launch the interactive browser, you can run:
+
 ```python
 from expipe_plugin_cinpla import display_browser

 project_path = "path-to-my-project"
 display_browser(project_path)
-
 ```

 ![alt text](docs/images/browser.png)

-
 ## Updating old projects

 The current version uses Neurodata Without Borders as backend instead of Exdir. If you have an existing
@@ -60,3 +55,40 @@ convert_old_project(old_project_path, new_project_path, probe_path)
 ```

 To check out other options, use `convert_old_project?`
+
+## How to contribute
+
+### Set up development environment
+
+First, we recommend creating a virtual environment and installing `pip`:
+
+* Using [venv](https://packaging.python.org/en/latest/key_projects/#venv):
+
+      $ python3.11 -m venv <path-to-venv>
+      $ source <path-to-venv>/bin/activate
+      $ python3 -m pip install --upgrade pip
+
+* Using [conda](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html):
+
+      $ conda create -n <env-name> python=3.11 pip
+      $ conda activate <env-name>
+
+Then install `expipe-plugin-cinpla` in editable mode from source:
+
+    $ git clone https://github.com/CINPLA/expipe-plugin-cinpla.git
+    $ cd expipe-plugin-cinpla
+    $ python3 -m pip install -e ".[full]"
+
+
+### pre-commit
+We use [pre-commit](https://pre-commit.com/) to run Git hooks on every commit to identify simple issues such as trailing whitespace or not complying with the required formatting. Our pre-commit configuration is specified in the `.pre-commit-config.yaml` file.
+
+To set up the Git hook scripts specified in `.pre-commit-config.yaml`, run
+
+    $ pre-commit install
+
+> **NOTE:** If `pre-commit` identifies formatting issues in the committed code, the pre-commit Git hooks will reformat the code. If code is reformatted, it will show up in your unstaged changes. Stage them and recommit to successfully commit your changes.
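+
+A single hook can also be run on its own by passing its id from `.pre-commit-config.yaml` to `pre-commit run`; for example, to run only the `black` hook against the currently staged files:
+
+    $ pre-commit run black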
If code is reformatted, it will show up in your unstaged changes. Stage them and recommit to successfully commit your changes. + +It is also possible to run the pre-commit hooks without attempting a commit: + + $ pre-commit run --all-files From 07209df9dcf38e94f353e125a374212188c6c6a7 Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 10:54:38 +0200 Subject: [PATCH 29/47] Remove old, commented out code --- setup.py | 39 +-------------------------------------- 1 file changed, 1 insertion(+), 38 deletions(-) diff --git a/setup.py b/setup.py index d0453b9..7a58127 100644 --- a/setup.py +++ b/setup.py @@ -1,43 +1,6 @@ -# # -*- coding: utf-8 -*- -# from setuptools import setup, find_packages - -# long_description = open("README.md").read() - -# with open("requirements.txt", mode='r') as f: -# install_requires = f.read().split('\n') - -# install_requires = [e for e in install_requires if len(e) > 0] - -# d = {} -# exec(open("expipe_plugin_cinpla/version.py").read(), None, d) -# version = d['version'] -# pkg_name = "expipe-pligin-cinpla" - -# setup( -# name=pkg_name, -# packages=find_packages(), -# version=version, -# include_package_data=True, -# author="CINPLA", -# author_email="", -# maintainer="Mikkel Elle Lepperød, Alessio Buccino", -# maintainer_email="mikkel@simula.no", -# platforms=["Linux", "Windows"], -# description="Expipe plugins for the CINPLA lab", -# url="https://github.com/CINPLA/expipe-plugin-cinpla", -# long_description_content_type="text/markdown", -# install_requires=install_requires, -# long_description=long_description, -# classifiers=['Intended Audience :: Science/Research', -# 'License :: OSI Approved :: GNU General Public License v2 (GPLv2)', -# 'Natural Language :: English', -# 'Programming Language :: Python :: 3', -# 'Topic :: Scientific/Engineering'], -# python_requires='>=3.9', -# ) +# -*- coding: utf-8 -*- import setuptools - if __name__ == "__main__": setuptools.setup() From fa8a44e3a4e3ab083d6ad420fe2acfcfafea0343 Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 10:55:23 +0200 Subject: [PATCH 30/47] Add formatting and linter check to CI --- .github/workflows/check_formatting.yml | 44 ++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 .github/workflows/check_formatting.yml diff --git a/.github/workflows/check_formatting.yml b/.github/workflows/check_formatting.yml new file mode 100644 index 0000000..c946b9f --- /dev/null +++ b/.github/workflows/check_formatting.yml @@ -0,0 +1,44 @@ +name: Check formatting + +on: [push, pull_request] + +permissions: + contents: read + +jobs: + main: + runs-on: ubuntu-latest + steps: + + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Check EOF + uses: pre-commit/action@v3.0.0 + with: + extra_args: end-of-file-fixer + + - name: Check trailing whitespace + uses: pre-commit/action@v3.0.0 + with: + extra_args: trailing-whitespace + + - name: Black + uses: psf/black@stable + with: + options: "--check" + src: "./code" + jupyter: true + + - name: isort + uses: isort/isort-action@v1 + with: + configuration: --profile=black --check-only --diff + + - name: Ruff + uses: chartboost/ruff-action@v1 + with: + args: "check" From 46bc0674c6a7f5fa2d9bc69e9e3389c54ff416ca Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 10:56:07 +0200 Subject: [PATCH 31/47] Remove duplicate pytest install --- .github/workflows/full_tests.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/workflows/full_tests.yml 
b/.github/workflows/full_tests.yml index c74c8ae..4906926 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -21,7 +21,6 @@ jobs: run: | python -m pip install --upgrade pip pip install .[test] - pip install pytest - name: Pytest run: | pytest -v From 5f4c52845e8d39547e41ff931678d8e96f4804f0 Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 10:58:29 +0200 Subject: [PATCH 32/47] =?UTF-8?q?Reformat=20the=20codebase=20=E2=9C=A8=20?= =?UTF-8?q?=F0=9F=8D=B0=20=E2=9C=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/conf.py | 1 - notebooks/convert_project.py | 1 + notebooks/expipe-plugin-cinpla-demo.ipynb | 4 +- probes/tetrode_32_openephys.json | 2 +- requirements.txt | 8 +- src/expipe_plugin_cinpla/__init__.py | 9 +- src/expipe_plugin_cinpla/cli/__init__.py | 3 +- src/expipe_plugin_cinpla/cli/main.py | 4 +- src/expipe_plugin_cinpla/cli/process.py | 4 +- src/expipe_plugin_cinpla/cli/register.py | 12 +- src/expipe_plugin_cinpla/cli/utils.py | 11 +- src/expipe_plugin_cinpla/data_loader.py | 5 +- src/expipe_plugin_cinpla/imports.py | 3 +- .../nwbutils/cinplanwbconverter.py | 3 +- .../nwbutils/interfaces/__init__.py | 3 +- .../interfaces/openephystrackinginterface.py | 11 +- .../nwbutils/nwbwidgetsunitviewer.py | 10 +- src/expipe_plugin_cinpla/scripts/__init__.py | 3 +- .../scripts/convert_old_project.py | 10 +- src/expipe_plugin_cinpla/scripts/curation.py | 29 ++-- src/expipe_plugin_cinpla/scripts/process.py | 34 ++-- src/expipe_plugin_cinpla/scripts/register.py | 62 +++---- src/expipe_plugin_cinpla/scripts/utils.py | 23 ++- .../tools/data_processing.py | 24 +-- .../tools/registration.py | 5 +- .../tools/track_units_tools.py | 7 +- .../tools/trackunitcomparison.py | 25 ++- .../tools/trackunitmulticomparison.py | 30 ++-- src/expipe_plugin_cinpla/utils.py | 1 + src/expipe_plugin_cinpla/widgets/__init__.py | 3 +- src/expipe_plugin_cinpla/widgets/browser.py | 11 +- src/expipe_plugin_cinpla/widgets/curation.py | 17 +- src/expipe_plugin_cinpla/widgets/process.py | 6 +- src/expipe_plugin_cinpla/widgets/register.py | 40 ++--- src/expipe_plugin_cinpla/widgets/utils.py | 23 ++- src/expipe_plugin_cinpla/widgets/viewer.py | 4 +- tests/conftest.py | 6 +- tests/test_cli.py | 13 +- tests/test_convert_old_project.py | 6 +- tests/test_data/axona/DVH_2013103103.set | 154 +++++++++--------- tests/test_data/axona/DVH_2013103103_2.cut | 2 +- tests/test_data/axona/DVH_2013103103_3.cut | 2 +- tests/test_data/axona/DVH_2013103103_4.cut | 2 +- tests/test_data/axona/DVH_2013103103_8.cut | 2 +- .../experiment1_all_channels_0.events | Bin 19664 -> 19662 bytes .../experiment1_binarymsg_0.eventsbinary | Bin 30074 -> 30070 bytes .../experiment1_all_channels_0.events | Bin 25936 -> 25934 bytes .../experiment1_binarymsg_0.eventsbinary | Bin 39874 -> 39871 bytes .../2/experiment1/recording1/structure.oebin | 2 +- .../spikesorting/mountainsort4/phy/params.py | 1 + .../experiment1/recording1/structure.oebin | 2 +- .../spikesorting/mountainsort4/phy/params.py | 1 + .../experiment1/recording1/structure.oebin | 2 +- .../experiment1/recording1/structure.oebin | 2 +- tests/test_data/tetrode_32_openephys.json | 2 +- tests/test_script.py | 12 +- 56 files changed, 361 insertions(+), 301 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 7f625c8..70a3c88 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,7 +22,6 @@ # sys.path.insert(0, os.path.abspath('.')) import os -import re # import expipe_plugin_cinpla diff --git 
a/notebooks/convert_project.py b/notebooks/convert_project.py index ec98fe8..4b7ed65 100644 --- a/notebooks/convert_project.py +++ b/notebooks/convert_project.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from pathlib import Path import expipe_plugin_cinpla diff --git a/notebooks/expipe-plugin-cinpla-demo.ipynb b/notebooks/expipe-plugin-cinpla-demo.ipynb index b37a790..177c469 100644 --- a/notebooks/expipe-plugin-cinpla-demo.ipynb +++ b/notebooks/expipe-plugin-cinpla-demo.ipynb @@ -50,7 +50,9 @@ "metadata": {}, "outputs": [], "source": [ - "expipe_plugin_cinpla.convert_old_project(old_project_path, new_project_path, probe_path=\"tetrode_32_openephys.json\", debug_n_actions=5, overwrite=True)" + "expipe_plugin_cinpla.convert_old_project(\n", + " old_project_path, new_project_path, probe_path=\"tetrode_32_openephys.json\", debug_n_actions=5, overwrite=True\n", + ")" ] }, { diff --git a/probes/tetrode_32_openephys.json b/probes/tetrode_32_openephys.json index fc1b2b2..37d723f 100644 --- a/probes/tetrode_32_openephys.json +++ b/probes/tetrode_32_openephys.json @@ -787,4 +787,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/requirements.txt b/requirements.txt index 3f35eac..35d5ce0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ expipe>=0.5.1 -neuroconv>=0.4.6 ipywidgets>=8.1.1 jupyter_contrib_nbextensions>=0.7.0 jupyterlab==3.6.5 +neuroconv>=0.4.6 +neuroconv>=0.4.6 +pynwb>=2.6.0 pyopenephys>=1.2.0 -tbb>=2021.11.0 spikeinterface[full,widgets]>=0.100.0 -pynwb>=2.6.0 -neuroconv>=0.4.6 +tbb>=2021.11.0 diff --git a/src/expipe_plugin_cinpla/__init__.py b/src/expipe_plugin_cinpla/__init__.py index 0e0c0fb..63fbebe 100644 --- a/src/expipe_plugin_cinpla/__init__.py +++ b/src/expipe_plugin_cinpla/__init__.py @@ -1,7 +1,8 @@ -from .cli import CinplaPlugin -from .widgets import display_browser -from .scripts import convert_old_project - +# -*- coding: utf-8 -*- import importlib.metadata +from .cli import CinplaPlugin # noqa +from .scripts import convert_old_project # noqa +from .widgets import display_browser # noqa + __version__ = importlib.metadata.version("expipe_plugin_cinpla") diff --git a/src/expipe_plugin_cinpla/cli/__init__.py b/src/expipe_plugin_cinpla/cli/__init__.py index c526aa8..6fca81e 100644 --- a/src/expipe_plugin_cinpla/cli/__init__.py +++ b/src/expipe_plugin_cinpla/cli/__init__.py @@ -1 +1,2 @@ -from .main import CinplaPlugin +# -*- coding: utf-8 -*- +from .main import CinplaPlugin # noqa diff --git a/src/expipe_plugin_cinpla/cli/main.py b/src/expipe_plugin_cinpla/cli/main.py index cb699f7..cf2b0a3 100644 --- a/src/expipe_plugin_cinpla/cli/main.py +++ b/src/expipe_plugin_cinpla/cli/main.py @@ -1,9 +1,9 @@ +# -*- coding: utf-8 -*- import click - from expipe.cliutils.plugin import IPlugin -from .register import attach_to_register from .process import attach_to_process +from .register import attach_to_register class CinplaPlugin(IPlugin): diff --git a/src/expipe_plugin_cinpla/cli/process.py b/src/expipe_plugin_cinpla/cli/process.py index 97b3108..04cca4b 100644 --- a/src/expipe_plugin_cinpla/cli/process.py +++ b/src/expipe_plugin_cinpla/cli/process.py @@ -1,5 +1,7 @@ -import click +# -*- coding: utf-8 -*- from pathlib import Path + +import click import ruamel.yaml as yaml from expipe_plugin_cinpla.imports import project diff --git a/src/expipe_plugin_cinpla/cli/register.py b/src/expipe_plugin_cinpla/cli/register.py index b5d690e..43dae56 100644 --- a/src/expipe_plugin_cinpla/cli/register.py +++ b/src/expipe_plugin_cinpla/cli/register.py @@ -1,11 +1,17 @@ 
-import click -from pathlib import Path +# -*- coding: utf-8 -*- from datetime import datetime +from pathlib import Path +import click import expipe +from expipe_plugin_cinpla.cli.utils import ( + validate_adjustment, + validate_angle, + validate_depth, + validate_position, +) from expipe_plugin_cinpla.scripts import register -from expipe_plugin_cinpla.cli.utils import validate_depth, validate_position, validate_angle, validate_adjustment def attach_to_register(cli): diff --git a/src/expipe_plugin_cinpla/cli/utils.py b/src/expipe_plugin_cinpla/cli/utils.py index 08b2ec8..d982f3d 100644 --- a/src/expipe_plugin_cinpla/cli/utils.py +++ b/src/expipe_plugin_cinpla/cli/utils.py @@ -1,4 +1,7 @@ +# -*- coding: utf-8 -*- import collections +import copy + import click from expipe_plugin_cinpla.imports import * @@ -98,12 +101,12 @@ def optional_choice(ctx, param, value): assert isinstance(options, list) if value is None: if param.required: - raise ValueError('Missing option "{}"'.format(param.opts)) + raise ValueError(f'Missing option "{param.opts}"') return value if param.multiple: if len(value) == 0: if param.required: - raise ValueError('Missing option "{}"'.format(param.opts)) + raise ValueError(f'Missing option "{param.opts}"') return value if len(options) == 0: return value @@ -113,8 +116,8 @@ def optional_choice(ctx, param, value): value, ] for val in value: - if not val in options: - raise ValueError('Value "{}" not in "{}".'.format(val, options)) + if val not in options: + raise ValueError(f'Value "{val}" not in "{options}".') else: if param.multiple: return value diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py index 70a854c..3324be5 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -1,8 +1,9 @@ +# -*- coding: utf-8 -*- """Utils for loading data from NWB files""" -import numpy as np -import quantities as pq import neo +import numpy as np +import quantities as pq import spikeinterface as si import spikeinterface.extractors as se diff --git a/src/expipe_plugin_cinpla/imports.py b/src/expipe_plugin_cinpla/imports.py index f5e5bd6..3b4af5e 100644 --- a/src/expipe_plugin_cinpla/imports.py +++ b/src/expipe_plugin_cinpla/imports.py @@ -1,6 +1,7 @@ -import expipe +# -*- coding: utf-8 -*- from pathlib import Path +import expipe local_root, _ = expipe.config._load_local_config(Path.cwd()) if local_root is not None: diff --git a/src/expipe_plugin_cinpla/nwbutils/cinplanwbconverter.py b/src/expipe_plugin_cinpla/nwbutils/cinplanwbconverter.py index 090bb2e..1eea5a6 100644 --- a/src/expipe_plugin_cinpla/nwbutils/cinplanwbconverter.py +++ b/src/expipe_plugin_cinpla/nwbutils/cinplanwbconverter.py @@ -1,6 +1,7 @@ -from probeinterface import ProbeGroup +# -*- coding: utf-8 -*- from neuroconv import NWBConverter from neuroconv.datainterfaces import OpenEphysRecordingInterface +from probeinterface import ProbeGroup from .interfaces.openephystrackinginterface import OpenEphysTrackingInterface diff --git a/src/expipe_plugin_cinpla/nwbutils/interfaces/__init__.py b/src/expipe_plugin_cinpla/nwbutils/interfaces/__init__.py index ab3bebe..4f94197 100644 --- a/src/expipe_plugin_cinpla/nwbutils/interfaces/__init__.py +++ b/src/expipe_plugin_cinpla/nwbutils/interfaces/__init__.py @@ -1 +1,2 @@ -from .openephystrackinginterface import OpenEphysTrackingInterface +# -*- coding: utf-8 -*- +from .openephystrackinginterface import OpenEphysTrackingInterface # noqa diff --git 
a/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py b/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py index 9c406d9..f11cb0a 100644 --- a/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py +++ b/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py @@ -1,14 +1,11 @@ +# -*- coding: utf-8 -*- import warnings + import numpy as np import pyopenephys - -from pynwb.behavior import ( - Position, - SpatialSeries, -) - from neuroconv import BaseDataInterface from neuroconv.utils import FolderPathType +from pynwb.behavior import Position, SpatialSeries class OpenEphysTrackingInterface(BaseDataInterface): @@ -89,7 +86,7 @@ def add_to_nwbfile( ) start_times = times[rising].rescale("s").magnitude stop_times = times[falling].rescale("s").magnitude - for start, stop in zip(start_times, stop_times): + for start, stop in zip(start_times, stop_times, strict=False): nwbfile.add_trial( start_time=start, stop_time=stop, diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index 041bbc0..eba833e 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -1,10 +1,10 @@ +# -*- coding: utf-8 -*- from functools import partial -import numpy as np -import ipywidgets as widgets -from ipywidgets import Layout, interactive_output +import ipywidgets as widgets import matplotlib.pyplot as plt - +import numpy as np +from ipywidgets import Layout, interactive_output color_wheel = plt.rcParams["axes.prop_cycle"].by_key()["color"] @@ -263,8 +263,8 @@ def show_unit_rate_maps(self, unit_index=None, spatial_series_selector=None, num def get_custom_spec(): - from pynwb.misc import Units from nwbwidgets.view import default_neurodata_vis_spec + from pynwb.misc import Units custom_neurodata_vis_spec = default_neurodata_vis_spec.copy() diff --git a/src/expipe_plugin_cinpla/scripts/__init__.py b/src/expipe_plugin_cinpla/scripts/__init__.py index 3ca085f..085d970 100644 --- a/src/expipe_plugin_cinpla/scripts/__init__.py +++ b/src/expipe_plugin_cinpla/scripts/__init__.py @@ -1 +1,2 @@ -from .convert_old_project import convert_old_project +# -*- coding: utf-8 -*- +from .convert_old_project import convert_old_project # noqa diff --git a/src/expipe_plugin_cinpla/scripts/convert_old_project.py b/src/expipe_plugin_cinpla/scripts/convert_old_project.py index b0899b3..a0884d1 100644 --- a/src/expipe_plugin_cinpla/scripts/convert_old_project.py +++ b/src/expipe_plugin_cinpla/scripts/convert_old_project.py @@ -1,14 +1,14 @@ +# -*- coding: utf-8 -*- import shutil +import time from datetime import datetime, timedelta from pathlib import Path -import time import expipe -from .utils import _get_data_path -from .register import convert_to_nwb, register_entity -from .process import process_ecephys from .curation import SortingCurator +from .process import process_ecephys +from .register import convert_to_nwb, register_entity def convert_old_project( @@ -244,7 +244,7 @@ def convert_old_project( t_stop_all = time.perf_counter() print(f"\nTotal time: {t_stop_all - t_start_all:.2f} s") - done_msg = f"ALL DONE!" + done_msg = "ALL DONE!" 
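+    done_msg = "ALL DONE!"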
delimeter = "*" * len(done_msg) print(f"\n{delimeter}\n{done_msg}\n{delimeter}\n") print(f"Successful: {len(actions_to_convert) - len(actions_failed)}\n") diff --git a/src/expipe_plugin_cinpla/scripts/curation.py b/src/expipe_plugin_cinpla/scripts/curation.py index f945806..874af48 100644 --- a/src/expipe_plugin_cinpla/scripts/curation.py +++ b/src/expipe_plugin_cinpla/scripts/curation.py @@ -1,12 +1,16 @@ -import shutil +# -*- coding: utf-8 -*- import json -from pathlib import Path -import numpy as np +import shutil import warnings +import numpy as np import spikeinterface as si -from .utils import _get_data_path, add_units_from_waveform_extractor, compute_and_set_unit_groups +from .utils import ( + _get_data_path, + add_units_from_waveform_extractor, + compute_and_set_unit_groups, +) warnings.filterwarnings("ignore", category=ResourceWarning) warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -76,10 +80,10 @@ def load_raw_sorting(self, sorter): unit_table_path=raw_units_path, electrical_series_path="acquisition/ElectricalSeries", ) + return sorting_raw except Exception as e: print(f"Could not load raw sorting for {sorter}. Using None: {e}") - sorting_raw = None - return sorting_raw + return None def load_raw_units(self, sorter): from pynwb import NWBHDF5IO @@ -90,9 +94,10 @@ def load_raw_units(self, sorter): nwbfile = self.io.read() try: units = _retrieve_unit_table_pynwb(nwbfile, raw_units_path) - except: - units = None - return units + return units + except Exception as e: + print(f"Could not load raw units for {sorter}. Using None: {e}") + return None def load_main_units(self): from pynwb import NWBHDF5IO @@ -137,9 +142,9 @@ def apply_curation(self, sorter, curated_sorting): print(f"No curation was performed for {sorter}. Using raw sorting") self.curated_we = None else: + import spikeinterface.curation as sc import spikeinterface.postprocessing as spost import spikeinterface.qualitymetrics as sqm - import spikeinterface.curation as sc recording = self.load_processed_recording(sorter) @@ -200,7 +205,7 @@ def get_sortingview_link(self, sorter): visualization_json = self.si_path / sorter / "sortingview_links.json" if not visualization_json.is_file(): return "Sorting view link not found." 
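+        # Note: open() defaults to text mode "r", so the explicit mode argument below was redundant.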
- with open(visualization_json, "r") as f: + with open(visualization_json) as f: sortingview_links = json.load(f) return sortingview_links["raw"] @@ -221,7 +226,7 @@ def apply_sortingview_curation(self, sorter, curated_link): uri = curation_str[curation_str.find("sha1://") : -2] sorting_curated = sc.apply_sortingview_curation(sorting_raw, uri_or_json=uri) # exclude noise - good_units = sorting_curated.unit_ids[sorting_curated.get_property("noise") == False] + good_units = sorting_curated.unit_ids[sorting_curated.get_property("noise") == False] # noqa E712 # create single property for SUA and MUA sorting_curated = sorting_curated.select_units(good_units) self.apply_curation(sorter, sorting_curated) diff --git a/src/expipe_plugin_cinpla/scripts/process.py b/src/expipe_plugin_cinpla/scripts/process.py index e293397..3b941e0 100644 --- a/src/expipe_plugin_cinpla/scripts/process.py +++ b/src/expipe_plugin_cinpla/scripts/process.py @@ -1,14 +1,13 @@ -import shutil +# -*- coding: utf-8 -*- import contextlib -import time import json -import os +import shutil +import time + import numpy as np from expipe_plugin_cinpla.scripts import utils -from ..nwbutils.cinplanwbconverter import CinplaNWBConverter - def process_ecephys( project, @@ -33,18 +32,17 @@ def process_ecephys( verbose=True, ): import warnings + import spikeinterface as si + import spikeinterface.exporters as sexp import spikeinterface.extractors as se - import spikeinterface.preprocessing as spre - import spikeinterface.sorters as ss import spikeinterface.postprocessing as spost + import spikeinterface.preprocessing as spre import spikeinterface.qualitymetrics as sqm - import spikeinterface.exporters as sexp + import spikeinterface.sorters as ss import spikeinterface.widgets as sw - - from pynwb import NWBHDF5IO - from neuroconv.tools.spikeinterface import add_recording + from pynwb import NWBHDF5IO from .utils import add_units_from_waveform_extractor, compute_and_set_unit_groups @@ -226,11 +224,7 @@ def process_ecephys( **spikesorter_params, ) except Exception as e: - try: - shutil.rmtree(output_folder) - except: - if verbose: - print(f"\tCould not tmp processing folder: {output_folder}") + shutil.rmtree(output_folder) raise Exception(f"Spike sorting failed:\n\n{e}") if verbose: print(f"\tFound {len(sorting.get_unit_ids())} units!") @@ -269,7 +263,7 @@ def process_ecephys( if verbose: print("\tExporting to phy") - phy_folder = output_base_folder / f"phy" + phy_folder = output_base_folder / "phy" if phy_folder.is_dir(): shutil.rmtree(phy_folder) sexp.export_to_phy( @@ -383,7 +377,7 @@ def process_ecephys( if not provenance_file.is_file(): (output_base_folder / "recording_cmr").mkdir(parents=True, exist_ok=True) recording_cmr.dump_to_json(output_base_folder / "recording_cmr" / "provenance.json") - with open(output_base_folder / "recording_cmr" / "provenance.json", "r") as f: + with open(output_base_folder / "recording_cmr" / "provenance.json") as f: provenance = json.load(f) provenance_str = json.dumps(provenance) provenance_str = provenance_str.replace("main_tmp.nwb", "main.nwb") @@ -393,9 +387,9 @@ def process_ecephys( shutil.rmtree(output_base_folder / "recording_cmr") try: nwb_path_tmp.unlink() - except: + except Exception as e: print(f"Could not remove: {nwb_path_tmp}") - raise Exception + raise e if verbose: print("\tSaved to NWB: ", nwb_path) diff --git a/src/expipe_plugin_cinpla/scripts/register.py b/src/expipe_plugin_cinpla/scripts/register.py index 3440992..e186ca7 100644 --- a/src/expipe_plugin_cinpla/scripts/register.py +++ 
b/src/expipe_plugin_cinpla/scripts/register.py @@ -1,21 +1,21 @@ +# -*- coding: utf-8 -*- import shutil -import warnings import time -import numpy as np -from pathlib import Path +import warnings from datetime import datetime -import pytz -import quantities as pq - -import pyopenephys -import probeinterface as pi +from pathlib import Path import expipe +import numpy as np +import probeinterface as pi +import pyopenephys +import pytz +import quantities as pq def convert_to_nwb(project, action, openephys_path, probe_path, entity_id, user, include_events, overwrite): - from .utils import _make_data_path from ..nwbutils.cinplanwbconverter import CinplaNWBConverter + from .utils import _make_data_path nwb_path = _make_data_path(action, overwrite) @@ -162,8 +162,9 @@ def register_openephys_recording( if delete_raw_data: try: shutil.rmtree(openephys_path) - except: + except Exception as e: print("Could not remove: ", openephys_path) + raise e ### Adjustment ### @@ -181,7 +182,11 @@ def register_openephys_recording( def register_adjustment(project, entity_id, date, adjustment, user, depth, yes): - from expipe_plugin_cinpla.scripts.utils import position_to_dict, get_depth_from_surgery, query_yes_no + from expipe_plugin_cinpla.scripts.utils import ( + get_depth_from_surgery, + position_to_dict, + query_yes_no, + ) user = user or project.config.get("username") if user is None: @@ -203,7 +208,7 @@ def register_adjustment(project, entity_id, date, adjustment, user, depth, yes): try: action = project.actions[action_id] init = False - except KeyError as e: + except KeyError: action = project.create_action(action_id) init = True @@ -213,7 +218,7 @@ def register_adjustment(project, entity_id, date, adjustment, user, depth, yes): if name.endswith("adjustment"): deltas.append(int(name.split("_")[0])) index = max(deltas) + 1 - prev_depth = action.modules["{:03d}_adjustment".format(max(deltas))].contents["depth"] + prev_depth = action.modules[f"{max(deltas):03d}_adjustment"].contents["depth"] if init: if len(depth) > 0: prev_depth = position_to_dict(depth) @@ -221,14 +226,14 @@ def register_adjustment(project, entity_id, date, adjustment, user, depth, yes): prev_depth = get_depth_from_surgery(project=project, entity_id=entity_id) index = 0 - name = "{:03d}_adjustment".format(index) + name = f"{index:03d}_adjustment" if not isinstance(prev_depth, dict): print("Unable to retrieve previous depth.") return adjustment_dict = {key: dict() for key in prev_depth} current = {key: dict() for key in prev_depth} for key, probe, val, unit in adjustment: - pos_key = "probe_{}".format(probe) + pos_key = f"probe_{probe}" adjustment_dict[key][pos_key] = pq.Quantity(val, unit) for key, val in prev_depth.items(): for pos_key in prev_depth[key]: @@ -243,13 +248,13 @@ def last_probe(x): correct = query_yes_no( "Correct adjustment?: \n" + " ".join( - "{} {} = {}\n".format(key, pos_key, val[pos_key]) + f"{key} {pos_key} = {val[pos_key]}\n" for key, val in adjustment_dict.items() for pos_key in sorted(val, key=lambda x: last_probe(x)) ) + "New depth: \n" + " ".join( - "{} {} = {}\n".format(key, pos_key, val[pos_key]) + f"{key} {pos_key} = {val[pos_key]}\n" for key, val in current.items() for pos_key in sorted(val, key=lambda x: last_probe(x)) ), @@ -262,13 +267,13 @@ def last_probe(x): print( "Registering adjustment: \n" + " ".join( - "{} {} = {}\n".format(key, pos_key, val[pos_key]) + f"{key} {pos_key} = {val[pos_key]}\n" for key, val in adjustment_dict.items() for pos_key in sorted(val, key=lambda x: last_probe(x)) ) + " New depth: 
\n" + " ".join( - "{} {} = {}\n".format(key, pos_key, val[pos_key]) + f"{key} {pos_key} = {val[pos_key]}\n" for key, val in current.items() for pos_key in sorted(val, key=lambda x: last_probe(x)) ) @@ -300,10 +305,7 @@ def register_annotation( templates, correct_depth_answer, ): - from expipe_plugin_cinpla.scripts.utils import ( - register_templates, - register_depth, - ) + from expipe_plugin_cinpla.scripts.utils import register_depth, register_templates user = user or project.config.get("username") action = project.actions[action_id] @@ -333,9 +335,7 @@ def register_annotation( print("Registering message", message) action.create_message(text=message, user=user, datetime=datetime.now()) if depth: - correct_depth = register_depth( - project=project, action=action, depth=depth, answer=correct_depth_answer, overwrite=True - ) + _ = register_depth(project=project, action=action, depth=depth, answer=correct_depth_answer, overwrite=True) ### Entity ### @@ -389,7 +389,7 @@ def register_entity( if isinstance(val, (str, float, int)): entity.modules["register"][key]["value"] = val elif isinstance(val, tuple): - if not None in val: + if None not in val: entity.modules["register"][key] = pq.Quantity(val[0], val[1]) elif isinstance(val, type(None)): pass @@ -456,15 +456,15 @@ def register_surgery( for key, probe, x, y, z, unit in position: action.modules[key] = {} - probe_key = "probe_{}".format(probe) + probe_key = f"probe_{probe}" action.modules[key][probe_key] = {} - print("Registering position " + "{} {}: x={}, y={}, z={} {}".format(key, probe, x, y, z, unit)) + print("Registering position " + f"{key} {probe}: x={x}, y={y}, z={z} {unit}") action.modules[key][probe_key]["position"] = pq.Quantity([x, y, z], unit) for key, probe, ang, unit in angle: - probe_key = "probe_{}".format(probe) + probe_key = f"probe_{probe}" if probe_key not in action.modules[key]: action.modules[key][probe_key] = {} - print("Registering angle " + "{} {}: angle={} {}".format(key, probe, ang, unit)) + print("Registering angle " + f"{key} {probe}: angle={ang} {unit}") action.modules[key][probe_key]["angle"] = pq.Quantity(ang, unit) diff --git a/src/expipe_plugin_cinpla/scripts/utils.py b/src/expipe_plugin_cinpla/scripts/utils.py index 4c52258..7cdc835 100644 --- a/src/expipe_plugin_cinpla/scripts/utils.py +++ b/src/expipe_plugin_cinpla/scripts/utils.py @@ -1,12 +1,12 @@ -import sys +# -*- coding: utf-8 -*- import shutil +import sys from datetime import datetime, timedelta from pathlib import Path, PureWindowsPath -import numpy as np - -import quantities as pq import expipe +import numpy as np +import quantities as pq nwb_main_groups = ["acquisition", "analysis", "processing", "epochs", "general"] tmp_phy_folders = [".klustakwik2", ".phy", ".spikedetect"] @@ -54,7 +54,7 @@ def deltadate(adjustdate, regdate): def position_to_dict(depth): position = {d[0]: dict() for d in depth} for key, num, val, unit in depth: - probe_key = "probe_{}".format(num) + probe_key = f"probe_{num}" position[key][probe_key] = pq.Quantity(val, unit) return position @@ -85,7 +85,6 @@ def write_python(path, dict): def get_depth_from_surgery(project, entity_id): - index = 0 surgery = project.actions[entity_id + "-surgery-implantation"] position = {} for key, module in surgery.modules.items(): @@ -97,7 +96,7 @@ def get_depth_from_surgery(project, entity_id): for key, groups in position.items(): for group, pos in groups.items(): if not isinstance(pos, pq.Quantity): - raise ValueError("Depth of implant " + '"{} {} = {}"'.format(key, group, pos) + " not 
recognized") + raise ValueError("Depth of implant " + f'"{key} {group} = {pos}"' + " not recognized") position[key][group] = pos.astype(float)[2] # index 2 = z return position @@ -106,7 +105,7 @@ def get_depth_from_adjustment(project, action, entity_id): DTIME_FORMAT = expipe.core.datetime_format try: adjustments = project.actions[entity_id + "-adjustment"] - except KeyError as e: + except KeyError: return None, None adjusts = {} for adjust in adjustments.modules.values(): @@ -130,7 +129,7 @@ def register_depth(project, action, depth=None, answer=None, overwrite=False): adjustdate = None else: curr_depth, adjustdate = get_depth_from_adjustment(project, action, action.entities[0]) - print("Adjust date time: {}\n".format(adjustdate)) + print(f"Adjust date time: {adjustdate}\n") if curr_depth is None: print("Cannot find current depth from adjustments.") return False @@ -140,7 +139,7 @@ def last_num(x): print( "".join( - "Depth: {} {} = {}\n".format(key, probe_key, val[probe_key]) + f"Depth: {key} {probe_key} = {val[probe_key]}\n" for key, val in curr_depth.items() for probe_key in sorted(val, key=lambda x: last_num(x)) ) @@ -186,8 +185,8 @@ def _get_data_path(action): # data_path = action.data['main'] data_path = project_path / str(Path(PureWindowsPath(action.data["main"]))) return data_path - except: - return + except Exception: + return None def register_templates(action, templates, overwrite=False): diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py index efcb403..04a0d95 100644 --- a/src/expipe_plugin_cinpla/tools/data_processing.py +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -1,20 +1,20 @@ +# -*- coding: utf-8 -*- # This is work in progress, +import pathlib +import warnings + +import expipe import numpy as np +import spatial_maps as sp + from expipe_plugin_cinpla.data_loader import ( - load_epochs, get_channel_groups, - load_spiketrains, - load_unit_annotations, - load_leds, get_duration, + load_epochs, + load_leds, load_lfp, - get_sample_rate, - get_data_path, + load_spiketrains, ) -import pathlib -import expipe -import spatial_maps as sp -import warnings def view_active_channels(action, sorter): @@ -174,8 +174,8 @@ def check_valid_tracking(x, y, box_size): if x.min() < 0 or x.max() > box_size[0] or y.min() < 0 or y.max() > box_size[1]: warnings.warn( "Invalid values found " - + "outside box: min [x, y] = [{}, {}], ".format(x.min(), y.min()) - + "max [x, y] = [{}, {}]".format(x.max(), y.max()) + + f"outside box: min [x, y] = [{x.min()}, {y.min()}], " + + f"max [x, y] = [{x.max()}, {y.max()}]" ) diff --git a/src/expipe_plugin_cinpla/tools/registration.py b/src/expipe_plugin_cinpla/tools/registration.py index dbff094..80bd395 100644 --- a/src/expipe_plugin_cinpla/tools/registration.py +++ b/src/expipe_plugin_cinpla/tools/registration.py @@ -1,6 +1,7 @@ +# -*- coding: utf-8 -*- import os -import shutil import pathlib +import shutil def store_notebook(action, notebook_path): @@ -9,7 +10,7 @@ def store_notebook(action, notebook_path): notebook_output_path = action.data_path("notebook") shutil.copy(notebook_path, notebook_output_path) # As HTML - os.system("jupyter nbconvert --to html {}".format(notebook_path)) + os.system(f"jupyter nbconvert --to html {notebook_path}") html_path = notebook_path.with_suffix(".html") action.data["html"] = html_path.name html_output_path = action.data_path("html") diff --git a/src/expipe_plugin_cinpla/tools/track_units_tools.py b/src/expipe_plugin_cinpla/tools/track_units_tools.py 
index fc1ce64..b512d4e 100644 --- a/src/expipe_plugin_cinpla/tools/track_units_tools.py +++ b/src/expipe_plugin_cinpla/tools/track_units_tools.py @@ -1,8 +1,9 @@ +# -*- coding: utf-8 -*- +import matplotlib.pyplot as plt import numpy as np import pandas as pd -from scipy.optimize import linear_sum_assignment from matplotlib import gridspec -import matplotlib.pyplot as plt +from scipy.optimize import linear_sum_assignment def dissimilarity(template_0, template_1): @@ -181,7 +182,7 @@ def make_hungarian_match(dissimilarity_scores, max_dissimilarity): hungarian_match_21 = pd.Series(index=unit2_ids, dtype="int64") hungarian_match_21[:] = -1 - for i1, i2 in zip(inds1, inds2): + for i1, i2 in zip(inds1, inds2, strict=False): u1 = unit1_ids[i1] u2 = unit2_ids[i2] if dissimilarity_scores.at[u1, u2] < max_dissimilarity: diff --git a/src/expipe_plugin_cinpla/tools/trackunitcomparison.py b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py index 980a33f..1997b0a 100644 --- a/src/expipe_plugin_cinpla/tools/trackunitcomparison.py +++ b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py @@ -1,10 +1,23 @@ -from .track_units_tools import dissimilarity_weighted, make_possible_match, make_best_match, make_hungarian_match -from expipe_plugin_cinpla.data_loader import get_data_path, load_spiketrains, get_channel_groups -import matplotlib.pylab as plt -import numpy as np +# -*- coding: utf-8 -*- from pathlib import Path + +import numpy as np import pandas as pd +from expipe_plugin_cinpla.data_loader import ( + get_channel_groups, + get_data_path, + load_spiketrains, +) + +from .track_units_tools import ( + dissimilarity_weighted, + make_best_match, + make_hungarian_match, + make_possible_match, +) + + class TrackingSession: """ Base class shared by SortingComparison and GroundTruthComparison @@ -60,7 +73,7 @@ def __init__( self._do_dissimilarity(channel_group) self._do_matching(channel_group) elif self._verbose: - print(f'Found no units in {channel_group}') + print(f"Found no units in {channel_group}") def save_dissimilarity_matrix(self, path=None): path = path or Path.cwd() @@ -77,7 +90,7 @@ def session_0_name(self): @property def session_1_name(self): return self.name_list[1] - + def make_dissimilary_matrix(self, channel_group): templates_0, templates_1 = self.templates[channel_group] diss_matrix = np.zeros((len(templates_0), len(templates_1))) diff --git a/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py index e24eab9..7ba2ec5 100644 --- a/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py +++ b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py @@ -1,16 +1,24 @@ -import numpy as np -import networkx as nx -import yaml -from .trackunitcomparison import TrackingSession -from expipe_plugin_cinpla.data_loader import get_data_path, get_channel_groups, load_spiketrains -from .track_units_tools import plot_template -import matplotlib.pylab as plt -from tqdm import tqdm +# -*- coding: utf-8 -*- +import datetime import uuid -from matplotlib import gridspec from collections import defaultdict from pathlib import Path -import datetime + +import matplotlib.pylab as plt +import networkx as nx +import numpy as np +import yaml +from matplotlib import gridspec +from tqdm import tqdm + +from expipe_plugin_cinpla.data_loader import ( + get_channel_groups, + get_data_path, + load_spiketrains, +) + +from .track_units_tools import plot_template +from .trackunitcomparison import TrackingSession class TrackMultipleSessions: @@ -191,7 
+199,7 @@ def load_graphs(self): for path in self.data_path.iterdir(): if path.name.startswith("graph-group") and path.suffix == ".yaml": ch = path.stem.split("-")[-1] - with open(path, "r") as f: + with open(path) as f: self.graphs[ch] = yaml.load(f, Loader=yaml.Loader) def identify_units(self): diff --git a/src/expipe_plugin_cinpla/utils.py b/src/expipe_plugin_cinpla/utils.py index e346d67..ec9311d 100644 --- a/src/expipe_plugin_cinpla/utils.py +++ b/src/expipe_plugin_cinpla/utils.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from expipe.backends.filesystem import yaml_dump diff --git a/src/expipe_plugin_cinpla/widgets/__init__.py b/src/expipe_plugin_cinpla/widgets/__init__.py index f386856..fb60c64 100644 --- a/src/expipe_plugin_cinpla/widgets/__init__.py +++ b/src/expipe_plugin_cinpla/widgets/__init__.py @@ -1 +1,2 @@ -from .browser import display_browser +# -*- coding: utf-8 -*- +from .browser import display_browser # noqa diff --git a/src/expipe_plugin_cinpla/widgets/browser.py b/src/expipe_plugin_cinpla/widgets/browser.py index b70db22..f689321 100644 --- a/src/expipe_plugin_cinpla/widgets/browser.py +++ b/src/expipe_plugin_cinpla/widgets/browser.py @@ -1,16 +1,17 @@ -import IPython.display as ipd +# -*- coding: utf-8 -*- import expipe +import IPython.display as ipd +from .curation import CurationView +from .process import process_ecephys_view from .register import ( - register_openephys_view, register_adjustment_view, register_annotate_view, register_entity_view, - register_surgery_view, + register_openephys_view, register_perfuse_view, + register_surgery_view, ) -from .process import process_ecephys_view -from .curation import CurationView from .viewer import NwbViewer diff --git a/src/expipe_plugin_cinpla/widgets/curation.py b/src/expipe_plugin_cinpla/widgets/curation.py index 2ef2eb2..b6b4a43 100644 --- a/src/expipe_plugin_cinpla/widgets/curation.py +++ b/src/expipe_plugin_cinpla/widgets/curation.py @@ -1,15 +1,14 @@ -import ipywidgets -import pandas as pd +# -*- coding: utf-8 -*- from collections import OrderedDict -import expipe -import expipe.config +import ipywidgets +import pandas as pd from expipe_plugin_cinpla.scripts import curation from expipe_plugin_cinpla.scripts.utils import _get_data_path -from .utils import BaseViewWithLog, required_values_filled -from ..utils import dump_project_config +from ..utils import dump_project_config +from .utils import BaseViewWithLog, required_values_filled default_qms = [ dict(name="isi_violations_ratio", sign="<", threshold=0.5), @@ -47,7 +46,11 @@ class CurationView(BaseViewWithLog): def __init__(self, project): from nwbwidgets import nwb2widget from pynwb.misc import Units - from ..nwbutils.nwbwidgetsunitviewer import UnitWaveformsWidget, UnitRateMapWidget + + from ..nwbutils.nwbwidgetsunitviewer import ( + UnitRateMapWidget, + UnitWaveformsWidget, + ) custom_raw_unit_vis = { Units: OrderedDict({"Raw Waveforms": UnitWaveformsWidget, "Rate Maps": UnitRateMapWidget}) diff --git a/src/expipe_plugin_cinpla/widgets/process.py b/src/expipe_plugin_cinpla/widgets/process.py index 65473b8..6897fef 100644 --- a/src/expipe_plugin_cinpla/widgets/process.py +++ b/src/expipe_plugin_cinpla/widgets/process.py @@ -1,5 +1,8 @@ +# -*- coding: utf-8 -*- import ast + from expipe_plugin_cinpla.scripts import process + from .utils import BaseViewWithLog metric_names = [ @@ -25,8 +28,9 @@ def process_ecephys_view(project): import ipywidgets import spikeinterface.sorters as ss - from .utils import SearchSelectMultiple, required_values_filled, 
ParameterSelectList + from ..scripts.utils import _get_data_path + from .utils import ParameterSelectList, SearchSelectMultiple, required_values_filled all_actions = project.actions diff --git a/src/expipe_plugin_cinpla/widgets/register.py b/src/expipe_plugin_cinpla/widgets/register.py index 6dec24b..7b7aa75 100644 --- a/src/expipe_plugin_cinpla/widgets/register.py +++ b/src/expipe_plugin_cinpla/widgets/register.py @@ -1,18 +1,17 @@ +# -*- coding: utf-8 -*- from pathlib import Path + from expipe_plugin_cinpla.scripts import register -from .utils import BaseViewWithLog + from ..utils import dump_project_config +from .utils import BaseViewWithLog ### Open Ephys recording ### def register_openephys_view(project): import ipywidgets - from .utils import ( - MultiInput, - required_values_filled, - none_if_empty, - split_tags, - ) + + from .utils import MultiInput, none_if_empty, required_values_filled, split_tags # left column layout_auto = ipywidgets.Layout(width="300px") @@ -127,12 +126,8 @@ def on_register(change): ### Adjustment ### def register_adjustment_view(project): import ipywidgets - from .utils import ( - DateTimePicker, - MultiInput, - required_values_filled, - SearchSelect, - ) + + from .utils import DateTimePicker, MultiInput, SearchSelect, required_values_filled entity_id = SearchSelect(options=project.entities, description="*Entities") user = ipywidgets.Text(placeholder="*User", value=project.config.get("username")) @@ -182,11 +177,12 @@ def on_register(change): ### Annotation ### def register_annotate_view(project): import ipywidgets + from .utils import ( DateTimePicker, MultiInput, - required_values_filled, SearchSelectMultiple, + required_values_filled, split_tags, ) @@ -235,13 +231,13 @@ def on_register(change): ### Entity ### def register_entity_view(project): import ipywidgets + from .utils import ( DatePicker, SearchSelectMultiple, - required_values_filled, none_if_empty, + required_values_filled, split_tags, - make_output_and_show, ) entity_id = ipywidgets.Text(placeholder="*Entity id") @@ -287,14 +283,15 @@ def on_register(change): ### Surgery ### def register_surgery_view(project): import ipywidgets + from .utils import ( DatePicker, MultiInput, + SearchSelect, SearchSelectMultiple, - required_values_filled, none_if_empty, + required_values_filled, split_tags, - SearchSelect, ) entity_id = SearchSelect(options=project.entities, description="*Entities") @@ -350,14 +347,13 @@ def on_register(change): ### PERFUSION ### def register_perfuse_view(project): import ipywidgets + from .utils import ( DatePicker, - MultiInput, + SearchSelect, SearchSelectMultiple, - required_values_filled, none_if_empty, - split_tags, - SearchSelect, + required_values_filled, ) entity_id = SearchSelect(options=project.entities, description="*Entities") diff --git a/src/expipe_plugin_cinpla/widgets/utils.py b/src/expipe_plugin_cinpla/widgets/utils.py index 35d1e0c..39f9ceb 100644 --- a/src/expipe_plugin_cinpla/widgets/utils.py +++ b/src/expipe_plugin_cinpla/widgets/utils.py @@ -1,9 +1,11 @@ -import ipywidgets -import numpy as np +# -*- coding: utf-8 -*- import datetime as dt -import expipe import warnings +import expipe +import ipywidgets +import numpy as np + warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -217,7 +219,7 @@ def value(self): for ch in self.children: keys.append(ch.description) values.append(ch.value) - return dict(zip(keys, values)) + return dict(zip(keys, values, strict=False)) class DateTimePicker(ipywidgets.HBox): @@ -277,9 +279,10 @@ def __init__(self, 
filetype=None, initialdir=None, *args, **kwargs): @staticmethod def select_file(self): - from tkfilebrowser import askopenfilename from tkinter import Tk + from tkfilebrowser import askopenfilename + # Create Tk root root = Tk() # Hide the main window @@ -294,7 +297,7 @@ def select_file(self): name = ft[1:].capitalize() result = askopenfilename( defaultextension=ft, - filetypes=[("{} file".format(name), "*{}".format(ft)), ("All files", "*.*")], + filetypes=[(f"{name} file", f"*{ft}"), ("All files", "*.*")], initialdir=self.initialdir, ) self.file = result if len(result) > 0 else "" @@ -347,9 +350,10 @@ def on_text_change(change): @staticmethod def select_file(self): - from tkfilebrowser import askopenfilenames from tkinter import Tk + from tkfilebrowser import askopenfilenames + # Create Tk root root = Tk() # Hide the main window @@ -364,7 +368,7 @@ def select_file(self): name = ft[1:].capitalize() self.files = askopenfilenames( defaultextension=ft, - filetypes=[("{} file".format(name), "*{}".format(ft)), ("All files", "*.*")], + filetypes=[(f"{name} file", f"*{ft}"), ("All files", "*.*")], initialdir=self.initialdir, ) else: @@ -399,9 +403,10 @@ def __init__(self, initialdir=None, *args, **kwargs): @staticmethod def select_directories(self): - from tkfilebrowser import askopendirnames from tkinter import Tk + from tkfilebrowser import askopendirnames + # Create Tk root root = Tk() # Hide the main window diff --git a/src/expipe_plugin_cinpla/widgets/viewer.py b/src/expipe_plugin_cinpla/widgets/viewer.py index 1324068..2cc710b 100644 --- a/src/expipe_plugin_cinpla/widgets/viewer.py +++ b/src/expipe_plugin_cinpla/widgets/viewer.py @@ -1,6 +1,6 @@ -import ipywidgets - +# -*- coding: utf-8 -*- import expipe +import ipywidgets from ..nwbutils.nwbwidgetsunitviewer import get_custom_spec from ..scripts.utils import _get_data_path diff --git a/tests/conftest.py b/tests/conftest.py index 8e45c3e..c8c4092 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,10 @@ -import pytest -import expipe +# -*- coding: utf-8 -*- import shutil from pathlib import Path +import expipe +import pytest + from expipe_plugin_cinpla.utils import dump_project_config TEST_DATA_PATH = Path(__file__).parent / "test_data" diff --git a/tests/test_cli.py b/tests/test_cli.py index 50377ff..839d4d6 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,15 +1,15 @@ -import pytest +# -*- coding: utf-8 -*- import time + import click -from pathlib import Path -from click.testing import CliRunner -import quantities as pq import numpy as np +import pytest +import quantities as pq +import spikeinterface.extractors as se +from click.testing import CliRunner from expipe_plugin_cinpla.cli import CinplaPlugin -import spikeinterface.extractors as se - @click.group() @click.pass_context @@ -23,7 +23,6 @@ def cli(ctx): def run_command(command_list, inp=None): runner = CliRunner() command_list = [str(c) for c in command_list] - # print(" ".join(command_list)) result = runner.invoke(cli, command_list, input=inp) if result.exit_code != 0: print(result.output) diff --git a/tests/test_convert_old_project.py b/tests/test_convert_old_project.py index a9307d3..9e5ab99 100644 --- a/tests/test_convert_old_project.py +++ b/tests/test_convert_old_project.py @@ -1,9 +1,11 @@ +# -*- coding: utf-8 -*- from pathlib import Path -from pynwb import NWBHDF5IO import expipe -from expipe_plugin_cinpla.scripts.utils import _get_data_path +from pynwb import NWBHDF5IO + from expipe_plugin_cinpla import convert_old_project +from 
expipe_plugin_cinpla.scripts.utils import _get_data_path test_folder = Path(__file__).parent old_project_path = test_folder / "test_data" / "old_project" diff --git a/tests/test_data/axona/DVH_2013103103.set b/tests/test_data/axona/DVH_2013103103.set index 36a0596..6cbbde8 100644 --- a/tests/test_data/axona/DVH_2013103103.set +++ b/tests/test_data/axona/DVH_2013103103.set @@ -1,8 +1,8 @@ trial_date Thursday, 31 Oct 2013 trial_time 17:20:11 -experimenter -comments -duration 394 +experimenter +comments +duration 394 sw_version 1.2.2.7 ADC_fullscale_mv 1500 tracker_version 0 @@ -20,7 +20,7 @@ filtfreq2_ch_0 0 filtripple_ch_0 0.00 filtdcblock_ch_0 0 dispmode_ch_0 1 -channame_ch_0 +channame_ch_0 gain_ch_1 10000 filter_ch_1 2 a_in_ch_1 1 @@ -33,7 +33,7 @@ filtfreq2_ch_1 0 filtripple_ch_1 0.00 filtdcblock_ch_1 0 dispmode_ch_1 1 -channame_ch_1 +channame_ch_1 gain_ch_2 10000 filter_ch_2 2 a_in_ch_2 2 @@ -46,7 +46,7 @@ filtfreq2_ch_2 0 filtripple_ch_2 0.00 filtdcblock_ch_2 0 dispmode_ch_2 1 -channame_ch_2 +channame_ch_2 gain_ch_3 10000 filter_ch_3 2 a_in_ch_3 3 @@ -59,7 +59,7 @@ filtfreq2_ch_3 0 filtripple_ch_3 0.00 filtdcblock_ch_3 0 dispmode_ch_3 1 -channame_ch_3 +channame_ch_3 gain_ch_4 10000 filter_ch_4 2 a_in_ch_4 4 @@ -72,7 +72,7 @@ filtfreq2_ch_4 0 filtripple_ch_4 0.00 filtdcblock_ch_4 0 dispmode_ch_4 1 -channame_ch_4 +channame_ch_4 gain_ch_5 10000 filter_ch_5 2 a_in_ch_5 5 @@ -85,7 +85,7 @@ filtfreq2_ch_5 0 filtripple_ch_5 0.00 filtdcblock_ch_5 0 dispmode_ch_5 1 -channame_ch_5 +channame_ch_5 gain_ch_6 10000 filter_ch_6 2 a_in_ch_6 6 @@ -98,7 +98,7 @@ filtfreq2_ch_6 0 filtripple_ch_6 0.00 filtdcblock_ch_6 0 dispmode_ch_6 1 -channame_ch_6 +channame_ch_6 gain_ch_7 10000 filter_ch_7 2 a_in_ch_7 7 @@ -111,7 +111,7 @@ filtfreq2_ch_7 0 filtripple_ch_7 0.00 filtdcblock_ch_7 0 dispmode_ch_7 1 -channame_ch_7 +channame_ch_7 gain_ch_8 10000 filter_ch_8 2 a_in_ch_8 8 @@ -124,7 +124,7 @@ filtfreq2_ch_8 0 filtripple_ch_8 0.00 filtdcblock_ch_8 0 dispmode_ch_8 1 -channame_ch_8 +channame_ch_8 gain_ch_9 10000 filter_ch_9 2 a_in_ch_9 9 @@ -137,7 +137,7 @@ filtfreq2_ch_9 0 filtripple_ch_9 0.00 filtdcblock_ch_9 0 dispmode_ch_9 1 -channame_ch_9 +channame_ch_9 gain_ch_10 10000 filter_ch_10 2 a_in_ch_10 10 @@ -150,7 +150,7 @@ filtfreq2_ch_10 0 filtripple_ch_10 0.00 filtdcblock_ch_10 0 dispmode_ch_10 1 -channame_ch_10 +channame_ch_10 gain_ch_11 10000 filter_ch_11 2 a_in_ch_11 11 @@ -163,7 +163,7 @@ filtfreq2_ch_11 0 filtripple_ch_11 0.00 filtdcblock_ch_11 0 dispmode_ch_11 1 -channame_ch_11 +channame_ch_11 gain_ch_12 10000 filter_ch_12 2 a_in_ch_12 12 @@ -176,7 +176,7 @@ filtfreq2_ch_12 0 filtripple_ch_12 0.00 filtdcblock_ch_12 0 dispmode_ch_12 1 -channame_ch_12 +channame_ch_12 gain_ch_13 10000 filter_ch_13 2 a_in_ch_13 13 @@ -189,7 +189,7 @@ filtfreq2_ch_13 0 filtripple_ch_13 0.00 filtdcblock_ch_13 0 dispmode_ch_13 1 -channame_ch_13 +channame_ch_13 gain_ch_14 10000 filter_ch_14 2 a_in_ch_14 14 @@ -202,7 +202,7 @@ filtfreq2_ch_14 0 filtripple_ch_14 0.00 filtdcblock_ch_14 0 dispmode_ch_14 1 -channame_ch_14 +channame_ch_14 gain_ch_15 10000 filter_ch_15 2 a_in_ch_15 15 @@ -215,7 +215,7 @@ filtfreq2_ch_15 0 filtripple_ch_15 0.00 filtdcblock_ch_15 0 dispmode_ch_15 1 -channame_ch_15 +channame_ch_15 gain_ch_16 8000 filter_ch_16 2 a_in_ch_16 16 @@ -228,7 +228,7 @@ filtfreq2_ch_16 0 filtripple_ch_16 0.00 filtdcblock_ch_16 0 dispmode_ch_16 1 -channame_ch_16 +channame_ch_16 gain_ch_17 8000 filter_ch_17 2 a_in_ch_17 17 @@ -241,7 +241,7 @@ filtfreq2_ch_17 0 filtripple_ch_17 0.00 filtdcblock_ch_17 0 dispmode_ch_17 1 -channame_ch_17 
+channame_ch_17 gain_ch_18 8000 filter_ch_18 2 a_in_ch_18 18 @@ -254,7 +254,7 @@ filtfreq2_ch_18 0 filtripple_ch_18 0.00 filtdcblock_ch_18 0 dispmode_ch_18 1 -channame_ch_18 +channame_ch_18 gain_ch_19 8000 filter_ch_19 2 a_in_ch_19 19 @@ -267,7 +267,7 @@ filtfreq2_ch_19 0 filtripple_ch_19 0.00 filtdcblock_ch_19 0 dispmode_ch_19 1 -channame_ch_19 +channame_ch_19 gain_ch_20 6000 filter_ch_20 2 a_in_ch_20 20 @@ -280,7 +280,7 @@ filtfreq2_ch_20 0 filtripple_ch_20 0.00 filtdcblock_ch_20 0 dispmode_ch_20 1 -channame_ch_20 +channame_ch_20 gain_ch_21 8000 filter_ch_21 2 a_in_ch_21 21 @@ -293,7 +293,7 @@ filtfreq2_ch_21 0 filtripple_ch_21 0.00 filtdcblock_ch_21 0 dispmode_ch_21 1 -channame_ch_21 +channame_ch_21 gain_ch_22 8000 filter_ch_22 2 a_in_ch_22 22 @@ -306,7 +306,7 @@ filtfreq2_ch_22 0 filtripple_ch_22 0.00 filtdcblock_ch_22 0 dispmode_ch_22 1 -channame_ch_22 +channame_ch_22 gain_ch_23 8000 filter_ch_23 2 a_in_ch_23 23 @@ -319,7 +319,7 @@ filtfreq2_ch_23 0 filtripple_ch_23 0.00 filtdcblock_ch_23 0 dispmode_ch_23 1 -channame_ch_23 +channame_ch_23 gain_ch_24 6000 filter_ch_24 2 a_in_ch_24 24 @@ -332,7 +332,7 @@ filtfreq2_ch_24 0 filtripple_ch_24 0.00 filtdcblock_ch_24 0 dispmode_ch_24 1 -channame_ch_24 +channame_ch_24 gain_ch_25 8000 filter_ch_25 2 a_in_ch_25 25 @@ -345,7 +345,7 @@ filtfreq2_ch_25 0 filtripple_ch_25 0.00 filtdcblock_ch_25 0 dispmode_ch_25 1 -channame_ch_25 +channame_ch_25 gain_ch_26 8000 filter_ch_26 2 a_in_ch_26 26 @@ -358,7 +358,7 @@ filtfreq2_ch_26 0 filtripple_ch_26 0.00 filtdcblock_ch_26 0 dispmode_ch_26 1 -channame_ch_26 +channame_ch_26 gain_ch_27 8000 filter_ch_27 2 a_in_ch_27 27 @@ -371,7 +371,7 @@ filtfreq2_ch_27 0 filtripple_ch_27 0.00 filtdcblock_ch_27 0 dispmode_ch_27 1 -channame_ch_27 +channame_ch_27 gain_ch_28 8000 filter_ch_28 2 a_in_ch_28 28 @@ -384,7 +384,7 @@ filtfreq2_ch_28 0 filtripple_ch_28 0.00 filtdcblock_ch_28 0 dispmode_ch_28 1 -channame_ch_28 +channame_ch_28 gain_ch_29 8000 filter_ch_29 2 a_in_ch_29 29 @@ -397,7 +397,7 @@ filtfreq2_ch_29 0 filtripple_ch_29 0.00 filtdcblock_ch_29 0 dispmode_ch_29 1 -channame_ch_29 +channame_ch_29 gain_ch_30 8000 filter_ch_30 2 a_in_ch_30 30 @@ -410,7 +410,7 @@ filtfreq2_ch_30 0 filtripple_ch_30 0.00 filtdcblock_ch_30 0 dispmode_ch_30 1 -channame_ch_30 +channame_ch_30 gain_ch_31 8000 filter_ch_31 2 a_in_ch_31 31 @@ -423,7 +423,7 @@ filtfreq2_ch_31 0 filtripple_ch_31 0.00 filtdcblock_ch_31 0 dispmode_ch_31 1 -channame_ch_31 +channame_ch_31 gain_ch_32 2000 filter_ch_32 3 a_in_ch_32 32 @@ -436,7 +436,7 @@ filtfreq2_ch_32 7000 filtripple_ch_32 0.10 filtdcblock_ch_32 1 dispmode_ch_32 1 -channame_ch_32 +channame_ch_32 gain_ch_33 10000 filter_ch_33 2 a_in_ch_33 33 @@ -449,7 +449,7 @@ filtfreq2_ch_33 7000 filtripple_ch_33 0.10 filtdcblock_ch_33 1 dispmode_ch_33 1 -channame_ch_33 +channame_ch_33 gain_ch_34 10000 filter_ch_34 2 a_in_ch_34 34 @@ -462,7 +462,7 @@ filtfreq2_ch_34 7000 filtripple_ch_34 0.10 filtdcblock_ch_34 1 dispmode_ch_34 1 -channame_ch_34 +channame_ch_34 gain_ch_35 2000 filter_ch_35 3 a_in_ch_35 35 @@ -475,7 +475,7 @@ filtfreq2_ch_35 7000 filtripple_ch_35 0.10 filtdcblock_ch_35 1 dispmode_ch_35 1 -channame_ch_35 +channame_ch_35 gain_ch_36 2000 filter_ch_36 3 a_in_ch_36 36 @@ -488,7 +488,7 @@ filtfreq2_ch_36 0 filtripple_ch_36 0.00 filtdcblock_ch_36 0 dispmode_ch_36 1 -channame_ch_36 +channame_ch_36 gain_ch_37 10000 filter_ch_37 3 a_in_ch_37 37 @@ -501,7 +501,7 @@ filtfreq2_ch_37 0 filtripple_ch_37 0.00 filtdcblock_ch_37 0 dispmode_ch_37 1 -channame_ch_37 +channame_ch_37 gain_ch_38 10000 filter_ch_38 3 a_in_ch_38 
38 @@ -514,7 +514,7 @@ filtfreq2_ch_38 0 filtripple_ch_38 0.00 filtdcblock_ch_38 0 dispmode_ch_38 1 -channame_ch_38 +channame_ch_38 gain_ch_39 2000 filter_ch_39 3 a_in_ch_39 39 @@ -527,7 +527,7 @@ filtfreq2_ch_39 7000 filtripple_ch_39 0.10 filtdcblock_ch_39 1 dispmode_ch_39 1 -channame_ch_39 +channame_ch_39 gain_ch_40 2000 filter_ch_40 3 a_in_ch_40 40 @@ -540,7 +540,7 @@ filtfreq2_ch_40 7000 filtripple_ch_40 0.10 filtdcblock_ch_40 1 dispmode_ch_40 1 -channame_ch_40 +channame_ch_40 gain_ch_41 10000 filter_ch_41 2 a_in_ch_41 41 @@ -553,7 +553,7 @@ filtfreq2_ch_41 7000 filtripple_ch_41 0.10 filtdcblock_ch_41 1 dispmode_ch_41 1 -channame_ch_41 +channame_ch_41 gain_ch_42 10000 filter_ch_42 2 a_in_ch_42 42 @@ -566,7 +566,7 @@ filtfreq2_ch_42 7000 filtripple_ch_42 0.10 filtdcblock_ch_42 1 dispmode_ch_42 1 -channame_ch_42 +channame_ch_42 gain_ch_43 2000 filter_ch_43 3 a_in_ch_43 43 @@ -579,7 +579,7 @@ filtfreq2_ch_43 7000 filtripple_ch_43 0.10 filtdcblock_ch_43 1 dispmode_ch_43 1 -channame_ch_43 +channame_ch_43 gain_ch_44 2000 filter_ch_44 3 a_in_ch_44 44 @@ -592,7 +592,7 @@ filtfreq2_ch_44 7000 filtripple_ch_44 0.10 filtdcblock_ch_44 1 dispmode_ch_44 1 -channame_ch_44 +channame_ch_44 gain_ch_45 10000 filter_ch_45 2 a_in_ch_45 45 @@ -605,7 +605,7 @@ filtfreq2_ch_45 7000 filtripple_ch_45 0.10 filtdcblock_ch_45 1 dispmode_ch_45 1 -channame_ch_45 +channame_ch_45 gain_ch_46 10000 filter_ch_46 2 a_in_ch_46 46 @@ -618,7 +618,7 @@ filtfreq2_ch_46 7000 filtripple_ch_46 0.10 filtdcblock_ch_46 1 dispmode_ch_46 1 -channame_ch_46 +channame_ch_46 gain_ch_47 2000 filter_ch_47 3 a_in_ch_47 47 @@ -631,7 +631,7 @@ filtfreq2_ch_47 7000 filtripple_ch_47 0.10 filtdcblock_ch_47 1 dispmode_ch_47 1 -channame_ch_47 +channame_ch_47 gain_ch_48 10000 filter_ch_48 2 a_in_ch_48 48 @@ -644,7 +644,7 @@ filtfreq2_ch_48 7000 filtripple_ch_48 0.10 filtdcblock_ch_48 1 dispmode_ch_48 1 -channame_ch_48 +channame_ch_48 gain_ch_49 10000 filter_ch_49 2 a_in_ch_49 49 @@ -657,7 +657,7 @@ filtfreq2_ch_49 7000 filtripple_ch_49 0.10 filtdcblock_ch_49 1 dispmode_ch_49 1 -channame_ch_49 +channame_ch_49 gain_ch_50 10000 filter_ch_50 2 a_in_ch_50 50 @@ -670,7 +670,7 @@ filtfreq2_ch_50 7000 filtripple_ch_50 0.10 filtdcblock_ch_50 1 dispmode_ch_50 1 -channame_ch_50 +channame_ch_50 gain_ch_51 10000 filter_ch_51 2 a_in_ch_51 51 @@ -683,7 +683,7 @@ filtfreq2_ch_51 7000 filtripple_ch_51 0.10 filtdcblock_ch_51 1 dispmode_ch_51 1 -channame_ch_51 +channame_ch_51 gain_ch_52 10000 filter_ch_52 2 a_in_ch_52 52 @@ -696,7 +696,7 @@ filtfreq2_ch_52 7000 filtripple_ch_52 0.10 filtdcblock_ch_52 1 dispmode_ch_52 1 -channame_ch_52 +channame_ch_52 gain_ch_53 10000 filter_ch_53 2 a_in_ch_53 53 @@ -709,7 +709,7 @@ filtfreq2_ch_53 7000 filtripple_ch_53 0.10 filtdcblock_ch_53 1 dispmode_ch_53 1 -channame_ch_53 +channame_ch_53 gain_ch_54 10000 filter_ch_54 2 a_in_ch_54 54 @@ -722,7 +722,7 @@ filtfreq2_ch_54 7000 filtripple_ch_54 0.10 filtdcblock_ch_54 1 dispmode_ch_54 1 -channame_ch_54 +channame_ch_54 gain_ch_55 10000 filter_ch_55 2 a_in_ch_55 55 @@ -735,7 +735,7 @@ filtfreq2_ch_55 7000 filtripple_ch_55 0.10 filtdcblock_ch_55 1 dispmode_ch_55 1 -channame_ch_55 +channame_ch_55 gain_ch_56 10000 filter_ch_56 2 a_in_ch_56 56 @@ -748,7 +748,7 @@ filtfreq2_ch_56 7000 filtripple_ch_56 0.10 filtdcblock_ch_56 1 dispmode_ch_56 1 -channame_ch_56 +channame_ch_56 gain_ch_57 10000 filter_ch_57 2 a_in_ch_57 57 @@ -761,7 +761,7 @@ filtfreq2_ch_57 7000 filtripple_ch_57 0.10 filtdcblock_ch_57 1 dispmode_ch_57 1 -channame_ch_57 +channame_ch_57 gain_ch_58 10000 filter_ch_58 2 a_in_ch_58 58 
@@ -774,7 +774,7 @@ filtfreq2_ch_58 7000 filtripple_ch_58 0.10 filtdcblock_ch_58 1 dispmode_ch_58 1 -channame_ch_58 +channame_ch_58 gain_ch_59 10000 filter_ch_59 2 a_in_ch_59 59 @@ -787,7 +787,7 @@ filtfreq2_ch_59 7000 filtripple_ch_59 0.10 filtdcblock_ch_59 1 dispmode_ch_59 1 -channame_ch_59 +channame_ch_59 gain_ch_60 10000 filter_ch_60 2 a_in_ch_60 60 @@ -800,7 +800,7 @@ filtfreq2_ch_60 7000 filtripple_ch_60 0.10 filtdcblock_ch_60 1 dispmode_ch_60 1 -channame_ch_60 +channame_ch_60 gain_ch_61 10000 filter_ch_61 2 a_in_ch_61 61 @@ -813,7 +813,7 @@ filtfreq2_ch_61 7000 filtripple_ch_61 0.10 filtdcblock_ch_61 1 dispmode_ch_61 1 -channame_ch_61 +channame_ch_61 gain_ch_62 10000 filter_ch_62 2 a_in_ch_62 62 @@ -826,7 +826,7 @@ filtfreq2_ch_62 7000 filtripple_ch_62 0.10 filtdcblock_ch_62 1 dispmode_ch_62 1 -channame_ch_62 +channame_ch_62 gain_ch_63 10000 filter_ch_63 2 a_in_ch_63 63 @@ -839,7 +839,7 @@ filtfreq2_ch_63 7000 filtripple_ch_63 0.10 filtdcblock_ch_63 1 dispmode_ch_63 1 -channame_ch_63 +channame_ch_63 second_audio -1 default_filtresp_hp 2 default_filtkind_hp 0 @@ -1136,14 +1136,14 @@ stim_start_delay 1 biphasic 0 use_dacstim 0 stimscript 0 -stimfile +stimfile numPatterns 1 stim_patt_1 "One 100 us pulse every 30 s" 100 100 30000000 0 1 1000 0 1 1000000 0 1 numProtocols 1 stim_prot_1 "Ten minutes of 30 s pulse baseline" 1 600 "One 100 us pulse every 30 s" 0 0 "Pause (no stimulation)" 0 0 "Pause (no stimulation)" 0 0 "Pause (no stimulation)" 0 0 "Pause (no stimulation)" stim_during_rec 0 -info_subject -info_trial +info_subject +info_trial waveform_period 32 pretrig_period 4 deadzone_period 500 @@ -1375,8 +1375,8 @@ recordSerial 0 useScript 0 script M:\pc\Desktop\1302\1302_0323.bas postProcess 0 -postProcessor -postProcessorParams +postProcessor +postProcessorParams sync_out 0 syncRate 25 autoTrial 0 @@ -1393,8 +1393,8 @@ rawGateChan 0 rawGatePol 1 defaultTime 600 defaultMode 0 -trial_comment -experimenter +trial_comment +experimenter digout_state 0 stim_phase 90 stim_period 100 @@ -1495,9 +1495,9 @@ lastfileext set lasttrialdatetime 1383240011 lastupdatecheck 1358640000 useupdateproxy 0 -updateproxy -updateproxyid -updateproxypw +updateproxy +updateproxyid +updateproxypw contaudio 0 mode128channels 0 modeanalog32 0 diff --git a/tests/test_data/axona/DVH_2013103103_2.cut b/tests/test_data/axona/DVH_2013103103_2.cut index 5f8de34..3fe4e29 100644 --- a/tests/test_data/axona/DVH_2013103103_2.cut +++ b/tests/test_data/axona/DVH_2013103103_2.cut @@ -155,4 +155,4 @@ Exact_cut_for: DVH_2013103103 spikes: 1466 0 0 0 0 0 0 1 0 1 1 0 1 1 0 0 1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 1 0 0 0 0 0 0 0 0 1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 1 1 1 1 1 1 0 0 1 0 - 0 0 0 0 0 1 1 1 0 0 0 1 0 0 0 0 \ No newline at end of file + 0 0 0 0 0 1 1 1 0 0 0 1 0 0 0 0 diff --git a/tests/test_data/axona/DVH_2013103103_3.cut b/tests/test_data/axona/DVH_2013103103_3.cut index 022b696..cda8b18 100644 --- a/tests/test_data/axona/DVH_2013103103_3.cut +++ b/tests/test_data/axona/DVH_2013103103_3.cut @@ -322,4 +322,4 @@ Exact_cut_for: DVH_2013103103 spikes: 5648 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - 0 4 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \ No newline at end of file + 0 4 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 diff --git a/tests/test_data/axona/DVH_2013103103_4.cut b/tests/test_data/axona/DVH_2013103103_4.cut index 8ed9dd0..266f0fc 100644 --- a/tests/test_data/axona/DVH_2013103103_4.cut +++ 
b/tests/test_data/axona/DVH_2013103103_4.cut @@ -141,4 +141,4 @@ Exact_cut_for: DVH_2013103103 spikes: 1103 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 \ No newline at end of file + 0 0 0 diff --git a/tests/test_data/axona/DVH_2013103103_8.cut b/tests/test_data/axona/DVH_2013103103_8.cut index 1da61a0..cfd3526 100644 --- a/tests/test_data/axona/DVH_2013103103_8.cut +++ b/tests/test_data/axona/DVH_2013103103_8.cut @@ -509,4 +509,4 @@ Exact_cut_for: DVH_2013103103 spikes: 10311 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 3 0 - 3 2 0 0 3 0 0 0 0 0 0 \ No newline at end of file + 3 2 0 0 3 0 0 0 0 0 0 diff --git a/tests/test_data/intan/test-rat_2017-06-23_11-15-46_1/experiment1_all_channels_0.events b/tests/test_data/intan/test-rat_2017-06-23_11-15-46_1/experiment1_all_channels_0.events index 7fd646dc02cfdc93b921b4414a35205d859a8f1a..bbcc7090aaf8140fbddf2aa346e0e8d8feefcb60 100755 GIT binary patch delta 41 xcmcaGlkwb4#tE8YTp6i}DXB$zWvNBQnfZANwh9J%Ce{-jziu{XoaxQT1pq%H4m1D& delta 43 zcmX>%lkvh##tE7d3S1efi7BZ?dS$6a#hLke3bqOcdM4Hi6Yak;Dr`1koaqe!P&N*S diff --git a/tests/test_data/intan/test-rat_2017-06-23_11-15-46_1/experiment1_binarymsg_0.eventsbinary b/tests/test_data/intan/test-rat_2017-06-23_11-15-46_1/experiment1_binarymsg_0.eventsbinary index fe5da80446574657d7c89a0ee32204fdffe7bd7f..e421a4aa2d60672029ab18f71c0e28101beaca6e 100755 GIT binary patch delta 71 zcmezMit*bk#tE8YTp6i}DXB$zWvNBQnfZANwh9J%Ce{-jziu{X?2(n=VgQ5F8p(DA Y4fgg74O%V?4%#j$oA1ia4P)d20E(d&vj6}9 delta 56 zcmezNit*Pg#tE7d3S1efi7BZ?dS$6a#hLke3bqOcdM4Hi6Yak;Dr`1k?2%>UoP0n| Mf{}Cc4Y|2t09Hg3c>n+a diff --git a/tests/test_data/intan/test-rat_2017-06-23_11-16-45_2/experiment1_all_channels_0.events b/tests/test_data/intan/test-rat_2017-06-23_11-16-45_2/experiment1_all_channels_0.events index 41f4ac19573e4279cc5dc3c71522a9c98de75b7f..4e5a09807f378db0b7cdd28f8caba9afa6ddeb41 100755 GIT binary patch delta 41 xcmca`it*ej#tE8YTp6i}DXB$zWvNBQnfZANwh9J%Ce{-jziu{XluTjd0su@Z4le)z delta 43 zcmX?iit)lJ#tE7d3S1efi7BZ?dS$6a#hLke3bqOcdM4Hi6Yak;Dr`1kluQ8tUeFGL diff --git a/tests/test_data/intan/test-rat_2017-06-23_11-16-45_2/experiment1_binarymsg_0.eventsbinary b/tests/test_data/intan/test-rat_2017-06-23_11-16-45_2/experiment1_binarymsg_0.eventsbinary index ef3bf5367fa42ba9cd7eb57c233448f4cce9e12e..59afc7e7ebea4a7a1a274e4d374b1443f2c32a00 100755 GIT binary patch delta 46 zcmX@KooWAerU{y2Tp6i}DXB$zWvNBQnfZANwh9J%Ce{-jziu{X+~c|Vi`TgTMlJxH C&k= Date: Thu, 16 May 2024 11:01:48 +0200 Subject: [PATCH 33/47] Correct the source path --- .github/workflows/check_formatting.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/check_formatting.yml b/.github/workflows/check_formatting.yml index c946b9f..3f106e8 100644 --- a/.github/workflows/check_formatting.yml +++ b/.github/workflows/check_formatting.yml @@ -30,7 +30,7 @@ jobs: uses: psf/black@stable with: options: "--check" - src: "./code" + src: "./src" jupyter: true - name: isort From 41e8eb7f7f4f0b13e63478fc8ef749f0085863df Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 12:24:08 +0200 Subject: [PATCH 34/47] Exclude test_data datasets from formatter and linter --- .pre-commit-config.yaml | 5 +++++ pyproject.toml | 3 +++ 2 files changed, 8 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7132741..7be81c7 100644 
--- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,12 +3,16 @@ repos: rev: v4.5.0 hooks: - id: fix-encoding-pragma + exclude: tests/test_data - id: trailing-whitespace + exclude: tests/test_data - id: end-of-file-fixer + exclude: tests/test_data - id: check-docstring-first - id: debug-statements - id: check-toml - id: check-yaml + exclude: tests/test_data - id: requirements-txt-fixer - id: detect-private-key - id: check-merge-conflict @@ -17,6 +21,7 @@ repos: rev: 24.4.2 hooks: - id: black + exclude: tests/test_data - id: black-jupyter - repo: https://github.com/pycqa/isort diff --git a/pyproject.toml b/pyproject.toml index 9711e54..62ab6ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -85,6 +85,9 @@ exclude = [ "venv", ] +# In addition to the standard set of exclusions, omit: +extend-exclude = ["tests/test_data"] + # Same as Black. line-length = 120 indent-width = 4 From cb2c9ff5297cc73d96eccec752b4df02defd849b Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Thu, 16 May 2024 12:24:58 +0200 Subject: [PATCH 35/47] Revert formatting and linting changes --- tests/test_data/axona/DVH_2013103103.set | 154 +++++++++--------- tests/test_data/axona/DVH_2013103103_2.cut | 2 +- tests/test_data/axona/DVH_2013103103_3.cut | 2 +- tests/test_data/axona/DVH_2013103103_4.cut | 2 +- tests/test_data/axona/DVH_2013103103_8.cut | 2 +- .../experiment1_all_channels_0.events | Bin 19662 -> 19664 bytes .../experiment1_binarymsg_0.eventsbinary | Bin 30070 -> 30074 bytes .../experiment1_all_channels_0.events | Bin 25934 -> 25936 bytes .../experiment1_binarymsg_0.eventsbinary | Bin 39871 -> 39874 bytes .../2/experiment1/recording1/structure.oebin | 2 +- .../spikesorting/mountainsort4/phy/params.py | 1 - .../experiment1/recording1/structure.oebin | 2 +- .../spikesorting/mountainsort4/phy/params.py | 1 - .../experiment1/recording1/structure.oebin | 2 +- .../experiment1/recording1/structure.oebin | 2 +- tests/test_data/tetrode_32_openephys.json | 2 +- 16 files changed, 86 insertions(+), 88 deletions(-) diff --git a/tests/test_data/axona/DVH_2013103103.set b/tests/test_data/axona/DVH_2013103103.set index 6cbbde8..36a0596 100644 --- a/tests/test_data/axona/DVH_2013103103.set +++ b/tests/test_data/axona/DVH_2013103103.set @@ -1,8 +1,8 @@ trial_date Thursday, 31 Oct 2013 trial_time 17:20:11 -experimenter -comments -duration 394 +experimenter +comments +duration 394 sw_version 1.2.2.7 ADC_fullscale_mv 1500 tracker_version 0 @@ -20,7 +20,7 @@ filtfreq2_ch_0 0 filtripple_ch_0 0.00 filtdcblock_ch_0 0 dispmode_ch_0 1 -channame_ch_0 +channame_ch_0 gain_ch_1 10000 filter_ch_1 2 a_in_ch_1 1 @@ -33,7 +33,7 @@ filtfreq2_ch_1 0 filtripple_ch_1 0.00 filtdcblock_ch_1 0 dispmode_ch_1 1 -channame_ch_1 +channame_ch_1 gain_ch_2 10000 filter_ch_2 2 a_in_ch_2 2 @@ -46,7 +46,7 @@ filtfreq2_ch_2 0 filtripple_ch_2 0.00 filtdcblock_ch_2 0 dispmode_ch_2 1 -channame_ch_2 +channame_ch_2 gain_ch_3 10000 filter_ch_3 2 a_in_ch_3 3 @@ -59,7 +59,7 @@ filtfreq2_ch_3 0 filtripple_ch_3 0.00 filtdcblock_ch_3 0 dispmode_ch_3 1 -channame_ch_3 +channame_ch_3 gain_ch_4 10000 filter_ch_4 2 a_in_ch_4 4 @@ -72,7 +72,7 @@ filtfreq2_ch_4 0 filtripple_ch_4 0.00 filtdcblock_ch_4 0 dispmode_ch_4 1 -channame_ch_4 +channame_ch_4 gain_ch_5 10000 filter_ch_5 2 a_in_ch_5 5 @@ -85,7 +85,7 @@ filtfreq2_ch_5 0 filtripple_ch_5 0.00 filtdcblock_ch_5 0 dispmode_ch_5 1 -channame_ch_5 +channame_ch_5 gain_ch_6 10000 filter_ch_6 2 a_in_ch_6 6 @@ -98,7 +98,7 @@ filtfreq2_ch_6 0 filtripple_ch_6 0.00 filtdcblock_ch_6 0 dispmode_ch_6 1 -channame_ch_6 +channame_ch_6 
gain_ch_7 10000 filter_ch_7 2 a_in_ch_7 7 @@ -111,7 +111,7 @@ filtfreq2_ch_7 0 filtripple_ch_7 0.00 filtdcblock_ch_7 0 dispmode_ch_7 1 -channame_ch_7 +channame_ch_7 gain_ch_8 10000 filter_ch_8 2 a_in_ch_8 8 @@ -124,7 +124,7 @@ filtfreq2_ch_8 0 filtripple_ch_8 0.00 filtdcblock_ch_8 0 dispmode_ch_8 1 -channame_ch_8 +channame_ch_8 gain_ch_9 10000 filter_ch_9 2 a_in_ch_9 9 @@ -137,7 +137,7 @@ filtfreq2_ch_9 0 filtripple_ch_9 0.00 filtdcblock_ch_9 0 dispmode_ch_9 1 -channame_ch_9 +channame_ch_9 gain_ch_10 10000 filter_ch_10 2 a_in_ch_10 10 @@ -150,7 +150,7 @@ filtfreq2_ch_10 0 filtripple_ch_10 0.00 filtdcblock_ch_10 0 dispmode_ch_10 1 -channame_ch_10 +channame_ch_10 gain_ch_11 10000 filter_ch_11 2 a_in_ch_11 11 @@ -163,7 +163,7 @@ filtfreq2_ch_11 0 filtripple_ch_11 0.00 filtdcblock_ch_11 0 dispmode_ch_11 1 -channame_ch_11 +channame_ch_11 gain_ch_12 10000 filter_ch_12 2 a_in_ch_12 12 @@ -176,7 +176,7 @@ filtfreq2_ch_12 0 filtripple_ch_12 0.00 filtdcblock_ch_12 0 dispmode_ch_12 1 -channame_ch_12 +channame_ch_12 gain_ch_13 10000 filter_ch_13 2 a_in_ch_13 13 @@ -189,7 +189,7 @@ filtfreq2_ch_13 0 filtripple_ch_13 0.00 filtdcblock_ch_13 0 dispmode_ch_13 1 -channame_ch_13 +channame_ch_13 gain_ch_14 10000 filter_ch_14 2 a_in_ch_14 14 @@ -202,7 +202,7 @@ filtfreq2_ch_14 0 filtripple_ch_14 0.00 filtdcblock_ch_14 0 dispmode_ch_14 1 -channame_ch_14 +channame_ch_14 gain_ch_15 10000 filter_ch_15 2 a_in_ch_15 15 @@ -215,7 +215,7 @@ filtfreq2_ch_15 0 filtripple_ch_15 0.00 filtdcblock_ch_15 0 dispmode_ch_15 1 -channame_ch_15 +channame_ch_15 gain_ch_16 8000 filter_ch_16 2 a_in_ch_16 16 @@ -228,7 +228,7 @@ filtfreq2_ch_16 0 filtripple_ch_16 0.00 filtdcblock_ch_16 0 dispmode_ch_16 1 -channame_ch_16 +channame_ch_16 gain_ch_17 8000 filter_ch_17 2 a_in_ch_17 17 @@ -241,7 +241,7 @@ filtfreq2_ch_17 0 filtripple_ch_17 0.00 filtdcblock_ch_17 0 dispmode_ch_17 1 -channame_ch_17 +channame_ch_17 gain_ch_18 8000 filter_ch_18 2 a_in_ch_18 18 @@ -254,7 +254,7 @@ filtfreq2_ch_18 0 filtripple_ch_18 0.00 filtdcblock_ch_18 0 dispmode_ch_18 1 -channame_ch_18 +channame_ch_18 gain_ch_19 8000 filter_ch_19 2 a_in_ch_19 19 @@ -267,7 +267,7 @@ filtfreq2_ch_19 0 filtripple_ch_19 0.00 filtdcblock_ch_19 0 dispmode_ch_19 1 -channame_ch_19 +channame_ch_19 gain_ch_20 6000 filter_ch_20 2 a_in_ch_20 20 @@ -280,7 +280,7 @@ filtfreq2_ch_20 0 filtripple_ch_20 0.00 filtdcblock_ch_20 0 dispmode_ch_20 1 -channame_ch_20 +channame_ch_20 gain_ch_21 8000 filter_ch_21 2 a_in_ch_21 21 @@ -293,7 +293,7 @@ filtfreq2_ch_21 0 filtripple_ch_21 0.00 filtdcblock_ch_21 0 dispmode_ch_21 1 -channame_ch_21 +channame_ch_21 gain_ch_22 8000 filter_ch_22 2 a_in_ch_22 22 @@ -306,7 +306,7 @@ filtfreq2_ch_22 0 filtripple_ch_22 0.00 filtdcblock_ch_22 0 dispmode_ch_22 1 -channame_ch_22 +channame_ch_22 gain_ch_23 8000 filter_ch_23 2 a_in_ch_23 23 @@ -319,7 +319,7 @@ filtfreq2_ch_23 0 filtripple_ch_23 0.00 filtdcblock_ch_23 0 dispmode_ch_23 1 -channame_ch_23 +channame_ch_23 gain_ch_24 6000 filter_ch_24 2 a_in_ch_24 24 @@ -332,7 +332,7 @@ filtfreq2_ch_24 0 filtripple_ch_24 0.00 filtdcblock_ch_24 0 dispmode_ch_24 1 -channame_ch_24 +channame_ch_24 gain_ch_25 8000 filter_ch_25 2 a_in_ch_25 25 @@ -345,7 +345,7 @@ filtfreq2_ch_25 0 filtripple_ch_25 0.00 filtdcblock_ch_25 0 dispmode_ch_25 1 -channame_ch_25 +channame_ch_25 gain_ch_26 8000 filter_ch_26 2 a_in_ch_26 26 @@ -358,7 +358,7 @@ filtfreq2_ch_26 0 filtripple_ch_26 0.00 filtdcblock_ch_26 0 dispmode_ch_26 1 -channame_ch_26 +channame_ch_26 gain_ch_27 8000 filter_ch_27 2 a_in_ch_27 27 @@ -371,7 +371,7 @@ filtfreq2_ch_27 0 
filtripple_ch_27 0.00 filtdcblock_ch_27 0 dispmode_ch_27 1 -channame_ch_27 +channame_ch_27 gain_ch_28 8000 filter_ch_28 2 a_in_ch_28 28 @@ -384,7 +384,7 @@ filtfreq2_ch_28 0 filtripple_ch_28 0.00 filtdcblock_ch_28 0 dispmode_ch_28 1 -channame_ch_28 +channame_ch_28 gain_ch_29 8000 filter_ch_29 2 a_in_ch_29 29 @@ -397,7 +397,7 @@ filtfreq2_ch_29 0 filtripple_ch_29 0.00 filtdcblock_ch_29 0 dispmode_ch_29 1 -channame_ch_29 +channame_ch_29 gain_ch_30 8000 filter_ch_30 2 a_in_ch_30 30 @@ -410,7 +410,7 @@ filtfreq2_ch_30 0 filtripple_ch_30 0.00 filtdcblock_ch_30 0 dispmode_ch_30 1 -channame_ch_30 +channame_ch_30 gain_ch_31 8000 filter_ch_31 2 a_in_ch_31 31 @@ -423,7 +423,7 @@ filtfreq2_ch_31 0 filtripple_ch_31 0.00 filtdcblock_ch_31 0 dispmode_ch_31 1 -channame_ch_31 +channame_ch_31 gain_ch_32 2000 filter_ch_32 3 a_in_ch_32 32 @@ -436,7 +436,7 @@ filtfreq2_ch_32 7000 filtripple_ch_32 0.10 filtdcblock_ch_32 1 dispmode_ch_32 1 -channame_ch_32 +channame_ch_32 gain_ch_33 10000 filter_ch_33 2 a_in_ch_33 33 @@ -449,7 +449,7 @@ filtfreq2_ch_33 7000 filtripple_ch_33 0.10 filtdcblock_ch_33 1 dispmode_ch_33 1 -channame_ch_33 +channame_ch_33 gain_ch_34 10000 filter_ch_34 2 a_in_ch_34 34 @@ -462,7 +462,7 @@ filtfreq2_ch_34 7000 filtripple_ch_34 0.10 filtdcblock_ch_34 1 dispmode_ch_34 1 -channame_ch_34 +channame_ch_34 gain_ch_35 2000 filter_ch_35 3 a_in_ch_35 35 @@ -475,7 +475,7 @@ filtfreq2_ch_35 7000 filtripple_ch_35 0.10 filtdcblock_ch_35 1 dispmode_ch_35 1 -channame_ch_35 +channame_ch_35 gain_ch_36 2000 filter_ch_36 3 a_in_ch_36 36 @@ -488,7 +488,7 @@ filtfreq2_ch_36 0 filtripple_ch_36 0.00 filtdcblock_ch_36 0 dispmode_ch_36 1 -channame_ch_36 +channame_ch_36 gain_ch_37 10000 filter_ch_37 3 a_in_ch_37 37 @@ -501,7 +501,7 @@ filtfreq2_ch_37 0 filtripple_ch_37 0.00 filtdcblock_ch_37 0 dispmode_ch_37 1 -channame_ch_37 +channame_ch_37 gain_ch_38 10000 filter_ch_38 3 a_in_ch_38 38 @@ -514,7 +514,7 @@ filtfreq2_ch_38 0 filtripple_ch_38 0.00 filtdcblock_ch_38 0 dispmode_ch_38 1 -channame_ch_38 +channame_ch_38 gain_ch_39 2000 filter_ch_39 3 a_in_ch_39 39 @@ -527,7 +527,7 @@ filtfreq2_ch_39 7000 filtripple_ch_39 0.10 filtdcblock_ch_39 1 dispmode_ch_39 1 -channame_ch_39 +channame_ch_39 gain_ch_40 2000 filter_ch_40 3 a_in_ch_40 40 @@ -540,7 +540,7 @@ filtfreq2_ch_40 7000 filtripple_ch_40 0.10 filtdcblock_ch_40 1 dispmode_ch_40 1 -channame_ch_40 +channame_ch_40 gain_ch_41 10000 filter_ch_41 2 a_in_ch_41 41 @@ -553,7 +553,7 @@ filtfreq2_ch_41 7000 filtripple_ch_41 0.10 filtdcblock_ch_41 1 dispmode_ch_41 1 -channame_ch_41 +channame_ch_41 gain_ch_42 10000 filter_ch_42 2 a_in_ch_42 42 @@ -566,7 +566,7 @@ filtfreq2_ch_42 7000 filtripple_ch_42 0.10 filtdcblock_ch_42 1 dispmode_ch_42 1 -channame_ch_42 +channame_ch_42 gain_ch_43 2000 filter_ch_43 3 a_in_ch_43 43 @@ -579,7 +579,7 @@ filtfreq2_ch_43 7000 filtripple_ch_43 0.10 filtdcblock_ch_43 1 dispmode_ch_43 1 -channame_ch_43 +channame_ch_43 gain_ch_44 2000 filter_ch_44 3 a_in_ch_44 44 @@ -592,7 +592,7 @@ filtfreq2_ch_44 7000 filtripple_ch_44 0.10 filtdcblock_ch_44 1 dispmode_ch_44 1 -channame_ch_44 +channame_ch_44 gain_ch_45 10000 filter_ch_45 2 a_in_ch_45 45 @@ -605,7 +605,7 @@ filtfreq2_ch_45 7000 filtripple_ch_45 0.10 filtdcblock_ch_45 1 dispmode_ch_45 1 -channame_ch_45 +channame_ch_45 gain_ch_46 10000 filter_ch_46 2 a_in_ch_46 46 @@ -618,7 +618,7 @@ filtfreq2_ch_46 7000 filtripple_ch_46 0.10 filtdcblock_ch_46 1 dispmode_ch_46 1 -channame_ch_46 +channame_ch_46 gain_ch_47 2000 filter_ch_47 3 a_in_ch_47 47 @@ -631,7 +631,7 @@ filtfreq2_ch_47 7000 filtripple_ch_47 0.10 
filtdcblock_ch_47 1 dispmode_ch_47 1 -channame_ch_47 +channame_ch_47 gain_ch_48 10000 filter_ch_48 2 a_in_ch_48 48 @@ -644,7 +644,7 @@ filtfreq2_ch_48 7000 filtripple_ch_48 0.10 filtdcblock_ch_48 1 dispmode_ch_48 1 -channame_ch_48 +channame_ch_48 gain_ch_49 10000 filter_ch_49 2 a_in_ch_49 49 @@ -657,7 +657,7 @@ filtfreq2_ch_49 7000 filtripple_ch_49 0.10 filtdcblock_ch_49 1 dispmode_ch_49 1 -channame_ch_49 +channame_ch_49 gain_ch_50 10000 filter_ch_50 2 a_in_ch_50 50 @@ -670,7 +670,7 @@ filtfreq2_ch_50 7000 filtripple_ch_50 0.10 filtdcblock_ch_50 1 dispmode_ch_50 1 -channame_ch_50 +channame_ch_50 gain_ch_51 10000 filter_ch_51 2 a_in_ch_51 51 @@ -683,7 +683,7 @@ filtfreq2_ch_51 7000 filtripple_ch_51 0.10 filtdcblock_ch_51 1 dispmode_ch_51 1 -channame_ch_51 +channame_ch_51 gain_ch_52 10000 filter_ch_52 2 a_in_ch_52 52 @@ -696,7 +696,7 @@ filtfreq2_ch_52 7000 filtripple_ch_52 0.10 filtdcblock_ch_52 1 dispmode_ch_52 1 -channame_ch_52 +channame_ch_52 gain_ch_53 10000 filter_ch_53 2 a_in_ch_53 53 @@ -709,7 +709,7 @@ filtfreq2_ch_53 7000 filtripple_ch_53 0.10 filtdcblock_ch_53 1 dispmode_ch_53 1 -channame_ch_53 +channame_ch_53 gain_ch_54 10000 filter_ch_54 2 a_in_ch_54 54 @@ -722,7 +722,7 @@ filtfreq2_ch_54 7000 filtripple_ch_54 0.10 filtdcblock_ch_54 1 dispmode_ch_54 1 -channame_ch_54 +channame_ch_54 gain_ch_55 10000 filter_ch_55 2 a_in_ch_55 55 @@ -735,7 +735,7 @@ filtfreq2_ch_55 7000 filtripple_ch_55 0.10 filtdcblock_ch_55 1 dispmode_ch_55 1 -channame_ch_55 +channame_ch_55 gain_ch_56 10000 filter_ch_56 2 a_in_ch_56 56 @@ -748,7 +748,7 @@ filtfreq2_ch_56 7000 filtripple_ch_56 0.10 filtdcblock_ch_56 1 dispmode_ch_56 1 -channame_ch_56 +channame_ch_56 gain_ch_57 10000 filter_ch_57 2 a_in_ch_57 57 @@ -761,7 +761,7 @@ filtfreq2_ch_57 7000 filtripple_ch_57 0.10 filtdcblock_ch_57 1 dispmode_ch_57 1 -channame_ch_57 +channame_ch_57 gain_ch_58 10000 filter_ch_58 2 a_in_ch_58 58 @@ -774,7 +774,7 @@ filtfreq2_ch_58 7000 filtripple_ch_58 0.10 filtdcblock_ch_58 1 dispmode_ch_58 1 -channame_ch_58 +channame_ch_58 gain_ch_59 10000 filter_ch_59 2 a_in_ch_59 59 @@ -787,7 +787,7 @@ filtfreq2_ch_59 7000 filtripple_ch_59 0.10 filtdcblock_ch_59 1 dispmode_ch_59 1 -channame_ch_59 +channame_ch_59 gain_ch_60 10000 filter_ch_60 2 a_in_ch_60 60 @@ -800,7 +800,7 @@ filtfreq2_ch_60 7000 filtripple_ch_60 0.10 filtdcblock_ch_60 1 dispmode_ch_60 1 -channame_ch_60 +channame_ch_60 gain_ch_61 10000 filter_ch_61 2 a_in_ch_61 61 @@ -813,7 +813,7 @@ filtfreq2_ch_61 7000 filtripple_ch_61 0.10 filtdcblock_ch_61 1 dispmode_ch_61 1 -channame_ch_61 +channame_ch_61 gain_ch_62 10000 filter_ch_62 2 a_in_ch_62 62 @@ -826,7 +826,7 @@ filtfreq2_ch_62 7000 filtripple_ch_62 0.10 filtdcblock_ch_62 1 dispmode_ch_62 1 -channame_ch_62 +channame_ch_62 gain_ch_63 10000 filter_ch_63 2 a_in_ch_63 63 @@ -839,7 +839,7 @@ filtfreq2_ch_63 7000 filtripple_ch_63 0.10 filtdcblock_ch_63 1 dispmode_ch_63 1 -channame_ch_63 +channame_ch_63 second_audio -1 default_filtresp_hp 2 default_filtkind_hp 0 @@ -1136,14 +1136,14 @@ stim_start_delay 1 biphasic 0 use_dacstim 0 stimscript 0 -stimfile +stimfile numPatterns 1 stim_patt_1 "One 100 us pulse every 30 s" 100 100 30000000 0 1 1000 0 1 1000000 0 1 numProtocols 1 stim_prot_1 "Ten minutes of 30 s pulse baseline" 1 600 "One 100 us pulse every 30 s" 0 0 "Pause (no stimulation)" 0 0 "Pause (no stimulation)" 0 0 "Pause (no stimulation)" 0 0 "Pause (no stimulation)" stim_during_rec 0 -info_subject -info_trial +info_subject +info_trial waveform_period 32 pretrig_period 4 deadzone_period 500 @@ -1375,8 +1375,8 @@ recordSerial 0 
useScript 0 script M:\pc\Desktop\1302\1302_0323.bas postProcess 0 -postProcessor -postProcessorParams +postProcessor +postProcessorParams sync_out 0 syncRate 25 autoTrial 0 @@ -1393,8 +1393,8 @@ rawGateChan 0 rawGatePol 1 defaultTime 600 defaultMode 0 -trial_comment -experimenter +trial_comment +experimenter digout_state 0 stim_phase 90 stim_period 100 @@ -1495,9 +1495,9 @@ lastfileext set lasttrialdatetime 1383240011 lastupdatecheck 1358640000 useupdateproxy 0 -updateproxy -updateproxyid -updateproxypw +updateproxy +updateproxyid +updateproxypw contaudio 0 mode128channels 0 modeanalog32 0 diff --git a/tests/test_data/axona/DVH_2013103103_2.cut b/tests/test_data/axona/DVH_2013103103_2.cut index 3fe4e29..5f8de34 100644 --- a/tests/test_data/axona/DVH_2013103103_2.cut +++ b/tests/test_data/axona/DVH_2013103103_2.cut @@ -155,4 +155,4 @@ Exact_cut_for: DVH_2013103103 spikes: 1466 0 0 0 0 0 0 1 0 1 1 0 1 1 0 0 1 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 1 0 0 0 0 0 0 0 0 1 0 1 0 1 0 0 0 0 0 0 0 0 0 0 1 1 0 1 1 1 1 1 1 0 0 1 0 - 0 0 0 0 0 1 1 1 0 0 0 1 0 0 0 0 + 0 0 0 0 0 1 1 1 0 0 0 1 0 0 0 0 \ No newline at end of file diff --git a/tests/test_data/axona/DVH_2013103103_3.cut b/tests/test_data/axona/DVH_2013103103_3.cut index cda8b18..022b696 100644 --- a/tests/test_data/axona/DVH_2013103103_3.cut +++ b/tests/test_data/axona/DVH_2013103103_3.cut @@ -322,4 +322,4 @@ Exact_cut_for: DVH_2013103103 spikes: 5648 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - 0 4 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 + 0 4 0 0 0 0 0 0 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 \ No newline at end of file diff --git a/tests/test_data/axona/DVH_2013103103_4.cut b/tests/test_data/axona/DVH_2013103103_4.cut index 266f0fc..8ed9dd0 100644 --- a/tests/test_data/axona/DVH_2013103103_4.cut +++ b/tests/test_data/axona/DVH_2013103103_4.cut @@ -141,4 +141,4 @@ Exact_cut_for: DVH_2013103103 spikes: 1103 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 - 0 0 0 + 0 0 0 \ No newline at end of file diff --git a/tests/test_data/axona/DVH_2013103103_8.cut b/tests/test_data/axona/DVH_2013103103_8.cut index cfd3526..1da61a0 100644 --- a/tests/test_data/axona/DVH_2013103103_8.cut +++ b/tests/test_data/axona/DVH_2013103103_8.cut @@ -509,4 +509,4 @@ Exact_cut_for: DVH_2013103103 spikes: 10311 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 0 0 0 0 0 0 0 0 0 3 0 0 0 3 0 - 3 2 0 0 3 0 0 0 0 0 0 + 3 2 0 0 3 0 0 0 0 0 0 \ No newline at end of file diff --git a/tests/test_data/intan/test-rat_2017-06-23_11-15-46_1/experiment1_all_channels_0.events b/tests/test_data/intan/test-rat_2017-06-23_11-15-46_1/experiment1_all_channels_0.events index bbcc7090aaf8140fbddf2aa346e0e8d8feefcb60..7fd646dc02cfdc93b921b4414a35205d859a8f1a 100755 GIT binary patch delta 43 zcmX>%lkvh##tE7d3S1efi7BZ?dS$6a#hLke3bqOcdM4Hi6Yak;Dr`1koaqe!P&N*S delta 41 xcmcaGlkwb4#tE8YTp6i}DXB$zWvNBQnfZANwh9J%Ce{-jziu{XoaxQT1pq%H4m1D& diff --git a/tests/test_data/intan/test-rat_2017-06-23_11-15-46_1/experiment1_binarymsg_0.eventsbinary b/tests/test_data/intan/test-rat_2017-06-23_11-15-46_1/experiment1_binarymsg_0.eventsbinary index e421a4aa2d60672029ab18f71c0e28101beaca6e..fe5da80446574657d7c89a0ee32204fdffe7bd7f 100755 GIT binary patch delta 56 zcmezNit*Pg#tE7d3S1efi7BZ?dS$6a#hLke3bqOcdM4Hi6Yak;Dr`1k?2%>UoP0n| 
Mf{}Cc4Y|2t09Hg3c>n+a delta 71 zcmezMit*bk#tE8YTp6i}DXB$zWvNBQnfZANwh9J%Ce{-jziu{X?2(n=VgQ5F8p(DA Y4fgg74O%V?4%#j$oA1ia4P)d20E(d&vj6}9 diff --git a/tests/test_data/intan/test-rat_2017-06-23_11-16-45_2/experiment1_all_channels_0.events b/tests/test_data/intan/test-rat_2017-06-23_11-16-45_2/experiment1_all_channels_0.events index 4e5a09807f378db0b7cdd28f8caba9afa6ddeb41..41f4ac19573e4279cc5dc3c71522a9c98de75b7f 100755 GIT binary patch delta 43 zcmX?iit)lJ#tE7d3S1efi7BZ?dS$6a#hLke3bqOcdM4Hi6Yak;Dr`1kluQ8tUeFGL delta 41 xcmca`it*ej#tE8YTp6i}DXB$zWvNBQnfZANwh9J%Ce{-jziu{XluTjd0su@Z4le)z diff --git a/tests/test_data/intan/test-rat_2017-06-23_11-16-45_2/experiment1_binarymsg_0.eventsbinary b/tests/test_data/intan/test-rat_2017-06-23_11-16-45_2/experiment1_binarymsg_0.eventsbinary index 59afc7e7ebea4a7a1a274e4d374b1443f2c32a00..ef3bf5367fa42ba9cd7eb57c233448f4cce9e12e 100755 GIT binary patch delta 50 zcmdnLo$1harU{x73S1efi7BZ?dS$6a#hLke3bqOcdM4Hi6Yak;Dr`1k+~di}z4?RJ Gxc~siV-es0 delta 46 zcmX@KooWAerU{y2Tp6i}DXB$zWvNBQnfZANwh9J%Ce{-jziu{X+~c|Vi`TgTMlJxH C&k= Date: Thu, 16 May 2024 19:07:58 +0200 Subject: [PATCH 36/47] Install tbb only on non-Darwin platforms and remove spython for now --- pyproject.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 62ab6ee..85f7502 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,9 +23,8 @@ dependencies = [ "pynwb>=2.5.0", "ipywidgets>=8.1.1", "nwbwidgets>=0.11.3", - "tbb>=2021.11.0", # TODO: pip can't find tbb or tbb4py (at least on macOS). Is it needed? + "tbb>=2021.11.0; platform_system != 'Darwin'", "pynapple>=0.5.1", - "spython>=0.3.13", # TODO: is this needed? ] [project.urls] From 367f33f37b48a6c45d7a55315e93652738669591 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 6 Jun 2024 10:32:02 +0200 Subject: [PATCH 37/47] Processing+Curation: extract waveforms from all spikes --- src/expipe_plugin_cinpla/scripts/curation.py | 2 +- src/expipe_plugin_cinpla/scripts/process.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/expipe_plugin_cinpla/scripts/curation.py b/src/expipe_plugin_cinpla/scripts/curation.py index f945806..5f7e584 100644 --- a/src/expipe_plugin_cinpla/scripts/curation.py +++ b/src/expipe_plugin_cinpla/scripts/curation.py @@ -157,7 +157,7 @@ def apply_curation(self, sorter, curated_sorting): curated_sorting, folder=None, mode="memory", - max_spikes_per_unit=100, + max_spikes_per_unit=None, sparse=True, method="by_property", by_property="group", diff --git a/src/expipe_plugin_cinpla/scripts/process.py b/src/expipe_plugin_cinpla/scripts/process.py index e293397..fa9cc96 100644 --- a/src/expipe_plugin_cinpla/scripts/process.py +++ b/src/expipe_plugin_cinpla/scripts/process.py @@ -252,6 +252,7 @@ def process_ecephys( ms_after=ms_after, sparsity_temp_folder=si_folder / "tmp", sparse=True, + max_spikes_per_unit=None, method="by_property", by_property="group", ) From 1898ca4a983faefa0bc0c97f26129069fb0ef70d Mon Sep 17 00:00:00 2001 From: Mikkel Date: Tue, 11 Jun 2024 09:31:33 +0200 Subject: [PATCH 38/47] set t[0] in tracking to zero --- src/expipe_plugin_cinpla/tools/data_processing.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py index efcb403..6df0e22 100644 --- a/src/expipe_plugin_cinpla/tools/data_processing.py +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -162,6 +162,8 @@ def load_head_direction(data_path, sampling_rate, 
low_pass_frequency, box_size): check_valid_tracking(x2, y2, box_size) angles, times = head_direction(x1, y1, x2, y2, t1) + # set t[0] to zero + times = times - times.min() return angles, times @@ -208,6 +210,8 @@ def load_tracking(data_path, sampling_rate, low_pass_frequency, box_size, veloci vel = np.gradient([x, y], axis=1) / np.gradient(t) speed = np.linalg.norm(vel, axis=0) x, y, t, speed = np.array(x), np.array(y), np.array(t), np.array(speed) + # set t[0] to zero + t = t - t.min() return x, y, t, speed From b3aad4ebae4e1fc72e9dae3478b7ba44daa1f6e8 Mon Sep 17 00:00:00 2001 From: Mikkel Date: Tue, 11 Jun 2024 11:48:19 +0200 Subject: [PATCH 39/47] Revert last commit and make reading unit ids faster --- src/expipe_plugin_cinpla/tools/data_processing.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py index 6df0e22..86fd9c1 100644 --- a/src/expipe_plugin_cinpla/tools/data_processing.py +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -162,8 +162,7 @@ def load_head_direction(data_path, sampling_rate, low_pass_frequency, box_size): check_valid_tracking(x2, y2, box_size) angles, times = head_direction(x1, y1, x2, y2, t1) - # set t[0] to zero - times = times - times.min() + return angles, times @@ -210,8 +209,6 @@ def load_tracking(data_path, sampling_rate, low_pass_frequency, box_size, veloci vel = np.gradient([x, y], axis=1) / np.gradient(t) speed = np.linalg.norm(vel, axis=0) x, y, t, speed = np.array(x), np.array(y), np.array(t), np.array(speed) - # set t[0] to zero - t = t - t.min() return x, y, t, speed @@ -483,8 +480,8 @@ def spike_trains(self, action_id): return self._spike_trains[action_id] def unit_names(self, action_id, channel_group): - self.spike_trains(action_id) - return list(self._spike_trains[action_id][channel_group].keys()) + units = load_unit_annotations(self.data_path(action_id), channel_group=channel_group) + return [u['name'] for u in units] def stim_times(self, action_id): if action_id not in self._stim_times: From 87e12aa9a664f0112b67d46d8a233f5c723f751f Mon Sep 17 00:00:00 2001 From: Mikkel Date: Tue, 11 Jun 2024 14:34:02 +0200 Subject: [PATCH 40/47] Add possibility to get a specific group --- .../tools/data_processing.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py index 86fd9c1..56766a8 100644 --- a/src/expipe_plugin_cinpla/tools/data_processing.py +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -465,19 +465,21 @@ def spike_train(self, action_id, channel_group, unit_id): self.spike_trains(action_id) return self._spike_trains[action_id][channel_group][unit_id] - def spike_trains(self, action_id): + def spike_trains(self, action_id, channel_group=None): if action_id not in self._spike_trains: self._spike_trains[action_id] = {} lim = self.get_lim(action_id) if self.stim_mask else None sts = load_spiketrains(self.data_path(action_id), lim=lim) for st in sts: - channel_group = st.annotations["group"] - if channel_group not in self._spike_trains[action_id]: - self._spike_trains[action_id][channel_group] = {} - self._spike_trains[action_id][channel_group][int(get_unit_id(st))] = st - - return self._spike_trains[action_id] + group = st.annotations["group"] + if group not in self._spike_trains[action_id]: + self._spike_trains[action_id][group] = {} + 
self._spike_trains[action_id][group][int(get_unit_id(st))] = st + if channel_group is None: + return self._spike_trains[action_id] + else: + return self._spike_trains[action_id][channel_group] def unit_names(self, action_id, channel_group): units = load_unit_annotations(self.data_path(action_id), channel_group=channel_group) From 991f7fdeadcf806a17e6e194eb1c4410c947f0c9 Mon Sep 17 00:00:00 2001 From: Mikkel Date: Fri, 14 Jun 2024 06:52:34 +0200 Subject: [PATCH 41/47] Subtract relative session start time from all timestamps (session start is now zero) --- src/expipe_plugin_cinpla/data_loader.py | 30 ++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py index 70a854c..3b1ffd6 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -106,8 +106,8 @@ def load_leds(data_path): green_data = green_spatial_series.data x1, y1 = red_data[:, 0], red_data[:, 1] x2, y2 = green_data[:, 0], green_data[:, 1] - t1 = red_spatial_series.timestamps - t2 = green_spatial_series.timestamps + t1 = red_spatial_series.timestamps - nwbfile.session_start_time + t2 = green_spatial_series.timestamps - nwbfile.session_start_time stop_time = np.max([t1[-1], t2[-1]]) return x1, y1, t1, x2, y2, t2, stop_time @@ -131,6 +131,10 @@ def load_lfp(data_path, channel_group=None, lim=None): LFP: neo.AnalogSignal The LFP signal """ + from pynwb import NWBHDF5IO + # get the session start time + io = NWBHDF5IO(str(data_path), "r") + nwbfile = io.read() recording_lfp = se.read_nwb_recording( str(data_path), electrical_series_path="processing/ecephys/LFP/ElectricalSeriesLFP" ) @@ -152,11 +156,11 @@ def load_lfp(data_path, channel_group=None, lim=None): if lim is None: lfp_traces = recording_lfp_group.get_traces(return_scaled=True) - t_start = recording_lfp.get_times()[0] * pq.s - t_stop = recording_lfp.get_times()[-1] * pq.s + t_start = (recording_lfp.get_times()[0] - nwbfile.session_start_time) * pq.s + t_stop = (recording_lfp.get_times()[-1] - nwbfile.session_start_time) * pq.s else: assert len(lim) == 2, "lim must be a list of two elements with t_start and t_stop" - times_all = recording_lfp_group.get_times() + times_all = recording_lfp_group.get_times() - nwbfile.session_start_time start_frame, end_frame = np.searchsorted(times_all, lim) times = times_all[start_frame:end_frame] t_start = times[0] * pq.s @@ -196,9 +200,9 @@ def load_epochs(data_path, label_column=None): with NWBHDF5IO(str(data_path), "r") as io: nwbfile = io.read() trials = nwbfile.trials.to_dataframe() - - start_times = trials["start_time"].values * pq.s - stop_times = trials["stop_time"].values * pq.s + nwbfile.session_start_time + start_times = (trials["start_time"].values - nwbfile.session_start_time) * pq.s + stop_times = (trials["stop_time"].values - nwbfile.session_start_time) * pq.s durations = stop_times - start_times if label_column is not None and label_column in trials.columns: @@ -251,6 +255,10 @@ def load_spiketrains(data_path, channel_group=None, lim=None): spiketrains: list of NEO spike trains The spike trains """ + from pynwb import NWBHDF5IO + # get the session start time + io = NWBHDF5IO(str(data_path), "r") + nwbfile = io.read() recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") sorting = se.read_nwb_sorting(str(data_path), electrical_series_path="acquisition/ElectricalSeries") @@ -265,9 +273,11 @@ def load_spiketrains(data_path, 
channel_group=None, lim=None): sptr = [] # build neo objects for unit in unit_ids: - spike_times = sorting.get_unit_spike_train(unit, return_times=True) * pq.s + spike_times = sorting.get_unit_spike_train(unit, return_times=True) + # subtract the session start time + spike_times = (spike_times - nwbfile.session_start_time) * pq.s if lim is None: - times = recording.get_times() * pq.s + times = (recording.get_times - nwbfile.session_start_time) * pq.s t_start = times[0] t_stop = times[-1] else: From 9b1637e8b58b7657e4eefa9e1e2b1608acb56afa Mon Sep 17 00:00:00 2001 From: Mikkel Date: Fri, 14 Jun 2024 11:19:28 +0200 Subject: [PATCH 42/47] Remove session start time subtraction --- src/expipe_plugin_cinpla/data_loader.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py index 3b1ffd6..40cdf19 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -106,8 +106,8 @@ def load_leds(data_path): green_data = green_spatial_series.data x1, y1 = red_data[:, 0], red_data[:, 1] x2, y2 = green_data[:, 0], green_data[:, 1] - t1 = red_spatial_series.timestamps - nwbfile.session_start_time - t2 = green_spatial_series.timestamps - nwbfile.session_start_time + t1 = red_spatial_series.timestamps + t2 = green_spatial_series.timestamps stop_time = np.max([t1[-1], t2[-1]]) return x1, y1, t1, x2, y2, t2, stop_time @@ -156,11 +156,11 @@ def load_lfp(data_path, channel_group=None, lim=None): if lim is None: lfp_traces = recording_lfp_group.get_traces(return_scaled=True) - t_start = (recording_lfp.get_times()[0] - nwbfile.session_start_time) * pq.s - t_stop = (recording_lfp.get_times()[-1] - nwbfile.session_start_time) * pq.s + t_start = recording_lfp.get_times()[0] * pq.s + t_stop = recording_lfp.get_times()[-1] * pq.s else: assert len(lim) == 2, "lim must be a list of two elements with t_start and t_stop" - times_all = recording_lfp_group.get_times() - nwbfile.session_start_time + times_all = recording_lfp_group.get_times() start_frame, end_frame = np.searchsorted(times_all, lim) times = times_all[start_frame:end_frame] t_start = times[0] * pq.s @@ -200,9 +200,8 @@ def load_epochs(data_path, label_column=None): with NWBHDF5IO(str(data_path), "r") as io: nwbfile = io.read() trials = nwbfile.trials.to_dataframe() - nwbfile.session_start_time - start_times = (trials["start_time"].values - nwbfile.session_start_time) * pq.s - stop_times = (trials["stop_time"].values - nwbfile.session_start_time) * pq.s + start_times = trials["start_time"].values * pq.s + stop_times = trials["stop_time"].values * pq.s durations = stop_times - start_times if label_column is not None and label_column in trials.columns: @@ -275,9 +274,9 @@ def load_spiketrains(data_path, channel_group=None, lim=None): for unit in unit_ids: spike_times = sorting.get_unit_spike_train(unit, return_times=True) # subtract the session start time - spike_times = (spike_times - nwbfile.session_start_time) * pq.s + spike_times = spike_times * pq.s if lim is None: - times = (recording.get_times - nwbfile.session_start_time) * pq.s + times = recording.get_times * pq.s t_start = times[0] t_stop = times[-1] else: From 66447c703e6378b8c3658a302216b87ff095dfab Mon Sep 17 00:00:00 2001 From: Mikkel Date: Mon, 17 Jun 2024 20:45:06 +0200 Subject: [PATCH 43/47] Bugfix: add missing call parentheses to recording.get_times --- src/expipe_plugin_cinpla/data_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/expipe_plugin_cinpla/data_loader.py 
b/src/expipe_plugin_cinpla/data_loader.py index 40cdf19..aaf4e5e 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -276,7 +276,7 @@ def load_spiketrains(data_path, channel_group=None, lim=None): # subtract the session start time spike_times = spike_times * pq.s if lim is None: - times = recording.get_times * pq.s + times = recording.get_times() * pq.s t_start = times[0] t_stop = times[-1] else: From 3497a51f7f901efbc3c5dfdab6dee08b20411ba6 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jun 2024 12:27:08 +0200 Subject: [PATCH 44/47] Remove units with fewer than n_components spikes --- src/expipe_plugin_cinpla/scripts/process.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/expipe_plugin_cinpla/scripts/process.py b/src/expipe_plugin_cinpla/scripts/process.py index fa9cc96..e05bdd2 100644 --- a/src/expipe_plugin_cinpla/scripts/process.py +++ b/src/expipe_plugin_cinpla/scripts/process.py @@ -235,6 +235,13 @@ def process_ecephys( if verbose: print(f"\tFound {len(sorting.get_unit_ids())} units!") + # remove units with fewer than n_components spikes + num_spikes = sorting.count_num_spikes_per_unit() + selected_units = sorting.unit_ids[np.array(list(num_spikes.values())) >= n_components] + n_too_few_spikes = int(len(sorting.unit_ids) - len(selected_units)) + print(f"\tRemoved {n_too_few_spikes} units with fewer than {n_components} spikes") + sorting = sorting.select_units(selected_units) + # extract waveforms if verbose: print("\nPostprocessing") From dc8c5b619db2ab1a6a8b13caa6c70c1580f80b53 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 20 Jun 2024 12:34:55 +0200 Subject: [PATCH 45/47] Pin scikit-learn and fix tests --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 6673ad2..683c855 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "neuroconv>=0.4.6", "pyopenephys>=1.2.0", "spikeinterface[full,widgets]>=0.100.0", + "scikit-learn<1.5.0", "pynwb>=2.5.0", "neuroconv>=0.4.6", "ipywidgets>=8.1.1", From b5424ffb9ce01128c0b1997dcf5a728eca1c6adb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 20 Aug 2024 18:16:04 +0200 Subject: [PATCH 46/47] Pin SI version --- pyproject.toml | 2 +- src/expipe_plugin_cinpla/scripts/curation.py | 5 ++--- src/expipe_plugin_cinpla/scripts/utils.py | 22 ++++++++++++++++---- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 683c855..22f5e8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ "expipe>=0.6.0", "neuroconv>=0.4.6", "pyopenephys>=1.2.0", - "spikeinterface[full,widgets]>=0.100.0", + "spikeinterface[full,widgets]>=0.100.0,<0.101.0", "scikit-learn<1.5.0", "pynwb>=2.5.0", "neuroconv>=0.4.6", diff --git a/src/expipe_plugin_cinpla/scripts/curation.py b/src/expipe_plugin_cinpla/scripts/curation.py index 5f7e584..2ebb5f6 100644 --- a/src/expipe_plugin_cinpla/scripts/curation.py +++ b/src/expipe_plugin_cinpla/scripts/curation.py @@ -147,9 +147,8 @@ def apply_curation(self, sorter, curated_sorting): print("Removing excess spikes from curated sorting") curated_sorting = sc.remove_excess_spikes(curated_sorting, recording=recording) - # if not sort by group, extract dense and estimate group - if "group" not in curated_sorting.get_property_keys(): - compute_and_set_unit_groups(curated_sorting, recording) + # if "group" is not available or some groups are missing, extract dense waveforms and estimate the group + 
compute_and_set_unit_groups(curated_sorting, recording) print("Extracting waveforms on curated sorting") self.curated_we = si.extract_waveforms( diff --git a/src/expipe_plugin_cinpla/scripts/utils.py b/src/expipe_plugin_cinpla/scripts/utils.py index 4c52258..d7bf380 100644 --- a/src/expipe_plugin_cinpla/scripts/utils.py +++ b/src/expipe_plugin_cinpla/scripts/utils.py @@ -289,7 +289,21 @@ def generate_phy_restore_files(phy_folder): def compute_and_set_unit_groups(sorting, recording): import spikeinterface as si - we_mem = si.extract_waveforms(recording, sorting, folder=None, mode="memory", sparse=False) - extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index") - unit_groups = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))] - sorting.set_property("group", unit_groups) + if len(np.unique(recording.get_channel_groups())) == 1: + sorting.set_property("group", np.zeros(len(sorting.unit_ids), dtype="int64")) + else: + if "group" not in sorting.get_property_keys(): + we_mem = si.extract_waveforms(recording, sorting, folder=None, mode="memory", sparse=False) + extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index") + unit_groups = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))] + sorting.set_property("group", unit_groups) + else: + unit_groups = sorting.get_property("group") + # if there are units without group, we need to compute them + unit_ids_without_group = np.array(sorting.unit_ids)[np.where(unit_groups == "nan")[0]] + if len(unit_ids_without_group) > 0: + sorting_no_group = sorting.select_units(unit_ids=unit_ids_without_group) + we_mem = si.extract_waveforms(recording, sorting_no_group, folder=None, mode="memory", sparse=False) + extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index") + unit_groups[unit_ids_without_group] = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))] + sorting.set_property("group", unit_groups) From 9bc0a2a27e676279968033ab36ee01c76c10f4ef Mon Sep 17 00:00:00 2001 From: Nicolai Haug Date: Wed, 21 Aug 2024 15:46:31 +0200 Subject: [PATCH 47/47] Fix formatting --- src/expipe_plugin_cinpla/data_loader.py | 18 +++++++++++------- src/expipe_plugin_cinpla/scripts/utils.py | 4 +++- .../tools/data_processing.py | 6 ++++-- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/expipe_plugin_cinpla/data_loader.py b/src/expipe_plugin_cinpla/data_loader.py index 34badee..3b65564 100644 --- a/src/expipe_plugin_cinpla/data_loader.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -132,10 +132,12 @@ def load_lfp(data_path, channel_group=None, lim=None): LFP: neo.AnalogSignal The LFP signal """ - from pynwb import NWBHDF5IO + # from pynwb import NWBHDF5IO + # get the session start time - io = NWBHDF5IO(str(data_path), "r") - nwbfile = io.read() + # TODO: are io and nwbfile needed? + # io = NWBHDF5IO(str(data_path), "r") + # nwbfile = io.read() recording_lfp = se.read_nwb_recording( str(data_path), electrical_series_path="processing/ecephys/LFP/ElectricalSeriesLFP" ) @@ -255,10 +257,12 @@ def load_spiketrains(data_path, channel_group=None, lim=None): spiketrains: list of NEO spike trains The spike trains """ - from pynwb import NWBHDF5IO + # from pynwb import NWBHDF5IO + # get the session start time - io = NWBHDF5IO(str(data_path), "r") - nwbfile = io.read() + # TODO: are io and nwbfile needed? 
+ # io = NWBHDF5IO(str(data_path), "r") + # nwbfile = io.read() recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") sorting = se.read_nwb_sorting(str(data_path), electrical_series_path="acquisition/ElectricalSeries") @@ -273,7 +277,7 @@ def load_spiketrains(data_path, channel_group=None, lim=None): sptr = [] # build neo objects for unit in unit_ids: - spike_times = sorting.get_unit_spike_train(unit, return_times=True) + spike_times = sorting.get_unit_spike_train(unit, return_times=True) # subtract the session start time spike_times = spike_times * pq.s if lim is None: diff --git a/src/expipe_plugin_cinpla/scripts/utils.py b/src/expipe_plugin_cinpla/scripts/utils.py index 4b61696..b3c8a55 100644 --- a/src/expipe_plugin_cinpla/scripts/utils.py +++ b/src/expipe_plugin_cinpla/scripts/utils.py @@ -304,5 +304,7 @@ def compute_and_set_unit_groups(sorting, recording): sorting_no_group = sorting.select_units(unit_ids=unit_ids_without_group) we_mem = si.extract_waveforms(recording, sorting_no_group, folder=None, mode="memory", sparse=False) extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index") - unit_groups[unit_ids_without_group] = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))] + unit_groups[unit_ids_without_group] = recording.get_channel_groups()[ + np.array(list(extremum_channel_indices.values())) + ] sorting.set_property("group", unit_groups) diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py index dec7167..eca4ee6 100644 --- a/src/expipe_plugin_cinpla/tools/data_processing.py +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -482,8 +482,10 @@ def spike_trains(self, action_id, channel_group=None): return self._spike_trains[action_id][channel_group] def unit_names(self, action_id, channel_group): - units = load_unit_annotations(self.data_path(action_id), channel_group=channel_group) - return [u['name'] for u in units] + # TODO + # units = load_unit_annotations(self.data_path(action_id), channel_group=channel_group) + units = None + return [u["name"] for u in units] def stim_times(self, action_id): if action_id not in self._stim_times:
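
A closing note on the minimum-spike filter introduced in PATCH 44/47: units with fewer spikes than n_components are dropped before waveform extraction, since PCA-based postprocessing needs at least n_components spikes per unit to fit. The snippet below is a minimal, illustrative sketch of that step, not the plugin's code verbatim; it assumes the spikeinterface >=0.100,<0.101 API pinned above (count_num_spikes_per_unit, select_units), and `sorting` and `n_components` stand in for the objects used in process_ecephys.

import numpy as np

def drop_sparse_units(sorting, n_components):
    # Spike count per unit (dict: unit_id -> count)
    num_spikes = sorting.count_num_spikes_per_unit()
    counts = np.array([num_spikes[unit_id] for unit_id in sorting.unit_ids])
    # Keep only units with at least n_components spikes
    keep = sorting.unit_ids[counts >= n_components]
    print(f"Removed {len(sorting.unit_ids) - len(keep)} units with fewer than {n_components} spikes")
    return sorting.select_units(keep)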
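
Similarly, for the group handling consolidated in PATCH 46/47: when a sorting lacks a "group" property, or some units carry no group, compute_and_set_unit_groups assigns each unit to the channel group of its extremum template channel. The sketch below condenses that idea and is illustrative only, under the same spikeinterface <0.101 assumption (the WaveformExtractor API with extract_waveforms and get_template_extremum_channel); `recording` and `sorting` stand in for any recording/sorting pair whose channels carry group labels.

import spikeinterface as si

def unit_groups_from_templates(recording, sorting):
    # Dense in-memory waveforms: nothing is written to disk and all channels
    # are kept, so the largest-amplitude (extremum) channel exists per unit
    we = si.extract_waveforms(recording, sorting, folder=None, mode="memory", sparse=False)
    # unit_id -> index of the channel where the mean template peaks
    extremum = si.get_template_extremum_channel(we, outputs="index")
    groups = recording.get_channel_groups()
    # unit_id -> channel group (e.g. tetrode number) of that channel
    return {unit_id: groups[chan_index] for unit_id, chan_index in extremum.items()}

Writing the result back with sorting.set_property("group", ...) is what lets the later extract_waveforms calls sparsify with method="by_property", by_property="group", as the processing and curation scripts above do.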