diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index e5e6a07a..19c91e36 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -1,7 +1,11 @@
name: Test
on:
push:
+ branches:
+ - main
pull_request:
+ branches:
+ - main
workflow_dispatch:
jobs:
devcontainer-build:
@@ -31,4 +35,3 @@ jobs:
run: |
python_version=${{matrix.py_ver}}
black element_array_ephys --check --verbose --target-version py${python_version//.}
-
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 0d513df7..6d28ef11 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -3,7 +3,7 @@ exclude: (^.github/|^docs/|^images/)
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
- rev: v4.4.0
+ rev: v4.5.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
@@ -16,7 +16,7 @@ repos:
# black
- repo: https://github.com/psf/black
- rev: 22.12.0
+ rev: 24.2.0
hooks:
- id: black
- id: black-jupyter
@@ -25,7 +25,7 @@ repos:
# isort
- repo: https://github.com/pycqa/isort
- rev: 5.11.2
+ rev: 5.13.2
hooks:
- id: isort
args: ["--profile", "black"]
@@ -33,7 +33,7 @@ repos:
# flake8
- repo: https://github.com/pycqa/flake8
- rev: 4.0.1
+ rev: 7.0.0
hooks:
- id: flake8
args: # arguments to configure flake8
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d2e48eaa..34d1a2e4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,21 @@
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
+
+## [1.0.0] - 2024-09-10
+
++ Update - No longer support multiple variations of the ephys module; keep only the `ephys_no_curation` module, renamed to `ephys`
++ Update - Remove other ephys modules (e.g. `ephys_acute`, `ephys_chronic`) (moved to different branches)
++ Update - Add support for `SpikeInterface`
++ Update - Remove support for `ecephys_spike_sorting` (moved to a different branch)
++ Update - Simplify the "activate" mechanism
+
+## [0.4.0] - 2024-08-16
+
++ Add - support for SpikeInterface version >= 0.101.0 (updated API)
++ Add - feature for memoization of spike sorting results (prevents duplicated runs)
+
+
## [0.3.5] - 2024-08-16
+ Fix - Improve `spikeglx` loader in extracting neuropixels probe type from the meta file
diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py
index 1c0c7285..079950b4 100644
--- a/element_array_ephys/__init__.py
+++ b/element_array_ephys/__init__.py
@@ -1 +1,3 @@
-from . import ephys_acute as ephys
+from . import ephys
+
+ephys_no_curation = ephys # alias for backward compatibility
diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys.py
similarity index 66%
rename from element_array_ephys/ephys_no_curation.py
rename to element_array_ephys/ephys.py
index cd0909c9..ad9bb8d7 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys.py
@@ -10,10 +10,10 @@
import pandas as pd
from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-from . import ephys_report, probe
+from . import probe
from .readers import kilosort, openephys, spikeglx
-log = dj.logger
+logger = dj.logger
schema = dj.schema()
@@ -22,7 +22,6 @@
def activate(
ephys_schema_name: str,
- probe_schema_name: str = None,
*,
create_schema: bool = True,
create_tables: bool = True,
@@ -32,7 +31,6 @@ def activate(
Args:
ephys_schema_name (str): A string containing the name of the ephys schema.
- probe_schema_name (str): A string containing the name of the probe schema.
create_schema (bool): If True, schema will be created in the database.
create_tables (bool): If True, tables related to the schema will be created in the database.
linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
@@ -46,7 +44,6 @@ def activate(
get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings.
get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory.
-
"""
if isinstance(linking_module, str):
@@ -58,17 +55,15 @@ def activate(
global _linking_module
_linking_module = linking_module
- # activate
- probe.activate(
- probe_schema_name, create_schema=create_schema, create_tables=create_tables
- )
+ if not probe.schema.is_activated():
+ raise RuntimeError("Please activate the `probe` schema first.")
+
schema.activate(
ephys_schema_name,
create_schema=create_schema,
create_tables=create_tables,
add_objects=_linking_module.__dict__,
)
- ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
# -------------- Functions required by the elements-ephys ---------------
@@ -129,7 +124,7 @@ class AcquisitionSoftware(dj.Lookup):
"""
definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys
- acq_software: varchar(24)
+ acq_software: varchar(24)
"""
contents = zip(["SpikeGLX", "Open Ephys"])
@@ -272,15 +267,24 @@ class EphysRecording(dj.Imported):
definition = """
# Ephys recording from a probe insertion for a given session.
- -> ProbeInsertion
+ -> ProbeInsertion
---
-> probe.ElectrodeConfig
-> AcquisitionSoftware
- sampling_rate: float # (Hz)
+ sampling_rate: float # (Hz)
recording_datetime: datetime # datetime of the recording from this probe
recording_duration: float # (seconds) duration of the recording from this probe
"""
+ class Channel(dj.Part):
+ definition = """
+ -> master
+ channel_idx: int # channel index (index of the raw data)
+ ---
+ -> probe.ElectrodeConfig.Electrode
+ channel_name="": varchar(64) # alias of the channel
+ """
+
class EphysFile(dj.Part):
"""Paths of electrophysiology recording files for each insertion.
@@ -304,7 +308,7 @@ def make(self, key):
"probe"
)
- # search session dir and determine acquisition software
+ # Search session dir and determine acquisition software
for ephys_pattern, ephys_acq_type in (
("*.ap.meta", "SpikeGLX"),
("*.oebin", "Open Ephys"),
@@ -315,8 +319,13 @@ def make(self, key):
break
else:
raise FileNotFoundError(
- "Ephys recording data not found!"
- " Neither SpikeGLX nor Open Ephys recording files found"
+ f"Ephys recording data not found in {session_dir}."
+ "Neither SpikeGLX nor Open Ephys recording files found"
+ )
+
+ if acq_software not in AcquisitionSoftware.fetch("acq_software"):
+ raise NotImplementedError(
+ f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented."
)
supported_probe_types = probe.ProbeType.fetch("probe_type")
@@ -325,51 +334,79 @@ def make(self, key):
for meta_filepath in ephys_meta_filepaths:
spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
+ spikeglx_meta_filepath = meta_filepath
break
else:
raise FileNotFoundError(
"No SpikeGLX data found for probe insertion: {}".format(key)
)
- if spikeglx_meta.probe_model in supported_probe_types:
- probe_type = spikeglx_meta.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
+ if spikeglx_meta.probe_model not in supported_probe_types:
+ raise NotImplementedError(
+ f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented."
+ )
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
+ probe_type = spikeglx_meta.probe_model
+ electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
- electrode_group_members = [
- probe_electrodes[(shank, shank_col, shank_row)]
- for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels probe model"
- " {} not yet implemented".format(spikeglx_meta.probe_model)
+ probe_electrodes = {
+ (shank, shank_col, shank_row): key
+ for key, shank, shank_col, shank_row in zip(
+ *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
)
+ } # electrode configuration
+ electrode_group_members = [
+ probe_electrodes[(shank, shank_col, shank_row)]
+ for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
+ ] # recording session-specific electrode configuration
+
+ econfig_entry, econfig_electrodes = generate_electrode_config_entry(
+ probe_type, electrode_group_members
+ )
- self.insert1(
+ ephys_recording_entry = {
+ **key,
+ "electrode_config_hash": econfig_entry["electrode_config_hash"],
+ "acq_software": acq_software,
+ "sampling_rate": spikeglx_meta.meta["imSampRate"],
+ "recording_datetime": spikeglx_meta.recording_time,
+ "recording_duration": (
+ spikeglx_meta.recording_duration
+ or spikeglx.retrieve_recording_duration(spikeglx_meta_filepath)
+ ),
+ }
+
+ root_dir = find_root_directory(
+ get_ephys_root_data_dir(), spikeglx_meta_filepath
+ )
+
+ ephys_file_entries = [
{
**key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": spikeglx_meta.meta["imSampRate"],
- "recording_datetime": spikeglx_meta.recording_time,
- "recording_duration": (
- spikeglx_meta.recording_duration
- or spikeglx.retrieve_recording_duration(meta_filepath)
- ),
+ "file_path": spikeglx_meta_filepath.relative_to(
+ root_dir
+ ).as_posix(),
}
- )
+ ]
- root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath)
- self.EphysFile.insert1(
- {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()}
- )
+ # Insert channel information
+ # Get channel and electrode-site mapping
+ channel2electrode_map = {
+ recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
+ for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
+ spikeglx_meta.shankmap["data"]
+ )
+ }
+
+ ephys_channel_entries = [
+ {
+ **key,
+ "electrode_config_hash": econfig_entry["electrode_config_hash"],
+ "channel_idx": channel_idx,
+ **channel_info,
+ }
+ for channel_idx, channel_info in channel2electrode_map.items()
+ ]
elif acq_software == "Open Ephys":
dataset = openephys.OpenEphys(session_dir)
for serial_number, probe_data in dataset.probes.items():
@@ -385,60 +422,84 @@ def make(self, key):
'No analog signals found - check "structure.oebin" file or "continuous" directory'
)
- if probe_data.probe_model in supported_probe_types:
- probe_type = probe_data.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_group_members = [
- probe_electrodes[channel_idx]
- for channel_idx in probe_data.ap_meta["channels_indices"]
- ]
- else:
+ if probe_data.probe_model not in supported_probe_types:
raise NotImplementedError(
- "Processing for neuropixels"
- " probe model {} not yet implemented".format(probe_data.probe_model)
+ f"Processing for neuropixels probe model {probe_data.probe_model} not yet implemented."
)
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": probe_data.ap_meta["sample_rate"],
- "recording_datetime": probe_data.recording_info[
- "recording_datetimes"
- ][0],
- "recording_duration": np.sum(
- probe_data.recording_info["recording_durations"]
- ),
- }
+ probe_type = probe_data.probe_model
+ electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
+
+ probe_electrodes = {
+ key["electrode"]: key for key in electrode_query.fetch("KEY")
+ } # electrode configuration
+
+ electrode_group_members = [
+ probe_electrodes[channel_idx]
+ for channel_idx in probe_data.ap_meta["channels_indices"]
+ ] # recording session-specific electrode configuration
+
+ econfig_entry, econfig_electrodes = generate_electrode_config_entry(
+ probe_type, electrode_group_members
)
+ ephys_recording_entry = {
+ **key,
+ "electrode_config_hash": econfig_entry["electrode_config_hash"],
+ "acq_software": acq_software,
+ "sampling_rate": probe_data.ap_meta["sample_rate"],
+ "recording_datetime": probe_data.recording_info["recording_datetimes"][
+ 0
+ ],
+ "recording_duration": np.sum(
+ probe_data.recording_info["recording_durations"]
+ ),
+ }
+
root_dir = find_root_directory(
get_ephys_root_data_dir(),
probe_data.recording_info["recording_files"][0],
)
- self.EphysFile.insert(
- [
- {**key, "file_path": fp.relative_to(root_dir).as_posix()}
- for fp in probe_data.recording_info["recording_files"]
- ]
- )
- # explicitly garbage collect "dataset"
- # as these may have large memory footprint and may not be cleared fast enough
+
+ ephys_file_entries = [
+ {**key, "file_path": fp.relative_to(root_dir).as_posix()}
+ for fp in probe_data.recording_info["recording_files"]
+ ]
+
+ channel2electrode_map = {
+ channel_idx: probe_electrodes[channel_idx]
+ for channel_idx in probe_data.ap_meta["channels_indices"]
+ }
+
+ ephys_channel_entries = [
+ {
+ **key,
+ "electrode_config_hash": econfig_entry["electrode_config_hash"],
+ "channel_idx": channel_idx,
+ **channel_info,
+ }
+ for channel_idx, channel_info in channel2electrode_map.items()
+ ]
+
+ # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough
del probe_data, dataset
gc.collect()
else:
raise NotImplementedError(
- f"Processing ephys files from"
- f" acquisition software of type {acq_software} is"
- f" not yet implemented"
+ f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented."
)
+ # Insert into probe.ElectrodeConfig (recording configuration)
+ if not probe.ElectrodeConfig & {
+ "electrode_config_hash": econfig_entry["electrode_config_hash"]
+ }:
+ probe.ElectrodeConfig.insert1(econfig_entry)
+ probe.ElectrodeConfig.Electrode.insert(econfig_electrodes)
+
+ self.insert1(ephys_recording_entry)
+ self.EphysFile.insert(ephys_file_entries)
+ self.Channel.insert(ephys_channel_entries)
+
@schema
class LFP(dj.Imported):
@@ -471,9 +532,9 @@ class Electrode(dj.Part):
definition = """
-> master
- -> probe.ElectrodeConfig.Electrode
+ -> probe.ElectrodeConfig.Electrode
---
- lfp: longblob # (uV) recorded lfp at this electrode
+ lfp: longblob # (uV) recorded lfp at this electrode
"""
# Only store LFP for every 9th channel, due to high channel density,
@@ -614,14 +675,14 @@ class ClusteringParamSet(dj.Lookup):
ClusteringMethod (dict): ClusteringMethod primary key.
paramset_desc (varchar(128) ): Description of the clustering parameter set.
param_set_hash (uuid): UUID hash for the parameter set.
- params (longblob): Set of clustering parameters
+ params (longblob): Set of clustering parameters.
"""
definition = """
# Parameter set to be used in a clustering procedure
paramset_idx: smallint
---
- -> ClusteringMethod
+ -> ClusteringMethod
paramset_desc: varchar(128)
param_set_hash: uuid
unique index (param_set_hash)
@@ -700,6 +761,7 @@ class ClusterQualityLabel(dj.Lookup):
("ok", "probably a single unit, but could be contaminated"),
("mua", "multi-unit activity"),
("noise", "bad unit"),
+ ("n.a.", "not available"),
]
@@ -724,18 +786,15 @@ class ClusteringTask(dj.Manual):
"""
@classmethod
- def infer_output_dir(
- cls, key, relative: bool = False, mkdir: bool = False
- ) -> pathlib.Path:
+ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False):
"""Infer output directory if it is not provided.
Args:
key (dict): ClusteringTask primary key.
Returns:
- Expected clustering_output_dir based on the following convention:
- processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
- e.g.: sub4/sess1/probe_2/kilosort2_0
+ Pathlib.Path: Expected clustering_output_dir based on the following convention: processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
+ e.g.: sub4/sess1/probe_2/kilosort2_0
"""
processed_dir = pathlib.Path(get_processed_root_data_dir())
session_dir = find_full_path(
@@ -758,7 +817,7 @@ def infer_output_dir(
if mkdir:
output_dir.mkdir(parents=True, exist_ok=True)
- log.info(f"{output_dir} created!")
+ logger.info(f"{output_dir} created!")
return output_dir.relative_to(processed_dir) if relative else output_dir
@@ -809,7 +868,7 @@ class Clustering(dj.Imported):
# Clustering Procedure
-> ClusteringTask
---
- clustering_time: datetime # time of generation of this set of clustering results
+ clustering_time: datetime # time of generation of this set of clustering results
package_version='': varchar(16)
"""
@@ -838,7 +897,7 @@ def make(self, key):
).fetch1("acq_software", "clustering_method", "params")
if "kilosort" in clustering_method:
- from element_array_ephys.readers import kilosort_triggering
+ from .spike_sorting import kilosort_triggering
# add additional probe-recording and channels details into `params`
params = {**params, **get_recording_channels_details(key)}
@@ -850,10 +909,6 @@ def make(self, key):
spikeglx_meta_filepath.parent
)
spikeglx_recording.validate_file("ap")
- run_CatGT = (
- params.pop("run_CatGT", True)
- and "_tcat." not in spikeglx_meta_filepath.stem
- )
if clustering_method.startswith("pykilosort"):
kilosort_triggering.run_pykilosort(
@@ -874,7 +929,7 @@ def make(self, key):
ks_output_dir=kilosort_dir,
params=params,
KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- run_CatGT=run_CatGT,
+ run_CatGT=True,
)
run_kilosort.run_modules()
elif acq_software == "Open Ephys":
@@ -929,7 +984,7 @@ class CuratedClustering(dj.Imported):
definition = """
# Clustering results of the spike sorting step.
- -> Clustering
+ -> Clustering
"""
class Unit(dj.Part):
@@ -946,7 +1001,7 @@ class Unit(dj.Part):
spike_depths (longblob): Array of depths associated with each spike, relative to each spike.
"""
- definition = """
+ definition = """
# Properties of a given unit from a round of clustering (and curation)
-> master
unit: int
@@ -956,85 +1011,175 @@ class Unit(dj.Part):
spike_count: int # how many spikes in this recording for this unit
spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
spike_sites : longblob # array of electrode associated with each spike
- spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
+ spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
"""
def make(self, key):
"""Automated population of Unit information."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ clustering_method, output_dir = (
+ ClusteringTask * ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir")
+ output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+
+ # Get channel and electrode-site mapping
+ electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
+ channel2electrode_map: dict[int, dict] = {
+ chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
+ }
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
- acq_software, sample_rate = (EphysRecording & key).fetch1(
- "acq_software", "sampling_rate"
- )
+ # Get sorter method and create output directory.
+ sorter_name = clustering_method.replace(".", "_")
+ si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
- sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate)
+ if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs
+ import spikeinterface as si
- # ---------- Unit ----------
- # -- Remove 0-spike units
- withspike_idx = [
- i
- for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
- if (kilosort_dataset.data["spike_clusters"] == u).any()
- ]
- valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
- valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
- # -- Get channel and electrode-site mapping
- channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software)
-
- # -- Spike-times --
- # spike_times_sec_adj > spike_times_sec > spike_times
- spike_time_key = (
- "spike_times_sec_adj"
- if "spike_times_sec_adj" in kilosort_dataset.data
- else (
- "spike_times_sec"
- if "spike_times_sec" in kilosort_dataset.data
- else "spike_times"
+ sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+ si_sorting = sorting_analyzer.sorting
+
+ # Find representative channel for each unit
+ unit_peak_channel: dict[int, np.ndarray] = (
+ si.ChannelSparsity.from_best_channels(
+ sorting_analyzer,
+ 1,
+ ).unit_id_to_channel_indices
)
- )
- spike_times = kilosort_dataset.data[spike_time_key]
- kilosort_dataset.extract_spike_depths()
+ unit_peak_channel: dict[int, int] = {
+ u: chn[0] for u, chn in unit_peak_channel.items()
+ }
- # -- Spike-sites and Spike-depths --
- spike_sites = np.array(
- [
- channel2electrodes[s]["electrode"]
- for s in kilosort_dataset.data["spike_sites"]
- ]
- )
- spike_depths = kilosort_dataset.data["spike_depths"]
-
- # -- Insert unit, label, peak-chn
- units = []
- for unit, unit_lbl in zip(valid_units, valid_unit_labels):
- if (kilosort_dataset.data["spike_clusters"] == unit).any():
- unit_channel, _ = kilosort_dataset.get_best_channel(unit)
- unit_spike_times = (
- spike_times[kilosort_dataset.data["spike_clusters"] == unit]
- / sample_rate
+ spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit()
+ # {unit: spike_count}
+
+ # update channel2electrode_map to match with probe's channel index
+ channel2electrode_map = {
+ idx: channel2electrode_map[int(chn_idx)]
+ for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
+ }
+
+ # Get unit id to quality label mapping
+ cluster_quality_label_map = {
+ int(unit_id): (
+ si_sorting.get_unit_property(unit_id, "KSLabel")
+ if "KSLabel" in si_sorting.get_property_keys()
+ else "n.a."
)
- spike_count = len(unit_spike_times)
+ for unit_id in si_sorting.unit_ids
+ }
+
+ spike_locations = sorting_analyzer.get_extension("spike_locations")
+ extremum_channel_inds = si.template_tools.get_template_extremum_channel(
+ sorting_analyzer, outputs="index"
+ )
+ spikes_df = pd.DataFrame(
+ sorting_analyzer.sorting.to_spike_vector(
+ extremum_channel_inds=extremum_channel_inds
+ )
+ )
+
+ units = []
+ for unit_idx, unit_id in enumerate(si_sorting.unit_ids):
+ unit_id = int(unit_id)
+ unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx]
+ spike_sites = np.array(
+ [
+ channel2electrode_map[chn_idx]["electrode"]
+ for chn_idx in unit_spikes_df.channel_index
+ ]
+ )
+ unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index]
+ _, spike_depths = zip(*unit_spikes_loc) # x-coordinates, y-coordinates
+ spike_times = si_sorting.get_unit_spike_train(
+ unit_id, return_times=True
+ )
+
+ assert len(spike_times) == len(spike_sites) == len(spike_depths)
units.append(
{
- "unit": unit,
- "cluster_quality_label": unit_lbl,
- **channel2electrodes[unit_channel],
- "spike_times": unit_spike_times,
- "spike_count": spike_count,
- "spike_sites": spike_sites[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
- "spike_depths": spike_depths[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
+ **key,
+ **channel2electrode_map[unit_peak_channel[unit_id]],
+ "unit": unit_id,
+ "cluster_quality_label": cluster_quality_label_map[unit_id],
+ "spike_times": spike_times,
+ "spike_count": spike_count_dict[unit_id],
+ "spike_sites": spike_sites,
+ "spike_depths": spike_depths,
}
)
+ else: # read from kilosort outputs
+ kilosort_dataset = kilosort.Kilosort(output_dir)
+ acq_software, sample_rate = (EphysRecording & key).fetch1(
+ "acq_software", "sampling_rate"
+ )
+
+ sample_rate = kilosort_dataset.data["params"].get(
+ "sample_rate", sample_rate
+ )
+
+ # ---------- Unit ----------
+ # -- Remove 0-spike units
+ withspike_idx = [
+ i
+ for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
+ if (kilosort_dataset.data["spike_clusters"] == u).any()
+ ]
+ valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
+ valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
+
+ # -- Spike-times --
+ # spike_times_sec_adj > spike_times_sec > spike_times
+ spike_time_key = (
+ "spike_times_sec_adj"
+ if "spike_times_sec_adj" in kilosort_dataset.data
+ else (
+ "spike_times_sec"
+ if "spike_times_sec" in kilosort_dataset.data
+ else "spike_times"
+ )
+ )
+ spike_times = kilosort_dataset.data[spike_time_key]
+ kilosort_dataset.extract_spike_depths()
+
+ # -- Spike-sites and Spike-depths --
+ spike_sites = np.array(
+ [
+ channel2electrode_map[s]["electrode"]
+ for s in kilosort_dataset.data["spike_sites"]
+ ]
+ )
+ spike_depths = kilosort_dataset.data["spike_depths"]
+
+ # -- Insert unit, label, peak-chn
+ units = []
+ for unit, unit_lbl in zip(valid_units, valid_unit_labels):
+ if (kilosort_dataset.data["spike_clusters"] == unit).any():
+ unit_channel, _ = kilosort_dataset.get_best_channel(unit)
+ unit_spike_times = (
+ spike_times[kilosort_dataset.data["spike_clusters"] == unit]
+ / sample_rate
+ )
+ spike_count = len(unit_spike_times)
+
+ units.append(
+ {
+ **key,
+ "unit": unit,
+ "cluster_quality_label": unit_lbl,
+ **channel2electrode_map[unit_channel],
+ "spike_times": unit_spike_times,
+ "spike_count": spike_count,
+ "spike_sites": spike_sites[
+ kilosort_dataset.data["spike_clusters"] == unit
+ ],
+ "spike_depths": spike_depths[
+ kilosort_dataset.data["spike_clusters"] == unit
+ ],
+ }
+ )
self.insert1(key)
- self.Unit.insert([{**key, **u} for u in units])
+ self.Unit.insert(units, ignore_extra_fields=True)
@schema
@@ -1082,113 +1227,171 @@ class Waveform(dj.Part):
# Spike waveforms and their mean across spikes for the given unit
-> master
-> CuratedClustering.Unit
- -> probe.ElectrodeConfig.Electrode
- ---
+ -> probe.ElectrodeConfig.Electrode
+ ---
waveform_mean: longblob # (uV) mean waveform across spikes of the given unit
waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit
"""
def make(self, key):
"""Populates waveform tables."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ clustering_method, output_dir = (
+ ClusteringTask * ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir")
+ output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
+
+ # Get channel and electrode-site mapping
+ electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name")
+ channel2electrode_map: dict[int, dict] = {
+ chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True)
+ }
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
+ si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
+ if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
+ import spikeinterface as si
- acq_software, probe_serial_number = (
- EphysRecording * ProbeInsertion & key
- ).fetch1("acq_software", "probe")
+ sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
- # -- Get channel and electrode-site mapping
- recording_key = (EphysRecording & key).fetch1("KEY")
- channel2electrodes = get_neuropixels_channel2electrode_map(
- recording_key, acq_software
- )
+ # Find representative channel for each unit
+ unit_peak_channel: dict[int, np.ndarray] = (
+ si.ChannelSparsity.from_best_channels(
+ sorting_analyzer, 1
+ ).unit_id_to_channel_indices
+ ) # {unit: peak_channel_index}
+ unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()}
- # Get all units
- units = {
- u["unit"]: u
- for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit")
- }
+ # update channel2electrode_map to match with probe's channel index
+ channel2electrode_map = {
+ idx: channel2electrode_map[int(chn_idx)]
+ for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids)
+ }
- if (kilosort_dir / "mean_waveforms.npy").exists():
- unit_waveforms = np.load(
- kilosort_dir / "mean_waveforms.npy"
- ) # unit x channel x sample
+ templates = sorting_analyzer.get_extension("templates")
def yield_unit_waveforms():
- for unit_no, unit_waveform in zip(
- kilosort_dataset.data["cluster_ids"], unit_waveforms
+ for unit in (CuratedClustering.Unit & key).fetch(
+ "KEY", order_by="unit"
):
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
- if unit_no in units:
+ # Get mean waveform for this unit from all channels - (sample x channel)
+ unit_waveforms = templates.get_unit_template(
+ unit_id=unit["unit"], operator="average"
+ )
+ unit_peak_waveform = {
+ **unit,
+ "peak_electrode_waveform": unit_waveforms[
+ :, unit_peak_channel[unit["unit"]]
+ ],
+ }
+
+ unit_electrode_waveforms = [
+ {
+ **unit,
+ **channel2electrode_map[chn_idx],
+ "waveform_mean": unit_waveforms[:, chn_idx],
+ }
+ for chn_idx in channel2electrode_map
+ ]
+
+ yield unit_peak_waveform, unit_electrode_waveforms
+
+ else: # read from kilosort outputs (ecephys pipeline)
+ kilosort_dataset = kilosort.Kilosort(output_dir)
+
+ acq_software, probe_serial_number = (
+ EphysRecording * ProbeInsertion & key
+ ).fetch1("acq_software", "probe")
+
+ # Get all units
+ units = {
+ u["unit"]: u
+ for u in (CuratedClustering.Unit & key).fetch(
+ as_dict=True, order_by="unit"
+ )
+ }
+
+ if (output_dir / "mean_waveforms.npy").exists():
+ unit_waveforms = np.load(
+ output_dir / "mean_waveforms.npy"
+ ) # unit x channel x sample
+
+ def yield_unit_waveforms():
+ for unit_no, unit_waveform in zip(
+ kilosort_dataset.data["cluster_ids"], unit_waveforms
+ ):
+ unit_peak_waveform = {}
+ unit_electrode_waveforms = []
+ if unit_no in units:
+ for channel, channel_waveform in zip(
+ kilosort_dataset.data["channel_map"], unit_waveform
+ ):
+ unit_electrode_waveforms.append(
+ {
+ **units[unit_no],
+ **channel2electrode_map[channel],
+ "waveform_mean": channel_waveform,
+ }
+ )
+ if (
+ channel2electrode_map[channel]["electrode"]
+ == units[unit_no]["electrode"]
+ ):
+ unit_peak_waveform = {
+ **units[unit_no],
+ "peak_electrode_waveform": channel_waveform,
+ }
+ yield unit_peak_waveform, unit_electrode_waveforms
+
+ else:
+ if acq_software == "SpikeGLX":
+ spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
+ neuropixels_recording = spikeglx.SpikeGLX(
+ spikeglx_meta_filepath.parent
+ )
+ elif acq_software == "Open Ephys":
+ session_dir = find_full_path(
+ get_ephys_root_data_dir(), get_session_directory(key)
+ )
+ openephys_dataset = openephys.OpenEphys(session_dir)
+ neuropixels_recording = openephys_dataset.probes[
+ probe_serial_number
+ ]
+
+ def yield_unit_waveforms():
+ for unit_dict in units.values():
+ unit_peak_waveform = {}
+ unit_electrode_waveforms = []
+
+ spikes = unit_dict["spike_times"]
+ waveforms = neuropixels_recording.extract_spike_waveforms(
+ spikes, kilosort_dataset.data["channel_map"]
+ ) # (sample x channel x spike)
+ waveforms = waveforms.transpose(
+ (1, 2, 0)
+ ) # (channel x spike x sample)
for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], unit_waveform
+ kilosort_dataset.data["channel_map"], waveforms
):
unit_electrode_waveforms.append(
{
- **units[unit_no],
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform,
+ **unit_dict,
+ **channel2electrode_map[channel],
+ "waveform_mean": channel_waveform.mean(axis=0),
+ "waveforms": channel_waveform,
}
)
if (
- channel2electrodes[channel]["electrode"]
- == units[unit_no]["electrode"]
+ channel2electrode_map[channel]["electrode"]
+ == unit_dict["electrode"]
):
unit_peak_waveform = {
- **units[unit_no],
- "peak_electrode_waveform": channel_waveform,
+ **unit_dict,
+ "peak_electrode_waveform": channel_waveform.mean(
+ axis=0
+ ),
}
- yield unit_peak_waveform, unit_electrode_waveforms
-
- else:
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- neuropixels_recording = openephys_dataset.probes[probe_serial_number]
-
- def yield_unit_waveforms():
- for unit_dict in units.values():
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
-
- spikes = unit_dict["spike_times"]
- waveforms = neuropixels_recording.extract_spike_waveforms(
- spikes, kilosort_dataset.data["channel_map"]
- ) # (sample x channel x spike)
- waveforms = waveforms.transpose(
- (1, 2, 0)
- ) # (channel x spike x sample)
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], waveforms
- ):
- unit_electrode_waveforms.append(
- {
- **unit_dict,
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform.mean(axis=0),
- "waveforms": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == unit_dict["electrode"]
- ):
- unit_peak_waveform = {
- **unit_dict,
- "peak_electrode_waveform": channel_waveform.mean(
- axis=0
- ),
- }
- yield unit_peak_waveform, unit_electrode_waveforms
+ yield unit_peak_waveform, unit_electrode_waveforms
# insert waveform on a per-unit basis to mitigate potential memory issue
self.insert1(key)
@@ -1209,7 +1412,7 @@ class QualityMetrics(dj.Imported):
definition = """
# Clusters and waveforms metrics
- -> CuratedClustering
+ -> CuratedClustering
"""
class Cluster(dj.Part):
@@ -1234,26 +1437,26 @@ class Cluster(dj.Part):
contamination_rate (float): Frequency of spikes in the refractory period.
"""
- definition = """
+ definition = """
# Cluster metrics for a particular unit
-> master
-> CuratedClustering.Unit
---
- firing_rate=null: float # (Hz) firing rate for a unit
+ firing_rate=null: float # (Hz) firing rate for a unit
snr=null: float # signal-to-noise ratio for a unit
presence_ratio=null: float # fraction of time in which spikes are present
isi_violation=null: float # rate of ISI violation as a fraction of overall rate
number_violation=null: int # total number of ISI violations
amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
- l_ratio=null: float #
+ l_ratio=null: float #
d_prime=null: float # Classification accuracy based on LDA
nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
silhouette_score=null: float # Standard metric for cluster overlap
max_drift=null: float # Maximum change in spike depth throughout recording
- cumulative_drift=null: float # Cumulative change in spike depth throughout recording
- contamination_rate=null: float #
+ cumulative_drift=null: float # Cumulative change in spike depth throughout recording
+ contamination_rate=null: float #
"""
class Waveform(dj.Part):
@@ -1273,13 +1476,13 @@ class Waveform(dj.Part):
velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe.
"""
- definition = """
+ definition = """
# Waveform metrics for a particular unit
-> master
-> CuratedClustering.Unit
---
- amplitude: float # (uV) absolute difference between waveform peak and trough
- duration: float # (ms) time between waveform peak and trough
+ amplitude=null: float # (uV) absolute difference between waveform peak and trough
+ duration=null: float # (ms) time between waveform peak and trough
halfwidth=null: float # (ms) spike width at half max amplitude
pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0
repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak
@@ -1291,24 +1494,63 @@ class Waveform(dj.Part):
def make(self, key):
"""Populates tables with quality metrics data."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ # Load metrics.csv
+ clustering_method, output_dir = (
+ ClusteringTask * ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir")
+ output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
+
+ si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer"
+ if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs
+ import spikeinterface as si
+
+ sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir)
+ qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data()
+ template_metrics = sorting_analyzer.get_extension(
+ "template_metrics"
+ ).get_data()
+ metrics_df = pd.concat([qc_metrics, template_metrics], axis=1)
+
+ metrics_df.rename(
+ columns={
+ "amplitude_median": "amplitude",
+ "isi_violations_ratio": "isi_violation",
+ "isi_violations_count": "number_violation",
+ "silhouette": "silhouette_score",
+ "rp_contamination": "contamination_rate",
+ "drift_ptp": "max_drift",
+ "drift_mad": "cumulative_drift",
+ "half_width": "halfwidth",
+ "peak_trough_ratio": "pt_ratio",
+ "peak_to_valley": "duration",
+ },
+ inplace=True,
+ )
+ else: # read from kilosort outputs (ecephys pipeline)
+ # find metric_fp
+ for metric_fp in [
+ output_dir / "metrics.csv",
+ ]:
+ if metric_fp.exists():
+ break
+ else:
+ raise FileNotFoundError(f"QC metrics file not found in: {output_dir}")
- metric_fp = kilosort_dir / "metrics.csv"
- rename_dict = {
- "isi_viol": "isi_violation",
- "num_viol": "number_violation",
- "contam_rate": "contamination_rate",
- }
+ metrics_df = pd.read_csv(metric_fp)
- if not metric_fp.exists():
- raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")
+ # Conform the dataframe to match the table definition
+ if "cluster_id" in metrics_df.columns:
+ metrics_df.set_index("cluster_id", inplace=True)
+ else:
+ metrics_df.rename(
+ columns={metrics_df.columns[0]: "cluster_id"}, inplace=True
+ )
+ metrics_df.set_index("cluster_id", inplace=True)
+
+ metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df = pd.read_csv(metric_fp)
- metrics_df.set_index("cluster_id", inplace=True)
metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
- metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df.rename(columns=rename_dict, inplace=True)
metrics_list = [
dict(metrics_df.loc[unit_key["unit"]], **unit_key)
for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
@@ -1382,99 +1624,6 @@ def get_openephys_probe_data(ephys_recording_key: dict) -> list:
return probe_data
-def get_neuropixels_channel2electrode_map(
- ephys_recording_key: dict, acq_software: str
-) -> dict:
- """Get the channel map for neuropixels probe."""
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath)
- electrode_config_key = (
- EphysRecording * probe.ElectrodeConfig & ephys_recording_key
- ).fetch1("KEY")
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
- & electrode_config_key
- )
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- channel2electrode_map = {
- recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
- for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
- spikeglx_meta.shankmap["data"]
- )
- }
- elif acq_software == "Open Ephys":
- probe_dataset = get_openephys_probe_data(ephys_recording_key)
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording
- & ephys_recording_key
- )
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- channel2electrode_map = {
- channel_idx: probe_electrodes[channel_idx]
- for channel_idx in probe_dataset.ap_meta["channels_indices"]
- }
-
- return channel2electrode_map
-
-
-def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict:
- """Generate and insert new ElectrodeConfig
-
- Args:
- probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
- electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
-
- Returns:
- dict: representing a key of the probe.ElectrodeConfig table
- """
- # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
- electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
-
- electrode_list = sorted([k["electrode"] for k in electrode_keys])
- electrode_gaps = (
- [-1]
- + np.where(np.diff(electrode_list) > 1)[0].tolist()
- + [len(electrode_list) - 1]
- )
- electrode_config_name = "; ".join(
- [
- f"{electrode_list[start + 1]}-{electrode_list[end]}"
- for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
- ]
- )
-
- electrode_config_key = {"electrode_config_hash": electrode_config_hash}
-
- # ---- make new ElectrodeConfig if needed ----
- if not probe.ElectrodeConfig & electrode_config_key:
- probe.ElectrodeConfig.insert1(
- {
- **electrode_config_key,
- "probe_type": probe_type,
- "electrode_config_name": electrode_config_name,
- }
- )
- probe.ElectrodeConfig.Electrode.insert(
- {**electrode_config_key, **electrode} for electrode in electrode_keys
- )
-
- return electrode_config_key
-
-
def get_recording_channels_details(ephys_recording_key: dict) -> np.array:
"""Get details of recording channels for a given recording."""
channels_details = {}
@@ -1530,3 +1679,41 @@ def get_recording_channels_details(ephys_recording_key: dict) -> np.array:
)
return channels_details
+
+
+def generate_electrode_config_entry(probe_type: str, electrode_keys: list) -> dict:
+ """Generate and insert new ElectrodeConfig
+
+ Args:
+ probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
+ electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
+
+ Returns:
+ dict: representing a key of the probe.ElectrodeConfig table
+ """
+ # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
+ electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
+
+ electrode_list = sorted([k["electrode"] for k in electrode_keys])
+ electrode_gaps = (
+ [-1]
+ + np.where(np.diff(electrode_list) > 1)[0].tolist()
+ + [len(electrode_list) - 1]
+ )
+ electrode_config_name = "; ".join(
+ [
+ f"{electrode_list[start + 1]}-{electrode_list[end]}"
+ for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
+ ]
+ )
+ electrode_config_key = {"electrode_config_hash": electrode_config_hash}
+ econfig_entry = {
+ **electrode_config_key,
+ "probe_type": probe_type,
+ "electrode_config_name": electrode_config_name,
+ }
+ econfig_electrodes = [
+ {**electrode, **electrode_config_key} for electrode in electrode_keys
+ ]
+
+ return econfig_entry, econfig_electrodes
diff --git a/element_array_ephys/ephys_acute.py b/element_array_ephys/ephys_acute.py
deleted file mode 100644
index 50371104..00000000
--- a/element_array_ephys/ephys_acute.py
+++ /dev/null
@@ -1,1594 +0,0 @@
-import gc
-import importlib
-import inspect
-import pathlib
-import re
-from decimal import Decimal
-
-import datajoint as dj
-import numpy as np
-import pandas as pd
-from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-
-from . import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-log = dj.logger
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
- ephys_schema_name: str,
- probe_schema_name: str = None,
- *,
- create_schema: bool = True,
- create_tables: bool = True,
- linking_module: str = None,
-):
- """Activates the `ephys` and `probe` schemas.
-
- Args:
- ephys_schema_name (str): A string containing the name of the ephys schema.
- probe_schema_name (str): A string containing the name of the probe schema.
- create_schema (bool): If True, schema will be created in the database.
- create_tables (bool): If True, tables related to the schema will be created in the database.
- linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
-
- Dependencies:
- Upstream tables:
- Session: A parent table to ProbeInsertion
- Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
- Functions:
- get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
- get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings.
- get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory.
- """
-
- if isinstance(linking_module, str):
- linking_module = importlib.import_module(linking_module)
- assert inspect.ismodule(
- linking_module
- ), "The argument 'dependency' must be a module's name or a module"
-
- global _linking_module
- _linking_module = linking_module
-
- probe.activate(
- probe_schema_name, create_schema=create_schema, create_tables=create_tables
- )
- schema.activate(
- ephys_schema_name,
- create_schema=create_schema,
- create_tables=create_tables,
- add_objects=_linking_module.__dict__,
- )
- ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
-
-# -------------- Functions required by the elements-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
- """Fetches absolute data path to ephys data directories.
-
- The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
- Returns:
- A list of the absolute path(s) to ephys data directories.
- """
- root_directories = _linking_module.get_ephys_root_data_dir()
- if isinstance(root_directories, (str, pathlib.Path)):
- root_directories = [root_directories]
-
- if hasattr(_linking_module, "get_processed_root_data_dir"):
- root_directories.append(_linking_module.get_processed_root_data_dir())
-
- return root_directories
-
-
-def get_session_directory(session_key: dict) -> str:
- """Retrieve the session directory with Neuropixels for the given session.
-
- Args:
- session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
- Returns:
- A string for the path to the session directory.
- """
- return _linking_module.get_session_directory(session_key)
-
-
-def get_processed_root_data_dir() -> str:
- """Retrieve the root directory for all processed data.
-
- Returns:
- A string for the full path to the root directory for processed data.
- """
-
- if hasattr(_linking_module, "get_processed_root_data_dir"):
- return _linking_module.get_processed_root_data_dir()
- else:
- return get_ephys_root_data_dir()[0]
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
- """Name of software used for recording electrophysiological data.
-
- Attributes:
- acq_software ( varchar(24) ): Acquisition software, e.g,. SpikeGLX, OpenEphys
- """
-
- definition = """ # Software used for recording of neuropixels probes
- acq_software: varchar(24)
- """
- contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
- """Information about probe insertion across subjects and sessions.
-
- Attributes:
- Session (foreign key): Session primary key.
- insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session.
- probe.Probe (str): probe.Probe primary key.
- """
-
- definition = """
- # Probe insertion implanted into an animal for a given session.
- -> Session
- insertion_number: tinyint unsigned
- ---
- -> probe.Probe
- """
-
- @classmethod
- def auto_generate_entries(cls, session_key):
- """Automatically populate entries in ProbeInsertion table for a session."""
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(session_key)
- )
- # search session dir and determine acquisition software
- for ephys_pattern, ephys_acq_type in (
- ("*.ap.meta", "SpikeGLX"),
- ("*.oebin", "Open Ephys"),
- ):
- ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern))
- if ephys_meta_filepaths:
- acq_software = ephys_acq_type
- break
- else:
- raise FileNotFoundError(
- f"Ephys recording data not found!"
- f" Neither SpikeGLX nor Open Ephys recording files found in: {session_dir}"
- )
-
- probe_list, probe_insertion_list = [], []
- if acq_software == "SpikeGLX":
- for meta_fp_idx, meta_filepath in enumerate(ephys_meta_filepaths):
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
-
- probe_key = {
- "probe_type": spikeglx_meta.probe_model,
- "probe": spikeglx_meta.probe_SN,
- }
- if probe_key["probe"] not in [p["probe"] for p in probe_list]:
- probe_list.append(probe_key)
-
- probe_dir = meta_filepath.parent
- try:
- probe_number = re.search(r"(imec)?\d{1}$", probe_dir.name).group()
- probe_number = int(probe_number.replace("imec", ""))
- except AttributeError:
- probe_number = meta_fp_idx
-
- probe_insertion_list.append(
- {
- **session_key,
- "probe": spikeglx_meta.probe_SN,
- "insertion_number": int(probe_number),
- }
- )
- elif acq_software == "Open Ephys":
- loaded_oe = openephys.OpenEphys(session_dir)
- for probe_idx, oe_probe in enumerate(loaded_oe.probes.values()):
- probe_key = {
- "probe_type": oe_probe.probe_model,
- "probe": oe_probe.probe_SN,
- }
- if probe_key["probe"] not in [p["probe"] for p in probe_list]:
- probe_list.append(probe_key)
- probe_insertion_list.append(
- {
- **session_key,
- "probe": oe_probe.probe_SN,
- "insertion_number": probe_idx,
- }
- )
- else:
- raise NotImplementedError(f"Unknown acquisition software: {acq_software}")
-
- probe.Probe.insert(probe_list, skip_duplicates=True)
- cls.insert(probe_insertion_list, skip_duplicates=True)
-
-
-@schema
-class InsertionLocation(dj.Manual):
- """Stereotaxic location information for each probe insertion.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- SkullReference (dict): SkullReference primary key.
- ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive.
- ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive.
- depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative.
- Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis.
- phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis.
- """
-
- definition = """
- # Brain Location of a given probe insertion.
- -> ProbeInsertion
- ---
- -> SkullReference
- ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive
- ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive
- depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
- theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis
- phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis
- beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior
- """
-
-
-@schema
-class EphysRecording(dj.Imported):
- """Automated table with electrophysiology recording information for each probe inserted during an experimental session.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key.
- AcquisitionSoftware (dict): AcquisitionSoftware primary key.
- sampling_rate (float): sampling rate of the recording in Hertz (Hz).
- recording_datetime (datetime): datetime of the recording from this probe.
- recording_duration (float): duration of the entire recording from this probe in seconds.
- """
-
- definition = """
- # Ephys recording from a probe insertion for a given session.
- -> ProbeInsertion
- ---
- -> probe.ElectrodeConfig
- -> AcquisitionSoftware
- sampling_rate: float # (Hz)
- recording_datetime: datetime # datetime of the recording from this probe
- recording_duration: float # (seconds) duration of the recording from this probe
- """
-
- class EphysFile(dj.Part):
- """Paths of electrophysiology recording files for each insertion.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- file_path (varchar(255) ): relative file path for electrophysiology recording.
- """
-
- definition = """
- # Paths of files of a given EphysRecording round.
- -> master
- file_path: varchar(255) # filepath relative to root data directory
- """
-
- def make(self, key):
- """Populates table with electrophysiology recording information."""
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
-
- inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1(
- "probe"
- )
-
- # search session dir and determine acquisition software
- for ephys_pattern, ephys_acq_type in (
- ("*.ap.meta", "SpikeGLX"),
- ("*.oebin", "Open Ephys"),
- ):
- ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern))
- if ephys_meta_filepaths:
- acq_software = ephys_acq_type
- break
- else:
- raise FileNotFoundError(
- f"Ephys recording data not found!"
- f" Neither SpikeGLX nor Open Ephys recording files found"
- f" in {session_dir}"
- )
-
- supported_probe_types = probe.ProbeType.fetch("probe_type")
-
- if acq_software == "SpikeGLX":
- for meta_filepath in ephys_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(key)
- )
-
- if spikeglx_meta.probe_model in supported_probe_types:
- probe_type = spikeglx_meta.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- electrode_group_members = [
- probe_electrodes[(shank, shank_col, shank_row)]
- for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels probe model"
- " {} not yet implemented".format(spikeglx_meta.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": spikeglx_meta.meta["imSampRate"],
- "recording_datetime": spikeglx_meta.recording_time,
- "recording_duration": (
- spikeglx_meta.recording_duration
- or spikeglx.retrieve_recording_duration(meta_filepath)
- ),
- }
- )
-
- root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath)
- self.EphysFile.insert1(
- {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()}
- )
- elif acq_software == "Open Ephys":
- dataset = openephys.OpenEphys(session_dir)
- for serial_number, probe_data in dataset.probes.items():
- if str(serial_number) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No Open Ephys data found for probe insertion: {}".format(key)
- )
-
- if not probe_data.ap_meta:
- raise IOError(
- 'No analog signals found - check "structure.oebin" file or "continuous" directory'
- )
-
- if probe_data.probe_model in supported_probe_types:
- probe_type = probe_data.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_group_members = [
- probe_electrodes[channel_idx]
- for channel_idx in probe_data.ap_meta["channels_indices"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels"
- " probe model {} not yet implemented".format(probe_data.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": probe_data.ap_meta["sample_rate"],
- "recording_datetime": probe_data.recording_info[
- "recording_datetimes"
- ][0],
- "recording_duration": np.sum(
- probe_data.recording_info["recording_durations"]
- ),
- }
- )
-
- root_dir = find_root_directory(
- get_ephys_root_data_dir(),
- probe_data.recording_info["recording_files"][0],
- )
- self.EphysFile.insert(
- [
- {**key, "file_path": fp.relative_to(root_dir).as_posix()}
- for fp in probe_data.recording_info["recording_files"]
- ]
- )
- # explicitly garbage collect "dataset"
- # as these may have large memory footprint and may not be cleared fast enough
- del probe_data, dataset
- gc.collect()
- else:
- raise NotImplementedError(
- f"Processing ephys files from"
- f" acquisition software of type {acq_software} is"
- f" not yet implemented"
- )
-
-
-@schema
-class LFP(dj.Imported):
- """Extracts local field potentials (LFP) from an electrophysiology recording.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- lfp_sampling_rate (float): Sampling rate for LFPs in Hz.
- lfp_time_stamps (longblob): Time stamps with respect to the start of the recording.
- lfp_mean (longblob): Overall mean LFP across electrodes.
- """
-
- definition = """
- # Acquired local field potential (LFP) from a given Ephys recording.
- -> EphysRecording
- ---
- lfp_sampling_rate: float # (Hz)
- lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp)
- lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,)
- """
-
- class Electrode(dj.Part):
- """Saves local field potential data for each electrode.
-
- Attributes:
- LFP (foreign key): LFP primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- lfp (longblob): LFP recording at this electrode in microvolts.
- """
-
- definition = """
- -> master
- -> probe.ElectrodeConfig.Electrode
- ---
- lfp: longblob # (uV) recorded lfp at this electrode
- """
-
- # Only store LFP for every 9th channel, due to high channel density,
- # close-by channels exhibit highly similar LFP
- _skip_channel_counts = 9
-
- def make(self, key):
- """Populates the LFP tables."""
- acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software")
-
- electrode_keys, lfp = [], []
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
-
- lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[
- -1 :: -self._skip_channel_counts
- ]
-
- # Extract LFP data at specified channels and convert to uV
- lfp = spikeglx_recording.lf_timeseries[
- :, lfp_channel_ind
- ] # (sample x channel)
- lfp = (
- lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind]
- ).T # (channel x sample)
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"],
- lfp_time_stamps=(
- np.arange(lfp.shape[1])
- / spikeglx_recording.lfmeta.meta["imSampRate"]
- ),
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- for recorded_site in lfp_channel_ind:
- shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[
- "data"
- ][recorded_site]
- electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)])
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(key)
-
- lfp_channel_ind = np.r_[
- len(oe_probe.lfp_meta["channels_indices"])
- - 1 : 0 : -self._skip_channel_counts
- ]
-
- # (sample x channel)
- lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind]
- lfp = (
- lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind]
- ).T # (channel x sample)
- lfp_timestamps = oe_probe.lfp_timestamps
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"],
- lfp_time_stamps=lfp_timestamps,
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_keys.extend(
- probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind
- )
- else:
- raise NotImplementedError(
- f"LFP extraction from acquisition software"
- f" of type {acq_software} is not yet implemented"
- )
-
- # single insert in loop to mitigate potential memory issue
- for electrode_key, lfp_trace in zip(electrode_keys, lfp):
- self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace})
-
-
-# ------------ Clustering --------------
-
-
-@schema
-class ClusteringMethod(dj.Lookup):
- """Kilosort clustering method.
-
- Attributes:
- clustering_method (foreign key, varchar(16) ): Kilosort clustering method.
- clustering_methods_desc (varchar(1000) ): Additional description of the clustering method.
- """
-
- definition = """
- # Method for clustering
- clustering_method: varchar(16)
- ---
- clustering_method_desc: varchar(1000)
- """
-
- contents = [
- ("kilosort2", "kilosort2 clustering method"),
- ("kilosort2.5", "kilosort2.5 clustering method"),
- ("kilosort3", "kilosort3 clustering method"),
- ]
-
-
-@schema
-class ClusteringParamSet(dj.Lookup):
- """Parameters to be used in clustering procedure for spike sorting.
-
- Attributes:
- paramset_idx (foreign key): Unique ID for the clustering parameter set.
- ClusteringMethod (dict): ClusteringMethod primary key.
- paramset_desc (varchar(128) ): Description of the clustering parameter set.
- param_set_hash (uuid): UUID hash for the parameter set.
- params (longblob): Parameters for clustering with Kilosort.
- """
-
- definition = """
- # Parameter set to be used in a clustering procedure
- paramset_idx: smallint
- ---
- -> ClusteringMethod
- paramset_desc: varchar(128)
- param_set_hash: uuid
- unique index (param_set_hash)
- params: longblob # dictionary of all applicable parameters
- """
-
- @classmethod
- def insert_new_params(
- cls,
- clustering_method: str,
- paramset_desc: str,
- params: dict,
- paramset_idx: int = None,
- ):
- """Inserts new parameters into the ClusteringParamSet table.
-
- Args:
- clustering_method (str): name of the clustering method.
- paramset_desc (str): description of the parameter set
- params (dict): clustering parameters
- paramset_idx (int, optional): Unique parameter set ID. Defaults to None.
- """
- if paramset_idx is None:
- paramset_idx = (
- dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0
- ) + 1
-
- param_dict = {
- "clustering_method": clustering_method,
- "paramset_idx": paramset_idx,
- "paramset_desc": paramset_desc,
- "params": params,
- "param_set_hash": dict_to_uuid(
- {**params, "clustering_method": clustering_method}
- ),
- }
- param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
- if param_query: # If the specified param-set already exists
- existing_paramset_idx = param_query.fetch1("paramset_idx")
- if (
- existing_paramset_idx == paramset_idx
- ): # If the existing set has the same paramset_idx: job done
- return
- else: # If not same name: human error, trying to add the same paramset with different name
- raise dj.DataJointError(
- f"The specified param-set already exists"
- f" - with paramset_idx: {existing_paramset_idx}"
- )
- else:
- if {"paramset_idx": paramset_idx} in cls.proj():
- raise dj.DataJointError(
- f"The specified paramset_idx {paramset_idx} already exists,"
- f" please pick a different one."
- )
- cls.insert1(param_dict)
-
-
-@schema
-class ClusterQualityLabel(dj.Lookup):
- """Quality label for each spike sorted cluster.
-
- Attributes:
- cluster_quality_label (foreign key, varchar(100) ): Cluster quality type.
- cluster_quality_description ( varchar(4000) ): Description of the cluster quality type.
- """
-
- definition = """
- # Quality
- cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc.
- ---
- cluster_quality_description: varchar(4000)
- """
- contents = [
- ("good", "single unit"),
- ("ok", "probably a single unit, but could be contaminated"),
- ("mua", "multi-unit activity"),
- ("noise", "bad unit"),
- ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
- """A clustering task to spike sort electrophysiology datasets.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- ClusteringParamSet (foreign key): ClusteringParamSet primary key.
- clustering_output_dir ( varchar (255) ): Relative path to output clustering results.
- task_mode (enum): `Trigger` computes clustering or and `load` imports existing data.
- """
-
- definition = """
- # Manual table for defining a clustering task ready to be run
- -> EphysRecording
- -> ClusteringParamSet
- ---
- clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory
- task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation
- """
-
- @classmethod
- def infer_output_dir(
- cls, key: dict, relative: bool = False, mkdir: bool = False
- ) -> pathlib.Path:
- """Infer output directory if it is not provided.
-
- Args:
- key (dict): ClusteringTask primary key.
-
- Returns:
- Expected clustering_output_dir based on the following convention:
- processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
- e.g.: sub4/sess1/probe_2/kilosort2_0
- """
- processed_dir = pathlib.Path(get_processed_root_data_dir())
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- root_dir = find_root_directory(get_ephys_root_data_dir(), session_dir)
-
- method = (
- (ClusteringParamSet * ClusteringMethod & key)
- .fetch1("clustering_method")
- .replace(".", "-")
- )
-
- output_dir = (
- processed_dir
- / session_dir.relative_to(root_dir)
- / f'probe_{key["insertion_number"]}'
- / f'{method}_{key["paramset_idx"]}'
- )
-
- if mkdir:
- output_dir.mkdir(parents=True, exist_ok=True)
- log.info(f"{output_dir} created!")
-
- return output_dir.relative_to(processed_dir) if relative else output_dir
-
- @classmethod
- def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0):
- """Autogenerate entries based on a particular ephys recording.
-
- Args:
- ephys_recording_key (dict): EphysRecording primary key.
- paramset_idx (int, optional): Parameter index to use for clustering task. Defaults to 0.
- """
- key = {**ephys_recording_key, "paramset_idx": paramset_idx}
-
- processed_dir = get_processed_root_data_dir()
- output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True)
-
- try:
- kilosort.Kilosort(
- output_dir
- ) # check if the directory is a valid Kilosort output
- except FileNotFoundError:
- task_mode = "trigger"
- else:
- task_mode = "load"
-
- cls.insert1(
- {
- **key,
- "clustering_output_dir": output_dir.relative_to(
- processed_dir
- ).as_posix(),
- "task_mode": task_mode,
- }
- )
-
-
-@schema
-class Clustering(dj.Imported):
- """A processing table to handle each clustering task.
-
- Attributes:
- ClusteringTask (foreign key): ClusteringTask primary key.
- clustering_time (datetime): Time when clustering results are generated.
- package_version ( varchar(16) ): Package version used for a clustering analysis.
- """
-
- definition = """
- # Clustering Procedure
- -> ClusteringTask
- ---
- clustering_time: datetime # time of generation of this set of clustering results
- package_version='': varchar(16)
- """
-
- def make(self, key):
- """Triggers or imports clustering analysis."""
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
-
- if not output_dir:
- output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True)
- # update clustering_output_dir
- ClusteringTask.update1(
- {**key, "clustering_output_dir": output_dir.as_posix()}
- )
-
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- if task_mode == "load":
- kilosort.Kilosort(
- kilosort_dir
- ) # check if the directory is a valid Kilosort output
- elif task_mode == "trigger":
- acq_software, clustering_method, params = (
- ClusteringTask * EphysRecording * ClusteringParamSet & key
- ).fetch1("acq_software", "clustering_method", "params")
-
- if "kilosort" in clustering_method:
- from element_array_ephys.readers import kilosort_triggering
-
- # add additional probe-recording and channels details into `params`
- params = {**params, **get_recording_channels_details(key)}
- params["fs"] = params["sample_rate"]
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(
- spikeglx_meta_filepath.parent
- )
- spikeglx_recording.validate_file("ap")
- run_CatGT = (
- params.pop("run_CatGT", True)
- and "_tcat." not in spikeglx_meta_filepath.stem
- )
-
- if clustering_method.startswith("pykilosort"):
- kilosort_triggering.run_pykilosort(
- continuous_file=spikeglx_recording.root_dir
- / (spikeglx_recording.root_name + ".ap.bin"),
- kilosort_output_directory=kilosort_dir,
- channel_ind=params.pop("channel_ind"),
- x_coords=params.pop("x_coords"),
- y_coords=params.pop("y_coords"),
- shank_ind=params.pop("shank_ind"),
- connected=params.pop("connected"),
- sample_rate=params.pop("sample_rate"),
- params=params,
- )
- else:
- run_kilosort = kilosort_triggering.SGLXKilosortPipeline(
- npx_input_dir=spikeglx_meta_filepath.parent,
- ks_output_dir=kilosort_dir,
- params=params,
- KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- run_CatGT=run_CatGT,
- )
- run_kilosort.run_modules()
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(key)
-
- assert len(oe_probe.recording_info["recording_files"]) == 1
-
- # run kilosort
- if clustering_method.startswith("pykilosort"):
- kilosort_triggering.run_pykilosort(
- continuous_file=pathlib.Path(
- oe_probe.recording_info["recording_files"][0]
- )
- / "continuous.dat",
- kilosort_output_directory=kilosort_dir,
- channel_ind=params.pop("channel_ind"),
- x_coords=params.pop("x_coords"),
- y_coords=params.pop("y_coords"),
- shank_ind=params.pop("shank_ind"),
- connected=params.pop("connected"),
- sample_rate=params.pop("sample_rate"),
- params=params,
- )
- else:
- run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline(
- npx_input_dir=oe_probe.recording_info["recording_files"][0],
- ks_output_dir=kilosort_dir,
- params=params,
- KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- )
- run_kilosort.run_modules()
- else:
- raise NotImplementedError(
- f"Automatic triggering of {clustering_method}"
- f" clustering analysis is not yet supported"
- )
-
- else:
- raise ValueError(f"Unknown task mode: {task_mode}")
-
- creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir)
- self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
- """Curation procedure table.
-
- Attributes:
- Clustering (foreign key): Clustering primary key.
- curation_id (foreign key, int): Unique curation ID.
- curation_time (datetime): Time when curation results are generated.
- curation_output_dir ( varchar(255) ): Output directory of the curated results.
- quality_control (bool): If True, this clustering result has undergone quality control.
- manual_curation (bool): If True, manual curation has been performed on this clustering result.
- curation_note ( varchar(2000) ): Notes about the curation task.
- """
-
- definition = """
- # Manual curation procedure
- -> Clustering
- curation_id: int
- ---
- curation_time: datetime # time of generation of this set of curated clustering results
- curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory
- quality_control: bool # has this clustering result undergone quality control?
- manual_curation: bool # has manual curation been performed on this clustering result?
- curation_note='': varchar(2000)
- """
-
- def create1_from_clustering_task(self, key, curation_note=""):
- """
- A function to create a new corresponding "Curation" for a particular
- "ClusteringTask"
- """
- if key not in Clustering():
- raise ValueError(
- f"No corresponding entry in Clustering available"
- f" for: {key}; do `Clustering.populate(key)`"
- )
-
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
- kilosort_dir
- )
- # Synthesize curation_id
- curation_id = (
- dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
- )
- self.insert1(
- {
- **key,
- "curation_id": curation_id,
- "curation_time": creation_time,
- "curation_output_dir": output_dir,
- "quality_control": is_qc,
- "manual_curation": is_curated,
- "curation_note": curation_note,
- }
- )
-
-
-@schema
-class CuratedClustering(dj.Imported):
- """Clustering results after curation.
-
- Attributes:
- Curation (foreign key): Curation primary key.
- """
-
- definition = """
- # Clustering results of a curation.
- -> Curation
- """
-
- class Unit(dj.Part):
- """Single unit properties after clustering and curation.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- unit (foreign key, int): Unique integer identifying a single unit.
- probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
- ClusteringQualityLabel (dict): CLusteringQualityLabel primary key.
- spike_count (int): Number of spikes in this recording for this unit.
- spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording.
- spike_sites (longblob): Array of electrode associated with each spike.
- spike_depths (longblob): Array of depths associated with each spike, relative to each spike.
- """
-
- definition = """
- # Properties of a given unit from a round of clustering (and curation)
- -> master
- unit: int
- ---
- -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit
- -> ClusterQualityLabel
- spike_count: int # how many spikes in this recording for this unit
- spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
- spike_sites : longblob # array of electrode associated with each spike
- spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
- """
-
- def make(self, key):
- """Automated population of Unit information."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
- acq_software, sample_rate = (EphysRecording & key).fetch1(
- "acq_software", "sampling_rate"
- )
-
- sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate)
-
- # ---------- Unit ----------
- # -- Remove 0-spike units
- withspike_idx = [
- i
- for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
- if (kilosort_dataset.data["spike_clusters"] == u).any()
- ]
- valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
- valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
- # -- Get channel and electrode-site mapping
- channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software)
-
- # -- Spike-times --
- # spike_times_sec_adj > spike_times_sec > spike_times
- spike_time_key = (
- "spike_times_sec_adj"
- if "spike_times_sec_adj" in kilosort_dataset.data
- else (
- "spike_times_sec"
- if "spike_times_sec" in kilosort_dataset.data
- else "spike_times"
- )
- )
- spike_times = kilosort_dataset.data[spike_time_key]
- kilosort_dataset.extract_spike_depths()
-
- # -- Spike-sites and Spike-depths --
- spike_sites = np.array(
- [
- channel2electrodes[s]["electrode"]
- for s in kilosort_dataset.data["spike_sites"]
- ]
- )
- spike_depths = kilosort_dataset.data["spike_depths"]
-
- # -- Insert unit, label, peak-chn
- units = []
- for unit, unit_lbl in zip(valid_units, valid_unit_labels):
- if (kilosort_dataset.data["spike_clusters"] == unit).any():
- unit_channel, _ = kilosort_dataset.get_best_channel(unit)
- unit_spike_times = (
- spike_times[kilosort_dataset.data["spike_clusters"] == unit]
- / sample_rate
- )
- spike_count = len(unit_spike_times)
-
- units.append(
- {
- "unit": unit,
- "cluster_quality_label": unit_lbl,
- **channel2electrodes[unit_channel],
- "spike_times": unit_spike_times,
- "spike_count": spike_count,
- "spike_sites": spike_sites[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
- "spike_depths": (
- spike_depths[
- kilosort_dataset.data["spike_clusters"] == unit
- ]
- if spike_depths is not None
- else None
- ),
- }
- )
-
- self.insert1(key)
- self.Unit.insert([{**key, **u} for u in units])
-
-
-@schema
-class WaveformSet(dj.Imported):
- """A set of spike waveforms for units out of a given CuratedClustering.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # A set of spike waveforms for units out of a given CuratedClustering
- -> CuratedClustering
- """
-
- class PeakWaveform(dj.Part):
- """Mean waveform across spikes for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode.
- """
-
- definition = """
- # Mean waveform across spikes for a given unit at its representative electrode
- -> master
- -> CuratedClustering.Unit
- ---
- peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode
- """
-
- class Waveform(dj.Part):
- """Spike waveforms for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- waveform_mean (longblob): mean waveform across spikes of the unit in microvolts.
- waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit.
- """
-
- definition = """
- # Spike waveforms and their mean across spikes for the given unit
- -> master
- -> CuratedClustering.Unit
- -> probe.ElectrodeConfig.Electrode
- ---
- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit
- waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit
- """
-
- def make(self, key):
- """Populates waveform tables."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
-
- acq_software, probe_serial_number = (
- EphysRecording * ProbeInsertion & key
- ).fetch1("acq_software", "probe")
-
- # -- Get channel and electrode-site mapping
- recording_key = (EphysRecording & key).fetch1("KEY")
- channel2electrodes = get_neuropixels_channel2electrode_map(
- recording_key, acq_software
- )
-
- is_qc = (Curation & key).fetch1("quality_control")
-
- # Get all units
- units = {
- u["unit"]: u
- for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit")
- }
-
- if is_qc:
- unit_waveforms = np.load(
- kilosort_dir / "mean_waveforms.npy"
- ) # unit x channel x sample
-
- def yield_unit_waveforms():
- for unit_no, unit_waveform in zip(
- kilosort_dataset.data["cluster_ids"], unit_waveforms
- ):
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
- if unit_no in units:
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], unit_waveform
- ):
- unit_electrode_waveforms.append(
- {
- **units[unit_no],
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == units[unit_no]["electrode"]
- ):
- unit_peak_waveform = {
- **units[unit_no],
- "peak_electrode_waveform": channel_waveform,
- }
- yield unit_peak_waveform, unit_electrode_waveforms
-
- else:
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- neuropixels_recording = openephys_dataset.probes[probe_serial_number]
-
- def yield_unit_waveforms():
- for unit_dict in units.values():
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
-
- spikes = unit_dict["spike_times"]
- waveforms = neuropixels_recording.extract_spike_waveforms(
- spikes, kilosort_dataset.data["channel_map"]
- ) # (sample x channel x spike)
- waveforms = waveforms.transpose(
- (1, 2, 0)
- ) # (channel x spike x sample)
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], waveforms
- ):
- unit_electrode_waveforms.append(
- {
- **unit_dict,
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform.mean(axis=0),
- "waveforms": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == unit_dict["electrode"]
- ):
- unit_peak_waveform = {
- **unit_dict,
- "peak_electrode_waveform": channel_waveform.mean(
- axis=0
- ),
- }
-
- yield unit_peak_waveform, unit_electrode_waveforms
-
- # insert waveform on a per-unit basis to mitigate potential memory issue
- self.insert1(key)
- for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
- if unit_peak_waveform:
- self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
- if unit_electrode_waveforms:
- self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
-
-
-@schema
-class QualityMetrics(dj.Imported):
- """Clustering and waveform quality metrics.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # Clusters and waveforms metrics
- -> CuratedClustering
- """
-
- class Cluster(dj.Part):
- """Cluster metrics for a unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- firing_rate (float): Firing rate of the unit.
- snr (float): Signal-to-noise ratio for a unit.
- presence_ratio (float): Fraction of time where spikes are present.
- isi_violation (float): rate of ISI violation as a fraction of overall rate.
- number_violation (int): Total ISI violations.
- amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
- isolation_distance (float): Distance to nearest cluster.
- l_ratio (float): Amount of empty space between a cluster and other spikes in dataset.
- d_prime (float): Classification accuracy based on LDA.
- nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster.
- nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster.
- silhouette_core (float): Maximum change in spike depth throughout recording.
- cumulative_drift (float): Cumulative change in spike depth throughout recording.
- contamination_rate (float): Frequency of spikes in the refractory period.
- """
-
- definition = """
- # Cluster metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- firing_rate=null: float # (Hz) firing rate for a unit
- snr=null: float # signal-to-noise ratio for a unit
- presence_ratio=null: float # fraction of time in which spikes are present
- isi_violation=null: float # rate of ISI violation as a fraction of overall rate
- number_violation=null: int # total number of ISI violations
- amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
- isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
- l_ratio=null: float #
- d_prime=null: float # Classification accuracy based on LDA
- nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
- nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
- silhouette_score=null: float # Standard metric for cluster overlap
- max_drift=null: float # Maximum change in spike depth throughout recording
- cumulative_drift=null: float # Cumulative change in spike depth throughout recording
- contamination_rate=null: float #
- """
-
- class Waveform(dj.Part):
- """Waveform metrics for a particular unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- amplitude (float): Absolute difference between waveform peak and trough in microvolts.
- duration (float): Time between waveform peak and trough in milliseconds.
- halfwidth (float): Spike width at half max amplitude.
- pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0.
- repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak.
- recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail.
- spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe.
- velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe.
- velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe.
- """
-
- definition = """
- # Waveform metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- amplitude: float # (uV) absolute difference between waveform peak and trough
- duration: float # (ms) time between waveform peak and trough
- halfwidth=null: float # (ms) spike width at half max amplitude
- pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0
- repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak
- recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail
- spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe
- velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe
- velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe
- """
-
- def make(self, key):
- """Populates tables with quality metrics data."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- metric_fp = kilosort_dir / "metrics.csv"
- rename_dict = {
- "isi_viol": "isi_violation",
- "num_viol": "number_violation",
- "contam_rate": "contamination_rate",
- }
-
- if not metric_fp.exists():
- raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")
-
- metrics_df = pd.read_csv(metric_fp)
- metrics_df.set_index("cluster_id", inplace=True)
- metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
- metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df.rename(columns=rename_dict, inplace=True)
- metrics_list = [
- dict(metrics_df.loc[unit_key["unit"]], **unit_key)
- for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
- ]
-
- self.insert1(key)
- self.Cluster.insert(metrics_list, ignore_extra_fields=True)
- self.Waveform.insert(metrics_list, ignore_extra_fields=True)
-
-
-# ---------------- HELPER FUNCTIONS ----------------
-
-
-def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str:
- """Get spikeGLX data filepath."""
- # attempt to retrieve from EphysRecording.EphysFile
- spikeglx_meta_filepath = pathlib.Path(
- (
- EphysRecording.EphysFile
- & ephys_recording_key
- & 'file_path LIKE "%.ap.meta"'
- ).fetch1("file_path")
- )
-
- try:
- spikeglx_meta_filepath = find_full_path(
- get_ephys_root_data_dir(), spikeglx_meta_filepath
- )
- except FileNotFoundError:
- # if not found, search in session_dir again
- if not spikeglx_meta_filepath.exists():
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
-
- spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")]
- for meta_filepath in spikeglx_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- spikeglx_meta_filepath = meta_filepath
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(
- ephys_recording_key
- )
- )
-
- return spikeglx_meta_filepath
-
-
-def get_openephys_probe_data(ephys_recording_key: dict) -> list:
- """Get OpenEphys probe data from file."""
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- loaded_oe = openephys.OpenEphys(session_dir)
- probe_data = loaded_oe.probes[inserted_probe_serial_number]
-
- # explicitly garbage collect "loaded_oe"
- # as these may have large memory footprint and may not be cleared fast enough
- del loaded_oe
- gc.collect()
-
- return probe_data
-
-
-def get_neuropixels_channel2electrode_map(
- ephys_recording_key: dict, acq_software: str
-) -> dict:
- """Get the channel map for neuropixels probe."""
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath)
- electrode_config_key = (
- EphysRecording * probe.ElectrodeConfig & ephys_recording_key
- ).fetch1("KEY")
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
- & electrode_config_key
- )
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- channel2electrode_map = {
- recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
- for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
- spikeglx_meta.shankmap["data"]
- )
- }
- elif acq_software == "Open Ephys":
- probe_dataset = get_openephys_probe_data(ephys_recording_key)
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording
- & ephys_recording_key
- )
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- channel2electrode_map = {
- channel_idx: probe_electrodes[channel_idx]
- for channel_idx in probe_dataset.ap_meta["channels_indices"]
- }
-
- return channel2electrode_map
-
-
-def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict:
- """Generate and insert new ElectrodeConfig
-
- Args:
- probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
- electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
-
- Returns:
- dict: representing a key of the probe.ElectrodeConfig table
- """
- # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
- electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
-
- electrode_list = sorted([k["electrode"] for k in electrode_keys])
- electrode_gaps = (
- [-1]
- + np.where(np.diff(electrode_list) > 1)[0].tolist()
- + [len(electrode_list) - 1]
- )
- electrode_config_name = "; ".join(
- [
- f"{electrode_list[start + 1]}-{electrode_list[end]}"
- for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
- ]
- )
-
- electrode_config_key = {"electrode_config_hash": electrode_config_hash}
-
- # ---- make new ElectrodeConfig if needed ----
- if not probe.ElectrodeConfig & electrode_config_key:
- probe.ElectrodeConfig.insert1(
- {
- **electrode_config_key,
- "probe_type": probe_type,
- "electrode_config_name": electrode_config_name,
- }
- )
- probe.ElectrodeConfig.Electrode.insert(
- {**electrode_config_key, **electrode} for electrode in electrode_keys
- )
-
- return electrode_config_key
-
-
-def get_recording_channels_details(ephys_recording_key: dict) -> np.array:
- """Get details of recording channels for a given recording."""
- channels_details = {}
-
- acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1(
- "acq_software", "sampling_rate"
- )
-
- probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1(
- "probe_type"
- )
- channels_details["probe_type"] = {
- "neuropixels 1.0 - 3A": "3A",
- "neuropixels 1.0 - 3B": "NP1",
- "neuropixels UHD": "NP1100",
- "neuropixels 2.0 - SS": "NP21",
- "neuropixels 2.0 - MS": "NP24",
- }[probe_type]
-
- electrode_config_key = (
- probe.ElectrodeConfig * EphysRecording & ephys_recording_key
- ).fetch1("KEY")
- (
- channels_details["channel_ind"],
- channels_details["x_coords"],
- channels_details["y_coords"],
- channels_details["shank_ind"],
- ) = (
- probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode
- & electrode_config_key
- ).fetch(
- "electrode", "x_coord", "y_coord", "shank"
- )
- channels_details["sample_rate"] = sample_rate
- channels_details["num_channels"] = len(channels_details["channel_ind"])
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0]
- channels_details["connected"] = np.array(
- [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]]
- )
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(ephys_recording_key)
- channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0]
- channels_details["connected"] = np.array(
- [
- int(v == 1)
- for c, v in oe_probe.channels_connected.items()
- if c in channels_details["channel_ind"]
- ]
- )
-
- return channels_details
diff --git a/element_array_ephys/ephys_chronic.py b/element_array_ephys/ephys_chronic.py
deleted file mode 100644
index 772e885f..00000000
--- a/element_array_ephys/ephys_chronic.py
+++ /dev/null
@@ -1,1523 +0,0 @@
-import gc
-import importlib
-import inspect
-import pathlib
-from decimal import Decimal
-
-import datajoint as dj
-import numpy as np
-import pandas as pd
-from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-
-from . import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-log = dj.logger
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
- ephys_schema_name: str,
- probe_schema_name: str = None,
- *,
- create_schema: bool = True,
- create_tables: bool = True,
- linking_module: str = None,
-):
- """Activates the `ephys` and `probe` schemas.
-
- Args:
- ephys_schema_name (str): A string containing the name of the ephys schema.
- probe_schema_name (str): A string containing the name of the probe schema.
- create_schema (bool): If True, schema will be created in the database.
- create_tables (bool): If True, tables related to the schema will be created in the database.
- linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
-
- Dependencies:
- Upstream tables:
- Session: A parent table to ProbeInsertion
- Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
- Functions:
- get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
- get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings.
- get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory.
- """
-
- if isinstance(linking_module, str):
- linking_module = importlib.import_module(linking_module)
- assert inspect.ismodule(
- linking_module
- ), "The argument 'dependency' must be a module's name or a module"
-
- global _linking_module
- _linking_module = linking_module
-
- probe.activate(
- probe_schema_name, create_schema=create_schema, create_tables=create_tables
- )
- schema.activate(
- ephys_schema_name,
- create_schema=create_schema,
- create_tables=create_tables,
- add_objects=_linking_module.__dict__,
- )
- ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
-
-# -------------- Functions required by the elements-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
- """Fetches absolute data path to ephys data directories.
-
- The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
- Returns:
- A list of the absolute path(s) to ephys data directories.
- """
- root_directories = _linking_module.get_ephys_root_data_dir()
- if isinstance(root_directories, (str, pathlib.Path)):
- root_directories = [root_directories]
-
- if hasattr(_linking_module, "get_processed_root_data_dir"):
- root_directories.append(_linking_module.get_processed_root_data_dir())
-
- return root_directories
-
-
-def get_session_directory(session_key: dict) -> str:
- """Retrieve the session directory with Neuropixels for the given session.
-
- Args:
- session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
- Returns:
- A string for the path to the session directory.
- """
- return _linking_module.get_session_directory(session_key)
-
-
-def get_processed_root_data_dir() -> str:
- """Retrieve the root directory for all processed data.
-
- Returns:
- A string for the full path to the root directory for processed data.
- """
-
- if hasattr(_linking_module, "get_processed_root_data_dir"):
- return _linking_module.get_processed_root_data_dir()
- else:
- return get_ephys_root_data_dir()[0]
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
- """Name of software used for recording electrophysiological data.
-
- Attributes:
- acq_software ( varchar(24) ): Acquisition software, e.g,. SpikeGLX, OpenEphys
- """
-
- definition = """ # Software used for recording of neuropixels probes
- acq_software: varchar(24)
- """
- contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
- """Information about probe insertion across subjects and sessions.
-
- Attributes:
- Session (foreign key): Session primary key.
- insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session.
- probe.Probe (str): probe.Probe primary key.
- """
-
- definition = """
- # Probe insertion chronically implanted into an animal.
- -> Subject
- insertion_number: tinyint unsigned
- ---
- -> probe.Probe
- insertion_datetime=null: datetime
- """
-
-
-@schema
-class InsertionLocation(dj.Manual):
- """Stereotaxic location information for each probe insertion.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- SkullReference (dict): SkullReference primary key.
- ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive.
- ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive.
- depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative.
- Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis.
- phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis.
- """
-
- definition = """
- # Brain Location of a given probe insertion.
- -> ProbeInsertion
- ---
- -> SkullReference
- ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive
- ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive
- depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
- theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis
- phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis
- beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior
- """
-
-
-@schema
-class EphysRecording(dj.Imported):
- """Automated table with electrophysiology recording information for each probe inserted during an experimental session.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key.
- AcquisitionSoftware (dict): AcquisitionSoftware primary key.
- sampling_rate (float): sampling rate of the recording in Hertz (Hz).
- recording_datetime (datetime): datetime of the recording from this probe.
- recording_duration (float): duration of the entire recording from this probe in seconds.
- """
-
- definition = """
- # Ephys recording from a probe insertion for a given session.
- -> Session
- -> ProbeInsertion
- ---
- -> probe.ElectrodeConfig
- -> AcquisitionSoftware
- sampling_rate: float # (Hz)
- recording_datetime: datetime # datetime of the recording from this probe
- recording_duration: float # (seconds) duration of the recording from this probe
- """
-
- class EphysFile(dj.Part):
- """Paths of electrophysiology recording files for each insertion.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- file_path (varchar(255) ): relative file path for electrophysiology recording.
- """
-
- definition = """
- # Paths of files of a given EphysRecording round.
- -> master
- file_path: varchar(255) # filepath relative to root data directory
- """
-
- def make(self, key):
- """Populates table with electrophysiology recording information."""
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
-
- inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1(
- "probe"
- )
-
- # search session dir and determine acquisition software
- for ephys_pattern, ephys_acq_type in (
- ("*.ap.meta", "SpikeGLX"),
- ("*.oebin", "Open Ephys"),
- ):
- ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern))
- if ephys_meta_filepaths:
- acq_software = ephys_acq_type
- break
- else:
- raise FileNotFoundError(
- f"Ephys recording data not found!"
- f" Neither SpikeGLX nor Open Ephys recording files found"
- f" in {session_dir}"
- )
-
- supported_probe_types = probe.ProbeType.fetch("probe_type")
-
- if acq_software == "SpikeGLX":
- for meta_filepath in ephys_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- f"No SpikeGLX data found for probe insertion: {key}"
- + " The probe serial number does not match."
- )
-
- if spikeglx_meta.probe_model in supported_probe_types:
- probe_type = spikeglx_meta.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- electrode_group_members = [
- probe_electrodes[(shank, shank_col, shank_row)]
- for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels probe model"
- " {} not yet implemented".format(spikeglx_meta.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": spikeglx_meta.meta["imSampRate"],
- "recording_datetime": spikeglx_meta.recording_time,
- "recording_duration": (
- spikeglx_meta.recording_duration
- or spikeglx.retrieve_recording_duration(meta_filepath)
- ),
- }
- )
-
- root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath)
- self.EphysFile.insert1(
- {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()}
- )
- elif acq_software == "Open Ephys":
- dataset = openephys.OpenEphys(session_dir)
- for serial_number, probe_data in dataset.probes.items():
- if str(serial_number) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No Open Ephys data found for probe insertion: {}".format(key)
- )
-
- if not probe_data.ap_meta:
- raise IOError(
- 'No analog signals found - check "structure.oebin" file or "continuous" directory'
- )
-
- if probe_data.probe_model in supported_probe_types:
- probe_type = probe_data.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_group_members = [
- probe_electrodes[channel_idx]
- for channel_idx in probe_data.ap_meta["channels_indices"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels"
- " probe model {} not yet implemented".format(probe_data.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": probe_data.ap_meta["sample_rate"],
- "recording_datetime": probe_data.recording_info[
- "recording_datetimes"
- ][0],
- "recording_duration": np.sum(
- probe_data.recording_info["recording_durations"]
- ),
- }
- )
-
- root_dir = find_root_directory(
- get_ephys_root_data_dir(),
- probe_data.recording_info["recording_files"][0],
- )
- self.EphysFile.insert(
- [
- {**key, "file_path": fp.relative_to(root_dir).as_posix()}
- for fp in probe_data.recording_info["recording_files"]
- ]
- )
- # explicitly garbage collect "dataset"
- # as these may have large memory footprint and may not be cleared fast enough
- del probe_data, dataset
- gc.collect()
- else:
- raise NotImplementedError(
- f"Processing ephys files from"
- f" acquisition software of type {acq_software} is"
- f" not yet implemented"
- )
-
-
-@schema
-class LFP(dj.Imported):
- """Extracts local field potentials (LFP) from an electrophysiology recording.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- lfp_sampling_rate (float): Sampling rate for LFPs in Hz.
- lfp_time_stamps (longblob): Time stamps with respect to the start of the recording.
- lfp_mean (longblob): Overall mean LFP across electrodes.
- """
-
- definition = """
- # Acquired local field potential (LFP) from a given Ephys recording.
- -> EphysRecording
- ---
- lfp_sampling_rate: float # (Hz)
- lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp)
- lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,)
- """
-
- class Electrode(dj.Part):
- """Saves local field potential data for each electrode.
-
- Attributes:
- LFP (foreign key): LFP primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- lfp (longblob): LFP recording at this electrode in microvolts.
- """
-
- definition = """
- -> master
- -> probe.ElectrodeConfig.Electrode
- ---
- lfp: longblob # (uV) recorded lfp at this electrode
- """
-
- # Only store LFP for every 9th channel, due to high channel density,
- # close-by channels exhibit highly similar LFP
- _skip_channel_counts = 9
-
- def make(self, key):
- """Populates the LFP tables."""
- acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software")
-
- electrode_keys, lfp = [], []
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
-
- lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[
- -1 :: -self._skip_channel_counts
- ]
-
- # Extract LFP data at specified channels and convert to uV
- lfp = spikeglx_recording.lf_timeseries[
- :, lfp_channel_ind
- ] # (sample x channel)
- lfp = (
- lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind]
- ).T # (channel x sample)
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"],
- lfp_time_stamps=(
- np.arange(lfp.shape[1])
- / spikeglx_recording.lfmeta.meta["imSampRate"]
- ),
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- for recorded_site in lfp_channel_ind:
- shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[
- "data"
- ][recorded_site]
- electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)])
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(key)
-
- lfp_channel_ind = np.r_[
- len(oe_probe.lfp_meta["channels_indices"])
- - 1 : 0 : -self._skip_channel_counts
- ]
-
- # (sample x channel)
- lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind]
- lfp = (
- lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind]
- ).T # (channel x sample)
- lfp_timestamps = oe_probe.lfp_timestamps
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"],
- lfp_time_stamps=lfp_timestamps,
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_keys.extend(
- probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind
- )
- else:
- raise NotImplementedError(
- f"LFP extraction from acquisition software"
- f" of type {acq_software} is not yet implemented"
- )
-
- # single insert in loop to mitigate potential memory issue
- for electrode_key, lfp_trace in zip(electrode_keys, lfp):
- self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace})
-
-
-# ------------ Clustering --------------
-
-
-@schema
-class ClusteringMethod(dj.Lookup):
- """Kilosort clustering method.
-
- Attributes:
- clustering_method (foreign key, varchar(16) ): Kilosort clustering method.
- clustering_methods_desc (varchar(1000) ): Additional description of the clustering method.
- """
-
- definition = """
- # Method for clustering
- clustering_method: varchar(16)
- ---
- clustering_method_desc: varchar(1000)
- """
-
- contents = [
- ("kilosort2", "kilosort2 clustering method"),
- ("kilosort2.5", "kilosort2.5 clustering method"),
- ("kilosort3", "kilosort3 clustering method"),
- ]
-
-
-@schema
-class ClusteringParamSet(dj.Lookup):
- """Parameters to be used in clustering procedure for spike sorting.
-
- Attributes:
- paramset_idx (foreign key): Unique ID for the clustering parameter set.
- ClusteringMethod (dict): ClusteringMethod primary key.
- paramset_desc (varchar(128) ): Description of the clustering parameter set.
- param_set_hash (uuid): UUID hash for the parameter set.
- params (longblob): Parameters for clustering with Kilosort.
- """
-
- definition = """
- # Parameter set to be used in a clustering procedure
- paramset_idx: smallint
- ---
- -> ClusteringMethod
- paramset_desc: varchar(128)
- param_set_hash: uuid
- unique index (param_set_hash)
- params: longblob # dictionary of all applicable parameters
- """
-
- @classmethod
- def insert_new_params(
- cls,
- clustering_method: str,
- paramset_desc: str,
- params: dict,
- paramset_idx: int = None,
- ):
- """Inserts new parameters into the ClusteringParamSet table.
-
- Args:
- clustering_method (str): name of the clustering method.
- paramset_desc (str): description of the parameter set
- params (dict): clustering parameters
- paramset_idx (int, optional): Unique parameter set ID. Defaults to None.
- """
- if paramset_idx is None:
- paramset_idx = (
- dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0
- ) + 1
-
- param_dict = {
- "clustering_method": clustering_method,
- "paramset_idx": paramset_idx,
- "paramset_desc": paramset_desc,
- "params": params,
- "param_set_hash": dict_to_uuid(
- {**params, "clustering_method": clustering_method}
- ),
- }
- param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
- if param_query: # If the specified param-set already exists
- existing_paramset_idx = param_query.fetch1("paramset_idx")
- if (
- existing_paramset_idx == paramset_idx
- ): # If the existing set has the same paramset_idx: job done
- return
- else: # If not same name: human error, trying to add the same paramset with different name
- raise dj.DataJointError(
- f"The specified param-set already exists"
- f" - with paramset_idx: {existing_paramset_idx}"
- )
- else:
- if {"paramset_idx": paramset_idx} in cls.proj():
- raise dj.DataJointError(
- f"The specified paramset_idx {paramset_idx} already exists,"
- f" please pick a different one."
- )
- cls.insert1(param_dict)
-
-
-@schema
-class ClusterQualityLabel(dj.Lookup):
- """Quality label for each spike sorted cluster.
-
- Attributes:
- cluster_quality_label (foreign key, varchar(100) ): Cluster quality type.
- cluster_quality_description (varchar(4000) ): Description of the cluster quality type.
- """
-
- definition = """
- # Quality
- cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc.
- ---
- cluster_quality_description: varchar(4000)
- """
- contents = [
- ("good", "single unit"),
- ("ok", "probably a single unit, but could be contaminated"),
- ("mua", "multi-unit activity"),
- ("noise", "bad unit"),
- ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
- """A clustering task to spike sort electrophysiology datasets.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- ClusteringParamSet (foreign key): ClusteringParamSet primary key.
- clustering_outdir_dir (varchar (255) ): Relative path to output clustering results.
- task_mode (enum): `Trigger` computes clustering or and `load` imports existing data.
- """
-
- definition = """
- # Manual table for defining a clustering task ready to be run
- -> EphysRecording
- -> ClusteringParamSet
- ---
- clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory
- task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation
- """
-
- @classmethod
- def infer_output_dir(cls, key, relative=False, mkdir=False) -> pathlib.Path:
- """Infer output directory if it is not provided.
-
- Args:
- key (dict): ClusteringTask primary key.
-
- Returns:
- Expected clustering_output_dir based on the following convention:
- processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
- e.g.: sub4/sess1/probe_2/kilosort2_0
- """
- processed_dir = pathlib.Path(get_processed_root_data_dir())
- sess_dir = find_full_path(get_ephys_root_data_dir(), get_session_directory(key))
- root_dir = find_root_directory(get_ephys_root_data_dir(), sess_dir)
-
- method = (
- (ClusteringParamSet * ClusteringMethod & key)
- .fetch1("clustering_method")
- .replace(".", "-")
- )
-
- output_dir = (
- processed_dir
- / sess_dir.relative_to(root_dir)
- / f'probe_{key["insertion_number"]}'
- / f'{method}_{key["paramset_idx"]}'
- )
-
- if mkdir:
- output_dir.mkdir(parents=True, exist_ok=True)
- log.info(f"{output_dir} created!")
-
- return output_dir.relative_to(processed_dir) if relative else output_dir
-
- @classmethod
- def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0):
- """Autogenerate entries based on a particular ephys recording.
-
- Args:
- ephys_recording_key (dict): EphysRecording primary key.
- paramset_idx (int, optional): Parameter index to use for clustering task. Defaults to 0.
- """
- key = {**ephys_recording_key, "paramset_idx": paramset_idx}
-
- processed_dir = get_processed_root_data_dir()
- output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True)
-
- try:
- kilosort.Kilosort(
- output_dir
- ) # check if the directory is a valid Kilosort output
- except FileNotFoundError:
- task_mode = "trigger"
- else:
- task_mode = "load"
-
- cls.insert1(
- {
- **key,
- "clustering_output_dir": output_dir.relative_to(
- processed_dir
- ).as_posix(),
- "task_mode": task_mode,
- }
- )
-
-
-@schema
-class Clustering(dj.Imported):
- """A processing table to handle each clustering task.
-
- Attributes:
- ClusteringTask (foreign key): ClusteringTask primary key.
- clustering_time (datetime): Time when clustering results are generated.
- package_version (varchar(16) ): Package version used for a clustering analysis.
- """
-
- definition = """
- # Clustering Procedure
- -> ClusteringTask
- ---
- clustering_time: datetime # time of generation of this set of clustering results
- package_version='': varchar(16)
- """
-
- def make(self, key):
- """Triggers or imports clustering analysis."""
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
-
- if not output_dir:
- output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True)
- # update clustering_output_dir
- ClusteringTask.update1(
- {**key, "clustering_output_dir": output_dir.as_posix()}
- )
-
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- if task_mode == "load":
- kilosort.Kilosort(
- kilosort_dir
- ) # check if the directory is a valid Kilosort output
- elif task_mode == "trigger":
- acq_software, clustering_method, params = (
- ClusteringTask * EphysRecording * ClusteringParamSet & key
- ).fetch1("acq_software", "clustering_method", "params")
-
- if "kilosort" in clustering_method:
- from element_array_ephys.readers import kilosort_triggering
-
- # add additional probe-recording and channels details into `params`
- params = {**params, **get_recording_channels_details(key)}
- params["fs"] = params["sample_rate"]
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(
- spikeglx_meta_filepath.parent
- )
- spikeglx_recording.validate_file("ap")
- run_CatGT = (
- params.pop("run_CatGT", True)
- and "_tcat." not in spikeglx_meta_filepath.stem
- )
-
- if clustering_method.startswith("pykilosort"):
- kilosort_triggering.run_pykilosort(
- continuous_file=spikeglx_recording.root_dir
- / (spikeglx_recording.root_name + ".ap.bin"),
- kilosort_output_directory=kilosort_dir,
- channel_ind=params.pop("channel_ind"),
- x_coords=params.pop("x_coords"),
- y_coords=params.pop("y_coords"),
- shank_ind=params.pop("shank_ind"),
- connected=params.pop("connected"),
- sample_rate=params.pop("sample_rate"),
- params=params,
- )
- else:
- run_kilosort = kilosort_triggering.SGLXKilosortPipeline(
- npx_input_dir=spikeglx_meta_filepath.parent,
- ks_output_dir=kilosort_dir,
- params=params,
- KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- run_CatGT=run_CatGT,
- )
- run_kilosort.run_modules()
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(key)
-
- assert len(oe_probe.recording_info["recording_files"]) == 1
-
- # run kilosort
- if clustering_method.startswith("pykilosort"):
- kilosort_triggering.run_pykilosort(
- continuous_file=pathlib.Path(
- oe_probe.recording_info["recording_files"][0]
- )
- / "continuous.dat",
- kilosort_output_directory=kilosort_dir,
- channel_ind=params.pop("channel_ind"),
- x_coords=params.pop("x_coords"),
- y_coords=params.pop("y_coords"),
- shank_ind=params.pop("shank_ind"),
- connected=params.pop("connected"),
- sample_rate=params.pop("sample_rate"),
- params=params,
- )
- else:
- run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline(
- npx_input_dir=oe_probe.recording_info["recording_files"][0],
- ks_output_dir=kilosort_dir,
- params=params,
- KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}',
- )
- run_kilosort.run_modules()
- else:
- raise NotImplementedError(
- f"Automatic triggering of {clustering_method}"
- f" clustering analysis is not yet supported"
- )
-
- else:
- raise ValueError(f"Unknown task mode: {task_mode}")
-
- creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir)
- self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
- """Curation procedure table.
-
- Attributes:
- Clustering (foreign key): Clustering primary key.
- curation_id (foreign key, int): Unique curation ID.
- curation_time (datetime): Time when curation results are generated.
- curation_output_dir (varchar(255) ): Output directory of the curated results.
- quality_control (bool): If True, this clustering result has undergone quality control.
- manual_curation (bool): If True, manual curation has been performed on this clustering result.
- curation_note (varchar(2000) ): Notes about the curation task.
- """
-
- definition = """
- # Manual curation procedure
- -> Clustering
- curation_id: int
- ---
- curation_time: datetime # time of generation of this set of curated clustering results
- curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory
- quality_control: bool # has this clustering result undergone quality control?
- manual_curation: bool # has manual curation been performed on this clustering result?
- curation_note='': varchar(2000)
- """
-
- def create1_from_clustering_task(self, key, curation_note: str = ""):
- """
- A function to create a new corresponding "Curation" for a particular
- "ClusteringTask"
- """
- if key not in Clustering():
- raise ValueError(
- f"No corresponding entry in Clustering available"
- f" for: {key}; do `Clustering.populate(key)`"
- )
-
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
- kilosort_dir
- )
- # Synthesize curation_id
- curation_id = (
- dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
- )
- self.insert1(
- {
- **key,
- "curation_id": curation_id,
- "curation_time": creation_time,
- "curation_output_dir": output_dir,
- "quality_control": is_qc,
- "manual_curation": is_curated,
- "curation_note": curation_note,
- }
- )
-
-
-@schema
-class CuratedClustering(dj.Imported):
- """Clustering results after curation.
-
- Attributes:
- Curation (foreign key): Curation primary key.
- """
-
- definition = """
- # Clustering results of a curation.
- -> Curation
- """
-
- class Unit(dj.Part):
- """Single unit properties after clustering and curation.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- unit (foreign key, int): Unique integer identifying a single unit.
- probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
- ClusteringQualityLabel (dict): CLusteringQualityLabel primary key.
- spike_count (int): Number of spikes in this recording for this unit.
- spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording.
- spike_sites (longblob): Array of electrode associated with each spike.
- spike_depths (longblob): Array of depths associated with each spike, relative to each spike.
- """
-
- definition = """
- # Properties of a given unit from a round of clustering (and curation)
- -> master
- unit: int
- ---
- -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit
- -> ClusterQualityLabel
- spike_count: int # how many spikes in this recording for this unit
- spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
- spike_sites : longblob # array of electrode associated with each spike
- spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
- """
-
- def make(self, key):
- """Automated population of Unit information."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
- acq_software, sample_rate = (EphysRecording & key).fetch1(
- "acq_software", "sampling_rate"
- )
-
- sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate)
-
- # ---------- Unit ----------
- # -- Remove 0-spike units
- withspike_idx = [
- i
- for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
- if (kilosort_dataset.data["spike_clusters"] == u).any()
- ]
- valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
- valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
- # -- Get channel and electrode-site mapping
- channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software)
-
- # -- Spike-times --
- # spike_times_sec_adj > spike_times_sec > spike_times
- spike_time_key = (
- "spike_times_sec_adj"
- if "spike_times_sec_adj" in kilosort_dataset.data
- else (
- "spike_times_sec"
- if "spike_times_sec" in kilosort_dataset.data
- else "spike_times"
- )
- )
- spike_times = kilosort_dataset.data[spike_time_key]
- kilosort_dataset.extract_spike_depths()
-
- # -- Spike-sites and Spike-depths --
- spike_sites = np.array(
- [
- channel2electrodes[s]["electrode"]
- for s in kilosort_dataset.data["spike_sites"]
- ]
- )
- spike_depths = kilosort_dataset.data["spike_depths"]
-
- # -- Insert unit, label, peak-chn
- units = []
- for unit, unit_lbl in zip(valid_units, valid_unit_labels):
- if (kilosort_dataset.data["spike_clusters"] == unit).any():
- unit_channel, _ = kilosort_dataset.get_best_channel(unit)
- unit_spike_times = (
- spike_times[kilosort_dataset.data["spike_clusters"] == unit]
- / sample_rate
- )
- spike_count = len(unit_spike_times)
-
- units.append(
- {
- "unit": unit,
- "cluster_quality_label": unit_lbl,
- **channel2electrodes[unit_channel],
- "spike_times": unit_spike_times,
- "spike_count": spike_count,
- "spike_sites": spike_sites[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
- "spike_depths": (
- spike_depths[
- kilosort_dataset.data["spike_clusters"] == unit
- ]
- if spike_depths is not None
- else None
- ),
- }
- )
-
- self.insert1(key)
- self.Unit.insert([{**key, **u} for u in units])
-
-
-@schema
-class WaveformSet(dj.Imported):
- """A set of spike waveforms for units out of a given CuratedClustering.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # A set of spike waveforms for units out of a given CuratedClustering
- -> CuratedClustering
- """
-
- class PeakWaveform(dj.Part):
- """Mean waveform across spikes for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode.
- """
-
- definition = """
- # Mean waveform across spikes for a given unit at its representative electrode
- -> master
- -> CuratedClustering.Unit
- ---
- peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode
- """
-
- class Waveform(dj.Part):
- """Spike waveforms for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- waveform_mean (longblob): mean waveform across spikes of the unit in microvolts.
- waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit.
- """
-
- definition = """
- # Spike waveforms and their mean across spikes for the given unit
- -> master
- -> CuratedClustering.Unit
- -> probe.ElectrodeConfig.Electrode
- ---
- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit
- waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit
- """
-
- def make(self, key):
- """Populates waveform tables."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
-
- acq_software, probe_serial_number = (
- EphysRecording * ProbeInsertion & key
- ).fetch1("acq_software", "probe")
-
- # -- Get channel and electrode-site mapping
- recording_key = (EphysRecording & key).fetch1("KEY")
- channel2electrodes = get_neuropixels_channel2electrode_map(
- recording_key, acq_software
- )
-
- is_qc = (Curation & key).fetch1("quality_control")
-
- # Get all units
- units = {
- u["unit"]: u
- for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit")
- }
-
- if is_qc:
- unit_waveforms = np.load(
- kilosort_dir / "mean_waveforms.npy"
- ) # unit x channel x sample
-
- def yield_unit_waveforms():
- for unit_no, unit_waveform in zip(
- kilosort_dataset.data["cluster_ids"], unit_waveforms
- ):
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
- if unit_no in units:
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], unit_waveform
- ):
- unit_electrode_waveforms.append(
- {
- **units[unit_no],
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == units[unit_no]["electrode"]
- ):
- unit_peak_waveform = {
- **units[unit_no],
- "peak_electrode_waveform": channel_waveform,
- }
- yield unit_peak_waveform, unit_electrode_waveforms
-
- else:
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- neuropixels_recording = openephys_dataset.probes[probe_serial_number]
-
- def yield_unit_waveforms():
- for unit_dict in units.values():
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
-
- spikes = unit_dict["spike_times"]
- waveforms = neuropixels_recording.extract_spike_waveforms(
- spikes, kilosort_dataset.data["channel_map"]
- ) # (sample x channel x spike)
- waveforms = waveforms.transpose(
- (1, 2, 0)
- ) # (channel x spike x sample)
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], waveforms
- ):
- unit_electrode_waveforms.append(
- {
- **unit_dict,
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform.mean(axis=0),
- "waveforms": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == unit_dict["electrode"]
- ):
- unit_peak_waveform = {
- **unit_dict,
- "peak_electrode_waveform": channel_waveform.mean(
- axis=0
- ),
- }
-
- yield unit_peak_waveform, unit_electrode_waveforms
-
- # insert waveform on a per-unit basis to mitigate potential memory issue
- self.insert1(key)
- for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
- if unit_peak_waveform:
- self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
- if unit_electrode_waveforms:
- self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
-
-
-@schema
-class QualityMetrics(dj.Imported):
- """Clustering and waveform quality metrics.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # Clusters and waveforms metrics
- -> CuratedClustering
- """
-
- class Cluster(dj.Part):
- """Cluster metrics for a unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- firing_rate (float): Firing rate of the unit.
- snr (float): Signal-to-noise ratio for a unit.
- presence_ratio (float): Fraction of time where spikes are present.
- isi_violation (float): rate of ISI violation as a fraction of overall rate.
- number_violation (int): Total ISI violations.
- amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
- isolation_distance (float): Distance to nearest cluster.
- l_ratio (float): Amount of empty space between a cluster and other spikes in dataset.
- d_prime (float): Classification accuracy based on LDA.
- nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster.
- nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster.
- silhouette_core (float): Maximum change in spike depth throughout recording.
- cumulative_drift (float): Cumulative change in spike depth throughout recording.
- contamination_rate (float): Frequency of spikes in the refractory period.
- """
-
- definition = """
- # Cluster metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- firing_rate=null: float # (Hz) firing rate for a unit
- snr=null: float # signal-to-noise ratio for a unit
- presence_ratio=null: float # fraction of time in which spikes are present
- isi_violation=null: float # rate of ISI violation as a fraction of overall rate
- number_violation=null: int # total number of ISI violations
- amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
- isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
- l_ratio=null: float #
- d_prime=null: float # Classification accuracy based on LDA
- nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
- nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
- silhouette_score=null: float # Standard metric for cluster overlap
- max_drift=null: float # Maximum change in spike depth throughout recording
- cumulative_drift=null: float # Cumulative change in spike depth throughout recording
- contamination_rate=null: float #
- """
-
- class Waveform(dj.Part):
- """Waveform metrics for a particular unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- amplitude (float): Absolute difference between waveform peak and trough in microvolts.
- duration (float): Time between waveform peak and trough in milliseconds.
- halfwidth (float): Spike width at half max amplitude.
- pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0.
- repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak.
- recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail.
- spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe.
- velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe.
- velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe.
- """
-
- definition = """
- # Waveform metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- amplitude: float # (uV) absolute difference between waveform peak and trough
- duration: float # (ms) time between waveform peak and trough
- halfwidth=null: float # (ms) spike width at half max amplitude
- pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0
- repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak
- recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail
- spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe
- velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe
- velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe
- """
-
- def make(self, key):
- """Populates tables with quality metrics data."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- metric_fp = kilosort_dir / "metrics.csv"
- rename_dict = {
- "isi_viol": "isi_violation",
- "num_viol": "number_violation",
- "contam_rate": "contamination_rate",
- }
-
- if not metric_fp.exists():
- raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")
-
- metrics_df = pd.read_csv(metric_fp)
- metrics_df.set_index("cluster_id", inplace=True)
- metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
- metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df.rename(columns=rename_dict, inplace=True)
- metrics_list = [
- dict(metrics_df.loc[unit_key["unit"]], **unit_key)
- for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
- ]
-
- self.insert1(key)
- self.Cluster.insert(metrics_list, ignore_extra_fields=True)
- self.Waveform.insert(metrics_list, ignore_extra_fields=True)
-
-
-# ---------------- HELPER FUNCTIONS ----------------
-
-
-def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str:
- """Get spikeGLX data filepath."""
- # attempt to retrieve from EphysRecording.EphysFile
- spikeglx_meta_filepath = pathlib.Path(
- (
- EphysRecording.EphysFile
- & ephys_recording_key
- & 'file_path LIKE "%.ap.meta"'
- ).fetch1("file_path")
- )
-
- try:
- spikeglx_meta_filepath = find_full_path(
- get_ephys_root_data_dir(), spikeglx_meta_filepath
- )
- except FileNotFoundError:
- # if not found, search in session_dir again
- if not spikeglx_meta_filepath.exists():
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
-
- spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")]
- for meta_filepath in spikeglx_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- spikeglx_meta_filepath = meta_filepath
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(
- ephys_recording_key
- )
- )
-
- return spikeglx_meta_filepath
-
-
-def get_openephys_probe_data(ephys_recording_key: dict) -> list:
- """Get OpenEphys probe data from file."""
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- loaded_oe = openephys.OpenEphys(session_dir)
- probe_data = loaded_oe.probes[inserted_probe_serial_number]
-
- # explicitly garbage collect "loaded_oe"
- # as these may have large memory footprint and may not be cleared fast enough
- del loaded_oe
- gc.collect()
-
- return probe_data
-
-
-def get_neuropixels_channel2electrode_map(
- ephys_recording_key: dict, acq_software: str
-) -> dict:
- """Get the channel map for neuropixels probe."""
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath)
- electrode_config_key = (
- EphysRecording * probe.ElectrodeConfig & ephys_recording_key
- ).fetch1("KEY")
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
- & electrode_config_key
- )
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- channel2electrode_map = {
- recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
- for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
- spikeglx_meta.shankmap["data"]
- )
- }
- elif acq_software == "Open Ephys":
- probe_dataset = get_openephys_probe_data(ephys_recording_key)
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording
- & ephys_recording_key
- )
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- channel2electrode_map = {
- channel_idx: probe_electrodes[channel_idx]
- for channel_idx in probe_dataset.ap_meta["channels_indices"]
- }
-
- return channel2electrode_map
-
-
-def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict:
- """Generate and insert new ElectrodeConfig
-
- Args:
- probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
- electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
-
- Returns:
- dict: representing a key of the probe.ElectrodeConfig table
- """
- # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
- electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
-
- electrode_list = sorted([k["electrode"] for k in electrode_keys])
- electrode_gaps = (
- [-1]
- + np.where(np.diff(electrode_list) > 1)[0].tolist()
- + [len(electrode_list) - 1]
- )
- electrode_config_name = "; ".join(
- [
- f"{electrode_list[start + 1]}-{electrode_list[end]}"
- for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
- ]
- )
-
- electrode_config_key = {"electrode_config_hash": electrode_config_hash}
-
- # ---- make new ElectrodeConfig if needed ----
- if not probe.ElectrodeConfig & electrode_config_key:
- probe.ElectrodeConfig.insert1(
- {
- **electrode_config_key,
- "probe_type": probe_type,
- "electrode_config_name": electrode_config_name,
- }
- )
- probe.ElectrodeConfig.Electrode.insert(
- {**electrode_config_key, **electrode} for electrode in electrode_keys
- )
-
- return electrode_config_key
-
-
-def get_recording_channels_details(ephys_recording_key: dict) -> np.array:
- """Get details of recording channels for a given recording."""
- channels_details = {}
-
- acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1(
- "acq_software", "sampling_rate"
- )
-
- probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1(
- "probe_type"
- )
- channels_details["probe_type"] = {
- "neuropixels 1.0 - 3A": "3A",
- "neuropixels 1.0 - 3B": "NP1",
- "neuropixels UHD": "NP1100",
- "neuropixels 2.0 - SS": "NP21",
- "neuropixels 2.0 - MS": "NP24",
- }[probe_type]
-
- electrode_config_key = (
- probe.ElectrodeConfig * EphysRecording & ephys_recording_key
- ).fetch1("KEY")
- (
- channels_details["channel_ind"],
- channels_details["x_coords"],
- channels_details["y_coords"],
- channels_details["shank_ind"],
- ) = (
- probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode
- & electrode_config_key
- ).fetch(
- "electrode", "x_coord", "y_coord", "shank"
- )
- channels_details["sample_rate"] = sample_rate
- channels_details["num_channels"] = len(channels_details["channel_ind"])
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0]
- channels_details["connected"] = np.array(
- [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]]
- )
- elif acq_software == "Open Ephys":
- oe_probe = get_openephys_probe_data(ephys_recording_key)
- channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0]
- channels_details["connected"] = np.array(
- [
- int(v == 1)
- for c, v in oe_probe.channels_connected.items()
- if c in channels_details["channel_ind"]
- ]
- )
-
- return channels_details
diff --git a/element_array_ephys/ephys_precluster.py b/element_array_ephys/ephys_precluster.py
deleted file mode 100644
index 4d52c610..00000000
--- a/element_array_ephys/ephys_precluster.py
+++ /dev/null
@@ -1,1435 +0,0 @@
-import importlib
-import inspect
-import re
-
-import datajoint as dj
-import numpy as np
-import pandas as pd
-from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory
-
-from . import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
- ephys_schema_name: str,
- probe_schema_name: str = None,
- *,
- create_schema: bool = True,
- create_tables: bool = True,
- linking_module: str = None,
-):
- """Activates the `ephys` and `probe` schemas.
-
- Args:
- ephys_schema_name (str): A string containing the name of the ephys schema.
- probe_schema_name (str): A string containing the name of the probe schema.
- create_schema (bool): If True, schema will be created in the database.
- create_tables (bool): If True, tables related to the schema will be created in the database.
- linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
-
- Dependencies:
- Upstream tables:
- Session: A parent table to ProbeInsertion
- Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
- Functions:
- get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
- get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings.
- """
-
- if isinstance(linking_module, str):
- linking_module = importlib.import_module(linking_module)
- assert inspect.ismodule(
- linking_module
- ), "The argument 'dependency' must be a module's name or a module"
-
- global _linking_module
- _linking_module = linking_module
-
- probe.activate(
- probe_schema_name, create_schema=create_schema, create_tables=create_tables
- )
- schema.activate(
- ephys_schema_name,
- create_schema=create_schema,
- create_tables=create_tables,
- add_objects=_linking_module.__dict__,
- )
- ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
-
-# -------------- Functions required by the elements-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
- """Fetches absolute data path to ephys data directories.
-
- The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
- Returns:
- A list of the absolute path(s) to ephys data directories.
- """
- return _linking_module.get_ephys_root_data_dir()
-
-
-def get_session_directory(session_key: dict) -> str:
- """Retrieve the session directory with Neuropixels for the given session.
-
- Args:
- session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
- Returns:
- A string for the path to the session directory.
- """
- return _linking_module.get_session_directory(session_key)
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
- """Name of software used for recording electrophysiological data.
-
- Attributes:
- acq_software ( varchar(24) ): Acquisition software, e.g,. SpikeGLX, OpenEphys
- """
-
- definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys
- acq_software: varchar(24)
- """
- contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
- """Information about probe insertion across subjects and sessions.
-
- Attributes:
- Session (foreign key): Session primary key.
- insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session.
- probe.Probe (str): probe.Probe primary key.
- """
-
- definition = """
- # Probe insertion implanted into an animal for a given session.
- -> Session
- insertion_number: tinyint unsigned
- ---
- -> probe.Probe
- """
-
-
-@schema
-class InsertionLocation(dj.Manual):
- """Stereotaxic location information for each probe insertion.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- SkullReference (dict): SkullReference primary key.
- ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive.
- ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive.
- depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative.
- Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis.
- phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis
-
- """
-
- definition = """
- # Brain Location of a given probe insertion.
- -> ProbeInsertion
- ---
- -> SkullReference
- ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive
- ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive
- depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative
- theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis
- phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis
- beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior
- """
-
-
-@schema
-class EphysRecording(dj.Imported):
- """Automated table with electrophysiology recording information for each probe inserted during an experimental session.
-
- Attributes:
- ProbeInsertion (foreign key): ProbeInsertion primary key.
- probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key.
- AcquisitionSoftware (dict): AcquisitionSoftware primary key.
- sampling_rate (float): sampling rate of the recording in Hertz (Hz).
- recording_datetime (datetime): datetime of the recording from this probe.
- recording_duration (float): duration of the entire recording from this probe in seconds.
- """
-
- definition = """
- # Ephys recording from a probe insertion for a given session.
- -> ProbeInsertion
- ---
- -> probe.ElectrodeConfig
- -> AcquisitionSoftware
- sampling_rate: float # (Hz)
- recording_datetime: datetime # datetime of the recording from this probe
- recording_duration: float # (seconds) duration of the recording from this probe
- """
-
- class EphysFile(dj.Part):
- """Paths of electrophysiology recording files for each insertion.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- file_path (varchar(255) ): relative file path for electrophysiology recording.
- """
-
- definition = """
- # Paths of files of a given EphysRecording round.
- -> master
- file_path: varchar(255) # filepath relative to root data directory
- """
-
- def make(self, key):
- """Populates table with electrophysiology recording information."""
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
-
- inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1(
- "probe"
- )
-
- # search session dir and determine acquisition software
- for ephys_pattern, ephys_acq_type in (
- ("*.ap.meta", "SpikeGLX"),
- ("*.oebin", "Open Ephys"),
- ):
- ephys_meta_filepaths = [fp for fp in session_dir.rglob(ephys_pattern)]
- if ephys_meta_filepaths:
- acq_software = ephys_acq_type
- break
- else:
- raise FileNotFoundError(
- f"Ephys recording data not found!"
- f" Neither SpikeGLX nor Open Ephys recording files found"
- f" in {session_dir}"
- )
-
- if acq_software == "SpikeGLX":
- for meta_filepath in ephys_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(key)
- )
-
- if re.search("(1.0|2.0)", spikeglx_meta.probe_model):
- probe_type = spikeglx_meta.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- electrode_group_members = [
- probe_electrodes[(shank, shank_col, shank_row)]
- for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels probe model"
- " {} not yet implemented".format(spikeglx_meta.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": spikeglx_meta.meta["imSampRate"],
- "recording_datetime": spikeglx_meta.recording_time,
- "recording_duration": (
- spikeglx_meta.recording_duration
- or spikeglx.retrieve_recording_duration(meta_filepath)
- ),
- }
- )
-
- root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath)
- self.EphysFile.insert1(
- {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()}
- )
- elif acq_software == "Open Ephys":
- dataset = openephys.OpenEphys(session_dir)
- for serial_number, probe_data in dataset.probes.items():
- if str(serial_number) == inserted_probe_serial_number:
- break
- else:
- raise FileNotFoundError(
- "No Open Ephys data found for probe insertion: {}".format(key)
- )
-
- if re.search("(1.0|2.0)", probe_data.probe_model):
- probe_type = probe_data.probe_model
- electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type}
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- electrode_group_members = [
- probe_electrodes[channel_idx]
- for channel_idx in probe_data.ap_meta["channels_ids"]
- ]
- else:
- raise NotImplementedError(
- "Processing for neuropixels"
- " probe model {} not yet implemented".format(probe_data.probe_model)
- )
-
- self.insert1(
- {
- **key,
- **generate_electrode_config(probe_type, electrode_group_members),
- "acq_software": acq_software,
- "sampling_rate": probe_data.ap_meta["sample_rate"],
- "recording_datetime": probe_data.recording_info[
- "recording_datetimes"
- ][0],
- "recording_duration": np.sum(
- probe_data.recording_info["recording_durations"]
- ),
- }
- )
-
- root_dir = find_root_directory(
- get_ephys_root_data_dir(),
- probe_data.recording_info["recording_files"][0],
- )
- self.EphysFile.insert(
- [
- {**key, "file_path": fp.relative_to(root_dir).as_posix()}
- for fp in probe_data.recording_info["recording_files"]
- ]
- )
- else:
- raise NotImplementedError(
- f"Processing ephys files from"
- f" acquisition software of type {acq_software} is"
- f" not yet implemented"
- )
-
-
-@schema
-class PreClusterMethod(dj.Lookup):
- """Pre-clustering method
-
- Attributes:
- precluster_method (foreign key, varchar(16) ): Pre-clustering method for the dataset.
- precluster_method_desc(varchar(1000) ): Pre-clustering method description.
- """
-
- definition = """
- # Method for pre-clustering
- precluster_method: varchar(16)
- ---
- precluster_method_desc: varchar(1000)
- """
-
- contents = [("catgt", "Time shift, Common average referencing, Zeroing")]
-
-
-@schema
-class PreClusterParamSet(dj.Lookup):
- """Parameters for the pre-clustering method.
-
- Attributes:
- paramset_idx (foreign key): Unique parameter set ID.
- PreClusterMethod (dict): PreClusterMethod query for this dataset.
- paramset_desc (varchar(128) ): Description for the pre-clustering parameter set.
- param_set_hash (uuid): Unique hash for parameter set.
- params (longblob): All parameters for the pre-clustering method.
- """
-
- definition = """
- # Parameter set to be used in a clustering procedure
- paramset_idx: smallint
- ---
- -> PreClusterMethod
- paramset_desc: varchar(128)
- param_set_hash: uuid
- unique index (param_set_hash)
- params: longblob # dictionary of all applicable parameters
- """
-
- @classmethod
- def insert_new_params(
- cls, precluster_method: str, paramset_idx: int, paramset_desc: str, params: dict
- ):
- param_dict = {
- "precluster_method": precluster_method,
- "paramset_idx": paramset_idx,
- "paramset_desc": paramset_desc,
- "params": params,
- "param_set_hash": dict_to_uuid(params),
- }
- param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
- if param_query: # If the specified param-set already exists
- existing_paramset_idx = param_query.fetch1("paramset_idx")
- if (
- existing_paramset_idx == paramset_idx
- ): # If the existing set has the same paramset_idx: job done
- return
- else: # If not same name: human error, trying to add the same paramset with different name
- raise dj.DataJointError(
- "The specified param-set"
- " already exists - paramset_idx: {}".format(existing_paramset_idx)
- )
- else:
- cls.insert1(param_dict)
-
-
-@schema
-class PreClusterParamSteps(dj.Manual):
- """Ordered list of parameter sets that will be run.
-
- Attributes:
- precluster_param_steps_id (foreign key): Unique ID for the pre-clustering parameter sets to be run.
- precluster_param_steps_name (varchar(32) ): User-friendly name for the parameter steps.
- precluster_param_steps_desc (varchar(128) ): Description of the parameter steps.
- """
-
- definition = """
- # Ordered list of paramset_idx that are to be run
- # When pre-clustering is not performed, do not create an entry in `Step` Part table
- precluster_param_steps_id: smallint
- ---
- precluster_param_steps_name: varchar(32)
- precluster_param_steps_desc: varchar(128)
- """
-
- class Step(dj.Part):
- """Define the order of operations for parameter sets.
-
- Attributes:
- PreClusterParamSteps (foreign key): PreClusterParamSteps primary key.
- step_number (foreign key, smallint): Order of operations.
- PreClusterParamSet (dict): PreClusterParamSet to be used in pre-clustering.
- """
-
- definition = """
- -> master
- step_number: smallint # Order of operations
- ---
- -> PreClusterParamSet
- """
-
-
-@schema
-class PreClusterTask(dj.Manual):
- """Defines a pre-clustering task ready to be run.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- PreclusterParamSteps (foreign key): PreClusterParam Steps primary key.
- precluster_output_dir (varchar(255) ): relative path to directory for storing results of pre-clustering.
- task_mode (enum ): `none` (no pre-clustering), `load` results from file, or `trigger` automated pre-clustering.
- """
-
- definition = """
- # Manual table for defining a clustering task ready to be run
- -> EphysRecording
- -> PreClusterParamSteps
- ---
- precluster_output_dir='': varchar(255) # pre-clustering output directory relative to the root data directory
- task_mode='none': enum('none','load', 'trigger') # 'none': no pre-clustering analysis
- # 'load': load analysis results
- # 'trigger': trigger computation
- """
-
-
-@schema
-class PreCluster(dj.Imported):
- """
- A processing table to handle each PreClusterTask:
-
- Attributes:
- PreClusterTask (foreign key): PreClusterTask primary key.
- precluster_time (datetime): Time of generation of this set of pre-clustering results.
- package_version (varchar(16) ): Package version used for performing pre-clustering.
- """
-
- definition = """
- -> PreClusterTask
- ---
- precluster_time: datetime # time of generation of this set of pre-clustering results
- package_version='': varchar(16)
- """
-
- def make(self, key):
- """Populate pre-clustering tables."""
- task_mode, output_dir = (PreClusterTask & key).fetch1(
- "task_mode", "precluster_output_dir"
- )
- precluster_output_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- if task_mode == "none":
- if len((PreClusterParamSteps.Step & key).fetch()) > 0:
- raise ValueError(
- "There are entries in the PreClusterParamSteps.Step "
- "table and task_mode=none"
- )
- creation_time = (EphysRecording & key).fetch1("recording_datetime")
- elif task_mode == "load":
- acq_software = (EphysRecording & key).fetch1("acq_software")
- inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1(
- "probe"
- )
-
- if acq_software == "SpikeGLX":
- for meta_filepath in precluster_output_dir.rglob("*.ap.meta"):
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
-
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- creation_time = spikeglx_meta.recording_time
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(key)
- )
- else:
- raise NotImplementedError(
- f"Pre-clustering analysis of {acq_software}" "is not yet supported."
- )
- elif task_mode == "trigger":
- raise NotImplementedError(
- "Automatic triggering of"
- " pre-clustering analysis is not yet supported."
- )
- else:
- raise ValueError(f"Unknown task mode: {task_mode}")
-
- self.insert1({**key, "precluster_time": creation_time, "package_version": ""})
-
-
-@schema
-class LFP(dj.Imported):
- """Extracts local field potentials (LFP) from an electrophysiology recording.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- lfp_sampling_rate (float): Sampling rate for LFPs in Hz.
- lfp_time_stamps (longblob): Time stamps with respect to the start of the recording.
- lfp_mean (longblob): Overall mean LFP across electrodes.
- """
-
- definition = """
- # Acquired local field potential (LFP) from a given Ephys recording.
- -> PreCluster
- ---
- lfp_sampling_rate: float # (Hz)
- lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp)
- lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,)
- """
-
- class Electrode(dj.Part):
- """Saves local field potential data for each electrode.
-
- Attributes:
- LFP (foreign key): LFP primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- lfp (longblob): LFP recording at this electrode in microvolts.
- """
-
- definition = """
- -> master
- -> probe.ElectrodeConfig.Electrode
- ---
- lfp: longblob # (uV) recorded lfp at this electrode
- """
-
- # Only store LFP for every 9th channel, due to high channel density,
- # close-by channels exhibit highly similar LFP
- _skip_channel_counts = 9
-
- def make(self, key):
- """Populates the LFP tables."""
- acq_software, probe_sn = (EphysRecording * ProbeInsertion & key).fetch1(
- "acq_software", "probe"
- )
-
- electrode_keys, lfp = [], []
-
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
-
- lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[
- -1 :: -self._skip_channel_counts
- ]
-
- # Extract LFP data at specified channels and convert to uV
- lfp = spikeglx_recording.lf_timeseries[
- :, lfp_channel_ind
- ] # (sample x channel)
- lfp = (
- lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind]
- ).T # (channel x sample)
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"],
- lfp_time_stamps=(
- np.arange(lfp.shape[1])
- / spikeglx_recording.lfmeta.meta["imSampRate"]
- ),
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- for recorded_site in lfp_channel_ind:
- shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[
- "data"
- ][recorded_site]
- electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)])
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
-
- loaded_oe = openephys.OpenEphys(session_dir)
- oe_probe = loaded_oe.probes[probe_sn]
-
- lfp_channel_ind = np.arange(len(oe_probe.lfp_meta["channels_ids"]))[
- -1 :: -self._skip_channel_counts
- ]
-
- lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] # (sample x channel)
- lfp = (
- lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind]
- ).T # (channel x sample)
- lfp_timestamps = oe_probe.lfp_timestamps
-
- self.insert1(
- dict(
- key,
- lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"],
- lfp_time_stamps=lfp_timestamps,
- lfp_mean=lfp.mean(axis=0),
- )
- )
-
- electrode_query = (
- probe.ProbeType.Electrode
- * probe.ElectrodeConfig.Electrode
- * EphysRecording
- & key
- )
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- for channel_idx in np.array(oe_probe.lfp_meta["channels_ids"])[
- lfp_channel_ind
- ]:
- electrode_keys.append(probe_electrodes[channel_idx])
- else:
- raise NotImplementedError(
- f"LFP extraction from acquisition software"
- f" of type {acq_software} is not yet implemented"
- )
-
- # single insert in loop to mitigate potential memory issue
- for electrode_key, lfp_trace in zip(electrode_keys, lfp):
- self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace})
-
-
-# ------------ Clustering --------------
-
-
-@schema
-class ClusteringMethod(dj.Lookup):
- """Kilosort clustering method.
-
- Attributes:
- clustering_method (foreign key, varchar(16) ): Kilosort clustering method.
- clustering_methods_desc (varchar(1000) ): Additional description of the clustering method.
- """
-
- definition = """
- # Method for clustering
- clustering_method: varchar(16)
- ---
- clustering_method_desc: varchar(1000)
- """
-
- contents = [
- ("kilosort", "kilosort clustering method"),
- ("kilosort2", "kilosort2 clustering method"),
- ]
-
-
-@schema
-class ClusteringParamSet(dj.Lookup):
- """Parameters to be used in clustering procedure for spike sorting.
-
- Attributes:
- paramset_idx (foreign key): Unique ID for the clustering parameter set.
- ClusteringMethod (dict): ClusteringMethod primary key.
- paramset_desc (varchar(128) ): Description of the clustering parameter set.
- param_set_hash (uuid): UUID hash for the parameter set.
- params (longblob): Paramset, dictionary of all applicable parameters.
- """
-
- definition = """
- # Parameter set to be used in a clustering procedure
- paramset_idx: smallint
- ---
- -> ClusteringMethod
- paramset_desc: varchar(128)
- param_set_hash: uuid
- unique index (param_set_hash)
- params: longblob # dictionary of all applicable parameters
- """
-
- @classmethod
- def insert_new_params(
- cls, processing_method: str, paramset_idx: int, paramset_desc: str, params: dict
- ):
- """Inserts new parameters into the ClusteringParamSet table.
-
- Args:
- processing_method (str): name of the clustering method.
- paramset_desc (str): description of the parameter set
- params (dict): clustering parameters
- paramset_idx (int, optional): Unique parameter set ID. Defaults to None.
- """
- param_dict = {
- "clustering_method": processing_method,
- "paramset_idx": paramset_idx,
- "paramset_desc": paramset_desc,
- "params": params,
- "param_set_hash": dict_to_uuid(params),
- }
- param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
- if param_query: # If the specified param-set already exists
- existing_paramset_idx = param_query.fetch1("paramset_idx")
- if (
- existing_paramset_idx == paramset_idx
- ): # If the existing set has the same paramset_idx: job done
- return
- else: # If not same name: human error, trying to add the same paramset with different name
- raise dj.DataJointError(
- "The specified param-set"
- " already exists - paramset_idx: {}".format(existing_paramset_idx)
- )
- else:
- cls.insert1(param_dict)
-
-
-@schema
-class ClusterQualityLabel(dj.Lookup):
- """Quality label for each spike sorted cluster.
-
- Attributes:
- cluster_quality_label (foreign key, varchar(100) ): Cluster quality type.
- cluster_quality_description (varchar(4000) ): Description of the cluster quality type.
- """
-
- definition = """
- # Quality
- cluster_quality_label: varchar(100)
- ---
- cluster_quality_description: varchar(4000)
- """
- contents = [
- ("good", "single unit"),
- ("ok", "probably a single unit, but could be contaminated"),
- ("mua", "multi-unit activity"),
- ("noise", "bad unit"),
- ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
- """A clustering task to spike sort electrophysiology datasets.
-
- Attributes:
- EphysRecording (foreign key): EphysRecording primary key.
- ClusteringParamSet (foreign key): ClusteringParamSet primary key.
- clustering_outdir_dir (varchar (255) ): Relative path to output clustering results.
- task_mode (enum): `Trigger` computes clustering or and `load` imports existing data.
- """
-
- definition = """
- # Manual table for defining a clustering task ready to be run
- -> PreCluster
- -> ClusteringParamSet
- ---
- clustering_output_dir: varchar(255) # clustering output directory relative to the clustering root data directory
- task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation
- """
-
-
-@schema
-class Clustering(dj.Imported):
- """A processing table to handle each clustering task.
-
- Attributes:
- ClusteringTask (foreign key): ClusteringTask primary key.
- clustering_time (datetime): Time when clustering results are generated.
- package_version (varchar(16) ): Package version used for a clustering analysis.
- """
-
- definition = """
- # Clustering Procedure
- -> ClusteringTask
- ---
- clustering_time: datetime # time of generation of this set of clustering results
- package_version='': varchar(16)
- """
-
- def make(self, key):
- """Triggers or imports clustering analysis."""
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- if task_mode == "load":
- _ = kilosort.Kilosort(
- kilosort_dir
- ) # check if the directory is a valid Kilosort output
- creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir)
- elif task_mode == "trigger":
- raise NotImplementedError(
- "Automatic triggering of" " clustering analysis is not yet supported"
- )
- else:
- raise ValueError(f"Unknown task mode: {task_mode}")
-
- self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
- """Curation procedure table.
-
- Attributes:
- Clustering (foreign key): Clustering primary key.
- curation_id (foreign key, int): Unique curation ID.
- curation_time (datetime): Time when curation results are generated.
- curation_output_dir (varchar(255) ): Output directory of the curated results.
- quality_control (bool): If True, this clustering result has undergone quality control.
- manual_curation (bool): If True, manual curation has been performed on this clustering result.
- curation_note (varchar(2000) ): Notes about the curation task.
- """
-
- definition = """
- # Manual curation procedure
- -> Clustering
- curation_id: int
- ---
- curation_time: datetime # time of generation of this set of curated clustering results
- curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory
- quality_control: bool # has this clustering result undergone quality control?
- manual_curation: bool # has manual curation been performed on this clustering result?
- curation_note='': varchar(2000)
- """
-
- def create1_from_clustering_task(self, key, curation_note: str = ""):
- """
- A function to create a new corresponding "Curation" for a particular
- "ClusteringTask"
- """
- if key not in Clustering():
- raise ValueError(
- f"No corresponding entry in Clustering available"
- f" for: {key}; do `Clustering.populate(key)`"
- )
-
- task_mode, output_dir = (ClusteringTask & key).fetch1(
- "task_mode", "clustering_output_dir"
- )
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
- kilosort_dir
- )
- # Synthesize curation_id
- curation_id = (
- dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
- )
- self.insert1(
- {
- **key,
- "curation_id": curation_id,
- "curation_time": creation_time,
- "curation_output_dir": output_dir,
- "quality_control": is_qc,
- "manual_curation": is_curated,
- "curation_note": curation_note,
- }
- )
-
-
-@schema
-class CuratedClustering(dj.Imported):
- """Clustering results after curation.
-
- Attributes:
- Curation (foreign key): Curation primary key.
- """
-
- definition = """
- # Clustering results of a curation.
- -> Curation
- """
-
- class Unit(dj.Part):
- """Single unit properties after clustering and curation.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- unit (foreign key, int): Unique integer identifying a single unit.
- probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
- ClusteringQualityLabel (dict): CLusteringQualityLabel primary key.
- spike_count (int): Number of spikes in this recording for this unit.
- spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording.
- spike_sites (longblob): Array of electrode associated with each spike.
- spike_depths (longblob): Array of depths associated with each spike, relative to each spike.
- """
-
- definition = """
- # Properties of a given unit from a round of clustering (and curation)
- -> master
- unit: int
- ---
- -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit
- -> ClusterQualityLabel
- spike_count: int # how many spikes in this recording for this unit
- spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording
- spike_sites : longblob # array of electrode associated with each spike
- spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe
- """
-
- def make(self, key):
- """Automated population of Unit information."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
- acq_software = (EphysRecording & key).fetch1("acq_software")
-
- # ---------- Unit ----------
- # -- Remove 0-spike units
- withspike_idx = [
- i
- for i, u in enumerate(kilosort_dataset.data["cluster_ids"])
- if (kilosort_dataset.data["spike_clusters"] == u).any()
- ]
- valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx]
- valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx]
- # -- Get channel and electrode-site mapping
- channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software)
-
- # -- Spike-times --
- # spike_times_sec_adj > spike_times_sec > spike_times
- spike_time_key = (
- "spike_times_sec_adj"
- if "spike_times_sec_adj" in kilosort_dataset.data
- else (
- "spike_times_sec"
- if "spike_times_sec" in kilosort_dataset.data
- else "spike_times"
- )
- )
- spike_times = kilosort_dataset.data[spike_time_key]
- kilosort_dataset.extract_spike_depths()
-
- # -- Spike-sites and Spike-depths --
- spike_sites = np.array(
- [
- channel2electrodes[s]["electrode"]
- for s in kilosort_dataset.data["spike_sites"]
- ]
- )
- spike_depths = kilosort_dataset.data["spike_depths"]
-
- # -- Insert unit, label, peak-chn
- units = []
- for unit, unit_lbl in zip(valid_units, valid_unit_labels):
- if (kilosort_dataset.data["spike_clusters"] == unit).any():
- unit_channel, _ = kilosort_dataset.get_best_channel(unit)
- unit_spike_times = (
- spike_times[kilosort_dataset.data["spike_clusters"] == unit]
- / kilosort_dataset.data["params"]["sample_rate"]
- )
- spike_count = len(unit_spike_times)
-
- units.append(
- {
- "unit": unit,
- "cluster_quality_label": unit_lbl,
- **channel2electrodes[unit_channel],
- "spike_times": unit_spike_times,
- "spike_count": spike_count,
- "spike_sites": spike_sites[
- kilosort_dataset.data["spike_clusters"] == unit
- ],
- "spike_depths": (
- spike_depths[
- kilosort_dataset.data["spike_clusters"] == unit
- ]
- if spike_depths is not None
- else None
- ),
- }
- )
-
- self.insert1(key)
- self.Unit.insert([{**key, **u} for u in units])
-
-
-@schema
-class WaveformSet(dj.Imported):
- """A set of spike waveforms for units out of a given CuratedClustering.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # A set of spike waveforms for units out of a given CuratedClustering
- -> CuratedClustering
- """
-
- class PeakWaveform(dj.Part):
- """Mean waveform across spikes for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode.
- """
-
- definition = """
- # Mean waveform across spikes for a given unit at its representative electrode
- -> master
- -> CuratedClustering.Unit
- ---
- peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode
- """
-
- class Waveform(dj.Part):
- """Spike waveforms for a given unit.
-
- Attributes:
- WaveformSet (foreign key): WaveformSet primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
- waveform_mean (longblob): mean waveform across spikes of the unit in microvolts.
- waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit.
- """
-
- definition = """
- # Spike waveforms and their mean across spikes for the given unit
- -> master
- -> CuratedClustering.Unit
- -> probe.ElectrodeConfig.Electrode
- ---
- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit
- waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit
- """
-
- def make(self, key):
- """Populates waveform tables."""
- output_dir = (Curation & key).fetch1("curation_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- kilosort_dataset = kilosort.Kilosort(kilosort_dir)
-
- acq_software, probe_serial_number = (
- EphysRecording * ProbeInsertion & key
- ).fetch1("acq_software", "probe")
-
- # -- Get channel and electrode-site mapping
- recording_key = (EphysRecording & key).fetch1("KEY")
- channel2electrodes = get_neuropixels_channel2electrode_map(
- recording_key, acq_software
- )
-
- is_qc = (Curation & key).fetch1("quality_control")
-
- # Get all units
- units = {
- u["unit"]: u
- for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit")
- }
-
- if is_qc:
- unit_waveforms = np.load(
- kilosort_dir / "mean_waveforms.npy"
- ) # unit x channel x sample
-
- def yield_unit_waveforms():
- for unit_no, unit_waveform in zip(
- kilosort_dataset.data["cluster_ids"], unit_waveforms
- ):
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
- if unit_no in units:
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], unit_waveform
- ):
- unit_electrode_waveforms.append(
- {
- **units[unit_no],
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == units[unit_no]["electrode"]
- ):
- unit_peak_waveform = {
- **units[unit_no],
- "peak_electrode_waveform": channel_waveform,
- }
- yield unit_peak_waveform, unit_electrode_waveforms
-
- else:
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(key)
- neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent)
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- neuropixels_recording = openephys_dataset.probes[probe_serial_number]
-
- def yield_unit_waveforms():
- for unit_dict in units.values():
- unit_peak_waveform = {}
- unit_electrode_waveforms = []
-
- spikes = unit_dict["spike_times"]
- waveforms = neuropixels_recording.extract_spike_waveforms(
- spikes, kilosort_dataset.data["channel_map"]
- ) # (sample x channel x spike)
- waveforms = waveforms.transpose(
- (1, 2, 0)
- ) # (channel x spike x sample)
- for channel, channel_waveform in zip(
- kilosort_dataset.data["channel_map"], waveforms
- ):
- unit_electrode_waveforms.append(
- {
- **unit_dict,
- **channel2electrodes[channel],
- "waveform_mean": channel_waveform.mean(axis=0),
- "waveforms": channel_waveform,
- }
- )
- if (
- channel2electrodes[channel]["electrode"]
- == unit_dict["electrode"]
- ):
- unit_peak_waveform = {
- **unit_dict,
- "peak_electrode_waveform": channel_waveform.mean(
- axis=0
- ),
- }
-
- yield unit_peak_waveform, unit_electrode_waveforms
-
- # insert waveform on a per-unit basis to mitigate potential memory issue
- self.insert1(key)
- for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
- self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
- self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
-
-
-@schema
-class QualityMetrics(dj.Imported):
- """Clustering and waveform quality metrics.
-
- Attributes:
- CuratedClustering (foreign key): CuratedClustering primary key.
- """
-
- definition = """
- # Clusters and waveforms metrics
- -> CuratedClustering
- """
-
- class Cluster(dj.Part):
- """Cluster metrics for a unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- firing_rate (float): Firing rate of the unit.
- snr (float): Signal-to-noise ratio for a unit.
- presence_ratio (float): Fraction of time where spikes are present.
- isi_violation (float): rate of ISI violation as a fraction of overall rate.
- number_violation (int): Total ISI violations.
- amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
- isolation_distance (float): Distance to nearest cluster.
- l_ratio (float): Amount of empty space between a cluster and other spikes in dataset.
- d_prime (float): Classification accuracy based on LDA.
- nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster.
- nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster.
- silhouette_core (float): Maximum change in spike depth throughout recording.
- cumulative_drift (float): Cumulative change in spike depth throughout recording.
- contamination_rate (float): Frequency of spikes in the refractory period.
- """
-
- definition = """
- # Cluster metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- firing_rate=null: float # (Hz) firing rate for a unit
- snr=null: float # signal-to-noise ratio for a unit
- presence_ratio=null: float # fraction of time in which spikes are present
- isi_violation=null: float # rate of ISI violation as a fraction of overall rate
- number_violation=null: int # total number of ISI violations
- amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram
- isolation_distance=null: float # distance to nearest cluster in Mahalanobis space
- l_ratio=null: float #
- d_prime=null: float # Classification accuracy based on LDA
- nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster
- nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster
- silhouette_score=null: float # Standard metric for cluster overlap
- max_drift=null: float # Maximum change in spike depth throughout recording
- cumulative_drift=null: float # Cumulative change in spike depth throughout recording
- contamination_rate=null: float #
- """
-
- class Waveform(dj.Part):
- """Waveform metrics for a particular unit.
-
- Attributes:
- QualityMetrics (foreign key): QualityMetrics primary key.
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
- amplitude (float): Absolute difference between waveform peak and trough in microvolts.
- duration (float): Time between waveform peak and trough in milliseconds.
- halfwidth (float): Spike width at half max amplitude.
- pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0.
- repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak.
- recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail.
- spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe.
- velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe.
- velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe.
- """
-
- definition = """
- # Waveform metrics for a particular unit
- -> master
- -> CuratedClustering.Unit
- ---
- amplitude: float # (uV) absolute difference between waveform peak and trough
- duration: float # (ms) time between waveform peak and trough
- halfwidth=null: float # (ms) spike width at half max amplitude
- pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0
- repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak
- recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail
- spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe
- velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe
- velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe
- """
-
- def make(self, key):
- """Populates tables with quality metrics data."""
- output_dir = (ClusteringTask & key).fetch1("clustering_output_dir")
- kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
- metric_fp = kilosort_dir / "metrics.csv"
- rename_dict = {
- "isi_viol": "isi_violation",
- "num_viol": "number_violation",
- "contam_rate": "contamination_rate",
- }
-
- if not metric_fp.exists():
- raise FileNotFoundError(f"QC metrics file not found: {metric_fp}")
-
- metrics_df = pd.read_csv(metric_fp)
- metrics_df.set_index("cluster_id", inplace=True)
- metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True)
- metrics_df.columns = metrics_df.columns.str.lower()
- metrics_df.rename(columns=rename_dict, inplace=True)
- metrics_list = [
- dict(metrics_df.loc[unit_key["unit"]], **unit_key)
- for unit_key in (CuratedClustering.Unit & key).fetch("KEY")
- ]
-
- self.insert1(key)
- self.Cluster.insert(metrics_list, ignore_extra_fields=True)
- self.Waveform.insert(metrics_list, ignore_extra_fields=True)
-
-
-# ---------------- HELPER FUNCTIONS ----------------
-
-
-def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str:
- """Get spikeGLX data filepath."""
- # attempt to retrieve from EphysRecording.EphysFile
- spikeglx_meta_filepath = (
- EphysRecording.EphysFile & ephys_recording_key & 'file_path LIKE "%.ap.meta"'
- ).fetch1("file_path")
-
- try:
- spikeglx_meta_filepath = find_full_path(
- get_ephys_root_data_dir(), spikeglx_meta_filepath
- )
- except FileNotFoundError:
- # if not found, search in session_dir again
- if not spikeglx_meta_filepath.exists():
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- inserted_probe_serial_number = (
- ProbeInsertion * probe.Probe & ephys_recording_key
- ).fetch1("probe")
-
- spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")]
- for meta_filepath in spikeglx_meta_filepaths:
- spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath)
- if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number:
- spikeglx_meta_filepath = meta_filepath
- break
- else:
- raise FileNotFoundError(
- "No SpikeGLX data found for probe insertion: {}".format(
- ephys_recording_key
- )
- )
-
- return spikeglx_meta_filepath
-
-
-def get_neuropixels_channel2electrode_map(
- ephys_recording_key: dict, acq_software: str
-) -> dict:
- """Get the channel map for neuropixels probe."""
- if acq_software == "SpikeGLX":
- spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key)
- spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath)
- electrode_config_key = (
- EphysRecording * probe.ElectrodeConfig & ephys_recording_key
- ).fetch1("KEY")
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode
- & electrode_config_key
- )
-
- probe_electrodes = {
- (shank, shank_col, shank_row): key
- for key, shank, shank_col, shank_row in zip(
- *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row")
- )
- }
-
- channel2electrode_map = {
- recorded_site: probe_electrodes[(shank, shank_col, shank_row)]
- for recorded_site, (shank, shank_col, shank_row, _) in enumerate(
- spikeglx_meta.shankmap["data"]
- )
- }
- elif acq_software == "Open Ephys":
- session_dir = find_full_path(
- get_ephys_root_data_dir(), get_session_directory(ephys_recording_key)
- )
- openephys_dataset = openephys.OpenEphys(session_dir)
- probe_serial_number = (ProbeInsertion & ephys_recording_key).fetch1("probe")
- probe_dataset = openephys_dataset.probes[probe_serial_number]
-
- electrode_query = (
- probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording
- & ephys_recording_key
- )
-
- probe_electrodes = {
- key["electrode"]: key for key in electrode_query.fetch("KEY")
- }
-
- channel2electrode_map = {
- channel_idx: probe_electrodes[channel_idx]
- for channel_idx in probe_dataset.ap_meta["channels_ids"]
- }
-
- return channel2electrode_map
-
-
-def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict:
- """Generate and insert new ElectrodeConfig
-
- Args:
- probe_type (str): probe type (e.g. neuropixels 2.0 - SS)
- electrode_keys (list): list of keys of the probe.ProbeType.Electrode table
-
- Returns:
- dict: representing a key of the probe.ElectrodeConfig table
- """
- # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode)
- electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys})
-
- electrode_list = sorted([k["electrode"] for k in electrode_keys])
- electrode_gaps = (
- [-1]
- + np.where(np.diff(electrode_list) > 1)[0].tolist()
- + [len(electrode_list) - 1]
- )
- electrode_config_name = "; ".join(
- [
- f"{electrode_list[start + 1]}-{electrode_list[end]}"
- for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:])
- ]
- )
-
- electrode_config_key = {"electrode_config_hash": electrode_config_hash}
-
- # ---- make new ElectrodeConfig if needed ----
- if not probe.ElectrodeConfig & electrode_config_key:
- probe.ElectrodeConfig.insert1(
- {
- **electrode_config_key,
- "probe_type": probe_type,
- "electrode_config_name": electrode_config_name,
- }
- )
- probe.ElectrodeConfig.Electrode.insert(
- {**electrode_config_key, **electrode} for electrode in electrode_keys
- )
-
- return electrode_config_key
diff --git a/element_array_ephys/ephys_report.py b/element_array_ephys/ephys_report.py
index 48bcf613..0c6836a0 100644
--- a/element_array_ephys/ephys_report.py
+++ b/element_array_ephys/ephys_report.py
@@ -2,31 +2,30 @@
import datetime
import pathlib
+import tempfile
from uuid import UUID
import datajoint as dj
from element_interface.utils import dict_to_uuid
-from . import probe
+from . import probe, ephys
schema = dj.schema()
-ephys = None
-
-def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True):
+def activate(schema_name, *, create_schema=True, create_tables=True):
"""Activate the current schema.
Args:
schema_name (str): schema name on the database server to activate the `ephys_report` schema.
- ephys_schema_name (str): schema name of the activated ephys element for which
- this ephys_report schema will be downstream from.
create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist.
create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist.
"""
+ if not probe.schema.is_activated():
+ raise RuntimeError("Please activate the `probe` schema first.")
+ if not ephys.schema.is_activated():
+ raise RuntimeError("Please activate the `ephys` schema first.")
- global ephys
- ephys = dj.create_virtual_module("ephys", ephys_schema_name)
schema.activate(
schema_name,
create_schema=create_schema,
@@ -55,7 +54,7 @@ class ProbeLevelReport(dj.Computed):
def make(self, key):
from .plotting.probe_level import plot_driftmap
- save_dir = _make_save_dir()
+ save_dir = tempfile.TemporaryDirectory()
units = ephys.CuratedClustering.Unit & key & "cluster_quality_label='good'"
@@ -90,13 +89,15 @@ def make(self, key):
fig_dict = _save_figs(
figs=(fig,),
fig_names=("drift_map_plot",),
- save_dir=save_dir,
+ save_dir=save_dir.name,
fig_prefix=fig_prefix,
extension=".png",
)
self.insert1({**key, **fig_dict, "shank": shank_no})
+ save_dir.cleanup()
+
@schema
class UnitLevelReport(dj.Computed):
@@ -268,17 +269,10 @@ def make(self, key):
)
-def _make_save_dir(root_dir: pathlib.Path = None) -> pathlib.Path:
- if root_dir is None:
- root_dir = pathlib.Path().absolute()
- save_dir = root_dir / "temp_ephys_figures"
- save_dir.mkdir(parents=True, exist_ok=True)
- return save_dir
-
-
def _save_figs(
figs, fig_names, save_dir, fig_prefix, extension=".png"
) -> dict[str, pathlib.Path]:
+ save_dir = pathlib.Path(save_dir)
fig_dict = {}
for fig, fig_name in zip(figs, fig_names):
fig_filepath = save_dir / (fig_prefix + "_" + fig_name + extension)
diff --git a/element_array_ephys/export/nwb/nwb.py b/element_array_ephys/export/nwb/nwb.py
index a45eb754..8d7da8f5 100644
--- a/element_array_ephys/export/nwb/nwb.py
+++ b/element_array_ephys/export/nwb/nwb.py
@@ -17,14 +17,7 @@
from spikeinterface import extractors
from tqdm import tqdm
-from ... import ephys_no_curation as ephys
-from ... import probe
-
-ephys_mode = os.getenv("EPHYS_MODE", dj.config["custom"].get("ephys_mode", "acute"))
-if ephys_mode != "no-curation":
- raise NotImplementedError(
- "This export function is designed for the no_curation " + "schema"
- )
+from ... import probe, ephys
class DecimalEncoder(json.JSONEncoder):
diff --git a/element_array_ephys/readers/kilosort.py b/element_array_ephys/readers/kilosort.py
index 4b50619d..4f8530d8 100644
--- a/element_array_ephys/readers/kilosort.py
+++ b/element_array_ephys/readers/kilosort.py
@@ -1,12 +1,10 @@
-import logging
-import pathlib
-import re
-from datetime import datetime
from os import path
-
-import numpy as np
+from datetime import datetime
+import pathlib
import pandas as pd
-
+import numpy as np
+import re
+import logging
from .utils import convert_to_number
log = logging.getLogger(__name__)
@@ -117,7 +115,8 @@ def _load(self):
# Read the Cluster Groups
for cluster_pattern, cluster_col_name in zip(
- ["cluster_group.*", "cluster_KSLabel.*"], ["group", "KSLabel"]
+ ["cluster_group.*", "cluster_KSLabel.*", "cluster_group.*"],
+ ["group", "KSLabel", "KSLabel"],
):
try:
cluster_file = next(self._kilosort_dir.glob(cluster_pattern))
@@ -126,22 +125,26 @@ def _load(self):
else:
cluster_file_suffix = cluster_file.suffix
assert cluster_file_suffix in (".tsv", ".xlsx")
- break
+
+ if cluster_file_suffix == ".tsv":
+ df = pd.read_csv(cluster_file, sep="\t", header=0)
+ elif cluster_file_suffix == ".xlsx":
+ df = pd.read_excel(cluster_file, engine="openpyxl")
+ else:
+ df = pd.read_csv(cluster_file, delimiter="\t")
+
+ try:
+ self._data["cluster_groups"] = np.array(df[cluster_col_name].values)
+ self._data["cluster_ids"] = np.array(df["cluster_id"].values)
+ except KeyError:
+ continue
+ else:
+ break
else:
raise FileNotFoundError(
'Neither "cluster_groups" nor "cluster_KSLabel" file found!'
)
- if cluster_file_suffix == ".tsv":
- df = pd.read_csv(cluster_file, sep="\t", header=0)
- elif cluster_file_suffix == ".xlsx":
- df = pd.read_excel(cluster_file, engine="openpyxl")
- else:
- df = pd.read_csv(cluster_file, delimiter="\t")
-
- self._data["cluster_groups"] = np.array(df[cluster_col_name].values)
- self._data["cluster_ids"] = np.array(df["cluster_id"].values)
-
def get_best_channel(self, unit):
template_idx = self.data["spike_templates"][
np.where(self.data["spike_clusters"] == unit)[0][0]
diff --git a/element_array_ephys/readers/probe_geometry.py b/element_array_ephys/readers/probe_geometry.py
index b6fbc09e..f0d50a1c 100644
--- a/element_array_ephys/readers/probe_geometry.py
+++ b/element_array_ephys/readers/probe_geometry.py
@@ -140,8 +140,8 @@ def build_npx_probe(
return elec_pos_df
-def to_probeinterface(electrodes_df):
- from probeinterface import Probe
+def to_probeinterface(electrodes_df, **kwargs):
+ import probeinterface as pi
probe_df = electrodes_df.copy()
probe_df.rename(
@@ -153,10 +153,22 @@ def to_probeinterface(electrodes_df):
},
inplace=True,
)
- probe_df["contact_shapes"] = "square"
- probe_df["width"] = 12
-
- return Probe.from_dataframe(probe_df)
+ # Get the contact shapes. By default, it's set to circle with a radius of 10.
+ contact_shapes = kwargs.get("contact_shapes", "circle")
+ assert (
+ contact_shapes in pi.probe._possible_contact_shapes
+ ), f"contacts shape should be in {pi.probe._possible_contact_shapes}"
+
+ probe_df["contact_shapes"] = contact_shapes
+ if contact_shapes == "circle":
+ probe_df["radius"] = kwargs.get("radius", 10)
+ elif contact_shapes == "square":
+ probe_df["width"] = kwargs.get("width", 10)
+ elif contact_shapes == "rect":
+ probe_df["width"] = kwargs.get("width")
+ probe_df["height"] = kwargs.get("height")
+
+ return pi.Probe.from_dataframe(probe_df)
def build_electrode_layouts(
diff --git a/element_array_ephys/spike_sorting/__init__.py b/element_array_ephys/spike_sorting/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/spike_sorting/kilosort_triggering.py
similarity index 100%
rename from element_array_ephys/readers/kilosort_triggering.py
rename to element_array_ephys/spike_sorting/kilosort_triggering.py
diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py
new file mode 100644
index 00000000..22adbdca
--- /dev/null
+++ b/element_array_ephys/spike_sorting/si_preprocessing.py
@@ -0,0 +1,37 @@
+import spikeinterface as si
+from spikeinterface import preprocessing
+
+
+def CatGT(recording):
+    recording = si.preprocessing.phase_shift(recording)
+    recording = si.preprocessing.common_reference(
+        recording, operator="median", reference="global"
+    )
+    return recording
+
+
+def IBLdestriping(recording):
+    # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022.
+    recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+    bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+    # For IBL destriping, interpolate bad channels (recording must be passed as first arg)
+    recording = si.preprocessing.interpolate_bad_channels(recording, bad_channel_ids)
+    recording = si.preprocessing.phase_shift(recording)
+    # For IBL destriping, highpass_spatial_filter is used instead of common reference
+    recording = si.preprocessing.highpass_spatial_filter(
+        recording
+    )
+    return recording
+
+
+def IBLdestriping_modified(recording):
+    # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html)
+    recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+    bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+    # Modified IBL destriping drops (rather than interpolates) bad channels
+    recording = recording.remove_channels(bad_channel_ids)
+    recording = si.preprocessing.phase_shift(recording)
+    recording = si.preprocessing.common_reference(
+        recording, operator="median", reference="global"
+    )
+    return recording
diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
new file mode 100644
index 00000000..e2f011e1
--- /dev/null
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -0,0 +1,403 @@
+"""
+The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the "spikeinterface" pipeline.
+Spikeinterface was developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface)
+If you use this pipeline, please cite SpikeInterface and the relevant sorter(s) used in your publication (see https://github.com/SpikeInterface for additional details for citation).
+"""
+
+from datetime import datetime
+
+import datajoint as dj
+import pandas as pd
+import spikeinterface as si
+from element_array_ephys import probe, ephys, readers
+from element_interface.utils import find_full_path, memoized_result
+from spikeinterface import exporters, extractors, sorters
+
+from . import si_preprocessing
+
+log = dj.logger
+
+schema = dj.schema()
+
+
+def activate(
+ schema_name,
+ *,
+ create_schema=True,
+ create_tables=True,
+):
+ """Activate the current schema.
+
+ Args:
+ schema_name (str): schema name on the database server to activate the `si_spike_sorting` schema.
+ create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist.
+ create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist.
+ """
+ if not probe.schema.is_activated():
+ raise RuntimeError("Please activate the `probe` schema first.")
+ if not ephys.schema.is_activated():
+ raise RuntimeError("Please activate the `ephys` schema first.")
+
+ schema.activate(
+ schema_name,
+ create_schema=create_schema,
+ create_tables=create_tables,
+ add_objects=ephys.__dict__,
+ )
+ ephys.Clustering.key_source -= PreProcessing.key_source.proj()
+
+
+SI_SORTERS = [s.replace("_", ".") for s in si.sorters.sorter_dict.keys()]
+
+
+@schema
+class PreProcessing(dj.Imported):
+ """A table to handle preprocessing of each clustering task. The output will be serialized and stored as a si_recording.pkl in the output directory."""
+
+ definition = """
+ -> ephys.ClusteringTask
+ ---
+ execution_time: datetime # datetime of the start of this step
+ execution_duration: float # execution duration in hours
+ """
+
+ @property
+ def key_source(self):
+ return (
+ ephys.ClusteringTask * ephys.ClusteringParamSet
+ & {"task_mode": "trigger"}
+ & f"clustering_method in {tuple(SI_SORTERS)}"
+ ) - ephys.Clustering
+
+ def make(self, key):
+ """Triggers or imports clustering analysis."""
+ execution_time = datetime.utcnow()
+
+ # Set the output directory
+ clustering_method, acq_software, output_dir, params = (
+ ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key
+ ).fetch1("clustering_method", "acq_software", "clustering_output_dir", "params")
+
+ # Get sorter method and create output directory.
+ sorter_name = clustering_method.replace(".", "_")
+
+ for required_key in (
+ "SI_PREPROCESSING_METHOD",
+ "SI_SORTING_PARAMS",
+ "SI_POSTPROCESSING_PARAMS",
+ ):
+ if required_key not in params:
+ raise ValueError(
+ f"{required_key} must be defined in ClusteringParamSet for SpikeInterface execution"
+ )
+
+ # Set directory to store recording file.
+ if not output_dir:
+ output_dir = ephys.ClusteringTask.infer_output_dir(
+ key, relative=True, mkdir=True
+ )
+ # update clustering_output_dir
+ ephys.ClusteringTask.update1(
+ {**key, "clustering_output_dir": output_dir.as_posix()}
+ )
+ output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+ recording_dir = output_dir / sorter_name / "recording"
+ recording_dir.mkdir(parents=True, exist_ok=True)
+ recording_file = recording_dir / "si_recording.pkl"
+
+ # Create SI recording extractor object
+ if acq_software == "SpikeGLX":
+ spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key)
+ spikeglx_recording = readers.spikeglx.SpikeGLX(
+ spikeglx_meta_filepath.parent
+ )
+ spikeglx_recording.validate_file("ap")
+ data_dir = spikeglx_meta_filepath.parent
+
+ si_extractor = (
+ si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor
+ )
+ stream_names, stream_ids = si.extractors.get_neo_streams(
+ "spikeglx", folder_path=data_dir
+ )
+ si_recording: si.BaseRecording = si_extractor(
+ folder_path=data_dir, stream_name=stream_names[0]
+ )
+ elif acq_software == "Open Ephys":
+ oe_probe = ephys.get_openephys_probe_data(key)
+ assert len(oe_probe.recording_info["recording_files"]) == 1
+ data_dir = oe_probe.recording_info["recording_files"][0]
+ si_extractor = (
+ si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor
+ )
+
+ stream_names, stream_ids = si.extractors.get_neo_streams(
+ "openephysbinary", folder_path=data_dir
+ )
+ si_recording: si.BaseRecording = si_extractor(
+ folder_path=data_dir, stream_name=stream_names[0]
+ )
+ else:
+ raise NotImplementedError(
+ f"SpikeInterface processing for {acq_software} not yet implemented."
+ )
+
+ # Add probe information to recording object
+ electrodes_df = (
+ (
+ ephys.EphysRecording.Channel
+ * probe.ElectrodeConfig.Electrode
+ * probe.ProbeType.Electrode
+ & key
+ )
+ .fetch(format="frame")
+ .reset_index()
+ )
+
+ # Create SI probe object
+ si_probe = readers.probe_geometry.to_probeinterface(
+ electrodes_df[["electrode", "x_coord", "y_coord", "shank"]]
+ )
+ si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values)
+ si_recording.set_probe(probe=si_probe, in_place=True)
+
+ # Run preprocessing and save results to output folder
+ si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"])
+ si_recording = si_preproc_func(si_recording)
+ si_recording.dump_to_pickle(file_path=recording_file, relative_to=output_dir)
+
+ self.insert1(
+ {
+ **key,
+ "execution_time": execution_time,
+ "execution_duration": (
+ datetime.utcnow() - execution_time
+ ).total_seconds()
+ / 3600,
+ }
+ )
+
+
+@schema
+class SIClustering(dj.Imported):
+ """A processing table to handle each clustering task."""
+
+ definition = """
+ -> PreProcessing
+ ---
+ execution_time: datetime # datetime of the start of this step
+ execution_duration: float # execution duration in hours
+ """
+
+ def make(self, key):
+ execution_time = datetime.utcnow()
+
+ # Load recording object.
+ clustering_method, output_dir, params = (
+ ephys.ClusteringTask * ephys.ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir", "params")
+ output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
+ recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
+ si_recording: si.BaseRecording = si.load_extractor(
+ recording_file, base_folder=output_dir
+ )
+
+ sorting_params = params["SI_SORTING_PARAMS"]
+ sorting_output_dir = output_dir / sorter_name / "spike_sorting"
+
+ # Run sorting
+ @memoized_result(
+ uniqueness_dict=sorting_params,
+ output_directory=sorting_output_dir,
+ )
+ def _run_sorter():
+ # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package.
+ si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(
+ sorter_name=sorter_name,
+ recording=si_recording,
+ folder=sorting_output_dir,
+ remove_existing_folder=True,
+ verbose=True,
+ docker_image=sorter_name not in si.sorters.installed_sorters(),
+ **sorting_params,
+ )
+
+ # Save sorting object
+ sorting_save_path = sorting_output_dir / "si_sorting.pkl"
+ si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir)
+
+ _run_sorter()
+
+ self.insert1(
+ {
+ **key,
+ "execution_time": execution_time,
+ "execution_duration": (
+ datetime.utcnow() - execution_time
+ ).total_seconds()
+ / 3600,
+ }
+ )
+
+
+@schema
+class PostProcessing(dj.Imported):
+ """A processing table to handle each clustering task."""
+
+ definition = """
+ -> SIClustering
+ ---
+ execution_time: datetime # datetime of the start of this step
+ execution_duration: float # execution duration in hours
+ do_si_export=0: bool # whether to export to phy
+ """
+
+ def make(self, key):
+ execution_time = datetime.utcnow()
+
+ # Load recording & sorting object.
+ clustering_method, output_dir, params = (
+ ephys.ClusteringTask * ephys.ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir", "params")
+ output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
+
+ recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl"
+ sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl"
+
+ si_recording: si.BaseRecording = si.load_extractor(
+ recording_file, base_folder=output_dir
+ )
+ si_sorting: si.sorters.BaseSorter = si.load_extractor(
+ sorting_file, base_folder=output_dir
+ )
+
+ postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
+
+ job_kwargs = postprocessing_params.get(
+ "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+ )
+
+ analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
+
+ @memoized_result(
+ uniqueness_dict=postprocessing_params,
+ output_directory=analyzer_output_dir,
+ )
+ def _sorting_analyzer_compute():
+ # Sorting Analyzer
+ sorting_analyzer = si.create_sorting_analyzer(
+ sorting=si_sorting,
+ recording=si_recording,
+ format="binary_folder",
+ folder=analyzer_output_dir,
+ sparse=True,
+ overwrite=True,
+ )
+
+ # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions()
+ # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified)
+ extensions_params = postprocessing_params.get("extensions", {})
+ extensions_to_compute = {
+ ext_name: extensions_params[ext_name]
+ for ext_name in sorting_analyzer.get_computable_extensions()
+ if ext_name in extensions_params
+ }
+
+ sorting_analyzer.compute(extensions_to_compute, **job_kwargs)
+
+ _sorting_analyzer_compute()
+
+ self.insert1(
+ {
+ **key,
+ "execution_time": execution_time,
+ "execution_duration": (
+ datetime.utcnow() - execution_time
+ ).total_seconds()
+ / 3600,
+ "do_si_export": postprocessing_params.get("export_to_phy", False)
+ or postprocessing_params.get("export_report", False),
+ }
+ )
+
+ # Once finished, insert this `key` into ephys.Clustering
+ ephys.Clustering.insert1(
+ {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True
+ )
+
+
+@schema
+class SIExport(dj.Computed):
+ """A SpikeInterface export report and to Phy"""
+
+ definition = """
+ -> PostProcessing
+ ---
+ execution_time: datetime
+ execution_duration: float
+ """
+
+ @property
+ def key_source(self):
+ return PostProcessing & "do_si_export = 1"
+
+ def make(self, key):
+ execution_time = datetime.utcnow()
+
+ clustering_method, output_dir, params = (
+ ephys.ClusteringTask * ephys.ClusteringParamSet & key
+ ).fetch1("clustering_method", "clustering_output_dir", "params")
+ output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir)
+ sorter_name = clustering_method.replace(".", "_")
+
+ postprocessing_params = params["SI_POSTPROCESSING_PARAMS"]
+
+ job_kwargs = postprocessing_params.get(
+ "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}
+ )
+
+ analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer"
+ sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir)
+
+ @memoized_result(
+ uniqueness_dict=postprocessing_params,
+ output_directory=analyzer_output_dir / "phy",
+ )
+ def _export_to_phy():
+ # Save to phy format
+ si.exporters.export_to_phy(
+ sorting_analyzer=sorting_analyzer,
+ output_folder=analyzer_output_dir / "phy",
+ use_relative_path=True,
+ **job_kwargs,
+ )
+
+ @memoized_result(
+ uniqueness_dict=postprocessing_params,
+ output_directory=analyzer_output_dir / "spikeinterface_report",
+ )
+ def _export_report():
+ # Generate spike interface report
+ si.exporters.export_report(
+ sorting_analyzer=sorting_analyzer,
+ output_folder=analyzer_output_dir / "spikeinterface_report",
+ **job_kwargs,
+ )
+
+ if postprocessing_params.get("export_report", False):
+ _export_report()
+ if postprocessing_params.get("export_to_phy", False):
+ _export_to_phy()
+
+ self.insert1(
+ {
+ **key,
+ "execution_time": execution_time,
+ "execution_duration": (
+ datetime.utcnow() - execution_time
+ ).total_seconds()
+ / 3600,
+ }
+ )
diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py
index 6b5406e8..19ba4c76 100644
--- a/element_array_ephys/version.py
+++ b/element_array_ephys/version.py
@@ -1,3 +1,3 @@
"""Package metadata."""
-__version__ = "0.3.5"
+__version__ = "1.0.0"
diff --git a/env.yml b/env.yml
new file mode 100644
index 00000000..e9b3ce13
--- /dev/null
+++ b/env.yml
@@ -0,0 +1,7 @@
+channels:
+ - conda-forge
+ - defaults
+dependencies:
+ - pip
+ - python>=3.7,<3.11
+name: element_array_ephys
diff --git a/images/attached_array_ephys_element_no_curation.svg b/images/attached_array_ephys_element.svg
similarity index 100%
rename from images/attached_array_ephys_element_no_curation.svg
rename to images/attached_array_ephys_element.svg
diff --git a/images/attached_array_ephys_element_acute.svg b/images/attached_array_ephys_element_acute.svg
deleted file mode 100644
index 5b2bc265..00000000
--- a/images/attached_array_ephys_element_acute.svg
+++ /dev/null
@@ -1,451 +0,0 @@
-
\ No newline at end of file
diff --git a/images/attached_array_ephys_element_chronic.svg b/images/attached_array_ephys_element_chronic.svg
deleted file mode 100644
index 808a2f17..00000000
--- a/images/attached_array_ephys_element_chronic.svg
+++ /dev/null
@@ -1,456 +0,0 @@
-
\ No newline at end of file
diff --git a/images/attached_array_ephys_element_precluster.svg b/images/attached_array_ephys_element_precluster.svg
deleted file mode 100644
index 7d854d2e..00000000
--- a/images/attached_array_ephys_element_precluster.svg
+++ /dev/null
@@ -1,535 +0,0 @@
-
\ No newline at end of file
diff --git a/notebooks/demo_prepare.ipynb b/notebooks/demo_prepare.ipynb
index 74057ba4..85ee1be2 100644
--- a/notebooks/demo_prepare.ipynb
+++ b/notebooks/demo_prepare.ipynb
@@ -213,7 +213,6 @@
"pygments_lexer": "ipython3",
"version": "3.9.17"
},
- "orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
diff --git a/notebooks/demo_run.ipynb b/notebooks/demo_run.ipynb
index 348a3c43..70fbb746 100644
--- a/notebooks/demo_run.ipynb
+++ b/notebooks/demo_run.ipynb
@@ -96,7 +96,6 @@
"pygments_lexer": "ipython3",
"version": "3.9.17"
},
- "orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "ff52d424e56dd643d8b2ec122f40a2e279e94970100b4e6430cb9025a65ba4cf"
diff --git a/setup.py b/setup.py
index 19a4a5ae..38b08c29 100644
--- a/setup.py
+++ b/setup.py
@@ -1,6 +1,6 @@
from os import path
-from setuptools import find_packages, setup
+from setuptools import find_packages, setup
pkg_name = "element_array_ephys"
here = path.abspath(path.dirname(__file__))
@@ -16,6 +16,7 @@
setup(
name=pkg_name.replace("_", "-"),
+ python_requires=">=3.7, <3.11",
version=__version__, # noqa F821
description="Extracellular Array Electrophysiology DataJoint Element",
long_description=long_description,
@@ -34,20 +35,22 @@
         "openpyxl",
         "plotly",
         "seaborn",
-        "spikeinterface",
+        "spikeinterface @ git+https://github.com/SpikeInterface/spikeinterface.git",
         "scikit-image>=0.20",
         "nbformat>=4.2.0",
         "pyopenephys>=1.1.6",
+        "element-interface @ git+https://github.com/datajoint/element-interface.git",
+        "numba",
     ],
     extras_require={
         "elements": [
             "element-animal @ git+https://github.com/datajoint/element-animal.git",
             "element-event @ git+https://github.com/datajoint/element-event.git",
-            "element-interface @ git+https://github.com/datajoint/element-interface.git",
             "element-lab @ git+https://github.com/datajoint/element-lab.git",
             "element-session @ git+https://github.com/datajoint/element-session.git",
         ],
         "nwb": ["dandi", "neuroconv[ecephys]", "pynwb"],
         "tests": ["pre-commit", "pytest", "pytest-cov"],
+        "spykingcircus": ["hdbscan"],  # matches SpikeInterface sorter name "spykingcircus"
     },
 )
diff --git a/tests/tutorial_pipeline.py b/tests/tutorial_pipeline.py
index 74b27ddc..1b27027d 100644
--- a/tests/tutorial_pipeline.py
+++ b/tests/tutorial_pipeline.py
@@ -3,7 +3,7 @@
import datajoint as dj
from element_animal import subject
from element_animal.subject import Subject
-from element_array_ephys import probe, ephys_no_curation as ephys, ephys_report
+from element_array_ephys import probe, ephys, ephys_report
from element_lab import lab
from element_lab.lab import Lab, Location, Project, Protocol, Source, User
from element_lab.lab import Device as Equipment
@@ -62,7 +62,9 @@ def get_session_directory(session_key):
return pathlib.Path(session_directory)
-ephys.activate(db_prefix + "ephys", db_prefix + "probe", linking_module=__name__)
+probe.activate(db_prefix + "probe")
+ephys.activate(db_prefix + "ephys", linking_module=__name__)
+ephys_report.activate(db_prefix + "ephys_report")
probe.create_neuropixels_probe_types()