diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e5e6a07a..19c91e36 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,7 +1,11 @@ name: Test on: push: + branches: + - main pull_request: + branches: + - main workflow_dispatch: jobs: devcontainer-build: @@ -31,4 +35,3 @@ jobs: run: | python_version=${{matrix.py_ver}} black element_array_ephys --check --verbose --target-version py${python_version//.} - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d513df7..6d28ef11 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ exclude: (^.github/|^docs/|^images/) repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -16,7 +16,7 @@ repos: # black - repo: https://github.com/psf/black - rev: 22.12.0 + rev: 24.2.0 hooks: - id: black - id: black-jupyter @@ -25,7 +25,7 @@ repos: # isort - repo: https://github.com/pycqa/isort - rev: 5.11.2 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black"] @@ -33,7 +33,7 @@ repos: # flake8 - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 7.0.0 hooks: - id: flake8 args: # arguments to configure flake8 diff --git a/CHANGELOG.md b/CHANGELOG.md index d2e48eaa..34d1a2e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,21 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. + +## [1.0.0] - 2024-09-10 + ++ Update - No longer support multiple variations of the ephys module; keep only the `ephys_no_curation` module, renamed to `ephys` ++ Update - Remove other ephys modules (e.g. `ephys_acute`, `ephys_chronic`) (moved to different branches) ++ Update - Add support for `SpikeInterface` ++ Update - Remove support for `ecephys_spike_sorting` (moved to a different branch) ++ Update - Simplify the "activate" mechanism + +## [0.4.0] - 2024-08-16 + ++ Add - support for SpikeInterface version >= 0.101.0 (updated API) ++ Add - feature for memoization of spike sorting results (prevents duplicate runs) + + ## [0.3.5] - 2024-08-16 + Fix - Improve `spikeglx` loader in extracting neuropixels probe type from the meta file diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py index 1c0c7285..079950b4 100644 --- a/element_array_ephys/__init__.py +++ b/element_array_ephys/__init__.py @@ -1 +1,3 @@ -from . import ephys_acute as ephys +from . import ephys + +ephys_no_curation = ephys # alias for backward compatibility diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys.py similarity index 66% rename from element_array_ephys/ephys_no_curation.py rename to element_array_ephys/ephys.py index cd0909c9..ad9bb8d7 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys.py @@ -10,10 +10,10 @@ import pandas as pd from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory -from . import ephys_report, probe +from . import probe from .readers import kilosort, openephys, spikeglx -log = dj.logger +logger = dj.logger schema = dj.schema() @@ -22,7 +22,6 @@ def activate( ephys_schema_name: str, - probe_schema_name: str = None, *, create_schema: bool = True, create_tables: bool = True, @@ -32,7 +31,6 @@ def activate( Args: ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe schema.
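# A minimal usage sketch of the simplified activation flow introduced by this
# change: `ephys.activate()` no longer activates the `probe` schema itself and
# instead raises a RuntimeError if `probe` is not activated first. The schema
# names and the linking module ("my_probe_schema", "my_ephys_schema",
# "workflow.pipeline") are hypothetical placeholders.
from element_array_ephys import probe, ephys

probe.activate("my_probe_schema", create_schema=True, create_tables=True)
ephys.activate("my_ephys_schema", linking_module="workflow.pipeline")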
create_schema (bool): If True, schema will be created in the database. create_tables (bool): If True, tables related to the schema will be created in the database. linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. @@ -46,7 +44,6 @@ def activate( get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s). get_session_directory(session_key: dict): Returns path to electrophysiology data for a particular session as a string. get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory. - """ if isinstance(linking_module, str): @@ -58,17 +55,15 @@ def activate( global _linking_module _linking_module = linking_module - # activate - probe.activate( - probe_schema_name, create_schema=create_schema, create_tables=create_tables - ) + if not probe.schema.is_activated(): + raise RuntimeError("Please activate the `probe` schema first.") + schema.activate( ephys_schema_name, create_schema=create_schema, create_tables=create_tables, add_objects=_linking_module.__dict__, ) - ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name) # -------------- Functions required by the elements-ephys --------------- @@ -129,7 +124,7 @@ class AcquisitionSoftware(dj.Lookup): """ definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys - acq_software: varchar(24) + acq_software: varchar(24) """ contents = zip(["SpikeGLX", "Open Ephys"]) @@ -272,15 +267,24 @@ class EphysRecording(dj.Imported): definition = """ # Ephys recording from a probe insertion for a given session. - -> ProbeInsertion + -> ProbeInsertion --- -> probe.ElectrodeConfig -> AcquisitionSoftware - sampling_rate: float # (Hz) + sampling_rate: float # (Hz) recording_datetime: datetime # datetime of the recording from this probe recording_duration: float # (seconds) duration of the recording from this probe """ + class Channel(dj.Part): + definition = """ + -> master + channel_idx: int # channel index (index of the raw data) + --- + -> probe.ElectrodeConfig.Electrode + channel_name="": varchar(64) # alias of the channel + """ + class EphysFile(dj.Part): """Paths of electrophysiology recording files for each insertion. @@ -304,7 +308,7 @@ def make(self, key): "probe" ) - # search session dir and determine acquisition software + # Search session dir and determine acquisition software for ephys_pattern, ephys_acq_type in ( ("*.ap.meta", "SpikeGLX"), ("*.oebin", "Open Ephys"), @@ -315,8 +319,13 @@ def make(self, key): break else: raise FileNotFoundError( - "Ephys recording data not found!" - " Neither SpikeGLX nor Open Ephys recording files found" + f"Ephys recording data not found in {session_dir}." + " Neither SpikeGLX nor Open Ephys recording files found" + ) + + if acq_software not in AcquisitionSoftware.fetch("acq_software"): + raise NotImplementedError( + f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented."
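# A minimal sketch of the acquisition-software detection used in `make` above:
# the first matching meta-file pattern wins, and the for/else raises when
# nothing matches. `detect_acq_software` and `session_dir` are hypothetical
# names introduced only for illustration.
from pathlib import Path

def detect_acq_software(session_dir: Path) -> tuple[str, list[Path]]:
    for ephys_pattern, ephys_acq_type in (
        ("*.ap.meta", "SpikeGLX"),
        ("*.oebin", "Open Ephys"),
    ):
        ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern))
        if ephys_meta_filepaths:
            return ephys_acq_type, ephys_meta_filepaths
    raise FileNotFoundError(f"Ephys recording data not found in {session_dir}.")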
) supported_probe_types = probe.ProbeType.fetch("probe_type") @@ -325,51 +334,79 @@ def make(self, key): for meta_filepath in ephys_meta_filepaths: spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: + spikeglx_meta_filepath = meta_filepath break else: raise FileNotFoundError( "No SpikeGLX data found for probe insertion: {}".format(key) ) - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} + if spikeglx_meta.probe_model not in supported_probe_types: + raise NotImplementedError( + f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented." + ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } + probe_type = spikeglx_meta.probe_model + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip( + *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") ) + } # electrode configuration + electrode_group_members = [ + probe_electrodes[(shank, shank_col, shank_row)] + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] + ] # recording session-specific electrode configuration + + econfig_entry, econfig_electrodes = generate_electrode_config_entry( + probe_type, electrode_group_members + ) - self.insert1( + ephys_recording_entry = { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "acq_software": acq_software, + "sampling_rate": spikeglx_meta.meta["imSampRate"], + "recording_datetime": spikeglx_meta.recording_time, + "recording_duration": ( + spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(spikeglx_meta_filepath) + ), + } + + root_dir = find_root_directory( + get_ephys_root_data_dir(), spikeglx_meta_filepath + ) + + ephys_file_entries = [ { **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), + "file_path": spikeglx_meta_filepath.relative_to( + root_dir + ).as_posix(), } - ) + ] - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) + # Insert channel information + # Get channel and electrode-site mapping + channel2electrode_map = { + recorded_site: probe_electrodes[(shank, shank_col, shank_row)] + for recorded_site, (shank, shank_col, shank_row, _) in enumerate( + spikeglx_meta.shankmap["data"] + ) + } + + ephys_channel_entries = [ + { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "channel_idx": channel_idx, + **channel_info, + } + for channel_idx, channel_info in channel2electrode_map.items() + ] elif acq_software == "Open Ephys": 
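# A sketch of the SpikeGLX channel-to-electrode mapping built in the branch
# above, assuming `shankmap_data` stands in for `spikeglx_meta.shankmap["data"]`
# (rows of (shank, shank_col, shank_row, is_used)) and `probe_electrodes` for
# the fetched (shank, shank_col, shank_row) -> electrode-key lookup; the row
# index doubles as the recorded channel index.
def build_channel2electrode_map(shankmap_data: list, probe_electrodes: dict) -> dict:
    return {
        recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
        for recorded_site, (shank, shank_col, shank_row, _) in enumerate(shankmap_data)
    }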
dataset = openephys.OpenEphys(session_dir) for serial_number, probe_data in dataset.probes.items(): @@ -385,60 +422,84 @@ def make(self, key): 'No analog signals found - check "structure.oebin" file or "continuous" directory' ) - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: + if probe_data.probe_model not in supported_probe_types: raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) + f"Processing for neuropixels probe model {probe_data.probe_model} not yet implemented." ) - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } + probe_type = probe_data.probe_model + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} + + probe_electrodes = { + key["electrode"]: key for key in electrode_query.fetch("KEY") + } # electrode configuration + + electrode_group_members = [ + probe_electrodes[channel_idx] + for channel_idx in probe_data.ap_meta["channels_indices"] + ] # recording session-specific electrode configuration + + econfig_entry, econfig_electrodes = generate_electrode_config_entry( + probe_type, electrode_group_members ) + ephys_recording_entry = { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "acq_software": acq_software, + "sampling_rate": probe_data.ap_meta["sample_rate"], + "recording_datetime": probe_data.recording_info["recording_datetimes"][ + 0 + ], + "recording_duration": np.sum( + probe_data.recording_info["recording_durations"] + ), + } + root_dir = find_root_directory( get_ephys_root_data_dir(), probe_data.recording_info["recording_files"][0], ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - # explicitly garbage collect "dataset" - # as these may have large memory footprint and may not be cleared fast enough + + ephys_file_entries = [ + {**key, "file_path": fp.relative_to(root_dir).as_posix()} + for fp in probe_data.recording_info["recording_files"] + ] + + channel2electrode_map = { + channel_idx: probe_electrodes[channel_idx] + for channel_idx in probe_data.ap_meta["channels_indices"] + } + + ephys_channel_entries = [ + { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "channel_idx": channel_idx, + **channel_info, + } + for channel_idx, channel_info in channel2electrode_map.items() + ] + + # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough del probe_data, dataset gc.collect() else: raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" + f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." 
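# A worked example of the electrode-config naming performed by
# `generate_electrode_config_entry` (defined near the end of this diff) for
# both acquisition branches above: contiguous electrode ranges are joined into
# a readable config name, and the config hash is a UUID over the electrode
# keys. The electrode list here is hypothetical.
import numpy as np

electrode_list = [0, 1, 2, 3, 10, 11]
electrode_gaps = (
    [-1] + np.where(np.diff(electrode_list) > 1)[0].tolist() + [len(electrode_list) - 1]
)
electrode_config_name = "; ".join(
    f"{electrode_list[start + 1]}-{electrode_list[end]}"
    for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
)
assert electrode_config_name == "0-3; 10-11"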
) + # Insert into probe.ElectrodeConfig (recording configuration) + if not probe.ElectrodeConfig & { + "electrode_config_hash": econfig_entry["electrode_config_hash"] + }: + probe.ElectrodeConfig.insert1(econfig_entry) + probe.ElectrodeConfig.Electrode.insert(econfig_electrodes) + + self.insert1(ephys_recording_entry) + self.EphysFile.insert(ephys_file_entries) + self.Channel.insert(ephys_channel_entries) + @schema class LFP(dj.Imported): @@ -471,9 +532,9 @@ class Electrode(dj.Part): definition = """ -> master - -> probe.ElectrodeConfig.Electrode + -> probe.ElectrodeConfig.Electrode --- - lfp: longblob # (uV) recorded lfp at this electrode + lfp: longblob # (uV) recorded lfp at this electrode """ # Only store LFP for every 9th channel, due to high channel density, @@ -614,14 +675,14 @@ class ClusteringParamSet(dj.Lookup): ClusteringMethod (dict): ClusteringMethod primary key. paramset_desc (varchar(128) ): Description of the clustering parameter set. param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Set of clustering parameters + params (longblob): Set of clustering parameters. """ definition = """ # Parameter set to be used in a clustering procedure paramset_idx: smallint --- - -> ClusteringMethod + -> ClusteringMethod paramset_desc: varchar(128) param_set_hash: uuid unique index (param_set_hash) @@ -700,6 +761,7 @@ class ClusterQualityLabel(dj.Lookup): ("ok", "probably a single unit, but could be contaminated"), ("mua", "multi-unit activity"), ("noise", "bad unit"), + ("n.a.", "not available"), ] @@ -724,18 +786,15 @@ class ClusteringTask(dj.Manual): """ @classmethod - def infer_output_dir( - cls, key, relative: bool = False, mkdir: bool = False - ) -> pathlib.Path: + def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False): """Infer output directory if it is not provided. Args: key (dict): ClusteringTask primary key. Returns: - Expected clustering_output_dir based on the following convention: - processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} - e.g.: sub4/sess1/probe_2/kilosort2_0 + Pathlib.Path: Expected clustering_output_dir based on the following convention: processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} + e.g.: sub4/sess1/probe_2/kilosort2_0 """ processed_dir = pathlib.Path(get_processed_root_data_dir()) session_dir = find_full_path( @@ -758,7 +817,7 @@ def infer_output_dir( if mkdir: output_dir.mkdir(parents=True, exist_ok=True) - log.info(f"{output_dir} created!") + logger.info(f"{output_dir} created!") return output_dir.relative_to(processed_dir) if relative else output_dir @@ -809,7 +868,7 @@ class Clustering(dj.Imported): # Clustering Procedure -> ClusteringTask --- - clustering_time: datetime # time of generation of this set of clustering results + clustering_time: datetime # time of generation of this set of clustering results package_version='': varchar(16) """ @@ -838,7 +897,7 @@ def make(self, key): ).fetch1("acq_software", "clustering_method", "params") if "kilosort" in clustering_method: - from element_array_ephys.readers import kilosort_triggering + from .spike_sorting import kilosort_triggering # add additional probe-recording and channels details into `params` params = {**params, **get_recording_channels_details(key)} @@ -850,10 +909,6 @@ def make(self, key): spikeglx_meta_filepath.parent ) spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.pop("run_CatGT", True) - and "_tcat." 
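# A worked example of the clustering output-directory convention documented in
# `ClusteringTask.infer_output_dir` above (all path components hypothetical):
# processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
import pathlib

processed_dir = pathlib.Path("/data/processed")
output_dir = processed_dir / "sub4/sess1" / "probe_2" / "kilosort2_0"
# -> /data/processed/sub4/sess1/probe_2/kilosort2_0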
not in spikeglx_meta_filepath.stem - ) if clustering_method.startswith("pykilosort"): kilosort_triggering.run_pykilosort( @@ -874,7 +929,7 @@ def make(self, key): ks_output_dir=kilosort_dir, params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, + run_CatGT=True, ) run_kilosort.run_modules() elif acq_software == "Open Ephys": @@ -929,7 +984,7 @@ class CuratedClustering(dj.Imported): definition = """ # Clustering results of the spike sorting step. - -> Clustering + -> Clustering """ class Unit(dj.Part): @@ -946,7 +1001,7 @@ class Unit(dj.Part): spike_depths (longblob): Array of depths associated with each spike, relative to each spike. """ - definition = """ + definition = """ # Properties of a given unit from a round of clustering (and curation) -> master unit: int @@ -956,85 +1011,175 @@ class Unit(dj.Part): spike_count: int # how many spikes in this recording for this unit spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe + spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe """ def make(self, key): """Automated population of Unit information.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + clustering_method, output_dir = ( + ClusteringTask * ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir") + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + + # Get channel and electrode-site mapping + electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") + channel2electrode_map: dict[int, dict] = { + chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True) + } - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) + # Get sorter method and create output directory. 
+ sorter_name = clustering_method.replace(".", "_") + si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" - sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate) + if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs + import spikeinterface as si - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) + si_sorting = sorting_analyzer.sorting + + # Find representative channel for each unit + unit_peak_channel: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + sorting_analyzer, + 1, + ).unit_id_to_channel_indices ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() + unit_peak_channel: dict[int, int] = { + u: chn[0] for u, chn in unit_peak_channel.items() + } - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / sample_rate + spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() + # {unit: spike_count} + + # update channel2electrode_map to match with probe's channel index + channel2electrode_map = { + idx: channel2electrode_map[int(chn_idx)] + for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids) + } + + # Get unit id to quality label mapping + cluster_quality_label_map = { + int(unit_id): ( + si_sorting.get_unit_property(unit_id, "KSLabel") + if "KSLabel" in si_sorting.get_property_keys() + else "n.a." 
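# A hedged sketch of the SpikeInterface access pattern relied on above
# (assumes spikeinterface >= 0.101 and a previously saved sorting-analyzer
# folder; the folder path is hypothetical): load the analyzer, find each
# unit's peak channel, and count spikes per unit.
import spikeinterface as si

sorting_analyzer = si.load_sorting_analyzer(folder="processed/kilosort4/sorting_analyzer")
sparsity = si.ChannelSparsity.from_best_channels(sorting_analyzer, 1)
unit_peak_channel = {u: chns[0] for u, chns in sparsity.unit_id_to_channel_indices.items()}
spike_count_dict = sorting_analyzer.sorting.count_num_spikes_per_unit()  # {unit_id: count}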
) - spike_count = len(unit_spike_times) + for unit_id in si_sorting.unit_ids + } + + spike_locations = sorting_analyzer.get_extension("spike_locations") + extremum_channel_inds = si.template_tools.get_template_extremum_channel( + sorting_analyzer, outputs="index" + ) + spikes_df = pd.DataFrame( + sorting_analyzer.sorting.to_spike_vector( + extremum_channel_inds=extremum_channel_inds + ) + ) + + units = [] + for unit_idx, unit_id in enumerate(si_sorting.unit_ids): + unit_id = int(unit_id) + unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx] + spike_sites = np.array( + [ + channel2electrode_map[chn_idx]["electrode"] + for chn_idx in unit_spikes_df.channel_index + ] + ) + unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index] + _, spike_depths = zip(*unit_spikes_loc) # x-coordinates, y-coordinates + spike_times = si_sorting.get_unit_spike_train( + unit_id, return_times=True + ) + + assert len(spike_times) == len(spike_sites) == len(spike_depths) units.append( { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ], + **key, + **channel2electrode_map[unit_peak_channel[unit_id]], + "unit": unit_id, + "cluster_quality_label": cluster_quality_label_map[unit_id], + "spike_times": spike_times, + "spike_count": spike_count_dict[unit_id], + "spike_sites": spike_sites, + "spike_depths": spike_depths, } ) + else: # read from kilosort outputs + kilosort_dataset = kilosort.Kilosort(output_dir) + acq_software, sample_rate = (EphysRecording & key).fetch1( + "acq_software", "sampling_rate" + ) + + sample_rate = kilosort_dataset.data["params"].get( + "sample_rate", sample_rate + ) + + # ---------- Unit ---------- + # -- Remove 0-spike units + withspike_idx = [ + i + for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) + if (kilosort_dataset.data["spike_clusters"] == u).any() + ] + valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] + valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] + + # -- Spike-times -- + # spike_times_sec_adj > spike_times_sec > spike_times + spike_time_key = ( + "spike_times_sec_adj" + if "spike_times_sec_adj" in kilosort_dataset.data + else ( + "spike_times_sec" + if "spike_times_sec" in kilosort_dataset.data + else "spike_times" + ) + ) + spike_times = kilosort_dataset.data[spike_time_key] + kilosort_dataset.extract_spike_depths() + + # -- Spike-sites and Spike-depths -- + spike_sites = np.array( + [ + channel2electrode_map[s]["electrode"] + for s in kilosort_dataset.data["spike_sites"] + ] + ) + spike_depths = kilosort_dataset.data["spike_depths"] + + # -- Insert unit, label, peak-chn + units = [] + for unit, unit_lbl in zip(valid_units, valid_unit_labels): + if (kilosort_dataset.data["spike_clusters"] == unit).any(): + unit_channel, _ = kilosort_dataset.get_best_channel(unit) + unit_spike_times = ( + spike_times[kilosort_dataset.data["spike_clusters"] == unit] + / sample_rate + ) + spike_count = len(unit_spike_times) + + units.append( + { + **key, + "unit": unit, + "cluster_quality_label": unit_lbl, + **channel2electrode_map[unit_channel], + "spike_times": unit_spike_times, + "spike_count": spike_count, + "spike_sites": spike_sites[ + kilosort_dataset.data["spike_clusters"] == unit + ], + "spike_depths": spike_depths[ + 
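# A small sketch of the Kilosort spike-times preference applied in the
# fallback branch above: drift-/sync-corrected times are used when present.
# `pick_spike_times_key` is a hypothetical helper name.
def pick_spike_times_key(kilosort_data: dict) -> str:
    for candidate in ("spike_times_sec_adj", "spike_times_sec", "spike_times"):
        if candidate in kilosort_data:
            return candidate
    raise KeyError("No spike-times array found in the Kilosort output.")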
kilosort_dataset.data["spike_clusters"] == unit + ], + } + ) self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) + self.Unit.insert(units, ignore_extra_fields=True) @schema @@ -1082,113 +1227,171 @@ class Waveform(dj.Part): # Spike waveforms and their mean across spikes for the given unit -> master -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- + -> probe.ElectrodeConfig.Electrode + --- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit """ def make(self, key): """Populates waveform tables.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + clustering_method, output_dir = ( + ClusteringTask * ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir") + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + sorter_name = clustering_method.replace(".", "_") + + # Get channel and electrode-site mapping + electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") + channel2electrode_map: dict[int, dict] = { + chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True) + } - kilosort_dataset = kilosort.Kilosort(kilosort_dir) + si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" + if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs + import spikeinterface as si - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) + # Find representative channel for each unit + unit_peak_channel: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + sorting_analyzer, 1 + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} + unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()} - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } + # update channel2electrode_map to match with probe's channel index + channel2electrode_map = { + idx: channel2electrode_map[int(chn_idx)] + for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids) + } - if (kilosort_dir / "mean_waveforms.npy").exists(): - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample + templates = sorting_analyzer.get_extension("templates") def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms + for unit in (CuratedClustering.Unit & key).fetch( + "KEY", order_by="unit" ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: + # Get mean waveform for this unit from all channels - (sample x channel) + unit_waveforms = templates.get_unit_template( + unit_id=unit["unit"], operator="average" + ) + unit_peak_waveform = { + **unit, + "peak_electrode_waveform": unit_waveforms[ + :, unit_peak_channel[unit["unit"]] + ], + } + + unit_electrode_waveforms = [ + { + **unit, + **channel2electrode_map[chn_idx], + "waveform_mean": unit_waveforms[:, chn_idx], + } + for chn_idx in channel2electrode_map 
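# A hedged sketch of reading a unit's mean waveform from the analyzer's
# "templates" extension, as done in the SpikeInterface branch above (folder
# path and unit id are hypothetical; the returned array is sample x channel).
import spikeinterface as si

sorting_analyzer = si.load_sorting_analyzer(folder="processed/kilosort4/sorting_analyzer")
templates = sorting_analyzer.get_extension("templates")
unit_waveforms = templates.get_unit_template(unit_id=0, operator="average")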
+ ] + + yield unit_peak_waveform, unit_electrode_waveforms + + else: # read from kilosort outputs (ecephys pipeline) + kilosort_dataset = kilosort.Kilosort(output_dir) + + acq_software, probe_serial_number = ( + EphysRecording * ProbeInsertion & key + ).fetch1("acq_software", "probe") + + # Get all units + units = { + u["unit"]: u + for u in (CuratedClustering.Unit & key).fetch( + as_dict=True, order_by="unit" + ) + } + + if (output_dir / "mean_waveforms.npy").exists(): + unit_waveforms = np.load( + output_dir / "mean_waveforms.npy" + ) # unit x channel x sample + + def yield_unit_waveforms(): + for unit_no, unit_waveform in zip( + kilosort_dataset.data["cluster_ids"], unit_waveforms + ): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + if unit_no in units: + for channel, channel_waveform in zip( + kilosort_dataset.data["channel_map"], unit_waveform + ): + unit_electrode_waveforms.append( + { + **units[unit_no], + **channel2electrode_map[channel], + "waveform_mean": channel_waveform, + } + ) + if ( + channel2electrode_map[channel]["electrode"] + == units[unit_no]["electrode"] + ): + unit_peak_waveform = { + **units[unit_no], + "peak_electrode_waveform": channel_waveform, + } + yield unit_peak_waveform, unit_electrode_waveforms + + else: + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + neuropixels_recording = spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + elif acq_software == "Open Ephys": + session_dir = find_full_path( + get_ephys_root_data_dir(), get_session_directory(key) + ) + openephys_dataset = openephys.OpenEphys(session_dir) + neuropixels_recording = openephys_dataset.probes[ + probe_serial_number + ] + + def yield_unit_waveforms(): + for unit_dict in units.values(): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + + spikes = unit_dict["spike_times"] + waveforms = neuropixels_recording.extract_spike_waveforms( + spikes, kilosort_dataset.data["channel_map"] + ) # (sample x channel x spike) + waveforms = waveforms.transpose( + (1, 2, 0) + ) # (channel x spike x sample) for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform + kilosort_dataset.data["channel_map"], waveforms ): unit_electrode_waveforms.append( { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, + **unit_dict, + **channel2electrode_map[channel], + "waveform_mean": channel_waveform.mean(axis=0), + "waveforms": channel_waveform, } ) if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] + channel2electrode_map[channel]["electrode"] + == unit_dict["electrode"] ): unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, + **unit_dict, + "peak_electrode_waveform": channel_waveform.mean( + axis=0 + ), } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, 
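# Shape bookkeeping for the raw-recording waveform extraction above, with
# hypothetical sizes: `extract_spike_waveforms` yields (sample x channel x
# spike); the code transposes to (channel x spike x sample) and averages over
# spikes to get each channel's mean waveform.
import numpy as np

waveforms = np.zeros((64, 384, 100))        # (sample x channel x spike)
waveforms = waveforms.transpose((1, 2, 0))  # (channel x spike x sample)
waveform_mean = waveforms.mean(axis=1)      # (channel x sample)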
kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( - (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - yield unit_peak_waveform, unit_electrode_waveforms + yield unit_peak_waveform, unit_electrode_waveforms # insert waveform on a per-unit basis to mitigate potential memory issue self.insert1(key) @@ -1209,7 +1412,7 @@ class QualityMetrics(dj.Imported): definition = """ # Clusters and waveforms metrics - -> CuratedClustering + -> CuratedClustering """ class Cluster(dj.Part): @@ -1234,26 +1437,26 @@ class Cluster(dj.Part): contamination_rate (float): Frequency of spikes in the refractory period. """ - definition = """ + definition = """ # Cluster metrics for a particular unit -> master -> CuratedClustering.Unit --- - firing_rate=null: float # (Hz) firing rate for a unit + firing_rate=null: float # (Hz) firing rate for a unit snr=null: float # signal-to-noise ratio for a unit presence_ratio=null: float # fraction of time in which spikes are present isi_violation=null: float # rate of ISI violation as a fraction of overall rate number_violation=null: int # total number of ISI violations amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # + l_ratio=null: float # d_prime=null: float # Classification accuracy based on LDA nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster silhouette_score=null: float # Standard metric for cluster overlap max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # + cumulative_drift=null: float # Cumulative change in spike depth throughout recording + contamination_rate=null: float # """ class Waveform(dj.Part): @@ -1273,13 +1476,13 @@ class Waveform(dj.Part): velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
""" - definition = """ + definition = """ # Waveform metrics for a particular unit -> master -> CuratedClustering.Unit --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough + amplitude=null: float # (uV) absolute difference between waveform peak and trough + duration=null: float # (ms) time between waveform peak and trough halfwidth=null: float # (ms) spike width at half max amplitude pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak @@ -1291,24 +1494,63 @@ class Waveform(dj.Part): def make(self, key): """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + # Load metrics.csv + clustering_method, output_dir = ( + ClusteringTask * ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir") + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + sorter_name = clustering_method.replace(".", "_") + + si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" + if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs + import spikeinterface as si + + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) + qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data() + template_metrics = sorting_analyzer.get_extension( + "template_metrics" + ).get_data() + metrics_df = pd.concat([qc_metrics, template_metrics], axis=1) + + metrics_df.rename( + columns={ + "amplitude_median": "amplitude", + "isi_violations_ratio": "isi_violation", + "isi_violations_count": "number_violation", + "silhouette": "silhouette_score", + "rp_contamination": "contamination_rate", + "drift_ptp": "max_drift", + "drift_mad": "cumulative_drift", + "half_width": "halfwidth", + "peak_trough_ratio": "pt_ratio", + "peak_to_valley": "duration", + }, + inplace=True, + ) + else: # read from kilosort outputs (ecephys pipeline) + # find metric_fp + for metric_fp in [ + output_dir / "metrics.csv", + ]: + if metric_fp.exists(): + break + else: + raise FileNotFoundError(f"QC metrics file not found in: {output_dir}") - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } + metrics_df = pd.read_csv(metric_fp) - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") + # Conform the dataframe to match the table definition + if "cluster_id" in metrics_df.columns: + metrics_df.set_index("cluster_id", inplace=True) + else: + metrics_df.rename( + columns={metrics_df.columns[0]: "cluster_id"}, inplace=True + ) + metrics_df.set_index("cluster_id", inplace=True) + + metrics_df.columns = metrics_df.columns.str.lower() - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) metrics_list = [ dict(metrics_df.loc[unit_key["unit"]], **unit_key) for unit_key in (CuratedClustering.Unit & key).fetch("KEY") @@ -1382,99 +1624,6 @@ def get_openephys_probe_data(ephys_recording_key: dict) -> list: 
return probe_data -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - probe_dataset = get_openephys_probe_data(ephys_recording_key) - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key - - def get_recording_channels_details(ephys_recording_key: dict) -> np.array: """Get details of recording channels for a given recording.""" channels_details = {} @@ -1530,3 +1679,41 @@ def get_recording_channels_details(ephys_recording_key: dict) -> np.array: ) return channels_details + + +def generate_electrode_config_entry(probe_type: str, electrode_keys: list) -> dict: + """Generate and insert new ElectrodeConfig + + Args: + probe_type (str): probe type (e.g. 
neuropixels 2.0 - SS) + electrode_keys (list): list of keys of the probe.ProbeType.Electrode table + + Returns: + dict: representing a key of the probe.ElectrodeConfig table + """ + # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) + electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) + + electrode_list = sorted([k["electrode"] for k in electrode_keys]) + electrode_gaps = ( + [-1] + + np.where(np.diff(electrode_list) > 1)[0].tolist() + + [len(electrode_list) - 1] + ) + electrode_config_name = "; ".join( + [ + f"{electrode_list[start + 1]}-{electrode_list[end]}" + for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) + ] + ) + electrode_config_key = {"electrode_config_hash": electrode_config_hash} + econfig_entry = { + **electrode_config_key, + "probe_type": probe_type, + "electrode_config_name": electrode_config_name, + } + econfig_electrodes = [ + {**electrode, **electrode_config_key} for electrode in electrode_keys + ] + + return econfig_entry, econfig_electrodes diff --git a/element_array_ephys/ephys_acute.py b/element_array_ephys/ephys_acute.py deleted file mode 100644 index 50371104..00000000 --- a/element_array_ephys/ephys_acute.py +++ /dev/null @@ -1,1594 +0,0 @@ -import gc -import importlib -import inspect -import pathlib -import re -from decimal import Decimal - -import datajoint as dj -import numpy as np -import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory - -from . import ephys_report, probe -from .readers import kilosort, openephys, spikeglx - -log = dj.logger - -schema = dj.schema() - -_linking_module = None - - -def activate( - ephys_schema_name: str, - probe_schema_name: str = None, - *, - create_schema: bool = True, - create_tables: bool = True, - linking_module: str = None, -): - """Activates the `ephys` and `probe` schemas. - - Args: - ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe schema. - create_schema (bool): If True, schema will be created in the database. - create_tables (bool): If True, tables related to the schema will be created in the database. - linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. - - Dependencies: - Upstream tables: - Session: A parent table to ProbeInsertion - Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported. - - Functions: - get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s). - get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings. - get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory. 
- """ - - if isinstance(linking_module, str): - linking_module = importlib.import_module(linking_module) - assert inspect.ismodule( - linking_module - ), "The argument 'dependency' must be a module's name or a module" - - global _linking_module - _linking_module = linking_module - - probe.activate( - probe_schema_name, create_schema=create_schema, create_tables=create_tables - ) - schema.activate( - ephys_schema_name, - create_schema=create_schema, - create_tables=create_tables, - add_objects=_linking_module.__dict__, - ) - ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name) - - -# -------------- Functions required by the elements-ephys --------------- - - -def get_ephys_root_data_dir() -> list: - """Fetches absolute data path to ephys data directories. - - The absolute path here is used as a reference for all downstream relative paths used in DataJoint. - - Returns: - A list of the absolute path(s) to ephys data directories. - """ - root_directories = _linking_module.get_ephys_root_data_dir() - if isinstance(root_directories, (str, pathlib.Path)): - root_directories = [root_directories] - - if hasattr(_linking_module, "get_processed_root_data_dir"): - root_directories.append(_linking_module.get_processed_root_data_dir()) - - return root_directories - - -def get_session_directory(session_key: dict) -> str: - """Retrieve the session directory with Neuropixels for the given session. - - Args: - session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database. - - Returns: - A string for the path to the session directory. - """ - return _linking_module.get_session_directory(session_key) - - -def get_processed_root_data_dir() -> str: - """Retrieve the root directory for all processed data. - - Returns: - A string for the full path to the root directory for processed data. - """ - - if hasattr(_linking_module, "get_processed_root_data_dir"): - return _linking_module.get_processed_root_data_dir() - else: - return get_ephys_root_data_dir()[0] - - -# ----------------------------- Table declarations ---------------------- - - -@schema -class AcquisitionSoftware(dj.Lookup): - """Name of software used for recording electrophysiological data. - - Attributes: - acq_software ( varchar(24) ): Acquisition software, e.g,. SpikeGLX, OpenEphys - """ - - definition = """ # Software used for recording of neuropixels probes - acq_software: varchar(24) - """ - contents = zip(["SpikeGLX", "Open Ephys"]) - - -@schema -class ProbeInsertion(dj.Manual): - """Information about probe insertion across subjects and sessions. - - Attributes: - Session (foreign key): Session primary key. - insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session. - probe.Probe (str): probe.Probe primary key. - """ - - definition = """ - # Probe insertion implanted into an animal for a given session. 
- -> Session - insertion_number: tinyint unsigned - --- - -> probe.Probe - """ - - @classmethod - def auto_generate_entries(cls, session_key): - """Automatically populate entries in ProbeInsertion table for a session.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(session_key) - ) - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found in: {session_dir}" - ) - - probe_list, probe_insertion_list = [], [] - if acq_software == "SpikeGLX": - for meta_fp_idx, meta_filepath in enumerate(ephys_meta_filepaths): - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - - probe_key = { - "probe_type": spikeglx_meta.probe_model, - "probe": spikeglx_meta.probe_SN, - } - if probe_key["probe"] not in [p["probe"] for p in probe_list]: - probe_list.append(probe_key) - - probe_dir = meta_filepath.parent - try: - probe_number = re.search(r"(imec)?\d{1}$", probe_dir.name).group() - probe_number = int(probe_number.replace("imec", "")) - except AttributeError: - probe_number = meta_fp_idx - - probe_insertion_list.append( - { - **session_key, - "probe": spikeglx_meta.probe_SN, - "insertion_number": int(probe_number), - } - ) - elif acq_software == "Open Ephys": - loaded_oe = openephys.OpenEphys(session_dir) - for probe_idx, oe_probe in enumerate(loaded_oe.probes.values()): - probe_key = { - "probe_type": oe_probe.probe_model, - "probe": oe_probe.probe_SN, - } - if probe_key["probe"] not in [p["probe"] for p in probe_list]: - probe_list.append(probe_key) - probe_insertion_list.append( - { - **session_key, - "probe": oe_probe.probe_SN, - "insertion_number": probe_idx, - } - ) - else: - raise NotImplementedError(f"Unknown acquisition software: {acq_software}") - - probe.Probe.insert(probe_list, skip_duplicates=True) - cls.insert(probe_insertion_list, skip_duplicates=True) - - -@schema -class InsertionLocation(dj.Manual): - """Stereotaxic location information for each probe insertion. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - SkullReference (dict): SkullReference primary key. - ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive. - ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive. - depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative. - Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis. - phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis. - """ - - definition = """ - # Brain Location of a given probe insertion. 
- -> ProbeInsertion - --- - -> SkullReference - ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive - ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive - depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative - theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis - phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis - beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior - """ - - -@schema -class EphysRecording(dj.Imported): - """Automated table with electrophysiology recording information for each probe inserted during an experimental session. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key. - AcquisitionSoftware (dict): AcquisitionSoftware primary key. - sampling_rate (float): sampling rate of the recording in Hertz (Hz). - recording_datetime (datetime): datetime of the recording from this probe. - recording_duration (float): duration of the entire recording from this probe in seconds. - """ - - definition = """ - # Ephys recording from a probe insertion for a given session. - -> ProbeInsertion - --- - -> probe.ElectrodeConfig - -> AcquisitionSoftware - sampling_rate: float # (Hz) - recording_datetime: datetime # datetime of the recording from this probe - recording_duration: float # (seconds) duration of the recording from this probe - """ - - class EphysFile(dj.Part): - """Paths of electrophysiology recording files for each insertion. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - file_path (varchar(255) ): relative file path for electrophysiology recording. - """ - - definition = """ - # Paths of files of a given EphysRecording round. - -> master - file_path: varchar(255) # filepath relative to root data directory - """ - - def make(self, key): - """Populates table with electrophysiology recording information.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" 
- f" Neither SpikeGLX nor Open Ephys recording files found" - f" in {session_dir}" - ) - - supported_probe_types = probe.ProbeType.fetch("probe_type") - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) - - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: {}".format(key) - ) - - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) - - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) - - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - # explicitly garbage collect "dataset" - # as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type 
{acq_software} is" - f" not yet implemented" - ) - - -@schema -class LFP(dj.Imported): - """Extracts local field potentials (LFP) from an electrophysiology recording. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - lfp_sampling_rate (float): Sampling rate for LFPs in Hz. - lfp_time_stamps (longblob): Time stamps with respect to the start of the recording. - lfp_mean (longblob): Overall mean LFP across electrodes. - """ - - definition = """ - # Acquired local field potential (LFP) from a given Ephys recording. - -> EphysRecording - --- - lfp_sampling_rate: float # (Hz) - lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) - lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) - """ - - class Electrode(dj.Part): - """Saves local field potential data for each electrode. - - Attributes: - LFP (foreign key): LFP primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - lfp (longblob): LFP recording at this electrode in microvolts. - """ - - definition = """ - -> master - -> probe.ElectrodeConfig.Electrode - --- - lfp: longblob # (uV) recorded lfp at this electrode - """ - - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - - def make(self, key): - """Populates the LFP tables.""" - acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software") - - electrode_keys, lfp = [], [] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - - lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ - -1 :: -self._skip_channel_counts - ] - - # Extract LFP data at specified channels and convert to uV - lfp = spikeglx_recording.lf_timeseries[ - :, lfp_channel_ind - ] # (sample x channel) - lfp = ( - lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind] - ).T # (channel x sample) - - self.insert1( - dict( - key, - lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"], - lfp_time_stamps=( - np.arange(lfp.shape[1]) - / spikeglx_recording.lfmeta.meta["imSampRate"] - ), - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - for recorded_site in lfp_channel_ind: - shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[ - "data" - ][recorded_site] - electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - lfp_channel_ind = np.r_[ - len(oe_probe.lfp_meta["channels_indices"]) - - 1 : 0 : -self._skip_channel_counts - ] - - # (sample x channel) - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps - - self.insert1( - dict( - key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - 
probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_keys.extend( - probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind - ) - else: - raise NotImplementedError( - f"LFP extraction from acquisition software" - f" of type {acq_software} is not yet implemented" - ) - - # single insert in loop to mitigate potential memory issue - for electrode_key, lfp_trace in zip(electrode_keys, lfp): - self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace}) - - -# ------------ Clustering -------------- - - -@schema -class ClusteringMethod(dj.Lookup): - """Kilosort clustering method. - - Attributes: - clustering_method (foreign key, varchar(16) ): Kilosort clustering method. - clustering_methods_desc (varchar(1000) ): Additional description of the clustering method. - """ - - definition = """ - # Method for clustering - clustering_method: varchar(16) - --- - clustering_method_desc: varchar(1000) - """ - - contents = [ - ("kilosort2", "kilosort2 clustering method"), - ("kilosort2.5", "kilosort2.5 clustering method"), - ("kilosort3", "kilosort3 clustering method"), - ] - - -@schema -class ClusteringParamSet(dj.Lookup): - """Parameters to be used in clustering procedure for spike sorting. - - Attributes: - paramset_idx (foreign key): Unique ID for the clustering parameter set. - ClusteringMethod (dict): ClusteringMethod primary key. - paramset_desc (varchar(128) ): Description of the clustering parameter set. - param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Parameters for clustering with Kilosort. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> ClusteringMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, - clustering_method: str, - paramset_desc: str, - params: dict, - paramset_idx: int = None, - ): - """Inserts new parameters into the ClusteringParamSet table. - - Args: - clustering_method (str): name of the clustering method. - paramset_desc (str): description of the parameter set - params (dict): clustering parameters - paramset_idx (int, optional): Unique parameter set ID. Defaults to None. - """ - if paramset_idx is None: - paramset_idx = ( - dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0 - ) + 1 - - param_dict = { - "clustering_method": clustering_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid( - {**params, "clustering_method": clustering_method} - ), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # If not same name: human error, trying to add the same paramset with different name - raise dj.DataJointError( - f"The specified param-set already exists" - f" - with paramset_idx: {existing_paramset_idx}" - ) - else: - if {"paramset_idx": paramset_idx} in cls.proj(): - raise dj.DataJointError( - f"The specified paramset_idx {paramset_idx} already exists," - f" please pick a different one." - ) - cls.insert1(param_dict) - - -@schema -class ClusterQualityLabel(dj.Lookup): - """Quality label for each spike sorted cluster. 
-
-    Attributes:
-        cluster_quality_label (foreign key, varchar(100) ): Cluster quality type.
-        cluster_quality_description ( varchar(4000) ): Description of the cluster quality type.
-    """
-
-    definition = """
-    # Quality
-    cluster_quality_label: varchar(100)  # cluster quality type - e.g. 'good', 'MUA', 'noise', etc.
-    ---
-    cluster_quality_description: varchar(4000)
-    """
-    contents = [
-        ("good", "single unit"),
-        ("ok", "probably a single unit, but could be contaminated"),
-        ("mua", "multi-unit activity"),
-        ("noise", "bad unit"),
-    ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
-    """A clustering task to spike sort electrophysiology datasets.
-
-    Attributes:
-        EphysRecording (foreign key): EphysRecording primary key.
-        ClusteringParamSet (foreign key): ClusteringParamSet primary key.
-        clustering_output_dir ( varchar(255) ): Relative path to output clustering results.
-        task_mode (enum): `trigger` computes clustering; `load` imports existing results.
-    """
-
-    definition = """
-    # Manual table for defining a clustering task ready to be run
-    -> EphysRecording
-    -> ClusteringParamSet
-    ---
-    clustering_output_dir='': varchar(255)  # clustering output directory relative to the clustering root data directory
-    task_mode='load': enum('load', 'trigger')  # 'load': load computed analysis results, 'trigger': trigger computation
-    """
-
-    @classmethod
-    def infer_output_dir(
-        cls, key: dict, relative: bool = False, mkdir: bool = False
-    ) -> pathlib.Path:
-        """Infer output directory if it is not provided.
-
-        Args:
-            key (dict): ClusteringTask primary key.
-
-        Returns:
-            Expected clustering_output_dir based on the following convention:
-                processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
-                e.g.: sub4/sess1/probe_2/kilosort2_0
-        """
-        processed_dir = pathlib.Path(get_processed_root_data_dir())
-        session_dir = find_full_path(
-            get_ephys_root_data_dir(), get_session_directory(key)
-        )
-        root_dir = find_root_directory(get_ephys_root_data_dir(), session_dir)
-
-        method = (
-            (ClusteringParamSet * ClusteringMethod & key)
-            .fetch1("clustering_method")
-            .replace(".", "-")
-        )
-
-        output_dir = (
-            processed_dir
-            / session_dir.relative_to(root_dir)
-            / f'probe_{key["insertion_number"]}'
-            / f'{method}_{key["paramset_idx"]}'
-        )
-
-        if mkdir:
-            output_dir.mkdir(parents=True, exist_ok=True)
-            log.info(f"{output_dir} created!")
-
-        return output_dir.relative_to(processed_dir) if relative else output_dir
-
-    @classmethod
-    def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0):
-        """Autogenerate entries based on a particular ephys recording.
-
-        Args:
-            ephys_recording_key (dict): EphysRecording primary key.
-            paramset_idx (int, optional): Parameter index to use for the clustering task. Defaults to 0.
-        """
-        key = {**ephys_recording_key, "paramset_idx": paramset_idx}
-
-        processed_dir = get_processed_root_data_dir()
-        output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True)
-
-        try:
-            kilosort.Kilosort(
-                output_dir
-            )  # check if the directory is a valid Kilosort output
-        except FileNotFoundError:
-            task_mode = "trigger"
-        else:
-            task_mode = "load"
-
-        cls.insert1(
-            {
-                **key,
-                "clustering_output_dir": output_dir.relative_to(
-                    processed_dir
-                ).as_posix(),
-                "task_mode": task_mode,
-            }
-        )
-
-
-@schema
-class Clustering(dj.Imported):
-    """A processing table to handle each clustering task.
-
-    Attributes:
-        ClusteringTask (foreign key): ClusteringTask primary key.
- clustering_time (datetime): Time when clustering results are generated. - package_version ( varchar(16) ): Package version used for a clustering analysis. - """ - - definition = """ - # Clustering Procedure - -> ClusteringTask - --- - clustering_time: datetime # time of generation of this set of clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Triggers or imports clustering analysis.""" - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - if not output_dir: - output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) - # update clustering_output_dir - ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "load": - kilosort.Kilosort( - kilosort_dir - ) # check if the directory is a valid Kilosort output - elif task_mode == "trigger": - acq_software, clustering_method, params = ( - ClusteringTask * EphysRecording * ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - if "kilosort" in clustering_method: - from element_array_ephys.readers import kilosort_triggering - - # add additional probe-recording and channels details into `params` - params = {**params, **get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.pop("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) - - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=spikeglx_recording.root_dir - / (spikeglx_recording.root_name + ".ap.bin"), - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=pathlib.Path( - oe_probe.recording_info["recording_files"][0] - ) - / "continuous.dat", - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort.run_modules() - else: - raise NotImplementedError( - f"Automatic triggering of {clustering_method}" - f" clustering analysis is not yet supported" 
-                    )
-
-        else:
-            raise ValueError(f"Unknown task mode: {task_mode}")
-
-        creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir)
-        self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
-    """Curation procedure table.
-
-    Attributes:
-        Clustering (foreign key): Clustering primary key.
-        curation_id (foreign key, int): Unique curation ID.
-        curation_time (datetime): Time when curation results are generated.
-        curation_output_dir ( varchar(255) ): Output directory of the curated results.
-        quality_control (bool): If True, this clustering result has undergone quality control.
-        manual_curation (bool): If True, manual curation has been performed on this clustering result.
-        curation_note ( varchar(2000) ): Notes about the curation task.
-    """
-
-    definition = """
-    # Manual curation procedure
-    -> Clustering
-    curation_id: int
-    ---
-    curation_time: datetime             # time of generation of this set of curated clustering results
-    curation_output_dir: varchar(255)   # output directory of the curated results, relative to root data directory
-    quality_control: bool               # has this clustering result undergone quality control?
-    manual_curation: bool               # has manual curation been performed on this clustering result?
-    curation_note='': varchar(2000)
-    """
-
-    def create1_from_clustering_task(self, key, curation_note=""):
-        """
-        A function to create a new corresponding "Curation" for a particular
-        "ClusteringTask"
-        """
-        if key not in Clustering():
-            raise ValueError(
-                f"No corresponding entry in Clustering available"
-                f" for: {key}; do `Clustering.populate(key)`"
-            )
-
-        task_mode, output_dir = (ClusteringTask & key).fetch1(
-            "task_mode", "clustering_output_dir"
-        )
-        kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
-        creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
-            kilosort_dir
-        )
-        # Synthesize curation_id
-        curation_id = (
-            dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
-        )
-        self.insert1(
-            {
-                **key,
-                "curation_id": curation_id,
-                "curation_time": creation_time,
-                "curation_output_dir": output_dir,
-                "quality_control": is_qc,
-                "manual_curation": is_curated,
-                "curation_note": curation_note,
-            }
-        )
-
-
-@schema
-class CuratedClustering(dj.Imported):
-    """Clustering results after curation.
-
-    Attributes:
-        Curation (foreign key): Curation primary key.
-    """
-
-    definition = """
-    # Clustering results of a curation.
-    -> Curation
-    """
-
-    class Unit(dj.Part):
-        """Single unit properties after clustering and curation.
-
-        Attributes:
-            CuratedClustering (foreign key): CuratedClustering primary key.
-            unit (foreign key, int): Unique integer identifying a single unit.
-            probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
-            ClusterQualityLabel (dict): ClusterQualityLabel primary key.
-            spike_count (int): Number of spikes in this recording for this unit.
-            spike_times (longblob): Spike times of this unit, relative to the start of the EphysRecording.
-            spike_sites (longblob): Array of electrode associated with each spike.
-            spike_depths (longblob): Array of depths associated with each spike, relative to the (0, 0) of the probe.
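Once populated, per-unit results can be queried directly. A hypothetical fetch follows; the restriction values are placeholders, since the real primary key depends on the upstream Session definition:

```python
# Hypothetical usage: pull one unit's spike train and quality label.
unit_key = {
    "subject": "sub4",          # placeholder upstream key attributes
    "insertion_number": 1,
    "paramset_idx": 0,
    "curation_id": 1,
    "unit": 5,
}
spike_times, label = (CuratedClustering.Unit & unit_key).fetch1(
    "spike_times", "cluster_quality_label"
)
print(f"unit 5: {len(spike_times)} spikes, label={label!r}")
```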
- """ - - definition = """ - # Properties of a given unit from a round of clustering (and curation) - -> master - unit: int - --- - -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit - -> ClusterQualityLabel - spike_count: int # how many spikes in this recording for this unit - spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording - spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe - """ - - def make(self, key): - """Automated population of Unit information.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) - - sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate) - - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() - - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / sample_rate - ) - spike_count = len(unit_spike_times) - - units.append( - { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": ( - spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ] - if spike_depths is not None - else None - ), - } - ) - - self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) - - -@schema -class WaveformSet(dj.Imported): - """A set of spike waveforms for units out of a given CuratedClustering. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # A set of spike waveforms for units out of a given CuratedClustering - -> CuratedClustering - """ - - class PeakWaveform(dj.Part): - """Mean waveform across spikes for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. 
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode. - """ - - definition = """ - # Mean waveform across spikes for a given unit at its representative electrode - -> master - -> CuratedClustering.Unit - --- - peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode - """ - - class Waveform(dj.Part): - """Spike waveforms for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - waveform_mean (longblob): mean waveform across spikes of the unit in microvolts. - waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit. - """ - - definition = """ - # Spike waveforms and their mean across spikes for the given unit - -> master - -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- - waveform_mean: longblob # (uV) mean waveform across spikes of the given unit - waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit - """ - - def make(self, key): - """Populates waveform tables.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") - - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) - - is_qc = (Curation & key).fetch1("quality_control") - - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } - - if is_qc: - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample - - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, 
kilosort_dataset.data["channel_map"]
-                )  # (sample x channel x spike)
-                waveforms = waveforms.transpose(
-                    (1, 2, 0)
-                )  # (channel x spike x sample)
-                for channel, channel_waveform in zip(
-                    kilosort_dataset.data["channel_map"], waveforms
-                ):
-                    unit_electrode_waveforms.append(
-                        {
-                            **unit_dict,
-                            **channel2electrodes[channel],
-                            "waveform_mean": channel_waveform.mean(axis=0),
-                            "waveforms": channel_waveform,
-                        }
-                    )
-                    if (
-                        channel2electrodes[channel]["electrode"]
-                        == unit_dict["electrode"]
-                    ):
-                        unit_peak_waveform = {
-                            **unit_dict,
-                            "peak_electrode_waveform": channel_waveform.mean(
-                                axis=0
-                            ),
-                        }
-
-                yield unit_peak_waveform, unit_electrode_waveforms
-
-        # insert waveform on a per-unit basis to mitigate potential memory issue
-        self.insert1(key)
-        for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
-            if unit_peak_waveform:
-                self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
-            if unit_electrode_waveforms:
-                self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
-
-
-@schema
-class QualityMetrics(dj.Imported):
-    """Clustering and waveform quality metrics.
-
-    Attributes:
-        CuratedClustering (foreign key): CuratedClustering primary key.
-    """
-
-    definition = """
-    # Clusters and waveforms metrics
-    -> CuratedClustering
-    """
-
-    class Cluster(dj.Part):
-        """Cluster metrics for a unit.
-
-        Attributes:
-            QualityMetrics (foreign key): QualityMetrics primary key.
-            CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
-            firing_rate (float): Firing rate of the unit.
-            snr (float): Signal-to-noise ratio for a unit.
-            presence_ratio (float): Fraction of time where spikes are present.
-            isi_violation (float): Rate of ISI violation as a fraction of overall rate.
-            number_violation (int): Total ISI violations.
-            amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
-            isolation_distance (float): Distance to nearest cluster in Mahalanobis space.
-            l_ratio (float): Amount of empty space between a cluster and other spikes in the dataset.
-            d_prime (float): Classification accuracy based on LDA.
-            nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster.
-            nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster.
-            silhouette_score (float): Standard metric for cluster overlap.
-            max_drift (float): Maximum change in spike depth throughout recording.
-            cumulative_drift (float): Cumulative change in spike depth throughout recording.
-            contamination_rate (float): Frequency of spikes in the refractory period.
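The metrics listed above come from Kilosort's `metrics.csv`; `QualityMetrics.make` below lowercases the column names, replaces infinities with NaN, and renames a few legacy columns. A self-contained pandas sketch of that normalization, on a synthetic one-row frame:

```python
# Sketch of the metrics.csv normalization performed in QualityMetrics.make;
# the single synthetic row stands in for a real Kilosort metrics file.
import numpy as np
import pandas as pd

metrics_df = pd.DataFrame(
    [{"cluster_id": 0, "firing_rate": 7.2, "isi_viol": 0.01,
      "num_viol": 3, "contam_rate": np.inf}]
).set_index("cluster_id")

metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)  # inf -> nullable NaN
metrics_df.columns = metrics_df.columns.str.lower()
metrics_df.rename(
    columns={"isi_viol": "isi_violation",
             "num_viol": "number_violation",
             "contam_rate": "contamination_rate"},
    inplace=True,
)
print(metrics_df.columns.tolist())
# ['firing_rate', 'isi_violation', 'number_violation', 'contamination_rate']
```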
- """ - - definition = """ - # Cluster metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - firing_rate=null: float # (Hz) firing rate for a unit - snr=null: float # signal-to-noise ratio for a unit - presence_ratio=null: float # fraction of time in which spikes are present - isi_violation=null: float # rate of ISI violation as a fraction of overall rate - number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram - isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # - d_prime=null: float # Classification accuracy based on LDA - nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster - nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # - """ - - class Waveform(dj.Part): - """Waveform metrics for a particular unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - amplitude (float): Absolute difference between waveform peak and trough in microvolts. - duration (float): Time between waveform peak and trough in milliseconds. - halfwidth (float): Spike width at half max amplitude. - pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0. - repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak. - recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. - spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. - velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
- """ - - definition = """ - # Waveform metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough - halfwidth=null: float # (ms) spike width at half max amplitude - pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 - repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak - recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail - spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe - velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe - velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe - """ - - def make(self, key): - """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) - metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) - metrics_list = [ - dict(metrics_df.loc[unit_key["unit"]], **unit_key) - for unit_key in (CuratedClustering.Unit & key).fetch("KEY") - ] - - self.insert1(key) - self.Cluster.insert(metrics_list, ignore_extra_fields=True) - self.Waveform.insert(metrics_list, ignore_extra_fields=True) - - -# ---------------- HELPER FUNCTIONS ---------------- - - -def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str: - """Get spikeGLX data filepath.""" - # attempt to retrieve from EphysRecording.EphysFile - spikeglx_meta_filepath = pathlib.Path( - ( - EphysRecording.EphysFile - & ephys_recording_key - & 'file_path LIKE "%.ap.meta"' - ).fetch1("file_path") - ) - - try: - spikeglx_meta_filepath = find_full_path( - get_ephys_root_data_dir(), spikeglx_meta_filepath - ) - except FileNotFoundError: - # if not found, search in session_dir again - if not spikeglx_meta_filepath.exists(): - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & ephys_recording_key - ).fetch1("probe") - - spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")] - for meta_filepath in spikeglx_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - spikeglx_meta_filepath = meta_filepath - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format( - ephys_recording_key - ) - ) - - return spikeglx_meta_filepath - - -def get_openephys_probe_data(ephys_recording_key: dict) -> list: - """Get OpenEphys probe data from file.""" - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & 
ephys_recording_key - ).fetch1("probe") - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - loaded_oe = openephys.OpenEphys(session_dir) - probe_data = loaded_oe.probes[inserted_probe_serial_number] - - # explicitly garbage collect "loaded_oe" - # as these may have large memory footprint and may not be cleared fast enough - del loaded_oe - gc.collect() - - return probe_data - - -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - probe_dataset = get_openephys_probe_data(ephys_recording_key) - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. 
neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key - - -def get_recording_channels_details(ephys_recording_key: dict) -> np.array: - """Get details of recording channels for a given recording.""" - channels_details = {} - - acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1( - "acq_software", "sampling_rate" - ) - - probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1( - "probe_type" - ) - channels_details["probe_type"] = { - "neuropixels 1.0 - 3A": "3A", - "neuropixels 1.0 - 3B": "NP1", - "neuropixels UHD": "NP1100", - "neuropixels 2.0 - SS": "NP21", - "neuropixels 2.0 - MS": "NP24", - }[probe_type] - - electrode_config_key = ( - probe.ElectrodeConfig * EphysRecording & ephys_recording_key - ).fetch1("KEY") - ( - channels_details["channel_ind"], - channels_details["x_coords"], - channels_details["y_coords"], - channels_details["shank_ind"], - ) = ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key - ).fetch( - "electrode", "x_coord", "y_coord", "shank" - ) - channels_details["sample_rate"] = sample_rate - channels_details["num_channels"] = len(channels_details["channel_ind"]) - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0] - channels_details["connected"] = np.array( - [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]] - ) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(ephys_recording_key) - channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0] - channels_details["connected"] = np.array( - [ - int(v == 1) - for c, v in oe_probe.channels_connected.items() - if c in channels_details["channel_ind"] - ] - ) - - return channels_details diff --git a/element_array_ephys/ephys_chronic.py b/element_array_ephys/ephys_chronic.py deleted file mode 100644 index 772e885f..00000000 --- a/element_array_ephys/ephys_chronic.py +++ /dev/null @@ -1,1523 +0,0 @@ -import gc -import importlib -import inspect -import pathlib -from decimal import Decimal - -import datajoint as dj -import numpy as np -import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory - -from . 
import ephys_report, probe -from .readers import kilosort, openephys, spikeglx - -log = dj.logger - -schema = dj.schema() - -_linking_module = None - - -def activate( - ephys_schema_name: str, - probe_schema_name: str = None, - *, - create_schema: bool = True, - create_tables: bool = True, - linking_module: str = None, -): - """Activates the `ephys` and `probe` schemas. - - Args: - ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe schema. - create_schema (bool): If True, schema will be created in the database. - create_tables (bool): If True, tables related to the schema will be created in the database. - linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. - - Dependencies: - Upstream tables: - Session: A parent table to ProbeInsertion - Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported. - - Functions: - get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s). - get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings. - get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory. - """ - - if isinstance(linking_module, str): - linking_module = importlib.import_module(linking_module) - assert inspect.ismodule( - linking_module - ), "The argument 'dependency' must be a module's name or a module" - - global _linking_module - _linking_module = linking_module - - probe.activate( - probe_schema_name, create_schema=create_schema, create_tables=create_tables - ) - schema.activate( - ephys_schema_name, - create_schema=create_schema, - create_tables=create_tables, - add_objects=_linking_module.__dict__, - ) - ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name) - - -# -------------- Functions required by the elements-ephys --------------- - - -def get_ephys_root_data_dir() -> list: - """Fetches absolute data path to ephys data directories. - - The absolute path here is used as a reference for all downstream relative paths used in DataJoint. - - Returns: - A list of the absolute path(s) to ephys data directories. - """ - root_directories = _linking_module.get_ephys_root_data_dir() - if isinstance(root_directories, (str, pathlib.Path)): - root_directories = [root_directories] - - if hasattr(_linking_module, "get_processed_root_data_dir"): - root_directories.append(_linking_module.get_processed_root_data_dir()) - - return root_directories - - -def get_session_directory(session_key: dict) -> str: - """Retrieve the session directory with Neuropixels for the given session. - - Args: - session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database. - - Returns: - A string for the path to the session directory. - """ - return _linking_module.get_session_directory(session_key) - - -def get_processed_root_data_dir() -> str: - """Retrieve the root directory for all processed data. - - Returns: - A string for the full path to the root directory for processed data. 
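For reference, a minimal sketch of the linking-module contract described above for this (now removed) `ephys_chronic` module. Every schema name, table stub, and path below is hypothetical:

```python
# A bare-bones "linking module" satisfying the activate() dependencies above.
import datajoint as dj

from element_array_ephys import ephys_chronic as ephys  # pre-1.0 module

schema = dj.schema("my_lab_core")  # hypothetical upstream schema


@schema
class SkullReference(dj.Lookup):  # required by InsertionLocation
    definition = """
    skull_reference: varchar(60)
    """
    contents = zip(["Bregma", "Lambda"])


@schema
class Subject(dj.Manual):  # required by the chronic ProbeInsertion
    definition = """
    subject: varchar(32)
    """


@schema
class Session(dj.Manual):  # required by the chronic EphysRecording
    definition = """
    -> Subject
    session_datetime: datetime
    """


def get_ephys_root_data_dir():
    return ["/data/raw_ephys"]  # hypothetical root(s) for raw recordings


def get_session_directory(session_key: dict) -> str:
    return f"{session_key['subject']}/session1"  # path relative to the root


# Activate probe + ephys schemas, resolving the names above via this module.
ephys.activate("my_lab_ephys", "my_lab_probe", linking_module=__name__)
```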
- """ - - if hasattr(_linking_module, "get_processed_root_data_dir"): - return _linking_module.get_processed_root_data_dir() - else: - return get_ephys_root_data_dir()[0] - - -# ----------------------------- Table declarations ---------------------- - - -@schema -class AcquisitionSoftware(dj.Lookup): - """Name of software used for recording electrophysiological data. - - Attributes: - acq_software ( varchar(24) ): Acquisition software, e.g,. SpikeGLX, OpenEphys - """ - - definition = """ # Software used for recording of neuropixels probes - acq_software: varchar(24) - """ - contents = zip(["SpikeGLX", "Open Ephys"]) - - -@schema -class ProbeInsertion(dj.Manual): - """Information about probe insertion across subjects and sessions. - - Attributes: - Session (foreign key): Session primary key. - insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session. - probe.Probe (str): probe.Probe primary key. - """ - - definition = """ - # Probe insertion chronically implanted into an animal. - -> Subject - insertion_number: tinyint unsigned - --- - -> probe.Probe - insertion_datetime=null: datetime - """ - - -@schema -class InsertionLocation(dj.Manual): - """Stereotaxic location information for each probe insertion. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - SkullReference (dict): SkullReference primary key. - ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive. - ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive. - depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative. - Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis. - phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis. - """ - - definition = """ - # Brain Location of a given probe insertion. - -> ProbeInsertion - --- - -> SkullReference - ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive - ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive - depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative - theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis - phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis - beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior - """ - - -@schema -class EphysRecording(dj.Imported): - """Automated table with electrophysiology recording information for each probe inserted during an experimental session. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key. - AcquisitionSoftware (dict): AcquisitionSoftware primary key. - sampling_rate (float): sampling rate of the recording in Hertz (Hz). - recording_datetime (datetime): datetime of the recording from this probe. - recording_duration (float): duration of the entire recording from this probe in seconds. - """ - - definition = """ - # Ephys recording from a probe insertion for a given session. 
- -> Session - -> ProbeInsertion - --- - -> probe.ElectrodeConfig - -> AcquisitionSoftware - sampling_rate: float # (Hz) - recording_datetime: datetime # datetime of the recording from this probe - recording_duration: float # (seconds) duration of the recording from this probe - """ - - class EphysFile(dj.Part): - """Paths of electrophysiology recording files for each insertion. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - file_path (varchar(255) ): relative file path for electrophysiology recording. - """ - - definition = """ - # Paths of files of a given EphysRecording round. - -> master - file_path: varchar(255) # filepath relative to root data directory - """ - - def make(self, key): - """Populates table with electrophysiology recording information.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" - f" in {session_dir}" - ) - - supported_probe_types = probe.ProbeType.fetch("probe_type") - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - f"No SpikeGLX data found for probe insertion: {key}" - + " The probe serial number does not match." 
- ) - - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: {}".format(key) - ) - - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) - - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) - - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - # explicitly garbage collect "dataset" - # as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) - - -@schema -class LFP(dj.Imported): - """Extracts local field potentials (LFP) from an electrophysiology recording. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - lfp_sampling_rate (float): Sampling rate for LFPs in Hz. - lfp_time_stamps (longblob): Time stamps with respect to the start of the recording. - lfp_mean (longblob): Overall mean LFP across electrodes. 
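The `make` above resolves the acquisition software purely from file patterns. A standalone sketch of that probe, using a temporary directory with a fake SpikeGLX meta file:

```python
# Standalone sketch of the acquisition-software detection in EphysRecording.make;
# the session directory is a temp stand-in for a real recording folder.
import pathlib
import tempfile

session_dir = pathlib.Path(tempfile.mkdtemp())
(session_dir / "run1_g0_t0.imec0.ap.meta").touch()  # fake SpikeGLX meta file

for ephys_pattern, ephys_acq_type in (
    ("*.ap.meta", "SpikeGLX"),
    ("*.oebin", "Open Ephys"),
):
    if list(session_dir.rglob(ephys_pattern)):
        acq_software = ephys_acq_type
        break
else:
    raise FileNotFoundError("Neither SpikeGLX nor Open Ephys files found")

print(acq_software)  # -> "SpikeGLX"
```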
- """ - - definition = """ - # Acquired local field potential (LFP) from a given Ephys recording. - -> EphysRecording - --- - lfp_sampling_rate: float # (Hz) - lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) - lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) - """ - - class Electrode(dj.Part): - """Saves local field potential data for each electrode. - - Attributes: - LFP (foreign key): LFP primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - lfp (longblob): LFP recording at this electrode in microvolts. - """ - - definition = """ - -> master - -> probe.ElectrodeConfig.Electrode - --- - lfp: longblob # (uV) recorded lfp at this electrode - """ - - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - - def make(self, key): - """Populates the LFP tables.""" - acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software") - - electrode_keys, lfp = [], [] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - - lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ - -1 :: -self._skip_channel_counts - ] - - # Extract LFP data at specified channels and convert to uV - lfp = spikeglx_recording.lf_timeseries[ - :, lfp_channel_ind - ] # (sample x channel) - lfp = ( - lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind] - ).T # (channel x sample) - - self.insert1( - dict( - key, - lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"], - lfp_time_stamps=( - np.arange(lfp.shape[1]) - / spikeglx_recording.lfmeta.meta["imSampRate"] - ), - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - for recorded_site in lfp_channel_ind: - shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[ - "data" - ][recorded_site] - electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - lfp_channel_ind = np.r_[ - len(oe_probe.lfp_meta["channels_indices"]) - - 1 : 0 : -self._skip_channel_counts - ] - - # (sample x channel) - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps - - self.insert1( - dict( - key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_keys.extend( - probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind - ) - else: - raise NotImplementedError( - f"LFP extraction from acquisition software" - f" of type {acq_software} is not yet implemented" - ) - - # single insert in loop to mitigate potential memory issue - for electrode_key, lfp_trace in zip(electrode_keys, lfp): - 
self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace}) - - -# ------------ Clustering -------------- - - -@schema -class ClusteringMethod(dj.Lookup): - """Kilosort clustering method. - - Attributes: - clustering_method (foreign key, varchar(16) ): Kilosort clustering method. - clustering_methods_desc (varchar(1000) ): Additional description of the clustering method. - """ - - definition = """ - # Method for clustering - clustering_method: varchar(16) - --- - clustering_method_desc: varchar(1000) - """ - - contents = [ - ("kilosort2", "kilosort2 clustering method"), - ("kilosort2.5", "kilosort2.5 clustering method"), - ("kilosort3", "kilosort3 clustering method"), - ] - - -@schema -class ClusteringParamSet(dj.Lookup): - """Parameters to be used in clustering procedure for spike sorting. - - Attributes: - paramset_idx (foreign key): Unique ID for the clustering parameter set. - ClusteringMethod (dict): ClusteringMethod primary key. - paramset_desc (varchar(128) ): Description of the clustering parameter set. - param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Parameters for clustering with Kilosort. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> ClusteringMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, - clustering_method: str, - paramset_desc: str, - params: dict, - paramset_idx: int = None, - ): - """Inserts new parameters into the ClusteringParamSet table. - - Args: - clustering_method (str): name of the clustering method. - paramset_desc (str): description of the parameter set - params (dict): clustering parameters - paramset_idx (int, optional): Unique parameter set ID. Defaults to None. - """ - if paramset_idx is None: - paramset_idx = ( - dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0 - ) + 1 - - param_dict = { - "clustering_method": clustering_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid( - {**params, "clustering_method": clustering_method} - ), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # If not same name: human error, trying to add the same paramset with different name - raise dj.DataJointError( - f"The specified param-set already exists" - f" - with paramset_idx: {existing_paramset_idx}" - ) - else: - if {"paramset_idx": paramset_idx} in cls.proj(): - raise dj.DataJointError( - f"The specified paramset_idx {paramset_idx} already exists," - f" please pick a different one." - ) - cls.insert1(param_dict) - - -@schema -class ClusterQualityLabel(dj.Lookup): - """Quality label for each spike sorted cluster. - - Attributes: - cluster_quality_label (foreign key, varchar(100) ): Cluster quality type. - cluster_quality_description (varchar(4000) ): Description of the cluster quality type. - """ - - definition = """ - # Quality - cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc. 
-    ---
-    cluster_quality_description: varchar(4000)
-    """
-    contents = [
-        ("good", "single unit"),
-        ("ok", "probably a single unit, but could be contaminated"),
-        ("mua", "multi-unit activity"),
-        ("noise", "bad unit"),
-    ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
-    """A clustering task to spike sort electrophysiology datasets.
-
-    Attributes:
-        EphysRecording (foreign key): EphysRecording primary key.
-        ClusteringParamSet (foreign key): ClusteringParamSet primary key.
-        clustering_output_dir (varchar (255) ): Relative path to output clustering results.
-        task_mode (enum): `trigger` computes clustering; `load` imports existing results.
-    """
-
-    definition = """
-    # Manual table for defining a clustering task ready to be run
-    -> EphysRecording
-    -> ClusteringParamSet
-    ---
-    clustering_output_dir='': varchar(255)  # clustering output directory relative to the clustering root data directory
-    task_mode='load': enum('load', 'trigger')  # 'load': load computed analysis results, 'trigger': trigger computation
-    """
-
-    @classmethod
-    def infer_output_dir(cls, key, relative=False, mkdir=False) -> pathlib.Path:
-        """Infer output directory if it is not provided.
-
-        Args:
-            key (dict): ClusteringTask primary key.
-
-        Returns:
-            Expected clustering_output_dir based on the following convention:
-                processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
-                e.g.: sub4/sess1/probe_2/kilosort2_0
-        """
-        processed_dir = pathlib.Path(get_processed_root_data_dir())
-        sess_dir = find_full_path(get_ephys_root_data_dir(), get_session_directory(key))
-        root_dir = find_root_directory(get_ephys_root_data_dir(), sess_dir)
-
-        method = (
-            (ClusteringParamSet * ClusteringMethod & key)
-            .fetch1("clustering_method")
-            .replace(".", "-")
-        )
-
-        output_dir = (
-            processed_dir
-            / sess_dir.relative_to(root_dir)
-            / f'probe_{key["insertion_number"]}'
-            / f'{method}_{key["paramset_idx"]}'
-        )
-
-        if mkdir:
-            output_dir.mkdir(parents=True, exist_ok=True)
-            log.info(f"{output_dir} created!")
-
-        return output_dir.relative_to(processed_dir) if relative else output_dir
-
-    @classmethod
-    def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0):
-        """Autogenerate entries based on a particular ephys recording.
-
-        Args:
-            ephys_recording_key (dict): EphysRecording primary key.
-            paramset_idx (int, optional): Parameter index to use for clustering task. Defaults to 0.
-        """
-        key = {**ephys_recording_key, "paramset_idx": paramset_idx}
-
-        processed_dir = get_processed_root_data_dir()
-        output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True)
-
-        try:
-            kilosort.Kilosort(
-                output_dir
-            )  # check if the directory is a valid Kilosort output
-        except FileNotFoundError:
-            task_mode = "trigger"
-        else:
-            task_mode = "load"
-
-        cls.insert1(
-            {
-                **key,
-                "clustering_output_dir": output_dir.relative_to(
-                    processed_dir
-                ).as_posix(),
-                "task_mode": task_mode,
-            }
-        )
-
-
-@schema
-class Clustering(dj.Imported):
-    """A processing table to handle each clustering task.
-
-    Attributes:
-        ClusteringTask (foreign key): ClusteringTask primary key.
-        clustering_time (datetime): Time when clustering results are generated.
-        package_version (varchar(16) ): Package version used for a clustering analysis.
- """ - - definition = """ - # Clustering Procedure - -> ClusteringTask - --- - clustering_time: datetime # time of generation of this set of clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Triggers or imports clustering analysis.""" - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - if not output_dir: - output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) - # update clustering_output_dir - ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "load": - kilosort.Kilosort( - kilosort_dir - ) # check if the directory is a valid Kilosort output - elif task_mode == "trigger": - acq_software, clustering_method, params = ( - ClusteringTask * EphysRecording * ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - if "kilosort" in clustering_method: - from element_array_ephys.readers import kilosort_triggering - - # add additional probe-recording and channels details into `params` - params = {**params, **get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.pop("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) - - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=spikeglx_recording.root_dir - / (spikeglx_recording.root_name + ".ap.bin"), - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=pathlib.Path( - oe_probe.recording_info["recording_files"][0] - ) - / "continuous.dat", - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort.run_modules() - else: - raise NotImplementedError( - f"Automatic triggering of {clustering_method}" - f" clustering analysis is not yet supported" - ) - - else: - raise ValueError(f"Unknown task mode: {task_mode}") - - creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) - 
self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
-    """Curation procedure table.
-
-    Attributes:
-        Clustering (foreign key): Clustering primary key.
-        curation_id (foreign key, int): Unique curation ID.
-        curation_time (datetime): Time when curation results are generated.
-        curation_output_dir (varchar(255) ): Output directory of the curated results.
-        quality_control (bool): If True, this clustering result has undergone quality control.
-        manual_curation (bool): If True, manual curation has been performed on this clustering result.
-        curation_note (varchar(2000) ): Notes about the curation task.
-    """
-
-    definition = """
-    # Manual curation procedure
-    -> Clustering
-    curation_id: int
-    ---
-    curation_time: datetime             # time of generation of this set of curated clustering results
-    curation_output_dir: varchar(255)   # output directory of the curated results, relative to root data directory
-    quality_control: bool               # has this clustering result undergone quality control?
-    manual_curation: bool               # has manual curation been performed on this clustering result?
-    curation_note='': varchar(2000)
-    """
-
-    def create1_from_clustering_task(self, key, curation_note: str = ""):
-        """Create a new "Curation" entry for a given "ClusteringTask"."""
-        if key not in Clustering():
-            raise ValueError(
-                f"No corresponding entry in Clustering available"
-                f" for: {key}; do `Clustering.populate(key)`"
-            )
-
-        task_mode, output_dir = (ClusteringTask & key).fetch1(
-            "task_mode", "clustering_output_dir"
-        )
-        kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
-        creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
-            kilosort_dir
-        )
-        # Synthesize curation_id
-        curation_id = (
-            dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
-        )
-        self.insert1(
-            {
-                **key,
-                "curation_id": curation_id,
-                "curation_time": creation_time,
-                "curation_output_dir": output_dir,
-                "quality_control": is_qc,
-                "manual_curation": is_curated,
-                "curation_note": curation_note,
-            }
-        )
-
-
-@schema
-class CuratedClustering(dj.Imported):
-    """Clustering results after curation.
-
-    Attributes:
-        Curation (foreign key): Curation primary key.
-    """
-
-    definition = """
-    # Clustering results of a curation.
-    -> Curation
-    """
-
-    class Unit(dj.Part):
-        """Single unit properties after clustering and curation.
-
-        Attributes:
-            CuratedClustering (foreign key): CuratedClustering primary key.
-            unit (foreign key, int): Unique integer identifying a single unit.
-            probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
-            ClusterQualityLabel (dict): ClusterQualityLabel primary key.
-            spike_count (int): Number of spikes in this recording for this unit.
-            spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording.
-            spike_sites (longblob): Array of electrode associated with each spike.
-            spike_depths (longblob): Array of depths associated with each spike, relative to the (0, 0) of the probe.
- """ - - definition = """ - # Properties of a given unit from a round of clustering (and curation) - -> master - unit: int - --- - -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit - -> ClusterQualityLabel - spike_count: int # how many spikes in this recording for this unit - spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording - spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe - """ - - def make(self, key): - """Automated population of Unit information.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) - - sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate) - - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() - - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / sample_rate - ) - spike_count = len(unit_spike_times) - - units.append( - { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": ( - spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ] - if spike_depths is not None - else None - ), - } - ) - - self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) - - -@schema -class WaveformSet(dj.Imported): - """A set of spike waveforms for units out of a given CuratedClustering. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # A set of spike waveforms for units out of a given CuratedClustering - -> CuratedClustering - """ - - class PeakWaveform(dj.Part): - """Mean waveform across spikes for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. 
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode. - """ - - definition = """ - # Mean waveform across spikes for a given unit at its representative electrode - -> master - -> CuratedClustering.Unit - --- - peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode - """ - - class Waveform(dj.Part): - """Spike waveforms for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - waveform_mean (longblob): mean waveform across spikes of the unit in microvolts. - waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit. - """ - - definition = """ - # Spike waveforms and their mean across spikes for the given unit - -> master - -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- - waveform_mean: longblob # (uV) mean waveform across spikes of the given unit - waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit - """ - - def make(self, key): - """Populates waveform tables.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") - - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) - - is_qc = (Curation & key).fetch1("quality_control") - - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } - - if is_qc: - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample - - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, 
kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( - (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - - yield unit_peak_waveform, unit_electrode_waveforms - - # insert waveform on a per-unit basis to mitigate potential memory issue - self.insert1(key) - for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): - if unit_peak_waveform: - self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) - if unit_electrode_waveforms: - self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) - - -@schema -class QualityMetrics(dj.Imported): - """Clustering and waveform quality metrics. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # Clusters and waveforms metrics - -> CuratedClustering - """ - - class Cluster(dj.Part): - """Cluster metrics for a unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - firing_rate (float): Firing rate of the unit. - snr (float): Signal-to-noise ratio for a unit. - presence_ratio (float): Fraction of time where spikes are present. - isi_violation (float): rate of ISI violation as a fraction of overall rate. - number_violation (int): Total ISI violations. - amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram. - isolation_distance (float): Distance to nearest cluster. - l_ratio (float): Amount of empty space between a cluster and other spikes in dataset. - d_prime (float): Classification accuracy based on LDA. - nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster. - nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster. - silhouette_core (float): Maximum change in spike depth throughout recording. - cumulative_drift (float): Cumulative change in spike depth throughout recording. - contamination_rate (float): Frequency of spikes in the refractory period. 
- """ - - definition = """ - # Cluster metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - firing_rate=null: float # (Hz) firing rate for a unit - snr=null: float # signal-to-noise ratio for a unit - presence_ratio=null: float # fraction of time in which spikes are present - isi_violation=null: float # rate of ISI violation as a fraction of overall rate - number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram - isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # - d_prime=null: float # Classification accuracy based on LDA - nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster - nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # - """ - - class Waveform(dj.Part): - """Waveform metrics for a particular unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - amplitude (float): Absolute difference between waveform peak and trough in microvolts. - duration (float): Time between waveform peak and trough in milliseconds. - halfwidth (float): Spike width at half max amplitude. - pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0. - repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak. - recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. - spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. - velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
- """ - - definition = """ - # Waveform metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough - halfwidth=null: float # (ms) spike width at half max amplitude - pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 - repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak - recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail - spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe - velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe - velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe - """ - - def make(self, key): - """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) - metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) - metrics_list = [ - dict(metrics_df.loc[unit_key["unit"]], **unit_key) - for unit_key in (CuratedClustering.Unit & key).fetch("KEY") - ] - - self.insert1(key) - self.Cluster.insert(metrics_list, ignore_extra_fields=True) - self.Waveform.insert(metrics_list, ignore_extra_fields=True) - - -# ---------------- HELPER FUNCTIONS ---------------- - - -def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str: - """Get spikeGLX data filepath.""" - # attempt to retrieve from EphysRecording.EphysFile - spikeglx_meta_filepath = pathlib.Path( - ( - EphysRecording.EphysFile - & ephys_recording_key - & 'file_path LIKE "%.ap.meta"' - ).fetch1("file_path") - ) - - try: - spikeglx_meta_filepath = find_full_path( - get_ephys_root_data_dir(), spikeglx_meta_filepath - ) - except FileNotFoundError: - # if not found, search in session_dir again - if not spikeglx_meta_filepath.exists(): - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & ephys_recording_key - ).fetch1("probe") - - spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")] - for meta_filepath in spikeglx_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - spikeglx_meta_filepath = meta_filepath - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format( - ephys_recording_key - ) - ) - - return spikeglx_meta_filepath - - -def get_openephys_probe_data(ephys_recording_key: dict) -> list: - """Get OpenEphys probe data from file.""" - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & 
ephys_recording_key - ).fetch1("probe") - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - loaded_oe = openephys.OpenEphys(session_dir) - probe_data = loaded_oe.probes[inserted_probe_serial_number] - - # explicitly garbage collect "loaded_oe" - # as these may have large memory footprint and may not be cleared fast enough - del loaded_oe - gc.collect() - - return probe_data - - -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - probe_dataset = get_openephys_probe_data(ephys_recording_key) - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. 
neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key - - -def get_recording_channels_details(ephys_recording_key: dict) -> np.array: - """Get details of recording channels for a given recording.""" - channels_details = {} - - acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1( - "acq_software", "sampling_rate" - ) - - probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1( - "probe_type" - ) - channels_details["probe_type"] = { - "neuropixels 1.0 - 3A": "3A", - "neuropixels 1.0 - 3B": "NP1", - "neuropixels UHD": "NP1100", - "neuropixels 2.0 - SS": "NP21", - "neuropixels 2.0 - MS": "NP24", - }[probe_type] - - electrode_config_key = ( - probe.ElectrodeConfig * EphysRecording & ephys_recording_key - ).fetch1("KEY") - ( - channels_details["channel_ind"], - channels_details["x_coords"], - channels_details["y_coords"], - channels_details["shank_ind"], - ) = ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key - ).fetch( - "electrode", "x_coord", "y_coord", "shank" - ) - channels_details["sample_rate"] = sample_rate - channels_details["num_channels"] = len(channels_details["channel_ind"]) - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0] - channels_details["connected"] = np.array( - [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]] - ) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(ephys_recording_key) - channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0] - channels_details["connected"] = np.array( - [ - int(v == 1) - for c, v in oe_probe.channels_connected.items() - if c in channels_details["channel_ind"] - ] - ) - - return channels_details diff --git a/element_array_ephys/ephys_precluster.py b/element_array_ephys/ephys_precluster.py deleted file mode 100644 index 4d52c610..00000000 --- a/element_array_ephys/ephys_precluster.py +++ /dev/null @@ -1,1435 +0,0 @@ -import importlib -import inspect -import re - -import datajoint as dj -import numpy as np -import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory - -from . 
import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
-    ephys_schema_name: str,
-    probe_schema_name: str = None,
-    *,
-    create_schema: bool = True,
-    create_tables: bool = True,
-    linking_module: str = None,
-):
-    """Activates the `ephys` and `probe` schemas.
-
-    Args:
-        ephys_schema_name (str): A string containing the name of the ephys schema.
-        probe_schema_name (str): A string containing the name of the probe schema.
-        create_schema (bool): If True, schema will be created in the database.
-        create_tables (bool): If True, tables related to the schema will be created in the database.
-        linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
-
-    Dependencies:
-    Upstream tables:
-        Session: A parent table to ProbeInsertion
-        Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
-    Functions:
-        get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
-        get_session_directory(session_key: dict): Returns the path to the electrophysiology data for a particular session as a string.
-    """
-
-    if isinstance(linking_module, str):
-        linking_module = importlib.import_module(linking_module)
-    assert inspect.ismodule(
-        linking_module
-    ), "The argument 'linking_module' must be a module name or a module"
-
-    global _linking_module
-    _linking_module = linking_module
-
-    probe.activate(
-        probe_schema_name, create_schema=create_schema, create_tables=create_tables
-    )
-    schema.activate(
-        ephys_schema_name,
-        create_schema=create_schema,
-        create_tables=create_tables,
-        add_objects=_linking_module.__dict__,
-    )
-    ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
-
-# -------------- Functions required by the elements-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
-    """Fetches absolute data path to ephys data directories.
-
-    The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
-    Returns:
-        A list of the absolute path(s) to ephys data directories.
-    """
-    return _linking_module.get_ephys_root_data_dir()
-
-
-def get_session_directory(session_key: dict) -> str:
-    """Retrieve the session directory with Neuropixels for the given session.
-
-    Args:
-        session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
-    Returns:
-        A string for the path to the session directory.
-    """
-    return _linking_module.get_session_directory(session_key)
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
-    """Name of software used for recording electrophysiological data.
-
-    Attributes:
-        acq_software ( varchar(24) ): Acquisition software, e.g., SpikeGLX, OpenEphys
-    """
-
-    definition = """  # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys
-    acq_software: varchar(24)
-    """
-    contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
-    """Information about probe insertion across subjects and sessions.
-
-    Attributes:
-        Session (foreign key): Session primary key.
-        insertion_number (foreign key, tinyint unsigned): Unique insertion number for each probe insertion for a given session.
-        probe.Probe (str): probe.Probe primary key.
-    """
-
-    definition = """
-    # Probe insertion implanted into an animal for a given session.
-    -> Session
-    insertion_number: tinyint unsigned
-    ---
-    -> probe.Probe
-    """
-
-
-@schema
-class InsertionLocation(dj.Manual):
-    """Stereotaxic location information for each probe insertion.
-
-    Attributes:
-        ProbeInsertion (foreign key): ProbeInsertion primary key.
-        SkullReference (dict): SkullReference primary key.
-        ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive.
-        ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive.
-        depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative.
-        theta (decimal (5, 2) ): Elevation - rotation about the ml-axis in degrees relative to the positive z-axis.
-        phi (decimal (5, 2) ): Azimuth - rotation about the dv-axis in degrees relative to the positive x-axis.
-        beta (decimal (5, 2) ): Rotation about the shank of the probe in degrees; clockwise is increasing, with 0 being the probe-front facing anterior.
-    """
-
-    definition = """
-    # Brain Location of a given probe insertion.
-    -> ProbeInsertion
-    ---
-    -> SkullReference
-    ap_location: decimal(6, 2)  # (um) anterior-posterior; ref is 0; more anterior is more positive
-    ml_location: decimal(6, 2)  # (um) medial axis; ref is 0 ; more right is more positive
-    depth:       decimal(6, 2)  # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
-    theta=null:  decimal(5, 2)  # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis
-    phi=null:    decimal(5, 2)  # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis
-    beta=null:   decimal(5, 2)  # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior
-    """
-
-
-@schema
-class EphysRecording(dj.Imported):
-    """Automated table with electrophysiology recording information for each probe inserted during an experimental session.
-
-    Attributes:
-        ProbeInsertion (foreign key): ProbeInsertion primary key.
-        probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key.
-        AcquisitionSoftware (dict): AcquisitionSoftware primary key.
-        sampling_rate (float): sampling rate of the recording in Hertz (Hz).
-        recording_datetime (datetime): datetime of the recording from this probe.
-        recording_duration (float): duration of the entire recording from this probe in seconds.
-    """
-
-    definition = """
-    # Ephys recording from a probe insertion for a given session.
-    -> ProbeInsertion
-    ---
-    -> probe.ElectrodeConfig
-    -> AcquisitionSoftware
-    sampling_rate: float          # (Hz)
-    recording_datetime: datetime  # datetime of the recording from this probe
-    recording_duration: float     # (seconds) duration of the recording from this probe
-    """
-
-    class EphysFile(dj.Part):
-        """Paths of electrophysiology recording files for each insertion.
-
-        Attributes:
-            EphysRecording (foreign key): EphysRecording primary key.
-            file_path (varchar(255) ): relative file path for electrophysiology recording.
-        """
-
-        definition = """
-        # Paths of files of a given EphysRecording round.
- -> master - file_path: varchar(255) # filepath relative to root data directory - """ - - def make(self, key): - """Populates table with electrophysiology recording information.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = [fp for fp in session_dir.rglob(ephys_pattern)] - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" - f" in {session_dir}" - ) - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) - - if re.search("(1.0|2.0)", spikeglx_meta.probe_model): - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: {}".format(key) - ) - - if re.search("(1.0|2.0)", probe_data.probe_model): - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_ids"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) - 
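The `EphysFile` rows inserted just below store each recording file's path relative to whichever root directory contains it, in POSIX form so the stored value is portable across operating systems. A minimal standalone sketch of that normalization, with hypothetical paths:

```python
import pathlib

# Hypothetical root directory and recording file, for illustration only.
root_dir = pathlib.Path("/data/ephys_root")
recording_file = root_dir / "sub4" / "sess1" / "probe_2" / "continuous.dat"

# Stored value: "sub4/sess1/probe_2/continuous.dat"
file_path = recording_file.relative_to(root_dir).as_posix()
print(file_path)
```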
- root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) - - -@schema -class PreClusterMethod(dj.Lookup): - """Pre-clustering method - - Attributes: - precluster_method (foreign key, varchar(16) ): Pre-clustering method for the dataset. - precluster_method_desc(varchar(1000) ): Pre-clustering method description. - """ - - definition = """ - # Method for pre-clustering - precluster_method: varchar(16) - --- - precluster_method_desc: varchar(1000) - """ - - contents = [("catgt", "Time shift, Common average referencing, Zeroing")] - - -@schema -class PreClusterParamSet(dj.Lookup): - """Parameters for the pre-clustering method. - - Attributes: - paramset_idx (foreign key): Unique parameter set ID. - PreClusterMethod (dict): PreClusterMethod query for this dataset. - paramset_desc (varchar(128) ): Description for the pre-clustering parameter set. - param_set_hash (uuid): Unique hash for parameter set. - params (longblob): All parameters for the pre-clustering method. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> PreClusterMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, precluster_method: str, paramset_idx: int, paramset_desc: str, params: dict - ): - param_dict = { - "precluster_method": precluster_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid(params), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # If not same name: human error, trying to add the same paramset with different name - raise dj.DataJointError( - "The specified param-set" - " already exists - paramset_idx: {}".format(existing_paramset_idx) - ) - else: - cls.insert1(param_dict) - - -@schema -class PreClusterParamSteps(dj.Manual): - """Ordered list of parameter sets that will be run. - - Attributes: - precluster_param_steps_id (foreign key): Unique ID for the pre-clustering parameter sets to be run. - precluster_param_steps_name (varchar(32) ): User-friendly name for the parameter steps. - precluster_param_steps_desc (varchar(128) ): Description of the parameter steps. - """ - - definition = """ - # Ordered list of paramset_idx that are to be run - # When pre-clustering is not performed, do not create an entry in `Step` Part table - precluster_param_steps_id: smallint - --- - precluster_param_steps_name: varchar(32) - precluster_param_steps_desc: varchar(128) - """ - - class Step(dj.Part): - """Define the order of operations for parameter sets. - - Attributes: - PreClusterParamSteps (foreign key): PreClusterParamSteps primary key. - step_number (foreign key, smallint): Order of operations. - PreClusterParamSet (dict): PreClusterParamSet to be used in pre-clustering. 
- """ - - definition = """ - -> master - step_number: smallint # Order of operations - --- - -> PreClusterParamSet - """ - - -@schema -class PreClusterTask(dj.Manual): - """Defines a pre-clustering task ready to be run. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - PreclusterParamSteps (foreign key): PreClusterParam Steps primary key. - precluster_output_dir (varchar(255) ): relative path to directory for storing results of pre-clustering. - task_mode (enum ): `none` (no pre-clustering), `load` results from file, or `trigger` automated pre-clustering. - """ - - definition = """ - # Manual table for defining a clustering task ready to be run - -> EphysRecording - -> PreClusterParamSteps - --- - precluster_output_dir='': varchar(255) # pre-clustering output directory relative to the root data directory - task_mode='none': enum('none','load', 'trigger') # 'none': no pre-clustering analysis - # 'load': load analysis results - # 'trigger': trigger computation - """ - - -@schema -class PreCluster(dj.Imported): - """ - A processing table to handle each PreClusterTask: - - Attributes: - PreClusterTask (foreign key): PreClusterTask primary key. - precluster_time (datetime): Time of generation of this set of pre-clustering results. - package_version (varchar(16) ): Package version used for performing pre-clustering. - """ - - definition = """ - -> PreClusterTask - --- - precluster_time: datetime # time of generation of this set of pre-clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Populate pre-clustering tables.""" - task_mode, output_dir = (PreClusterTask & key).fetch1( - "task_mode", "precluster_output_dir" - ) - precluster_output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "none": - if len((PreClusterParamSteps.Step & key).fetch()) > 0: - raise ValueError( - "There are entries in the PreClusterParamSteps.Step " - "table and task_mode=none" - ) - creation_time = (EphysRecording & key).fetch1("recording_datetime") - elif task_mode == "load": - acq_software = (EphysRecording & key).fetch1("acq_software") - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - if acq_software == "SpikeGLX": - for meta_filepath in precluster_output_dir.rglob("*.ap.meta"): - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - creation_time = spikeglx_meta.recording_time - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) - else: - raise NotImplementedError( - f"Pre-clustering analysis of {acq_software}" "is not yet supported." - ) - elif task_mode == "trigger": - raise NotImplementedError( - "Automatic triggering of" - " pre-clustering analysis is not yet supported." - ) - else: - raise ValueError(f"Unknown task mode: {task_mode}") - - self.insert1({**key, "precluster_time": creation_time, "package_version": ""}) - - -@schema -class LFP(dj.Imported): - """Extracts local field potentials (LFP) from an electrophysiology recording. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - lfp_sampling_rate (float): Sampling rate for LFPs in Hz. - lfp_time_stamps (longblob): Time stamps with respect to the start of the recording. - lfp_mean (longblob): Overall mean LFP across electrodes. - """ - - definition = """ - # Acquired local field potential (LFP) from a given Ephys recording. 
- -> PreCluster - --- - lfp_sampling_rate: float # (Hz) - lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) - lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) - """ - - class Electrode(dj.Part): - """Saves local field potential data for each electrode. - - Attributes: - LFP (foreign key): LFP primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - lfp (longblob): LFP recording at this electrode in microvolts. - """ - - definition = """ - -> master - -> probe.ElectrodeConfig.Electrode - --- - lfp: longblob # (uV) recorded lfp at this electrode - """ - - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - - def make(self, key): - """Populates the LFP tables.""" - acq_software, probe_sn = (EphysRecording * ProbeInsertion & key).fetch1( - "acq_software", "probe" - ) - - electrode_keys, lfp = [], [] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - - lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ - -1 :: -self._skip_channel_counts - ] - - # Extract LFP data at specified channels and convert to uV - lfp = spikeglx_recording.lf_timeseries[ - :, lfp_channel_ind - ] # (sample x channel) - lfp = ( - lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind] - ).T # (channel x sample) - - self.insert1( - dict( - key, - lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"], - lfp_time_stamps=( - np.arange(lfp.shape[1]) - / spikeglx_recording.lfmeta.meta["imSampRate"] - ), - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - for recorded_site in lfp_channel_ind: - shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[ - "data" - ][recorded_site] - electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - loaded_oe = openephys.OpenEphys(session_dir) - oe_probe = loaded_oe.probes[probe_sn] - - lfp_channel_ind = np.arange(len(oe_probe.lfp_meta["channels_ids"]))[ - -1 :: -self._skip_channel_counts - ] - - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] # (sample x channel) - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps - - self.insert1( - dict( - key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - for channel_idx in np.array(oe_probe.lfp_meta["channels_ids"])[ - lfp_channel_ind - ]: - electrode_keys.append(probe_electrodes[channel_idx]) - else: - raise NotImplementedError( - f"LFP extraction from acquisition software" - f" of type {acq_software} is not yet implemented" - ) - - # single insert in 
loop to mitigate potential memory issue - for electrode_key, lfp_trace in zip(electrode_keys, lfp): - self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace}) - - -# ------------ Clustering -------------- - - -@schema -class ClusteringMethod(dj.Lookup): - """Kilosort clustering method. - - Attributes: - clustering_method (foreign key, varchar(16) ): Kilosort clustering method. - clustering_methods_desc (varchar(1000) ): Additional description of the clustering method. - """ - - definition = """ - # Method for clustering - clustering_method: varchar(16) - --- - clustering_method_desc: varchar(1000) - """ - - contents = [ - ("kilosort", "kilosort clustering method"), - ("kilosort2", "kilosort2 clustering method"), - ] - - -@schema -class ClusteringParamSet(dj.Lookup): - """Parameters to be used in clustering procedure for spike sorting. - - Attributes: - paramset_idx (foreign key): Unique ID for the clustering parameter set. - ClusteringMethod (dict): ClusteringMethod primary key. - paramset_desc (varchar(128) ): Description of the clustering parameter set. - param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Paramset, dictionary of all applicable parameters. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> ClusteringMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, processing_method: str, paramset_idx: int, paramset_desc: str, params: dict - ): - """Inserts new parameters into the ClusteringParamSet table. - - Args: - processing_method (str): name of the clustering method. - paramset_desc (str): description of the parameter set - params (dict): clustering parameters - paramset_idx (int, optional): Unique parameter set ID. Defaults to None. - """ - param_dict = { - "clustering_method": processing_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid(params), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # If not same name: human error, trying to add the same paramset with different name - raise dj.DataJointError( - "The specified param-set" - " already exists - paramset_idx: {}".format(existing_paramset_idx) - ) - else: - cls.insert1(param_dict) - - -@schema -class ClusterQualityLabel(dj.Lookup): - """Quality label for each spike sorted cluster. - - Attributes: - cluster_quality_label (foreign key, varchar(100) ): Cluster quality type. - cluster_quality_description (varchar(4000) ): Description of the cluster quality type. - """ - - definition = """ - # Quality - cluster_quality_label: varchar(100) - --- - cluster_quality_description: varchar(4000) - """ - contents = [ - ("good", "single unit"), - ("ok", "probably a single unit, but could be contaminated"), - ("mua", "multi-unit activity"), - ("noise", "bad unit"), - ] - - -@schema -class ClusteringTask(dj.Manual): - """A clustering task to spike sort electrophysiology datasets. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - ClusteringParamSet (foreign key): ClusteringParamSet primary key. 
-        clustering_output_dir (varchar (255) ): Relative path to output clustering results.
-        task_mode (enum): `trigger` computes clustering; `load` imports existing results.
-    """
-
-    definition = """
-    # Manual table for defining a clustering task ready to be run
-    -> PreCluster
-    -> ClusteringParamSet
-    ---
-    clustering_output_dir: varchar(255)  # clustering output directory relative to the clustering root data directory
-    task_mode='load': enum('load', 'trigger')  # 'load': load computed analysis results, 'trigger': trigger computation
-    """
-
-
-@schema
-class Clustering(dj.Imported):
-    """A processing table to handle each clustering task.
-
-    Attributes:
-        ClusteringTask (foreign key): ClusteringTask primary key.
-        clustering_time (datetime): Time when clustering results are generated.
-        package_version (varchar(16) ): Package version used for a clustering analysis.
-    """
-
-    definition = """
-    # Clustering Procedure
-    -> ClusteringTask
-    ---
-    clustering_time: datetime  # time of generation of this set of clustering results
-    package_version='': varchar(16)
-    """
-
-    def make(self, key):
-        """Triggers or imports clustering analysis."""
-        task_mode, output_dir = (ClusteringTask & key).fetch1(
-            "task_mode", "clustering_output_dir"
-        )
-        kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
-        if task_mode == "load":
-            _ = kilosort.Kilosort(
-                kilosort_dir
-            )  # check if the directory is a valid Kilosort output
-            creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir)
-        elif task_mode == "trigger":
-            raise NotImplementedError(
-                "Automatic triggering of clustering analysis is not yet supported"
-            )
-        else:
-            raise ValueError(f"Unknown task mode: {task_mode}")
-
-        self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
-    """Curation procedure table.
-
-    Attributes:
-        Clustering (foreign key): Clustering primary key.
-        curation_id (foreign key, int): Unique curation ID.
-        curation_time (datetime): Time when curation results are generated.
-        curation_output_dir (varchar(255) ): Output directory of the curated results.
-        quality_control (bool): If True, this clustering result has undergone quality control.
-        manual_curation (bool): If True, manual curation has been performed on this clustering result.
-        curation_note (varchar(2000) ): Notes about the curation task.
-    """
-
-    definition = """
-    # Manual curation procedure
-    -> Clustering
-    curation_id: int
-    ---
-    curation_time: datetime             # time of generation of this set of curated clustering results
-    curation_output_dir: varchar(255)   # output directory of the curated results, relative to root data directory
-    quality_control: bool               # has this clustering result undergone quality control?
-    manual_curation: bool               # has manual curation been performed on this clustering result?
- curation_note='': varchar(2000) - """ - - def create1_from_clustering_task(self, key, curation_note: str = ""): - """ - A function to create a new corresponding "Curation" for a particular - "ClusteringTask" - """ - if key not in Clustering(): - raise ValueError( - f"No corresponding entry in Clustering available" - f" for: {key}; do `Clustering.populate(key)`" - ) - - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - creation_time, is_curated, is_qc = kilosort.extract_clustering_info( - kilosort_dir - ) - # Synthesize curation_id - curation_id = ( - dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n") - ) - self.insert1( - { - **key, - "curation_id": curation_id, - "curation_time": creation_time, - "curation_output_dir": output_dir, - "quality_control": is_qc, - "manual_curation": is_curated, - "curation_note": curation_note, - } - ) - - -@schema -class CuratedClustering(dj.Imported): - """Clustering results after curation. - - Attributes: - Curation (foreign key): Curation primary key. - """ - - definition = """ - # Clustering results of a curation. - -> Curation - """ - - class Unit(dj.Part): - """Single unit properties after clustering and curation. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - unit (foreign key, int): Unique integer identifying a single unit. - probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key. - ClusteringQualityLabel (dict): CLusteringQualityLabel primary key. - spike_count (int): Number of spikes in this recording for this unit. - spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording. - spike_sites (longblob): Array of electrode associated with each spike. - spike_depths (longblob): Array of depths associated with each spike, relative to each spike. 
- """ - - definition = """ - # Properties of a given unit from a round of clustering (and curation) - -> master - unit: int - --- - -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit - -> ClusterQualityLabel - spike_count: int # how many spikes in this recording for this unit - spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording - spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe - """ - - def make(self, key): - """Automated population of Unit information.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software = (EphysRecording & key).fetch1("acq_software") - - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() - - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / kilosort_dataset.data["params"]["sample_rate"] - ) - spike_count = len(unit_spike_times) - - units.append( - { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": ( - spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ] - if spike_depths is not None - else None - ), - } - ) - - self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) - - -@schema -class WaveformSet(dj.Imported): - """A set of spike waveforms for units out of a given CuratedClustering. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # A set of spike waveforms for units out of a given CuratedClustering - -> CuratedClustering - """ - - class PeakWaveform(dj.Part): - """Mean waveform across spikes for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. 
- peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode. - """ - - definition = """ - # Mean waveform across spikes for a given unit at its representative electrode - -> master - -> CuratedClustering.Unit - --- - peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode - """ - - class Waveform(dj.Part): - """Spike waveforms for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - waveform_mean (longblob): mean waveform across spikes of the unit in microvolts. - waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit. - """ - - definition = """ - # Spike waveforms and their mean across spikes for the given unit - -> master - -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- - waveform_mean: longblob # (uV) mean waveform across spikes of the given unit - waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit - """ - - def make(self, key): - """Populates waveform tables.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") - - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) - - is_qc = (Curation & key).fetch1("quality_control") - - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } - - if is_qc: - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample - - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( 
- (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - - yield unit_peak_waveform, unit_electrode_waveforms - - # insert waveform on a per-unit basis to mitigate potential memory issue - self.insert1(key) - for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): - self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) - self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) - - -@schema -class QualityMetrics(dj.Imported): - """Clustering and waveform quality metrics. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # Clusters and waveforms metrics - -> CuratedClustering - """ - - class Cluster(dj.Part): - """Cluster metrics for a unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - firing_rate (float): Firing rate of the unit. - snr (float): Signal-to-noise ratio for a unit. - presence_ratio (float): Fraction of time where spikes are present. - isi_violation (float): rate of ISI violation as a fraction of overall rate. - number_violation (int): Total ISI violations. - amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram. - isolation_distance (float): Distance to nearest cluster. - l_ratio (float): Amount of empty space between a cluster and other spikes in dataset. - d_prime (float): Classification accuracy based on LDA. - nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster. - nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster. - silhouette_core (float): Maximum change in spike depth throughout recording. - cumulative_drift (float): Cumulative change in spike depth throughout recording. - contamination_rate (float): Frequency of spikes in the refractory period. 
- """ - - definition = """ - # Cluster metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - firing_rate=null: float # (Hz) firing rate for a unit - snr=null: float # signal-to-noise ratio for a unit - presence_ratio=null: float # fraction of time in which spikes are present - isi_violation=null: float # rate of ISI violation as a fraction of overall rate - number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram - isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # - d_prime=null: float # Classification accuracy based on LDA - nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster - nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # - """ - - class Waveform(dj.Part): - """Waveform metrics for a particular unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - amplitude (float): Absolute difference between waveform peak and trough in microvolts. - duration (float): Time between waveform peak and trough in milliseconds. - halfwidth (float): Spike width at half max amplitude. - pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0. - repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak. - recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. - spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. - velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
- """ - - definition = """ - # Waveform metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough - halfwidth=null: float # (ms) spike width at half max amplitude - pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 - repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak - recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail - spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe - velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe - velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe - """ - - def make(self, key): - """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) - metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) - metrics_list = [ - dict(metrics_df.loc[unit_key["unit"]], **unit_key) - for unit_key in (CuratedClustering.Unit & key).fetch("KEY") - ] - - self.insert1(key) - self.Cluster.insert(metrics_list, ignore_extra_fields=True) - self.Waveform.insert(metrics_list, ignore_extra_fields=True) - - -# ---------------- HELPER FUNCTIONS ---------------- - - -def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str: - """Get spikeGLX data filepath.""" - # attempt to retrieve from EphysRecording.EphysFile - spikeglx_meta_filepath = ( - EphysRecording.EphysFile & ephys_recording_key & 'file_path LIKE "%.ap.meta"' - ).fetch1("file_path") - - try: - spikeglx_meta_filepath = find_full_path( - get_ephys_root_data_dir(), spikeglx_meta_filepath - ) - except FileNotFoundError: - # if not found, search in session_dir again - if not spikeglx_meta_filepath.exists(): - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & ephys_recording_key - ).fetch1("probe") - - spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")] - for meta_filepath in spikeglx_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - spikeglx_meta_filepath = meta_filepath - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format( - ephys_recording_key - ) - ) - - return spikeglx_meta_filepath - - -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = 
get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - probe_serial_number = (ProbeInsertion & ephys_recording_key).fetch1("probe") - probe_dataset = openephys_dataset.probes[probe_serial_number] - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_ids"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key diff --git a/element_array_ephys/ephys_report.py b/element_array_ephys/ephys_report.py index 48bcf613..0c6836a0 100644 --- a/element_array_ephys/ephys_report.py +++ b/element_array_ephys/ephys_report.py @@ -2,31 +2,30 @@ import datetime import pathlib +import tempfile from uuid import UUID import datajoint as dj from element_interface.utils import dict_to_uuid -from . import probe +from . import probe, ephys schema = dj.schema() -ephys = None - -def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True): +def activate(schema_name, *, create_schema=True, create_tables=True): """Activate the current schema. 
Args: schema_name (str): schema name on the database server to activate the `ephys_report` schema. - ephys_schema_name (str): schema name of the activated ephys element for which - this ephys_report schema will be downstream from. create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist. create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist. """ + if not probe.schema.is_activated(): + raise RuntimeError("Please activate the `probe` schema first.") + if not ephys.schema.is_activated(): + raise RuntimeError("Please activate the `ephys` schema first.") - global ephys - ephys = dj.create_virtual_module("ephys", ephys_schema_name) schema.activate( schema_name, create_schema=create_schema, @@ -55,7 +54,7 @@ class ProbeLevelReport(dj.Computed): def make(self, key): from .plotting.probe_level import plot_driftmap - save_dir = _make_save_dir() + save_dir = tempfile.TemporaryDirectory() units = ephys.CuratedClustering.Unit & key & "cluster_quality_label='good'" @@ -90,13 +89,15 @@ def make(self, key): fig_dict = _save_figs( figs=(fig,), fig_names=("drift_map_plot",), - save_dir=save_dir, + save_dir=save_dir.name, fig_prefix=fig_prefix, extension=".png", ) self.insert1({**key, **fig_dict, "shank": shank_no}) + save_dir.cleanup() + @schema class UnitLevelReport(dj.Computed): @@ -268,17 +269,10 @@ def make(self, key): ) -def _make_save_dir(root_dir: pathlib.Path = None) -> pathlib.Path: - if root_dir is None: - root_dir = pathlib.Path().absolute() - save_dir = root_dir / "temp_ephys_figures" - save_dir.mkdir(parents=True, exist_ok=True) - return save_dir - - def _save_figs( figs, fig_names, save_dir, fig_prefix, extension=".png" ) -> dict[str, pathlib.Path]: + save_dir = pathlib.Path(save_dir) fig_dict = {} for fig, fig_name in zip(figs, fig_names): fig_filepath = save_dir / (fig_prefix + "_" + fig_name + extension) diff --git a/element_array_ephys/export/nwb/nwb.py b/element_array_ephys/export/nwb/nwb.py index a45eb754..8d7da8f5 100644 --- a/element_array_ephys/export/nwb/nwb.py +++ b/element_array_ephys/export/nwb/nwb.py @@ -17,14 +17,7 @@ from spikeinterface import extractors from tqdm import tqdm -from ... import ephys_no_curation as ephys -from ... import probe - -ephys_mode = os.getenv("EPHYS_MODE", dj.config["custom"].get("ephys_mode", "acute")) -if ephys_mode != "no-curation": - raise NotImplementedError( - "This export function is designed for the no_curation " + "schema" - ) +from ... 
import probe, ephys class DecimalEncoder(json.JSONEncoder): diff --git a/element_array_ephys/readers/kilosort.py b/element_array_ephys/readers/kilosort.py index 4b50619d..4f8530d8 100644 --- a/element_array_ephys/readers/kilosort.py +++ b/element_array_ephys/readers/kilosort.py @@ -1,12 +1,10 @@ -import logging -import pathlib -import re -from datetime import datetime from os import path - -import numpy as np +from datetime import datetime +import pathlib import pandas as pd - +import numpy as np +import re +import logging from .utils import convert_to_number log = logging.getLogger(__name__) @@ -117,7 +115,8 @@ def _load(self): # Read the Cluster Groups for cluster_pattern, cluster_col_name in zip( - ["cluster_group.*", "cluster_KSLabel.*"], ["group", "KSLabel"] + ["cluster_group.*", "cluster_KSLabel.*", "cluster_group.*"], + ["group", "KSLabel", "KSLabel"], ): try: cluster_file = next(self._kilosort_dir.glob(cluster_pattern)) @@ -126,22 +125,26 @@ def _load(self): else: cluster_file_suffix = cluster_file.suffix assert cluster_file_suffix in (".tsv", ".xlsx") - break + + if cluster_file_suffix == ".tsv": + df = pd.read_csv(cluster_file, sep="\t", header=0) + elif cluster_file_suffix == ".xlsx": + df = pd.read_excel(cluster_file, engine="openpyxl") + else: + df = pd.read_csv(cluster_file, delimiter="\t") + + try: + self._data["cluster_groups"] = np.array(df[cluster_col_name].values) + self._data["cluster_ids"] = np.array(df["cluster_id"].values) + except KeyError: + continue + else: + break else: raise FileNotFoundError( 'Neither "cluster_groups" nor "cluster_KSLabel" file found!' ) - if cluster_file_suffix == ".tsv": - df = pd.read_csv(cluster_file, sep="\t", header=0) - elif cluster_file_suffix == ".xlsx": - df = pd.read_excel(cluster_file, engine="openpyxl") - else: - df = pd.read_csv(cluster_file, delimiter="\t") - - self._data["cluster_groups"] = np.array(df[cluster_col_name].values) - self._data["cluster_ids"] = np.array(df["cluster_id"].values) - def get_best_channel(self, unit): template_idx = self.data["spike_templates"][ np.where(self.data["spike_clusters"] == unit)[0][0] diff --git a/element_array_ephys/readers/probe_geometry.py b/element_array_ephys/readers/probe_geometry.py index b6fbc09e..f0d50a1c 100644 --- a/element_array_ephys/readers/probe_geometry.py +++ b/element_array_ephys/readers/probe_geometry.py @@ -140,8 +140,8 @@ def build_npx_probe( return elec_pos_df -def to_probeinterface(electrodes_df): - from probeinterface import Probe +def to_probeinterface(electrodes_df, **kwargs): + import probeinterface as pi probe_df = electrodes_df.copy() probe_df.rename( @@ -153,10 +153,22 @@ def to_probeinterface(electrodes_df): }, inplace=True, ) - probe_df["contact_shapes"] = "square" - probe_df["width"] = 12 - - return Probe.from_dataframe(probe_df) + # Get the contact shapes. By default, it's set to circle with a radius of 10. 
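+    # Supported shapes follow probeinterface: "circle" (takes `radius`),
+    # "square" (takes `width`), and "rect" (takes `width` and `height`);
+    # contact dimensions are in micrometers, like the contact positions.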
+    contact_shapes = kwargs.get("contact_shapes", "circle")
+    assert (
+        contact_shapes in pi.probe._possible_contact_shapes
+    ), f"contact shape should be in {pi.probe._possible_contact_shapes}"
+
+    probe_df["contact_shapes"] = contact_shapes
+    if contact_shapes == "circle":
+        probe_df["radius"] = kwargs.get("radius", 10)
+    elif contact_shapes == "square":
+        probe_df["width"] = kwargs.get("width", 10)
+    elif contact_shapes == "rect":
+        probe_df["width"] = kwargs.get("width")
+        probe_df["height"] = kwargs.get("height")
+
+    return pi.Probe.from_dataframe(probe_df)
 
 
 def build_electrode_layouts(
diff --git a/element_array_ephys/spike_sorting/__init__.py b/element_array_ephys/spike_sorting/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/spike_sorting/kilosort_triggering.py
similarity index 100%
rename from element_array_ephys/readers/kilosort_triggering.py
rename to element_array_ephys/spike_sorting/kilosort_triggering.py
diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py
new file mode 100644
index 00000000..22adbdca
--- /dev/null
+++ b/element_array_ephys/spike_sorting/si_preprocessing.py
@@ -0,0 +1,37 @@
+import spikeinterface as si
+from spikeinterface import preprocessing
+
+
+def CatGT(recording):
+    recording = si.preprocessing.phase_shift(recording)
+    recording = si.preprocessing.common_reference(
+        recording, operator="median", reference="global"
+    )
+    return recording
+
+
+def IBLdestriping(recording):
+    # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022.
+    recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+    bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+    # For IBL destriping, interpolate the bad channels
+    recording = si.preprocessing.interpolate_bad_channels(recording, bad_channel_ids)
+    recording = si.preprocessing.phase_shift(recording)
+    # For IBL destriping, highpass_spatial_filter is used instead of a common reference
+    recording = si.preprocessing.highpass_spatial_filter(recording)
+    return recording
+
+
+def IBLdestriping_modified(recording):
+    # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html)
+    recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+    bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+    # For IBL destriping, remove the bad channels
+    recording = recording.remove_channels(bad_channel_ids)
+    recording = si.preprocessing.phase_shift(recording)
+    recording = si.preprocessing.common_reference(
+        recording, operator="median", reference="global"
+    )
+    return recording
diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
new file mode 100644
index 00000000..e2f011e1
--- /dev/null
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -0,0 +1,403 @@
+"""
+The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the "spikeinterface" pipeline.
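+The pipeline chains four computed tables: PreProcessing -> SIClustering -> PostProcessing -> SIExport (optional).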
+Spikeinterface was developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface)
+If you use this pipeline, please cite SpikeInterface and the relevant sorter(s) used in your publication (see https://github.com/SpikeInterface for additional citation details).
+"""
+
+from datetime import datetime
+
+import datajoint as dj
+import pandas as pd
+import spikeinterface as si
+from element_array_ephys import probe, ephys, readers
+from element_interface.utils import find_full_path, memoized_result
+from spikeinterface import exporters, extractors, sorters
+
+from . import si_preprocessing
+
+log = dj.logger
+
+schema = dj.schema()
+
+
+def activate(
+    schema_name,
+    *,
+    create_schema=True,
+    create_tables=True,
+):
+    """Activate the current schema.
+
+    Args:
+        schema_name (str): schema name on the database server to activate the `si_spike_sorting` schema.
+        create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist.
+        create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist.
+    """
+    if not probe.schema.is_activated():
+        raise RuntimeError("Please activate the `probe` schema first.")
+    if not ephys.schema.is_activated():
+        raise RuntimeError("Please activate the `ephys` schema first.")
+
+    schema.activate(
+        schema_name,
+        create_schema=create_schema,
+        create_tables=create_tables,
+        add_objects=ephys.__dict__,
+    )
+    ephys.Clustering.key_source -= PreProcessing.key_source.proj()
+
+
+SI_SORTERS = [s.replace("_", ".") for s in si.sorters.sorter_dict.keys()]
+
+
+@schema
+class PreProcessing(dj.Imported):
+    """A table to handle preprocessing of each clustering task. The output is serialized and stored as `si_recording.pkl` in the output directory."""
+
+    definition = """
+    -> ephys.ClusteringTask
+    ---
+    execution_time: datetime   # datetime of the start of this step
+    execution_duration: float  # execution duration in hours
+    """
+
+    @property
+    def key_source(self):
+        return (
+            ephys.ClusteringTask * ephys.ClusteringParamSet
+            & {"task_mode": "trigger"}
+            & f"clustering_method in {tuple(SI_SORTERS)}"
+        ) - ephys.Clustering
+
+    def make(self, key):
+        """Loads the raw recording, attaches probe geometry, and runs preprocessing."""
+        execution_time = datetime.utcnow()
+
+        # Fetch the clustering method, acquisition software, output directory, and params
+        clustering_method, acq_software, output_dir, params = (
+            ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key
+        ).fetch1("clustering_method", "acq_software", "clustering_output_dir", "params")
+
+        # Derive the SpikeInterface sorter name (e.g. "kilosort2.5" -> "kilosort2_5")
+        sorter_name = clustering_method.replace(".", "_")
+
+        for required_key in (
+            "SI_PREPROCESSING_METHOD",
+            "SI_SORTING_PARAMS",
+            "SI_POSTPROCESSING_PARAMS",
+        ):
+            if required_key not in params:
+                raise ValueError(
+                    f"{required_key} must be defined in ClusteringParamSet for SpikeInterface execution"
+                )
+
+        # Set directory to store recording file.
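+        # If the task has no output directory yet, infer one from the session
+        # directory and write it back to ClusteringTask for downstream steps.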
+ if not output_dir: + output_dir = ephys.ClusteringTask.infer_output_dir( + key, relative=True, mkdir=True + ) + # update clustering_output_dir + ephys.ClusteringTask.update1( + {**key, "clustering_output_dir": output_dir.as_posix()} + ) + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_dir = output_dir / sorter_name / "recording" + recording_dir.mkdir(parents=True, exist_ok=True) + recording_file = recording_dir / "si_recording.pkl" + + # Create SI recording extractor object + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = readers.spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + spikeglx_recording.validate_file("ap") + data_dir = spikeglx_meta_filepath.parent + + si_extractor = ( + si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor + ) + stream_names, stream_ids = si.extractors.get_neo_streams( + "spikeglx", folder_path=data_dir + ) + si_recording: si.BaseRecording = si_extractor( + folder_path=data_dir, stream_name=stream_names[0] + ) + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + assert len(oe_probe.recording_info["recording_files"]) == 1 + data_dir = oe_probe.recording_info["recording_files"][0] + si_extractor = ( + si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor + ) + + stream_names, stream_ids = si.extractors.get_neo_streams( + "openephysbinary", folder_path=data_dir + ) + si_recording: si.BaseRecording = si_extractor( + folder_path=data_dir, stream_name=stream_names[0] + ) + else: + raise NotImplementedError( + f"SpikeInterface processing for {acq_software} not yet implemented." + ) + + # Add probe information to recording object + electrodes_df = ( + ( + ephys.EphysRecording.Channel + * probe.ElectrodeConfig.Electrode + * probe.ProbeType.Electrode + & key + ) + .fetch(format="frame") + .reset_index() + ) + + # Create SI probe object + si_probe = readers.probe_geometry.to_probeinterface( + electrodes_df[["electrode", "x_coord", "y_coord", "shank"]] + ) + si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values) + si_recording.set_probe(probe=si_probe, in_place=True) + + # Run preprocessing and save results to output folder + si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"]) + si_recording = si_preproc_func(si_recording) + si_recording.dump_to_pickle(file_path=recording_file, relative_to=output_dir) + + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + +@schema +class SIClustering(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> PreProcessing + --- + execution_time: datetime # datetime of the start of this step + execution_duration: float # execution duration in hours + """ + + def make(self, key): + execution_time = datetime.utcnow() + + # Load recording object. 
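+        # The preprocessed recording was pickled by PreProcessing with paths
+        # relative to the output directory, hence `base_folder=output_dir` below.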
+        clustering_method, output_dir, params = (
+            ephys.ClusteringTask * ephys.ClusteringParamSet & key
+        ).fetch1("clustering_method", "clustering_output_dir", "params")
+        output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+        sorter_name = clustering_method.replace(".", "_")
+        recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
+        si_recording: si.BaseRecording = si.load_extractor(
+            recording_file, base_folder=output_dir
+        )
+
+        sorting_params = params["SI_SORTING_PARAMS"]
+        sorting_output_dir = output_dir / sorter_name / "spike_sorting"
+
+        # Run sorting
+        @memoized_result(
+            uniqueness_dict=sorting_params,
+            output_directory=sorting_output_dir,
+        )
+        def _run_sorter():
+            # Sorting runs in a dedicated docker environment if the sorter is not built into the spikeinterface package.
+            si_sorting: si.BaseSorting = si.sorters.run_sorter(
+                sorter_name=sorter_name,
+                recording=si_recording,
+                folder=sorting_output_dir,
+                remove_existing_folder=True,
+                verbose=True,
+                docker_image=sorter_name not in si.sorters.installed_sorters(),
+                **sorting_params,
+            )
+
+            # Save sorting object
+            sorting_save_path = sorting_output_dir / "si_sorting.pkl"
+            si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
+
+        _run_sorter()
+
+        self.insert1(
+            {
+                **key,
+                "execution_time": execution_time,
+                "execution_duration": (
+                    datetime.utcnow() - execution_time
+                ).total_seconds()
+                / 3600,
+            }
+        )
+
+
+@schema
+class PostProcessing(dj.Imported):
+    """A processing table to handle the postprocessing of each spike-sorted result."""
+
+    definition = """
+    -> SIClustering
+    ---
+    execution_time: datetime   # datetime of the start of this step
+    execution_duration: float  # execution duration in hours
+    do_si_export=0: bool       # whether to export to phy and/or generate a report
+    """
+
+    def make(self, key):
+        execution_time = datetime.utcnow()
+
+        # Load recording & sorting object.
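+        # Both are rehydrated from the pickles written by the PreProcessing and
+        # SIClustering steps, again resolving paths against `base_folder`.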
+        clustering_method, output_dir, params = (
+            ephys.ClusteringTask * ephys.ClusteringParamSet & key
+        ).fetch1("clustering_method", "clustering_output_dir", "params")
+        output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+        sorter_name = clustering_method.replace(".", "_")
+
+        recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
+        sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
+
+        si_recording: si.BaseRecording = si.load_extractor(
+            recording_file, base_folder=output_dir
+        )
+        si_sorting: si.BaseSorting = si.load_extractor(
+            sorting_file, base_folder=output_dir
+        )
+
+        postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
+
+        job_kwargs = postprocessing_params.get(
+            "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+        )
+
+        analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
+
+        @memoized_result(
+            uniqueness_dict=postprocessing_params,
+            output_directory=analyzer_output_dir,
+        )
+        def _sorting_analyzer_compute():
+            # Sorting Analyzer
+            sorting_analyzer = si.create_sorting_analyzer(
+                sorting=si_sorting,
+                recording=si_recording,
+                format="binary_folder",
+                folder=analyzer_output_dir,
+                sparse=True,
+                overwrite=True,
+            )
+
+            # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions();
+            # each extension is parameterized by the params specified in the extensions_params dictionary (skipped if not specified)
+            extensions_params = postprocessing_params.get("extensions", {})
+            extensions_to_compute = {
+                ext_name: extensions_params[ext_name]
+                for ext_name in sorting_analyzer.get_computable_extensions()
+                if ext_name in extensions_params
+            }
+
+            sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
+
+        _sorting_analyzer_compute()
+
+        self.insert1(
+            {
+                **key,
+                "execution_time": execution_time,
+                "execution_duration": (
+                    datetime.utcnow() - execution_time
+                ).total_seconds()
+                / 3600,
+                "do_si_export": postprocessing_params.get("export_to_phy", False)
+                or postprocessing_params.get("export_report", False),
+            }
+        )
+
+        # Once finished, insert this `key` into ephys.Clustering
+        ephys.Clustering.insert1(
+            {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True
+        )
+
+
+@schema
+class SIExport(dj.Computed):
+    """Export to Phy and/or generate a SpikeInterface report."""
+
+    definition = """
+    -> PostProcessing
+    ---
+    execution_time: datetime
+    execution_duration: float
+    """
+
+    @property
+    def key_source(self):
+        return PostProcessing & "do_si_export = 1"
+
+    def make(self, key):
+        execution_time = datetime.utcnow()
+
+        clustering_method, output_dir, params = (
+            ephys.ClusteringTask * ephys.ClusteringParamSet & key
+        ).fetch1("clustering_method", "clustering_output_dir", "params")
+        output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+        sorter_name = clustering_method.replace(".", "_")
+
+        postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
+
+        job_kwargs = postprocessing_params.get(
+            "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+        )
+
+        analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
+        sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir)
+
+        @memoized_result(
+            uniqueness_dict=postprocessing_params,
+            output_directory=analyzer_output_dir / "phy",
+        )
+        def _export_to_phy():
+            # Save to phy format
+            si.exporters.export_to_phy(
+                sorting_analyzer=sorting_analyzer,
+                output_folder=analyzer_output_dir / "phy",
+                use_relative_path=True,
+                **job_kwargs,
+            )
+
+        @memoized_result(
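+            # Memoized on the postprocessing params, so re-populating with an
+            # unchanged param set skips regenerating the report.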
+            uniqueness_dict=postprocessing_params,
+            output_directory=analyzer_output_dir / "spikeinterface_report",
+        )
+        def _export_report():
+            # Generate spike interface report
+            si.exporters.export_report(
+                sorting_analyzer=sorting_analyzer,
+                output_folder=analyzer_output_dir / "spikeinterface_report",
+                **job_kwargs,
+            )
+
+        if postprocessing_params.get("export_report", False):
+            _export_report()
+        if postprocessing_params.get("export_to_phy", False):
+            _export_to_phy()
+
+        self.insert1(
+            {
+                **key,
+                "execution_time": execution_time,
+                "execution_duration": (
+                    datetime.utcnow() - execution_time
+                ).total_seconds()
+                / 3600,
+            }
+        )
diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py
index 6b5406e8..19ba4c76 100644
--- a/element_array_ephys/version.py
+++ b/element_array_ephys/version.py
@@ -1,3 +1,3 @@
 """Package metadata."""
 
-__version__ = "0.3.5"
+__version__ = "1.0.0"
diff --git a/env.yml b/env.yml
new file mode 100644
index 00000000..e9b3ce13
--- /dev/null
+++ b/env.yml
@@ -0,0 +1,7 @@
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - pip
+  - python>=3.7,<3.11
+name: element_array_ephys
diff --git a/images/attached_array_ephys_element_no_curation.svg b/images/attached_array_ephys_element.svg
similarity index 100%
rename from images/attached_array_ephys_element_no_curation.svg
rename to images/attached_array_ephys_element.svg
diff --git a/images/attached_array_ephys_element_acute.svg b/images/attached_array_ephys_element_acute.svg
deleted file mode 100644
index 5b2bc265..00000000
--- a/images/attached_array_ephys_element_acute.svg
+++ /dev/null
@@ -1,451 +0,0 @@
[deleted: 451 lines of graphviz SVG markup for the acute-pipeline schema diagram]
diff --git a/images/attached_array_ephys_element_chronic.svg b/images/attached_array_ephys_element_chronic.svg
deleted file mode 100644
index 808a2f17..00000000
--- a/images/attached_array_ephys_element_chronic.svg
+++ /dev/null
@@ -1,456 +0,0 @@
[deleted: 456 lines of graphviz SVG markup for the chronic-pipeline schema diagram]
diff --git a/images/attached_array_ephys_element_precluster.svg b/images/attached_array_ephys_element_precluster.svg
deleted file mode 100644
index 7d854d2e..00000000
--- a/images/attached_array_ephys_element_precluster.svg
+++ /dev/null
@@ -1,535 +0,0 @@
[deleted: 535 lines of graphviz SVG markup for the precluster-pipeline schema diagram]
diff --git a/notebooks/demo_prepare.ipynb b/notebooks/demo_prepare.ipynb
index 74057ba4..85ee1be2 100644
--- a/notebooks/demo_prepare.ipynb
+++ b/notebooks/demo_prepare.ipynb
@@ -213,7 +213,6 @@
    "pygments_lexer": "ipython3",
    "version": "3.9.17"
   },
-  "orig_nbformat": 4,
   "vscode": {
    "interpreter": {
     "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
diff --git a/notebooks/demo_run.ipynb b/notebooks/demo_run.ipynb
index 348a3c43..70fbb746 100644
--- a/notebooks/demo_run.ipynb
+++ b/notebooks/demo_run.ipynb
@@ -96,7 +96,6 @@
    "pygments_lexer": "ipython3",
    "version": "3.9.17"
   },
-  "orig_nbformat": 4,
   "vscode": {
    "interpreter": {
     "hash": "ff52d424e56dd643d8b2ec122f40a2e279e94970100b4e6430cb9025a65ba4cf"
diff --git a/setup.py b/setup.py
index 19a4a5ae..38b08c29 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
 from os import path
 
-from setuptools import find_packages, setup
+from setuptools import find_packages, setup
 
 pkg_name = "element_array_ephys"
 here = path.abspath(path.dirname(__file__))
@@ -16,6 +16,7 @@
 
 setup(
     name=pkg_name.replace("_", "-"),
+    python_requires=">=3.7, <3.11",
    version=__version__,  # noqa F821
    description="Extracellular Array Electrophysiology DataJoint Element",
    long_description=long_description,
@@ -34,20 +35,22 @@
        "openpyxl",
        "plotly",
        "seaborn",
-        "spikeinterface",
+        "spikeinterface @ git+https://github.com/SpikeInterface/spikeinterface.git",
        "scikit-image>=0.20",
        "nbformat>=4.2.0",
        "pyopenephys>=1.1.6",
+        "element-interface @ git+https://github.com/datajoint/element-interface.git",
+        "numba",
    ],
    extras_require={
        "elements": [
            "element-animal @ git+https://github.com/datajoint/element-animal.git",
            "element-event @ git+https://github.com/datajoint/element-event.git",
-            "element-interface @ git+https://github.com/datajoint/element-interface.git",
            "element-lab @ git+https://github.com/datajoint/element-lab.git",
            "element-session @ git+https://github.com/datajoint/element-session.git",
        ],
        "nwb": ["dandi", "neuroconv[ecephys]", "pynwb"],
        "tests": ["pre-commit", "pytest", "pytest-cov"],
+        "spikingcircus": ["hdbscan"],
    },
 )
diff --git a/tests/tutorial_pipeline.py b/tests/tutorial_pipeline.py
index 74b27ddc..1b27027d 100644
--- a/tests/tutorial_pipeline.py
+++ b/tests/tutorial_pipeline.py
@@ -3,7 +3,7 @@
 import datajoint as dj
 from element_animal import subject
 from element_animal.subject import Subject
-from element_array_ephys import probe, ephys_no_curation as ephys, ephys_report
+from element_array_ephys import probe, ephys, ephys_report
 from element_lab import lab
 from element_lab.lab import Lab, Location, Project, Protocol, Source, User
 from element_lab.lab import Device as Equipment
@@ -62,7 +62,9 @@
 def get_session_directory(session_key):
     return pathlib.Path(session_directory)
 
 
-ephys.activate(db_prefix + "ephys", db_prefix + "probe", linking_module=__name__)
+probe.activate(db_prefix + "probe")
+ephys.activate(db_prefix + "ephys", linking_module=__name__)
+ephys_report.activate(db_prefix + "ephys_report")
 
 probe.create_neuropixels_probe_types()
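For anyone trying the new SpikeInterface route end to end, a minimal usage sketch follows. It assumes the linking-module conventions of tests/tutorial_pipeline.py above (`db_prefix`, `linking_module`), a "kilosort2.5" entry in `ephys.ClusteringMethod`, and placeholder parameter values; the three `SI_*` keys are the ones `PreProcessing.make` requires.

# Hypothetical sketch; values are illustrative, not recommended settings.
from element_array_ephys import probe, ephys
from element_array_ephys.spike_sorting import si_spike_sorting

db_prefix = "neuro_"  # illustrative prefix

probe.activate(db_prefix + "probe")
ephys.activate(db_prefix + "ephys", linking_module=__name__)
si_spike_sorting.activate(db_prefix + "si_spike_sorting")

# Register a parameter set for the SpikeInterface route; arguments follow the
# (method, paramset_idx, description, params) order shown in this diff.
ephys.ClusteringParamSet.insert_new_params(
    "kilosort2.5",  # mapped to sorter_name "kilosort2_5" by the pipeline
    1,
    "Kilosort2.5 via SpikeInterface with CatGT-style preprocessing",
    {
        "SI_PREPROCESSING_METHOD": "CatGT",  # a function name in si_preprocessing
        "SI_SORTING_PARAMS": {},  # forwarded to si.sorters.run_sorter
        "SI_POSTPROCESSING_PARAMS": {
            "extensions": {"random_spikes": {}, "waveforms": {}, "templates": {}},
            "export_report": True,
        },
    },
)

# With an ephys.ClusteringTask row in "trigger" mode for this paramset,
# populate() drives the chain PreProcessing -> SIClustering -> PostProcessing
# (which back-fills ephys.Clustering) -> SIExport.
si_spike_sorting.PreProcessing.populate(display_progress=True)
si_spike_sorting.SIClustering.populate(display_progress=True)
si_spike_sorting.PostProcessing.populate(display_progress=True)
si_spike_sorting.SIExport.populate(display_progress=True)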