From ec9fcf3350b8e5809a9968504d6e55ae0ab10fbd Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 30 Nov 2022 15:46:21 -0600 Subject: [PATCH 001/146] first prototype in separating `Clustering` into multiple steps --- element_array_ephys/spike_sorting/__init__.py | 0 .../spike_sorting/ecephys_spike_sorting.py | 250 ++++++++++++++++++ 2 files changed, 250 insertions(+) create mode 100644 element_array_ephys/spike_sorting/__init__.py create mode 100644 element_array_ephys/spike_sorting/ecephys_spike_sorting.py diff --git a/element_array_ephys/spike_sorting/__init__.py b/element_array_ephys/spike_sorting/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py new file mode 100644 index 00000000..1dc71e7a --- /dev/null +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -0,0 +1,250 @@ +import datajoint as dj +from element_array_ephys import get_logger +from decimal import Decimal +import json +from datetime import datetime, timedelta + +from element_interface.utils import find_full_path +from element_array_ephys.readers import spikeglx, kilosort, openephys, kilosort_triggering + +log = get_logger(__name__) + +schema = dj.schema() + +ephys = None + + +def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True): + """ + activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None) + :param schema_name: schema name on the database server to activate the `spike_sorting` schema + :param ephys_schema_name: schema name of the activated ephys element for which this ephys_report schema will be downstream from + :param create_schema: when True (default), create schema in the database if it does not yet exist. + :param create_tables: when True (default), create tables in the database if they do not yet exist. + (The "activation" of this ephys_report module should be evoked by one of the ephys modules only) + """ + global ephys + ephys = dj.create_virtual_module("ephys", ephys_schema_name) + schema.activate( + schema_name, + create_schema=create_schema, + create_tables=create_tables, + add_objects=ephys.__dict__, + ) + + +@schema +class KilosortPreProcessing(dj.Imported): + """A processing table to handle each clustering task. 
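For reference, wiring this schema downstream of an already-activated ephys schema might look like the following sketch (the schema names are hypothetical):

from element_array_ephys.spike_sorting import ecephys_spike_sorting

ecephys_spike_sorting.activate(
    schema_name="neuro_spike_sorting",   # schema that will hold the step tables
    ephys_schema_name="neuro_ephys",     # name of the already-activated ephys schema
)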
+ """ + definition = """ + -> ephys.ClusteringTask + --- + params: longblob # finalized parameterset for this run + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + @property + def key_source(self): + return (ephys.ClusteringTask * ephys.ClusteringParamSet + & {'task_mode': 'trigger'} + & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")') + + def make(self, key): + """Triggers or imports clustering analysis.""" + execution_time = datetime.utcnow() + + task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( + "task_mode", "clustering_output_dir" + ) + + assert task_mode == "trigger", 'Supporting "trigger" task_mode only' + + if not output_dir: + output_dir = ephys.ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) + # update clustering_output_dir + ephys.ClusteringTask.update1( + {**key, "clustering_output_dir": output_dir.as_posix()} + ) + + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method, params = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method", "params") + + assert clustering_method in ("kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + + # add additional probe-recording and channels details into `params` + params = {**params, **ephys.get_recording_channels_details(key)} + params["fs"] = params["sample_rate"] + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort.run_CatGT() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ['depth_estimation', 'median_subtraction'] + run_kilosort.run_modules() + + self.insert1({**key, + "params": params, + "execution_time": execution_time, + "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + + +@schema +class KilosortClustering(dj.Imported): + """A processing table to handle each clustering task. 
+ """ + definition = """ + -> KilosortPreProcessing + --- + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + assert clustering_method in ("kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + + params = (KilosortPreProcessing & key).fetch1('params') + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort._modules = ['kilosort_helper'] + run_kilosort._CatGT_finished = True + run_kilosort.run_modules() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ['kilosort_helper'] + run_kilosort.run_modules() + + self.insert1({**key, + "execution_time": execution_time, + "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + + +@schema +class KilosortPostProcessing(dj.Imported): + """A processing table to handle each clustering task. 
+ """ + definition = """ + -> KilosortClustering + --- + modules_status: longblob # dictionary of summary status for all modules + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + assert clustering_method in ( + "kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + + params = (KilosortPreProcessing & key).fetch1('params') + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort._modules = ['kilosort_postprocessing', + 'noise_templates', + 'mean_waveforms', + 'quality_metrics'] + run_kilosort._CatGT_finished = True + run_kilosort.run_modules() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ['kilosort_postprocessing', + 'noise_templates', + 'mean_waveforms', + 'quality_metrics'] + run_kilosort.run_modules() + + with open(self._modules_input_hash_fp) as f: + modules_status = json.load(f) + + self.insert1({**key, + "modules_status": modules_status, + "execution_time": execution_time, + "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) From f5724384952f801086e751d3645437bea2694604 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 5 Jan 2023 14:06:00 -0600 Subject: [PATCH 002/146] Update ecephys_spike_sorting.py --- .../spike_sorting/ecephys_spike_sorting.py | 145 ++++++++++++------ 1 file changed, 94 insertions(+), 51 deletions(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 1dc71e7a..1592e65e 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -5,7 +5,12 @@ from datetime import datetime, timedelta from element_interface.utils import find_full_path -from element_array_ephys.readers import spikeglx, kilosort, openephys, kilosort_triggering +from element_array_ephys.readers import ( + spikeglx, + kilosort, + openephys, + kilosort_triggering, +) log = get_logger(__name__) @@ -35,8 +40,8 @@ def activate(schema_name, ephys_schema_name, *, create_schema=True, create_table @schema class KilosortPreProcessing(dj.Imported): - """A processing table to handle each clustering task. 
- """ + """A processing table to handle each clustering task.""" + definition = """ -> ephys.ClusteringTask --- @@ -47,9 +52,11 @@ class KilosortPreProcessing(dj.Imported): @property def key_source(self): - return (ephys.ClusteringTask * ephys.ClusteringParamSet - & {'task_mode': 'trigger'} - & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")') + return ( + ephys.ClusteringTask * ephys.ClusteringParamSet + & {"task_mode": "trigger"} + & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' + ) def make(self, key): """Triggers or imports clustering analysis.""" @@ -62,7 +69,9 @@ def make(self, key): assert task_mode == "trigger", 'Supporting "trigger" task_mode only' if not output_dir: - output_dir = ephys.ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) + output_dir = ephys.ClusteringTask.infer_output_dir( + key, relative=True, mkdir=True + ) # update clustering_output_dir ephys.ClusteringTask.update1( {**key, "clustering_output_dir": output_dir.as_posix()} @@ -71,10 +80,14 @@ def make(self, key): kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) acq_software, clustering_method, params = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method", "params") - assert clustering_method in ("kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + assert clustering_method in ( + "kilosort2", + "kilosort2.5", + "kilosort3", + ), 'Supporting "kilosort" clustering_method only' # add additional probe-recording and channels details into `params` params = {**params, **ephys.get_recording_channels_details(key)} @@ -82,17 +95,19 @@ def make(self, key): if acq_software == "SpikeGLX": spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) spikeglx_recording.validate_file("ap") + run_CatGT = ( + params.pop("run_CatGT", True) + and "_tcat." not in spikeglx_meta_filepath.stem + ) run_kilosort = kilosort_triggering.SGLXKilosortPipeline( npx_input_dir=spikeglx_meta_filepath.parent, ks_output_dir=kilosort_dir, params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, + run_CatGT=run_CatGT, ) run_kilosort.run_CatGT() elif acq_software == "Open Ephys": @@ -107,19 +122,26 @@ def make(self, key): params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', ) - run_kilosort._modules = ['depth_estimation', 'median_subtraction'] + run_kilosort._modules = ["depth_estimation", "median_subtraction"] run_kilosort.run_modules() - self.insert1({**key, - "params": params, - "execution_time": execution_time, - "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + self.insert1( + { + **key, + "params": params, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) @schema class KilosortClustering(dj.Imported): - """A processing table to handle each clustering task. 
- """ + """A processing table to handle each clustering task.""" + definition = """ -> KilosortPreProcessing --- @@ -134,17 +156,19 @@ def make(self, key): kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - assert clustering_method in ("kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + assert clustering_method in ( + "kilosort2", + "kilosort2.5", + "kilosort3", + ), 'Supporting "kilosort" clustering_method only' - params = (KilosortPreProcessing & key).fetch1('params') + params = (KilosortPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) spikeglx_recording.validate_file("ap") run_kilosort = kilosort_triggering.SGLXKilosortPipeline( @@ -154,7 +178,7 @@ def make(self, key): KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', run_CatGT=True, ) - run_kilosort._modules = ['kilosort_helper'] + run_kilosort._modules = ["kilosort_helper"] run_kilosort._CatGT_finished = True run_kilosort.run_modules() elif acq_software == "Open Ephys": @@ -169,18 +193,25 @@ def make(self, key): params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', ) - run_kilosort._modules = ['kilosort_helper'] + run_kilosort._modules = ["kilosort_helper"] run_kilosort.run_modules() - self.insert1({**key, - "execution_time": execution_time, - "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) @schema class KilosortPostProcessing(dj.Imported): - """A processing table to handle each clustering task. 
- """ + """A processing table to handle each clustering task.""" + definition = """ -> KilosortClustering --- @@ -196,18 +227,19 @@ def make(self, key): kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") assert clustering_method in ( - "kilosort2", "kilosort2.5", "kilosort3"), 'Supporting "kilosort" clustering_method only' + "kilosort2", + "kilosort2.5", + "kilosort3", + ), 'Supporting "kilosort" clustering_method only' - params = (KilosortPreProcessing & key).fetch1('params') + params = (KilosortPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) spikeglx_recording.validate_file("ap") run_kilosort = kilosort_triggering.SGLXKilosortPipeline( @@ -217,10 +249,12 @@ def make(self, key): KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', run_CatGT=True, ) - run_kilosort._modules = ['kilosort_postprocessing', - 'noise_templates', - 'mean_waveforms', - 'quality_metrics'] + run_kilosort._modules = [ + "kilosort_postprocessing", + "noise_templates", + "mean_waveforms", + "quality_metrics", + ] run_kilosort._CatGT_finished = True run_kilosort.run_modules() elif acq_software == "Open Ephys": @@ -235,16 +269,25 @@ def make(self, key): params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', ) - run_kilosort._modules = ['kilosort_postprocessing', - 'noise_templates', - 'mean_waveforms', - 'quality_metrics'] + run_kilosort._modules = [ + "kilosort_postprocessing", + "noise_templates", + "mean_waveforms", + "quality_metrics", + ] run_kilosort.run_modules() with open(self._modules_input_hash_fp) as f: modules_status = json.load(f) - self.insert1({**key, - "modules_status": modules_status, - "execution_time": execution_time, - "execution_duration": (datetime.utcnow() - execution_time).total_seconds() / 3600}) + self.insert1( + { + **key, + "modules_status": modules_status, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) From edf1578b45425410c6cb53e5777866afa5f04f98 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 5 Jan 2023 14:07:53 -0600 Subject: [PATCH 003/146] Update ecephys_spike_sorting.py --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 1592e65e..eb02c251 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -98,7 +98,7 @@ def make(self, key): spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) spikeglx_recording.validate_file("ap") run_CatGT = ( - params.pop("run_CatGT", True) + params.get("run_CatGT", True) and "_tcat." 
not in spikeglx_meta_filepath.stem ) From 7e267c57571c3396ecdc60d159db0245326e4047 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 5 Jan 2023 17:15:51 -0600 Subject: [PATCH 004/146] fix typo --- element_array_ephys/ephys_no_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 943d3354..9414c49e 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -799,7 +799,7 @@ class Clustering(dj.Imported): Attributes: ClusteringTask (foreign key): ClusteringTask primary key. clustering_time (datetime): Time when clustering results are generated. - package_version (varchar(16) ): Package version used for a clustering analysis. + package_version (varchar(16): Package version used for a clustering analysis. """ definition = """ From 6e7ddf15966d51474455efdd758e99e93850b6b7 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 6 Jan 2023 10:23:28 -0600 Subject: [PATCH 005/146] Update ecephys_spike_sorting.py --- .../spike_sorting/ecephys_spike_sorting.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index eb02c251..ed6d699e 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -19,17 +19,23 @@ ephys = None -def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True): +def activate( + schema_name, + *, + ephys_module, + create_schema=True, + create_tables=True, +): """ activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None) :param schema_name: schema name on the database server to activate the `spike_sorting` schema - :param ephys_schema_name: schema name of the activated ephys element for which this ephys_report schema will be downstream from + :param ephys_module: the activated ephys element for which this ephys_report schema will be downstream from :param create_schema: when True (default), create schema in the database if it does not yet exist. :param create_tables: when True (default), create tables in the database if they do not yet exist. 
(The "activation" of this ephys_report module should be evoked by one of the ephys modules only) """ global ephys - ephys = dj.create_virtual_module("ephys", ephys_schema_name) + ephys = ephys_module schema.activate( schema_name, create_schema=create_schema, From 7fd9bb4368208e12a809b49e402190e3d34eaa07 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 6 Jan 2023 11:15:43 -0600 Subject: [PATCH 006/146] improve log messages --- element_array_ephys/readers/kilosort_triggering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/readers/kilosort_triggering.py index 8e4b80ff..5d76c7af 100644 --- a/element_array_ephys/readers/kilosort_triggering.py +++ b/element_array_ephys/readers/kilosort_triggering.py @@ -21,13 +21,13 @@ get_noise_channels, ) except Exception as e: - print(f'Error in loading "ecephys_spike_sorting" package - {str(e)}') + print(f'Warning: Failed loading "ecephys_spike_sorting" package - {str(e)}') # import pykilosort package try: import pykilosort except Exception as e: - print(f'Error in loading "pykilosort" package - {str(e)}') + print(f'Warning: Failed loading "pykilosort" package - {str(e)}') class SGLXKilosortPipeline: From 26a56e72f76afa24833f275c01df0606e9d24ec2 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 6 Jan 2023 11:58:01 -0600 Subject: [PATCH 007/146] fix key_source --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index ed6d699e..d7f7865a 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -62,7 +62,7 @@ def key_source(self): ephys.ClusteringTask * ephys.ClusteringParamSet & {"task_mode": "trigger"} & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) + ) - ephys.Clustering def make(self, key): """Triggers or imports clustering analysis.""" From 654bc522ce8bba576b2793924f39d1a97416bb0f Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 18 Jan 2023 11:58:13 -0600 Subject: [PATCH 008/146] Update ecephys_spike_sorting.py --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index d7f7865a..0cf4bea8 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -297,3 +297,6 @@ def make(self, key): / 3600, } ) + + # all finished, insert this `key` into ephys.Clustering + ephys.Clustering.insert1({**key, "clustering_time": datetime.utcnow()}) From f75e14f13a2f413b2c34be401652ade70ce11a94 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 19 Jan 2023 15:43:15 -0600 Subject: [PATCH 009/146] bugfix --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 0cf4bea8..3eca46b9 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -283,7 +283,7 @@ def make(self, key): ] run_kilosort.run_modules() - with 
open(self._modules_input_hash_fp) as f: + with open(run_kilosort._modules_input_hash_fp) as f: modules_status = json.load(f) self.insert1( From a32d1d25b895bc1c05f4149f920b206b07646aae Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 19 Jan 2023 16:23:37 -0600 Subject: [PATCH 010/146] bugfix --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 3eca46b9..d33a3752 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -299,4 +299,6 @@ def make(self, key): ) # all finished, insert this `key` into ephys.Clustering - ephys.Clustering.insert1({**key, "clustering_time": datetime.utcnow()}) + ephys.Clustering.insert1( + {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True + ) From 3bea7755245905dc59271a248514caad004c7f10 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 20 Jan 2023 16:50:35 -0600 Subject: [PATCH 011/146] Update kilosort_triggering.py --- element_array_ephys/readers/kilosort_triggering.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/readers/kilosort_triggering.py index 7f30cac4..4e831d1b 100644 --- a/element_array_ephys/readers/kilosort_triggering.py +++ b/element_array_ephys/readers/kilosort_triggering.py @@ -777,8 +777,7 @@ def _write_channel_map_file( # channels to exclude mask = get_noise_channels(ap_band_file, channel_count, sample_rate, bit_volts) - bad_channel_ind = np.where(mask is False)[0] - connected[bad_channel_ind] = 0 + connected = np.where(mask is False, 0, connected) mdict = { "chanMap": chanMap, From 4f955b32e8049c2961e9c916029b2d27f008476a Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Mon, 23 Jan 2023 10:53:46 -0600 Subject: [PATCH 012/146] fix docstring --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index d33a3752..cec8f7ac 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -29,10 +29,9 @@ def activate( """ activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None) :param schema_name: schema name on the database server to activate the `spike_sorting` schema - :param ephys_module: the activated ephys element for which this ephys_report schema will be downstream from + :param ephys_module: the activated ephys element for which this `spike_sorting` schema will be downstream from :param create_schema: when True (default), create schema in the database if it does not yet exist. :param create_tables: when True (default), create tables in the database if they do not yet exist. 
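With the module-object signature introduced above, activation might look like this sketch (the schema name is hypothetical, and the ephys module is assumed to be activated already):

import element_array_ephys.ephys_no_curation as ephys
from element_array_ephys.spike_sorting import ecephys_spike_sorting

ecephys_spike_sorting.activate("neuro_spike_sorting", ephys_module=ephys)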
- (The "activation" of this ephys_report module should be evoked by one of the ephys modules only) """ global ephys ephys = ephys_module From 53854d0a986d564c8e75295f83da4dbd5446d92b Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Mon, 23 Jan 2023 11:05:47 -0600 Subject: [PATCH 013/146] added description --- .../spike_sorting/ecephys_spike_sorting.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index cec8f7ac..d779d0c0 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -1,3 +1,26 @@ +""" +The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the +"ecephys_spike_sorting" pipeline. +The "ecephys_spike_sorting" was originally developed by the Allen Institute (https://github.com/AllenInstitute/ecephys_spike_sorting) for Neuropixels data acquired with Open Ephys acquisition system. +Then forked by Jennifer Colonell from the Janelia Research Campus (https://github.com/jenniferColonell/ecephys_spike_sorting) to support SpikeGLX acquisition system. + +At DataJoint, we fork from Jennifer's fork and implemented a version that supports both Open Ephys and Spike GLX. +https://github.com/datajoint-company/ecephys_spike_sorting + +The follow pipeline features three tables: +1. KilosortPreProcessing - for preprocessing steps (no GPU required) + - median_subtraction for Open Ephys + - or the CatGT step for SpikeGLX +2. KilosortClustering - kilosort (MATLAB) - requires GPU + - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) +3. KilosortPostProcessing - for postprocessing steps (no GPU required) + - kilosort_postprocessing + - noise_templates + - mean_waveforms + - quality_metrics +""" + + import datajoint as dj from element_array_ephys import get_logger from decimal import Decimal From fe9955ca4269b86958541440f1587bdecc5b0c1b Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 25 Jan 2023 15:33:08 -0600 Subject: [PATCH 014/146] refactor `_supported_kilosort_versions` --- .../spike_sorting/ecephys_spike_sorting.py | 24 +++++++------------ 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index d779d0c0..fca4e452 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -41,6 +41,12 @@ ephys = None +_supported_kilosort_versions = [ + "kilosort2", + "kilosort2.5", + "kilosort3", +] + def activate( schema_name, @@ -111,11 +117,9 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method", "params") - assert clustering_method in ( - "kilosort2", - "kilosort2.5", - "kilosort3", - ), 'Supporting "kilosort" clustering_method only' + assert ( + clustering_method in _supported_kilosort_versions + ), f'Clustering_method "{clustering_method}" is not supported' # add additional probe-recording and channels details into `params` params = {**params, **ephys.get_recording_channels_details(key)} @@ -186,11 +190,6 @@ def make(self, key): acq_software, clustering_method = ( ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - assert 
clustering_method in ( - "kilosort2", - "kilosort2.5", - "kilosort3", - ), 'Supporting "kilosort" clustering_method only' params = (KilosortPreProcessing & key).fetch1("params") @@ -257,11 +256,6 @@ def make(self, key): acq_software, clustering_method = ( ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - assert clustering_method in ( - "kilosort2", - "kilosort2.5", - "kilosort3", - ), 'Supporting "kilosort" clustering_method only' params = (KilosortPreProcessing & key).fetch1("params") From aea325d9bb6a975fd4e2c382f313b209a5be0017 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 25 Jan 2023 17:27:00 -0600 Subject: [PATCH 015/146] remove unused imports --- element_array_ephys/spike_sorting/ecephys_spike_sorting.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index fca4e452..4de349eb 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -30,8 +30,6 @@ from element_interface.utils import find_full_path from element_array_ephys.readers import ( spikeglx, - kilosort, - openephys, kilosort_triggering, ) From 4f648cc8054a6e237b971ab6c8cb4b88c6b0c568 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 1 Feb 2023 18:37:39 -0600 Subject: [PATCH 016/146] add new file for spike interface modularized clustering approach --- .../spike_sorting/si_clustering.py | 534 ++++++++++++++++++ 1 file changed, 534 insertions(+) create mode 100644 element_array_ephys/spike_sorting/si_clustering.py diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py new file mode 100644 index 00000000..32384d01 --- /dev/null +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -0,0 +1,534 @@ +""" +The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the +"ecephys_spike_sorting" pipeline. +The "ecephys_spike_sorting" was originally developed by the Allen Institute (https://github.com/AllenInstitute/ecephys_spike_sorting) for Neuropixels data acquired with Open Ephys acquisition system. +Then forked by Jennifer Colonell from the Janelia Research Campus (https://github.com/jenniferColonell/ecephys_spike_sorting) to support SpikeGLX acquisition system. + +At DataJoint, we fork from Jennifer's fork and implemented a version that supports both Open Ephys and Spike GLX. +https://github.com/datajoint-company/ecephys_spike_sorting + +The follow pipeline features intermediary tables: +1. KilosortPreProcessing - for preprocessing steps (no GPU required) + - median_subtraction for Open Ephys + - or the CatGT step for SpikeGLX +2. KilosortClustering - kilosort (MATLAB) - requires GPU + - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) +3. 
KilosortPostProcessing - for postprocessing steps (no GPU required) + - kilosort_postprocessing + - noise_templates + - mean_waveforms + - quality_metrics + + +""" +import datajoint as dj +import os +from element_array_ephys import get_logger +from decimal import Decimal +import json +import numpy as np +from datetime import datetime, timedelta + +from element_interface.utils import find_full_path +from element_array_ephys.readers import ( + spikeglx, + kilosort_triggering, +) +import element_array_ephys.ephys_no_curation as ephys +import element_array_ephys.probe as probe +# from element_array_ephys.ephys_no_curation import ( +# get_ephys_root_data_dir, +# get_session_directory, +# get_openephys_filepath, +# get_spikeglx_meta_filepath, +# get_recording_channels_details, +# ) +import spikeinterface as si +import spikeinterface.extractors as se +import spikeinterface.sorters as ss +import spikeinterface.comparison as sc +import spikeinterface.widgets as sw +import spikeinterface.preprocessing as sip +import probeinterface as pi + +log = get_logger(__name__) + +schema = dj.schema() + +ephys = None + +_supported_kilosort_versions = [ + "kilosort2", + "kilosort2.5", + "kilosort3", +] + + +def activate( + schema_name, + *, + ephys_module, + create_schema=True, + create_tables=True, +): + """ + activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None) + :param schema_name: schema name on the database server to activate the `spike_sorting` schema + :param ephys_module: the activated ephys element for which this `spike_sorting` schema will be downstream from + :param create_schema: when True (default), create schema in the database if it does not yet exist. + :param create_tables: when True (default), create tables in the database if they do not yet exist. 
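The SI_PreProcessing step below is built around a short SpikeInterface chain; reduced to its core it reads (the folder path and stream name are hypothetical):

import spikeinterface.extractors as se
import spikeinterface.preprocessing as sip

rec = se.read_spikeglx(folder_path="/data/session1", stream_name="imec0.ap")
rec = sip.bandpass_filter(rec, freq_min=300, freq_max=6000)
rec = sip.common_reference(rec, reference="global", operator="median")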
+ """ + global ephys + ephys = ephys_module + schema.activate( + schema_name, + create_schema=create_schema, + create_tables=create_tables, + add_objects=ephys.__dict__, + ) + +@schema +class SI_preprocessing(dj.Imported): + """A table to handle preprocessing of each clustering task.""" + + definition = """ + -> ephys.ClusteringTask + --- + params: longblob # finalized parameterset for this run + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + @property + def key_source(self): + return ( + ephys.ClusteringTask * ephys.ClusteringParamSet + & {"task_mode": "trigger"} + & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' + ) - ephys.Clustering + def make(self, key): + """Triggers or imports clustering analysis.""" + execution_time = datetime.utcnow() + + task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( + "task_mode", "clustering_output_dir" + ) + + assert task_mode == "trigger", 'Supporting "trigger" task_mode only' + + if not output_dir: + output_dir = ephys.ClusteringTask.infer_output_dir( + key, relative=True, mkdir=True + ) + # update clustering_output_dir + ephys.ClusteringTask.update1( + {**key, "clustering_output_dir": output_dir.as_posix()} + ) + + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method, params = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method", "params") + + assert ( + clustering_method in _supported_kilosort_versions + ), f'Clustering_method "{clustering_method}" is not supported' + + # add additional probe-recording and channels details into `params` + params = {**params, **ephys.get_recording_channels_details(key)} + params["fs"] = params["sample_rate"] + + if acq_software == "SpikeGLX": + sglx_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) + sglx_filepath = ephys.get_spikeglx_meta_filepath(key) + stream_name = os.path.split(sglx_filepath)[1] + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # Create SI recording extractor object + # sglx_si_recording = se.SpikeGLXRecordingExtractor(folder_path=sglx_full_path, stream_name=stream_name) + sglx_si_recording = se.read_spikeglx(folder_path=sglx_full_path, stream_name=stream_name) + electrode_query = (probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode + * ephys.EphysRecording & key) + + xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] + channels_details = ephys.get_recording_channels_details(key) + + # Create SI probe object + probe = pi.Probe(ndim=2, si_units='um') + probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + probe.create_auto_shape(probe_type='tip') + channel_indices = np.arange(channels_details['num_channels']) + probe.set_device_channel_indices(channel_indices) + oe_si_recording.set_probe(probe=probe) + + # run preprocessing and save results to output folder + sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) + sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") + sglx_recording_cmr.save_to_folder('sglx_recording_cmr', kilosort_dir) + + + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + oe_full_path = find_full_path(get_ephys_root_data_dir(),get_session_directory(key)) + 
oe_filepath = get_openephys_filepath(key) + stream_name = os.path.split(oe_filepath)[1] + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # Create SI recording extractor object + # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) + oe_si_recording = se.read_openephys(folder_path=oe_full_path, stream_name=stream_name) + electrode_query = (probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode + * ephys.EphysRecording & key) + + xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] + channels_details = get_recording_channels_details(key) + + # Create SI probe object + probe = pi.Probe(ndim=2, si_units='um') + probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + probe.create_auto_shape(probe_type='tip') + channel_indices = np.arange(channels_details['num_channels']) + probe.set_device_channel_indices(channel_indices) + oe_si_recording.set_probe(probe=probe) + + # run preprocessing and save results to output folder + oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) + oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") + oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) + + self.insert1( + { + **key, + "params": params, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) +@schema +class SI_KilosortClustering(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> KilosortPreProcessing + --- + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + + params = (KilosortPreProcessing & key).fetch1("params") + + if acq_software == "SpikeGLX": + sglx_probe = ephys.get_openephys_probe_data(key) + oe_si_recording = se.load_from_folder + assert len(oe_probe.recording_info["recording_files"]) == 1 + if clustering_method.startswith('kilosort2.5'): + sorter_name = "kilosort2_5" + else: + sorter_name = clustering_method + sorting_kilosort = si.run_sorter( + sorter_name = sorter_name, + recording = oe_si_recording, + output_folder = kilosort_dir, + docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", + **params + ) + sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + oe_si_recording = se.load_from_folder + assert len(oe_probe.recording_info["recording_files"]) == 1 + if clustering_method.startswith('kilosort2.5'): + sorter_name = "kilosort2_5" + else: + sorter_name = clustering_method + sorting_kilosort = si.run_sorter( + sorter_name = sorter_name, + recording = oe_si_recording, + output_folder = kilosort_dir, + docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", + **params + ) + sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) + + self.insert1( + { + **key, + "execution_time": execution_time, + 
"execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + + + + +@schema +class KilosortPreProcessing(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> ephys.ClusteringTask + --- + params: longblob # finalized parameterset for this run + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + @property + def key_source(self): + return ( + ephys.ClusteringTask * ephys.ClusteringParamSet + & {"task_mode": "trigger"} + & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' + ) - ephys.Clustering + + def make(self, key): + """Triggers or imports clustering analysis.""" + execution_time = datetime.utcnow() + + task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( + "task_mode", "clustering_output_dir" + ) + + assert task_mode == "trigger", 'Supporting "trigger" task_mode only' + + if not output_dir: + output_dir = ephys.ClusteringTask.infer_output_dir( + key, relative=True, mkdir=True + ) + # update clustering_output_dir + ephys.ClusteringTask.update1( + {**key, "clustering_output_dir": output_dir.as_posix()} + ) + + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method, params = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method", "params") + + assert ( + clustering_method in _supported_kilosort_versions + ), f'Clustering_method "{clustering_method}" is not supported' + + # add additional probe-recording and channels details into `params` + params = {**params, **ephys.get_recording_channels_details(key)} + params["fs"] = params["sample_rate"] + + + + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file("ap") + run_CatGT = ( + params.get("run_CatGT", True) + and "_tcat." 
not in spikeglx_meta_filepath.stem + ) + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=run_CatGT, + ) + run_kilosort.run_CatGT() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ["depth_estimation", "median_subtraction"] + run_kilosort.run_modules() + + self.insert1( + { + **key, + "params": params, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + +@schema +class KilosortClustering(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> KilosortPreProcessing + --- + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + + params = (KilosortPreProcessing & key).fetch1("params") + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort._modules = ["kilosort_helper"] + run_kilosort._CatGT_finished = True + run_kilosort.run_modules() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = ["kilosort_helper"] + run_kilosort.run_modules() + + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + +@schema +class KilosortPostProcessing(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> KilosortClustering + --- + modules_status: longblob # dictionary of summary status for all modules + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = 
find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + + params = (KilosortPreProcessing & key).fetch1("params") + + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file("ap") + + run_kilosort = kilosort_triggering.SGLXKilosortPipeline( + npx_input_dir=spikeglx_meta_filepath.parent, + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + run_CatGT=True, + ) + run_kilosort._modules = [ + "kilosort_postprocessing", + "noise_templates", + "mean_waveforms", + "quality_metrics", + ] + run_kilosort._CatGT_finished = True + run_kilosort.run_modules() + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + + assert len(oe_probe.recording_info["recording_files"]) == 1 + + # run kilosort + run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( + npx_input_dir=oe_probe.recording_info["recording_files"][0], + ks_output_dir=kilosort_dir, + params=params, + KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', + ) + run_kilosort._modules = [ + "kilosort_postprocessing", + "noise_templates", + "mean_waveforms", + "quality_metrics", + ] + run_kilosort.run_modules() + + with open(run_kilosort._modules_input_hash_fp) as f: + modules_status = json.load(f) + + self.insert1( + { + **key, + "modules_status": modules_status, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + # all finished, insert this `key` into ephys.Clustering + ephys.Clustering.insert1( + {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True + ) From 60091acad42b3081f3fbea53301d15993bc2e175 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 01:32:18 -0600 Subject: [PATCH 017/146] add spike interface clustering and post processing modules --- .../spike_sorting/si_clustering.py | 123 ++++++++++++++++-- 1 file changed, 111 insertions(+), 12 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 32384d01..9ddddb75 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -44,7 +44,9 @@ # get_recording_channels_details, # ) import spikeinterface as si +import spikeinterface.core as sic import spikeinterface.extractors as se +import spikeinterface.exporters as sie import spikeinterface.sorters as ss import spikeinterface.comparison as sc import spikeinterface.widgets as sw @@ -88,7 +90,7 @@ def activate( ) @schema -class SI_preprocessing(dj.Imported): +class SI_PreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" definition = """ @@ -172,8 +174,8 @@ def make(self, key): elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_full_path = find_full_path(get_ephys_root_data_dir(),get_session_directory(key)) - oe_filepath = get_openephys_filepath(key) + oe_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) + oe_filepath = ephys.get_openephys_filepath(key) stream_name = os.path.split(oe_filepath)[1] assert 
len(oe_probe.recording_info["recording_files"]) == 1 @@ -186,7 +188,7 @@ def make(self, key): * ephys.EphysRecording & key) xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] - channels_details = get_recording_channels_details(key) + channels_details = ephys.get_recording_channels_details(key) # Create SI probe object probe = pi.Probe(ndim=2, si_units='um') @@ -199,7 +201,8 @@ def make(self, key): # run preprocessing and save results to output folder oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") - oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) + # oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) + oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) self.insert1( { @@ -217,7 +220,7 @@ class SI_KilosortClustering(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> KilosortPreProcessing + -> SI_PreProcessing --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -236,16 +239,18 @@ def make(self, key): params = (KilosortPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": - sglx_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = se.load_from_folder - assert len(oe_probe.recording_info["recording_files"]) == 1 + # sglx_probe = ephys.get_openephys_probe_data(key) + recording_file = kilosort_dir / 'sglx_recording_cmr.json' + # sglx_si_recording = se.load_from_folder(recording_file) + sglx_si_recording = sic.load_extractor(recording_file) + # assert len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" else: sorter_name = clustering_method sorting_kilosort = si.run_sorter( sorter_name = sorter_name, - recording = oe_si_recording, + recording = sglx_si_recording, output_folder = kilosort_dir, docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", **params @@ -253,7 +258,7 @@ def make(self, key): sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = se.load_from_folder + oe_si_recording = se.load_from_folder assert len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" @@ -266,7 +271,8 @@ def make(self, key): docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", **params ) - sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) + sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir, n_jobs=-1, chunk_size=30000) + # sorting_kilosort.save(folder=kilosort_dir, n_jobs=20, chunk_size=30000) self.insert1( { @@ -279,7 +285,100 @@ def make(self, key): } ) +@schema +class SI_KilosortPostProcessing(dj.Imported): + """A processing table to handle each clustering task.""" + + definition = """ + -> SI_KilosortClustering + --- + modules_status: longblob # dictionary of summary status for all modules + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration + """ + + def make(self, key): + execution_time = datetime.utcnow() + + output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + kilosort_dir = 
find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + acq_software, clustering_method = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method") + + params = (KilosortPreProcessing & key).fetch1("params") + + if acq_software == "SpikeGLX": + sorting_file = kilosort_dir / 'sorting_kilosort' + recording_file = kilosort_dir / 'sglx_recording_cmr.json' + sglx_si_recording = sic.load_extractor(recording_file) + sorting_kilosort = sic.load_extractor(sorting_file) + + we_kilosort = si.WaveformExtractor.create(sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True) + we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) + unit_id0 = sorting_kilosort.unit_ids[0] + waveforms = we_kilosort.get_waveforms(unit_id0) + template = we_kilosort.get_template(unit_id0) + snrs = si.compute_snrs(we_kilosort) + + + # QC Metrics + si_violations_ratio, isi_violations_rate, isi_violations_count = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + metrics = si.compute_quality_metrics(we_kilosort, metric_names=["firing_rate","snr","presence_ratio","isi_violation", + "num_spikes","amplitude_cutoff","amplitude_median","sliding_rp_violation","rp_violation","drift"]) + sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) + # ["firing_rate","snr","presence_ratio","isi_violation", + # "number_violation","amplitude_cutoff","isolation_distance","l_ratio","d_prime","nn_hit_rate", + # "nn_miss_rate","silhouette_core","cumulative_drift","contamination_rate"]) + + we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) + + + + elif acq_software == "Open Ephys": + sorting_file = kilosort_dir / 'sorting_kilosort' + recording_file = kilosort_dir / 'sglx_recording_cmr.json' + sglx_si_recording = sic.load_extractor(recording_file) + sorting_kilosort = sic.load_extractor(sorting_file) + + we_kilosort = si.WaveformExtractor.create(sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True) + we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) + unit_id0 = sorting_kilosort.unit_ids[0] + waveforms = we_kilosort.get_waveforms(unit_id0) + template = we_kilosort.get_template(unit_id0) + snrs = si.compute_snrs(we_kilosort) + + + # QC Metrics + si_violations_ratio, isi_violations_rate, isi_violations_count = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + metrics = si.compute_quality_metrics(we_kilosort, metric_names=["firing_rate","snr","presence_ratio","isi_violation", + "num_spikes","amplitude_cutoff","amplitude_median","sliding_rp_violation","rp_violation","drift"]) + sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) + + we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) + + + + with open(run_kilosort._modules_input_hash_fp) as f: + modules_status = json.load(f) + self.insert1( + { + **key, + "modules_status": modules_status, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) + + # all finished, insert this `key` into ephys.Clustering + ephys.Clustering.insert1( + {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True + ) From bca5fa9593e7736548d253daae6ec0452bfec94e Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 01:47:15 -0600 Subject: [PATCH 018/146] edit typos --- .../spike_sorting/si_clustering.py | 261 +----------------- 1 file changed, 4 
insertions(+), 257 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 9ddddb75..b3391f93 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -236,7 +236,8 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (KilosortPreProcessing & key).fetch1("params") + params = (SI_PreProcessing & key).fetch1("params") + if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) @@ -286,7 +287,7 @@ def make(self, key): ) @schema -class SI_KilosortPostProcessing(dj.Imported): +class SI_PostProcessing(dj.Imported): """A processing table to handle each clustering task.""" definition = """ @@ -307,7 +308,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (KilosortPreProcessing & key).fetch1("params") + params = (SI_PreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": sorting_file = kilosort_dir / 'sorting_kilosort' @@ -335,7 +336,6 @@ def make(self, key): we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - elif acq_software == "Open Ephys": sorting_file = kilosort_dir / 'sorting_kilosort' recording_file = kilosort_dir / 'sglx_recording_cmr.json' @@ -358,8 +358,6 @@ def make(self, key): we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - - with open(run_kilosort._modules_input_hash_fp) as f: modules_status = json.load(f) @@ -380,254 +378,3 @@ def make(self, key): {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) - - -@schema -class KilosortPreProcessing(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> ephys.ClusteringTask - --- - params: longblob # finalized parameterset for this run - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - @property - def key_source(self): - return ( - ephys.ClusteringTask * ephys.ClusteringParamSet - & {"task_mode": "trigger"} - & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) - ephys.Clustering - - def make(self, key): - """Triggers or imports clustering analysis.""" - execution_time = datetime.utcnow() - - task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - assert task_mode == "trigger", 'Supporting "trigger" task_mode only' - - if not output_dir: - output_dir = ephys.ClusteringTask.infer_output_dir( - key, relative=True, mkdir=True - ) - # update clustering_output_dir - ephys.ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method, params = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - assert ( - clustering_method in _supported_kilosort_versions - ), f'Clustering_method "{clustering_method}" is not supported' - - # add additional probe-recording and channels details into `params` - params = {**params, **ephys.get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - - - - if acq_software == "SpikeGLX": - 
spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.get("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_CatGT() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = ["depth_estimation", "median_subtraction"] - run_kilosort.run_modules() - - self.insert1( - { - **key, - "params": params, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - -@schema -class KilosortClustering(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> KilosortPreProcessing - --- - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - def make(self, key): - execution_time = datetime.utcnow() - - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (KilosortPreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, - ) - run_kilosort._modules = ["kilosort_helper"] - run_kilosort._CatGT_finished = True - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = ["kilosort_helper"] - run_kilosort.run_modules() - - self.insert1( - { - **key, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - -@schema -class KilosortPostProcessing(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> KilosortClustering - --- - modules_status: longblob # dictionary of summary status for all modules - execution_time: datetime # datetime of the start of this step - 
execution_duration: float # (hour) execution duration - """ - - def make(self, key): - execution_time = datetime.utcnow() - - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (KilosortPreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, - ) - run_kilosort._modules = [ - "kilosort_postprocessing", - "noise_templates", - "mean_waveforms", - "quality_metrics", - ] - run_kilosort._CatGT_finished = True - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = [ - "kilosort_postprocessing", - "noise_templates", - "mean_waveforms", - "quality_metrics", - ] - run_kilosort.run_modules() - - with open(run_kilosort._modules_input_hash_fp) as f: - modules_status = json.load(f) - - self.insert1( - { - **key, - "modules_status": modules_status, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - # all finished, insert this `key` into ephys.Clustering - ephys.Clustering.insert1( - {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True - ) From 1c4b0b578f31b2779a65db8ef67a8738a6352ff1 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 14:49:47 -0600 Subject: [PATCH 019/146] removed module_status from table keys --- element_array_ephys/spike_sorting/si_clustering.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index b3391f93..d3c3bbf9 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -293,7 +293,6 @@ class SI_PostProcessing(dj.Imported): definition = """ -> SI_KilosortClustering --- - modules_status: longblob # dictionary of summary status for all modules execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration """ @@ -358,13 +357,10 @@ def make(self, key): we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - with open(run_kilosort._modules_input_hash_fp) as f: - modules_status = json.load(f) self.insert1( { **key, - "modules_status": modules_status, "execution_time": execution_time, "execution_duration": ( datetime.utcnow() - execution_time From 56c9941f12072b3d572520f69ee999025117ffb5 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 17:16:47 -0600 Subject: [PATCH 
020/146] remove _ from SI table names --- .../spike_sorting/si_clustering.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index d3c3bbf9..a72989ed 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -8,8 +8,8 @@ https://github.com/datajoint-company/ecephys_spike_sorting The follow pipeline features intermediary tables: -1. KilosortPreProcessing - for preprocessing steps (no GPU required) - - median_subtraction for Open Ephys +1. SIPreProcessing - for preprocessing steps (no GPU required) + - - or the CatGT step for SpikeGLX 2. KilosortClustering - kilosort (MATLAB) - requires GPU - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) @@ -90,7 +90,7 @@ def activate( ) @schema -class SI_PreProcessing(dj.Imported): +class SIPreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" definition = """ @@ -168,8 +168,8 @@ def make(self, key): # run preprocessing and save results to output folder sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) - sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - sglx_recording_cmr.save_to_folder('sglx_recording_cmr', kilosort_dir) + # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") + sglx_si_recording_filtered.save_to_folder('sglx_si_recording_filtered', kilosort_dir) elif acq_software == "Open Ephys": @@ -216,7 +216,7 @@ def make(self, key): } ) @schema -class SI_KilosortClustering(dj.Imported): +class SIClustering(dj.Imported): """A processing table to handle each clustering task.""" definition = """ @@ -236,8 +236,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (SI_PreProcessing & key).fetch1("params") - + params = (SIPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) @@ -287,11 +286,11 @@ def make(self, key): ) @schema -class SI_PostProcessing(dj.Imported): +class SIPostProcessing(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> SI_KilosortClustering + -> SIClustering --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -307,7 +306,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (SI_PreProcessing & key).fetch1("params") + params = (SIPreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": sorting_file = kilosort_dir / 'sorting_kilosort' From ce14098041a5292ec8dd9abd9776074d975ad3b2 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 17:23:56 -0600 Subject: [PATCH 021/146] bugfix --- element_array_ephys/spike_sorting/si_clustering.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index a72989ed..97a7dd53 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -202,7 +202,8 @@ def make(self, key): 
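# [Editor's note] A minimal, hedged sketch of the preprocessing chain this hunk edits: bandpass
# filtering followed by a global-median common reference. `recording` is a placeholder for any
# SpikeInterface recording extractor (an assumption of this sketch); both calls appear verbatim
# in the surrounding code.
import spikeinterface.preprocessing as sip

recording_filtered = sip.bandpass_filter(recording, freq_min=300, freq_max=6000)  # keep the spike band
recording_cmr = sip.common_reference(recording_filtered, reference="global", operator="median")  # subtract per-sample median across channels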
oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") # oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) - oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) + # oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) + oe_si_recording_filtered.save_to_folder('', kilosort_dir) self.insert1( { @@ -220,7 +221,7 @@ class SIClustering(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> SI_PreProcessing + -> SIPreProcessing --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration From 7c836f12fd1d47e5b7cf15435eaa19ad1faa7ae4 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 17:33:52 -0600 Subject: [PATCH 022/146] change si related table names --- element_array_ephys/spike_sorting/si_clustering.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 97a7dd53..d97066a6 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -90,7 +90,7 @@ def activate( ) @schema -class SIPreProcessing(dj.Imported): +class PreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" definition = """ @@ -164,7 +164,7 @@ def make(self, key): probe.create_auto_shape(probe_type='tip') channel_indices = np.arange(channels_details['num_channels']) probe.set_device_channel_indices(channel_indices) - oe_si_recording.set_probe(probe=probe) + sglx_si_recording.set_probe(probe=probe) # run preprocessing and save results to output folder sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) @@ -217,7 +217,7 @@ def make(self, key): } ) @schema -class SIClustering(dj.Imported): +class ClusteringModule(dj.Imported): """A processing table to handle each clustering task.""" definition = """ @@ -237,7 +237,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (SIPreProcessing & key).fetch1("params") + params = (PreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) @@ -287,11 +287,11 @@ def make(self, key): ) @schema -class SIPostProcessing(dj.Imported): +class PostProcessing(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> SIClustering + -> ClusteringModule --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -307,7 +307,7 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (SIPreProcessing & key).fetch1("params") + params = (PreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": sorting_file = kilosort_dir / 'sorting_kilosort' From dd6366498d1a4bd974803b89abdfd7ab30a96623 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 17:38:49 -0600 Subject: [PATCH 023/146] bugfix --- element_array_ephys/spike_sorting/si_clustering.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index d97066a6..cabe5c25 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -221,7 +221,7 @@ class ClusteringModule(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> SIPreProcessing + -> PreProcessing --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration From 54888ed7c747bf5452e90d590641824bd70684f1 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 3 Feb 2023 18:02:46 -0600 Subject: [PATCH 024/146] update initial comment --- .../spike_sorting/si_clustering.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index cabe5c25..e69855f1 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -1,22 +1,20 @@ """ The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the -"ecephys_spike_sorting" pipeline. -The "ecephys_spike_sorting" was originally developed by the Allen Institute (https://github.com/AllenInstitute/ecephys_spike_sorting) for Neuropixels data acquired with Open Ephys acquisition system. -Then forked by Jennifer Colonell from the Janelia Research Campus (https://github.com/jenniferColonell/ecephys_spike_sorting) to support SpikeGLX acquisition system. +"spikeinterface" pipeline. +SpikeInterface is developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface). -At DataJoint, we fork from Jennifer's fork and implemented a version that supports both Open Ephys and Spike GLX. -https://github.com/datajoint-company/ecephys_spike_sorting +The DataJoint pipeline currently incorporates SpikeInterface's approach of running Kilosort in a container. The following pipeline features intermediary tables: -1. SIPreProcessing - for preprocessing steps (no GPU required) - - - - or the CatGT step for SpikeGLX -2. KilosortClustering - kilosort (MATLAB) - requires GPU +1. PreProcessing - for preprocessing steps (no GPU required) + - create recording extractor and link it to a probe + - bandpass filtering + - common mode referencing +2. ClusteringModule - kilosort (MATLAB) - requires GPU and docker/singularity containers - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) -3. KilosortPostProcessing - for postprocessing steps (no GPU required) - - kilosort_postprocessing - - noise_templates - - mean_waveforms +3.
PostProcessing - for postprocessing steps (no GPU required) + - create waveform extractor object + - extract templates, waveforms and snrs - quality_metrics @@ -48,8 +46,6 @@ import spikeinterface.extractors as se import spikeinterface.exporters as sie import spikeinterface.sorters as ss -import spikeinterface.comparison as sc -import spikeinterface.widgets as sw import spikeinterface.preprocessing as sip import probeinterface as pi From f804233b63a74ad935b953c3051ad74de0a4f4b2 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 8 Feb 2023 19:17:51 -0600 Subject: [PATCH 025/146] fix preprocessing file loading issues --- .../spike_sorting/si_clustering.py | 64 ++++++++----------- 1 file changed, 27 insertions(+), 37 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index e69855f1..e6e129bc 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -32,15 +32,8 @@ spikeglx, kilosort_triggering, ) -import element_array_ephys.ephys_no_curation as ephys import element_array_ephys.probe as probe -# from element_array_ephys.ephys_no_curation import ( -# get_ephys_root_data_dir, -# get_session_directory, -# get_openephys_filepath, -# get_spikeglx_meta_filepath, -# get_recording_channels_details, -# ) + import spikeinterface as si import spikeinterface.core as sic import spikeinterface.extractors as se @@ -92,6 +85,7 @@ class PreProcessing(dj.Imported): definition = """ -> ephys.ClusteringTask --- + file_name: varchar(60) # filename where recording object is saved to params: longblob # finalized parameterset for this run execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -137,30 +131,27 @@ def make(self, key): params = {**params, **ephys.get_recording_channels_details(key)} params["fs"] = params["sample_rate"] + if acq_software == "SpikeGLX": sglx_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) sglx_filepath = ephys.get_spikeglx_meta_filepath(key) stream_name = os.path.split(sglx_filepath)[1] - assert len(oe_probe.recording_info["recording_files"]) == 1 + # assert len(oe_probe.recording_info["recording_files"]) == 1 # Create SI recording extractor object # sglx_si_recording = se.SpikeGLXRecordingExtractor(folder_path=sglx_full_path, stream_name=stream_name) sglx_si_recording = se.read_spikeglx(folder_path=sglx_full_path, stream_name=stream_name) - electrode_query = (probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * ephys.EphysRecording & key) - xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] + xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] channels_details = ephys.get_recording_channels_details(key) # Create SI probe object - probe = pi.Probe(ndim=2, si_units='um') - probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) - probe.create_auto_shape(probe_type='tip') - channel_indices = np.arange(channels_details['num_channels']) - probe.set_device_channel_indices(channel_indices) - sglx_si_recording.set_probe(probe=probe) + si_probe = pi.Probe(ndim=2, si_units='um') + si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + si_probe.create_auto_shape(probe_type='tip') + si_probe.set_device_channel_indices(channels_details['channel_ind']) + 
sglx_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) @@ -170,29 +161,25 @@ def make(self, key): elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) - oe_filepath = ephys.get_openephys_filepath(key) - stream_name = os.path.split(oe_filepath)[1] - + oe_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) + assert len(oe_probe.recording_info["recording_files"]) == 1 + stream_name = os.path.split(oe_probe.recording_info['recording_files'][0])[1] # Create SI recording extractor object # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) - oe_si_recording = se.read_openephys(folder_path=oe_full_path, stream_name=stream_name) - electrode_query = (probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * ephys.EphysRecording & key) + oe_si_recording = se.read_openephys(folder_path=oe_session_full_path, stream_name=stream_name) - xy_coords = [list(i) for i in zip(electrode_query.fetch('x_coord'),electrode_query.fetch('y_coord'))] + xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] + channels_details = ephys.get_recording_channels_details(key) # Create SI probe object - probe = pi.Probe(ndim=2, si_units='um') - probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) - probe.create_auto_shape(probe_type='tip') - channel_indices = np.arange(channels_details['num_channels']) - probe.set_device_channel_indices(channel_indices) - oe_si_recording.set_probe(probe=probe) + si_probe = pi.Probe(ndim=2, si_units='um') + si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + si_probe.create_auto_shape(probe_type='tip') + si_probe.set_device_channel_indices(channels_details['channel_ind']) + oe_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) @@ -219,8 +206,10 @@ class ClusteringModule(dj.Imported): definition = """ -> PreProcessing --- - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration + recording_file: varchar(60) # filename of saved recording object + sorting_file: varchar(60) # filename of saved sorting object + execution_time: datetime # datetime of the start of this step + execution_duration: float # (hour) execution duration """ def make(self, key): @@ -234,10 +223,11 @@ def make(self, key): ).fetch1("acq_software", "clustering_method") params = (PreProcessing & key).fetch1("params") + file_name = (PreProcessing & key).fetch1("file_name") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) - recording_file = kilosort_dir / 'sglx_recording_cmr.json' + recording_file = kilosort_dir / file_name # sglx_si_recording = se.load_from_folder(recording_file) sglx_si_recording = sic.load_extractor(recording_file) # assert len(oe_probe.recording_info["recording_files"]) == 1 From 8e1b73dd5013b9eb6da542678726fea800a3dcf6 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Thu, 9 Feb 2023 15:58:12 -0600 Subject: [PATCH 026/146] set file saving and file loading to pickle format --- 
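[Editor's note: a minimal, hedged sketch of the pickle round-trip this patch switches to. The path below is a hypothetical placeholder and `recording` stands for any SpikeInterface extractor object; dump_to_pickle() and load_extractor() are the calls used in the diff that follows.]

    from pathlib import Path
    import spikeinterface.core as sic

    rec_path = Path("/data/kilosort_output") / "si_recording.pkl"  # hypothetical location
    recording.dump_to_pickle(file_path=rec_path)    # serialize the lazy recording definition
    recording_again = sic.load_extractor(rec_path)  # a later pipeline step re-instantiates it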
element_array_ephys/spike_sorting/si_clustering.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index e6e129bc..cf909725 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -186,11 +186,14 @@ def make(self, key): oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") # oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) # oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) - oe_si_recording_filtered.save_to_folder('', kilosort_dir) + save_file_name = 'si_recording.pkl' + save_file_path = kilosort_dir / save_file_name + oe_si_recording_filtered.dump_to_pickle(file_path=save_file_path) self.insert1( { **key, + "file_name": save_file_name, "params": params, "execution_time": execution_time, "execution_duration": ( From f0b7497e7b173396a35efbd9a1545981a095d96c Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 10 Feb 2023 18:09:40 -0600 Subject: [PATCH 027/146] sglx preprocessing modifications --- .../spike_sorting/si_clustering.py | 58 ++++++++++--------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index cf909725..6ec4b6e2 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -34,7 +34,7 @@ ) import element_array_ephys.probe as probe -import spikeinterface as si +import spikeinterface.full as si import spikeinterface.core as sic import spikeinterface.extractors as se import spikeinterface.exporters as sie @@ -85,7 +85,7 @@ class PreProcessing(dj.Imported): definition = """ -> ephys.ClusteringTask --- - file_name: varchar(60) # filename where recording object is saved to + recording_filename: varchar(60) # filename where recording object is saved to params: longblob # finalized parameterset for this run execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -133,22 +133,19 @@ def make(self, key): if acq_software == "SpikeGLX": - sglx_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) + # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) sglx_filepath = ephys.get_spikeglx_meta_filepath(key) - stream_name = os.path.split(sglx_filepath)[1] - - # assert len(oe_probe.recording_info["recording_files"]) == 1 # Create SI recording extractor object - # sglx_si_recording = se.SpikeGLXRecordingExtractor(folder_path=sglx_full_path, stream_name=stream_name) - sglx_si_recording = se.read_spikeglx(folder_path=sglx_full_path, stream_name=stream_name) - - xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] + sglx_si_recording = se.read_spikeglx(folder_path=sglx_filepath.parent) + channels_details = ephys.get_recording_channels_details(key) + xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] + # Create SI probe object si_probe = pi.Probe(ndim=2, si_units='um') - si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 12}) si_probe.create_auto_shape(probe_type='tip') 
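# [Editor's note] A self-contained toy version of the probe-wiring sequence shown here, using
# hypothetical 4-contact geometry; the probeinterface calls mirror the surrounding lines, and
# only the coordinate values are invented for illustration.
import numpy as np
import probeinterface as pi

demo_xy = [[0.0, 0.0], [32.0, 0.0], [0.0, 20.0], [32.0, 20.0]]  # hypothetical contact positions (um)
demo_probe = pi.Probe(ndim=2, si_units='um')
demo_probe.set_contacts(positions=demo_xy, shapes='square', shape_params={'width': 12})
demo_probe.create_auto_shape(probe_type='tip')
demo_probe.set_device_channel_indices(np.arange(4))  # map each contact to a recording channel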
si_probe.set_device_channel_indices(channels_details['channel_ind']) sglx_si_recording.set_probe(probe=si_probe) @@ -156,7 +153,10 @@ def make(self, key): # run preprocessing and save results to output folder sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - sglx_si_recording_filtered.save_to_folder('sglx_si_recording_filtered', kilosort_dir) + + save_file_name = 'si_recording.pkl' + save_file_path = kilosort_dir / save_file_name + sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) elif acq_software == "Open Ephys": @@ -170,22 +170,21 @@ def make(self, key): # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) oe_si_recording = se.read_openephys(folder_path=oe_session_full_path, stream_name=stream_name) + channels_details = ephys.get_recording_channels_details(key) xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] - channels_details = ephys.get_recording_channels_details(key) - # Create SI probe object si_probe = pi.Probe(ndim=2, si_units='um') - si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 5}) + si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 12}) si_probe.create_auto_shape(probe_type='tip') si_probe.set_device_channel_indices(channels_details['channel_ind']) oe_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder + # Switch case to allow for specified preprocessing steps oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") - # oe_recording_cmr.save_to_folder('oe_recording_cmr', kilosort_dir) - # oe_recording_cmr.dump_to_json('oe_recording_cmr.json', kilosort_dir) + save_file_name = 'si_recording.pkl' save_file_path = kilosort_dir / save_file_name oe_si_recording_filtered.dump_to_pickle(file_path=save_file_path) @@ -193,7 +192,7 @@ def make(self, key): self.insert1( { **key, - "file_name": save_file_name, + "recording_filename": save_file_name, "params": params, "execution_time": execution_time, "execution_duration": ( @@ -202,15 +201,14 @@ def make(self, key): / 3600, } ) -@schema + @schema class ClusteringModule(dj.Imported): """A processing table to handle each clustering task.""" definition = """ -> PreProcessing --- - recording_file: varchar(60) # filename of saved recording object - sorting_file: varchar(60) # filename of saved sorting object + sorting_filename: varchar(60) # filename of saved sorting object execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration """ @@ -226,13 +224,13 @@ def make(self, key): ).fetch1("acq_software", "clustering_method") params = (PreProcessing & key).fetch1("params") - file_name = (PreProcessing & key).fetch1("file_name") + recording_filename = (PreProcessing & key).fetch1("recording_filename") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) - recording_file = kilosort_dir / file_name + recording_fullpath = kilosort_dir / recording_filename # sglx_si_recording = se.load_from_folder(recording_file) - sglx_si_recording = sic.load_extractor(recording_file) + sglx_si_recording = sic.load_extractor(recording_fullpath) # assert 
len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" @@ -245,10 +243,11 @@ def make(self, key): docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", **params ) - sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir) + sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' + sorting_kilosort.dump_to_pickle(sorting_save_path) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = se.load_from_folder + oe_si_recording = sic.load_extractor(recording_fullpath) assert len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" @@ -261,7 +260,8 @@ def make(self, key): docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", **params ) - sorting_kilosort.save_to_folder('sorting_kilosort', kilosort_dir, n_jobs=-1, chunk_size=30000) + sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' + sorting_kilosort.dump_to_pickle(sorting_save_path) # sorting_kilosort.save(folder=kilosort_dir, n_jobs=20, chunk_size=30000) self.insert1( @@ -363,3 +363,7 @@ def make(self, key): {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) + + +def preProcessing_switch(preprocess_list): + \ No newline at end of file From 13fe31c49fa5e5286c4ef916605bf017a33a9d87 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Mon, 13 Feb 2023 17:48:12 -0600 Subject: [PATCH 028/146] sglx testing progress --- element_array_ephys/spike_sorting/si_clustering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 6ec4b6e2..08ff86bd 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -128,8 +128,8 @@ def make(self, key): ), f'Clustering_method "{clustering_method}" is not supported' # add additional probe-recording and channels details into `params` - params = {**params, **ephys.get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] + # params = {**params, **ephys.get_recording_channels_details(key)} + # params["fs"] = params["sample_rate"] if acq_software == "SpikeGLX": From d41c7f3c6d0fb2b0c97578d6c8cd28ab1b6a2832 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 14 Feb 2023 18:46:54 -0600 Subject: [PATCH 029/146] wip parametrize preprocessing --- .../spike_sorting/si_clustering.py | 117 +++++++++++++++--- 1 file changed, 99 insertions(+), 18 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 08ff86bd..f0144ea6 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -131,6 +131,46 @@ def make(self, key): # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] + + preprocess_list = params.pop('PreProcessing_params') + + # If else + if preprocess_list['Filter']: + oe_si_recording = sip.FilterRecording(oe_si_recording) + elif preprocess_list['BandpassFilter']: + oe_si_recording = sip.BandpassFilterRecording(oe_si_recording) + elif preprocess_list['HighpassFilter']: + oe_si_recording = sip.HighpassFilterRecording(oe_si_recording) + elif preprocess_list['NormalizeByQuantile']: + oe_si_recording = sip.NormalizeByQuantileRecording(oe_si_recording) + elif 
preprocess_list['Scale']: + oe_si_recording = sip.ScaleRecording(oe_si_recording) + elif preprocess_list['Center']: + oe_si_recording = sip.CenterRecording(oe_si_recording) + elif preprocess_list['ZScore']: + oe_si_recording = sip.ZScoreRecording(oe_si_recording) + elif preprocess_list['Whiten']: + oe_si_recording = sip.WhitenRecording(oe_si_recording) + elif preprocess_list['CommonReference']: + oe_si_recording = sip.CommonReferenceRecording(oe_si_recording) + elif preprocess_list['PhaseShift']: + oe_si_recording = sip.PhaseShiftRecording(oe_si_recording) + elif preprocess_list['Rectify']: + oe_si_recording = sip.RectifyRecording(oe_si_recording) + elif preprocess_list['Clip']: + oe_si_recording = sip.ClipRecording(oe_si_recording) + elif preprocess_list['BlankSaturation']: + oe_si_recording = sip.BlankSaturationRecording(oe_si_recording) + elif preprocess_list['RemoveArtifacts']: + oe_si_recording = sip.RemoveArtifactsRecording(oe_si_recording) + elif preprocess_list['RemoveBadChannels']: + oe_si_recording = sip.RemoveBadChannelsRecording(oe_si_recording) + elif preprocess_list['ZeroChannelPad']: + oe_si_recording = sip.ZeroChannelPadRecording(oe_si_recording) + elif preprocess_list['DeepInterpolation']: + oe_si_recording = sip.DeepInterpolationRecording(oe_si_recording) + elif preprocess_list['Resample']: + oe_si_recording = sip.ResampleRecording(oe_si_recording) if acq_software == "SpikeGLX": # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) @@ -232,17 +272,23 @@ def make(self, key): # sglx_si_recording = se.load_from_folder(recording_file) sglx_si_recording = sic.load_extractor(recording_fullpath) # assert len(oe_probe.recording_info["recording_files"]) == 1 + + ## Assume that the worker process will trigger this sorting step + # - Will need to store/load the sorter_name, sglx_si_recording object etc. 
+ # - Store in shared EC2 space accessible by all containers (needs to be mounted) + # - Load into the cloud init script, and + # - Option A: Can call this function within a separate container within spike_sorting_worker if clustering_method.startswith('kilosort2.5'): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - sorting_kilosort = si.run_sorter( - sorter_name = sorter_name, - recording = sglx_si_recording, - output_folder = kilosort_dir, - docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", - **params - ) + # sorting_kilosort = si.run_sorter( + # sorter_name = sorter_name, + # recording = sglx_si_recording, + # output_folder = kilosort_dir, + # docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", + # **params + # ) sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' sorting_kilosort.dump_to_pickle(sorting_save_path) elif acq_software == "Open Ephys": @@ -253,13 +299,13 @@ def make(self, key): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - sorting_kilosort = si.run_sorter( - sorter_name = sorter_name, - recording = oe_si_recording, - output_folder = kilosort_dir, - docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", - **params - ) + # sorting_kilosort = si.run_sorter( + # sorter_name = sorter_name, + # recording = oe_si_recording, + # output_folder = kilosort_dir, + # docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", + # **params + # ) sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' sorting_kilosort.dump_to_pickle(sorting_save_path) # sorting_kilosort.save(folder=kilosort_dir, n_jobs=20, chunk_size=30000) @@ -363,7 +409,42 @@ def make(self, key): {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) - - -def preProcessing_switch(preprocess_list): - \ No newline at end of file +## Example SI parameter set +''' +{'detect_threshold': 6, + 'projection_threshold': [10, 4], + 'preclust_threshold': 8, + 'car': True, + 'minFR': 0.02, + 'minfr_goodchannels': 0.1, + 'nblocks': 5, + 'sig': 20, + 'freq_min': 150, + 'sigmaMask': 30, + 'nPCs': 3, + 'ntbuff': 64, + 'nfilt_factor': 4, + 'NT': None, + 'do_correction': True, + 'wave_length': 61, + 'keep_good_only': False, + 'PreProcessing_params': {'Filter': False, + 'BandpassFilter': True, + 'HighpassFilter': False, + 'NotchFilter': False, + 'NormalizeByQuantile': False, + 'Scale': False, + 'Center': False, + 'ZScore': False, + 'Whiten': False, + 'CommonReference': False, + 'PhaseShift': False, + 'Rectify': False, + 'Clip': False, + 'BlankSaturation': False, + 'RemoveArtifacts': False, + 'RemoveBadChannels': False, + 'ZeroChannelPad': False, + 'DeepInterpolation': False, + 'Resample': False}} +''' \ No newline at end of file From 109a71ad4e8c884659bee146dda2afdbfe4fa77a Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 17 Feb 2023 19:34:53 -0600 Subject: [PATCH 030/146] post processing waveform extractor extensions --- .../spike_sorting/si_clustering.py | 282 +++++++++++------- 1 file changed, 178 insertions(+), 104 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index f0144ea6..13f129d8 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -34,6 +34,7 @@ ) import element_array_ephys.probe as probe +import spikeinterface import spikeinterface.full as si import spikeinterface.core as sic import spikeinterface.extractors as se @@ -78,6 +79,7 @@ def 
activate( add_objects=ephys.__dict__, ) + @schema class PreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" @@ -85,7 +87,7 @@ class PreProcessing(dj.Imported): definition = """ -> ephys.ClusteringTask --- - recording_filename: varchar(60) # filename where recording object is saved to + recording_filename: varchar(30) # filename where recording object is saved to params: longblob # finalized parameterset for this run execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -98,6 +100,7 @@ def key_source(self): & {"task_mode": "trigger"} & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' ) - ephys.Clustering + def make(self, key): """Triggers or imports clustering analysis.""" execution_time = datetime.utcnow() @@ -131,101 +134,121 @@ def make(self, key): # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] - - preprocess_list = params.pop('PreProcessing_params') + preprocess_list = params.pop("PreProcessing_params") - # If else - if preprocess_list['Filter']: + # If else + # need to figure out ordering + if preprocess_list["Filter"]: oe_si_recording = sip.FilterRecording(oe_si_recording) - elif preprocess_list['BandpassFilter']: + elif preprocess_list["BandpassFilter"]: oe_si_recording = sip.BandpassFilterRecording(oe_si_recording) - elif preprocess_list['HighpassFilter']: + elif preprocess_list["HighpassFilter"]: oe_si_recording = sip.HighpassFilterRecording(oe_si_recording) - elif preprocess_list['NormalizeByQuantile']: + elif preprocess_list["NormalizeByQuantile"]: oe_si_recording = sip.NormalizeByQuantileRecording(oe_si_recording) - elif preprocess_list['Scale']: + elif preprocess_list["Scale"]: oe_si_recording = sip.ScaleRecording(oe_si_recording) - elif preprocess_list['Center']: + elif preprocess_list["Center"]: oe_si_recording = sip.CenterRecording(oe_si_recording) - elif preprocess_list['ZScore']: + elif preprocess_list["ZScore"]: oe_si_recording = sip.ZScoreRecording(oe_si_recording) - elif preprocess_list['Whiten']: + elif preprocess_list["Whiten"]: oe_si_recording = sip.WhitenRecording(oe_si_recording) - elif preprocess_list['CommonReference']: + elif preprocess_list["CommonReference"]: oe_si_recording = sip.CommonReferenceRecording(oe_si_recording) - elif preprocess_list['PhaseShift']: + elif preprocess_list["PhaseShift"]: oe_si_recording = sip.PhaseShiftRecording(oe_si_recording) - elif preprocess_list['Rectify']: + elif preprocess_list["Rectify"]: oe_si_recording = sip.RectifyRecording(oe_si_recording) - elif preprocess_list['Clip']: + elif preprocess_list["Clip"]: oe_si_recording = sip.ClipRecording(oe_si_recording) - elif preprocess_list['BlankSaturation']: + elif preprocess_list["BlankSaturation"]: oe_si_recording = sip.BlankSaturationRecording(oe_si_recording) - elif preprocess_list['RemoveArtifacts']: + elif preprocess_list["RemoveArtifacts"]: oe_si_recording = sip.RemoveArtifactsRecording(oe_si_recording) - elif preprocess_list['RemoveBadChannels']: + elif preprocess_list["RemoveBadChannels"]: oe_si_recording = sip.RemoveBadChannelsRecording(oe_si_recording) - elif preprocess_list['ZeroChannelPad']: + elif preprocess_list["ZeroChannelPad"]: oe_si_recording = sip.ZeroChannelPadRecording(oe_si_recording) - elif preprocess_list['DeepInterpolation']: + elif preprocess_list["DeepInterpolation"]: oe_si_recording = sip.DeepInterpolationRecording(oe_si_recording) - elif preprocess_list['Resample']: + elif 
preprocess_list["Resample"]: oe_si_recording = sip.ResampleRecording(oe_si_recording) - + if acq_software == "SpikeGLX": # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) sglx_filepath = ephys.get_spikeglx_meta_filepath(key) # Create SI recording extractor object - sglx_si_recording = se.read_spikeglx(folder_path=sglx_filepath.parent) - + sglx_si_recording = se.read_spikeglx(folder_path=sglx_filepath.parent) + channels_details = ephys.get_recording_channels_details(key) - xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] - - - # Create SI probe object - si_probe = pi.Probe(ndim=2, si_units='um') - si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 12}) - si_probe.create_auto_shape(probe_type='tip') - si_probe.set_device_channel_indices(channels_details['channel_ind']) + xy_coords = [ + list(i) + for i in zip(channels_details["x_coords"], channels_details["y_coords"]) + ] + + # Create SI probe object + si_probe = pi.Probe(ndim=2, si_units="um") + si_probe.set_contacts( + positions=xy_coords, shapes="square", shape_params={"width": 12} + ) + si_probe.create_auto_shape(probe_type="tip") + si_probe.set_device_channel_indices(channels_details["channel_ind"]) sglx_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder - sglx_si_recording_filtered = sip.bandpass_filter(sglx_si_recording, freq_min=300, freq_max=6000) + sglx_si_recording_filtered = sip.bandpass_filter( + sglx_si_recording, freq_min=300, freq_max=6000 + ) # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - save_file_name = 'si_recording.pkl' + save_file_name = "si_recording.pkl" save_file_path = kilosort_dir / save_file_name sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) - elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) - + oe_session_full_path = find_full_path( + ephys.get_ephys_root_data_dir(), ephys.get_session_directory(key) + ) + assert len(oe_probe.recording_info["recording_files"]) == 1 - stream_name = os.path.split(oe_probe.recording_info['recording_files'][0])[1] + stream_name = os.path.split(oe_probe.recording_info["recording_files"][0])[ + 1 + ] # Create SI recording extractor object - # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) - oe_si_recording = se.read_openephys(folder_path=oe_session_full_path, stream_name=stream_name) + # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) + oe_si_recording = se.read_openephys( + folder_path=oe_session_full_path, stream_name=stream_name + ) channels_details = ephys.get_recording_channels_details(key) - xy_coords = [list(i) for i in zip(channels_details['x_coords'],channels_details['y_coords'])] - - # Create SI probe object - si_probe = pi.Probe(ndim=2, si_units='um') - si_probe.set_contacts(positions=xy_coords, shapes='square', shape_params={'width': 12}) - si_probe.create_auto_shape(probe_type='tip') - si_probe.set_device_channel_indices(channels_details['channel_ind']) + xy_coords = [ + list(i) + for i in zip(channels_details["x_coords"], channels_details["y_coords"]) + ] + + # Create SI probe object + si_probe = pi.Probe(ndim=2, si_units="um") + si_probe.set_contacts( + 
positions=xy_coords, shapes="square", shape_params={"width": 12} + ) + si_probe.create_auto_shape(probe_type="tip") + si_probe.set_device_channel_indices(channels_details["channel_ind"]) oe_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder # Switch case to allow for specified preprocessing steps - oe_si_recording_filtered = sip.bandpass_filter(oe_si_recording, freq_min=300, freq_max=6000) - oe_recording_cmr = sip.common_reference(oe_si_recording_filtered, reference="global", operator="median") + oe_si_recording_filtered = sip.bandpass_filter( + oe_si_recording, freq_min=300, freq_max=6000 + ) + oe_recording_cmr = sip.common_reference( + oe_si_recording_filtered, reference="global", operator="median" + ) - save_file_name = 'si_recording.pkl' + save_file_name = "si_recording.pkl" save_file_path = kilosort_dir / save_file_name oe_si_recording_filtered.dump_to_pickle(file_path=save_file_path) @@ -240,15 +263,17 @@ def make(self, key): ).total_seconds() / 3600, } - ) - @schema -class ClusteringModule(dj.Imported): + ) + + +@schema +class SIClustering(dj.Imported): """A processing table to handle each clustering task.""" definition = """ -> PreProcessing --- - sorting_filename: varchar(60) # filename of saved sorting object + sorting_filename: varchar(30) # filename of saved sorting object execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration """ @@ -263,56 +288,56 @@ def make(self, key): ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("acq_software", "clustering_method") - params = (PreProcessing & key).fetch1("params") - recording_filename = (PreProcessing & key).fetch1("recording_filename") + params = (PreProcessing & key).fetch1("params") + recording_filename = (PreProcessing & key).fetch1("recording_filename") if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) recording_fullpath = kilosort_dir / recording_filename - # sglx_si_recording = se.load_from_folder(recording_file) + # sglx_si_recording = se.load_from_folder(recording_file) sglx_si_recording = sic.load_extractor(recording_fullpath) # assert len(oe_probe.recording_info["recording_files"]) == 1 ## Assume that the worker process will trigger this sorting step - # - Will need to store/load the sorter_name, sglx_si_recording object etc. + # - Will need to store/load the sorter_name, sglx_si_recording object etc. 
# - Store in shared EC2 space accessible by all containers (needs to be mounted) - # - Load into the cloud init script, and + # - Load into the cloud init script, and # - Option A: Can call this function within a separate container within spike_sorting_worker - if clustering_method.startswith('kilosort2.5'): + if clustering_method.startswith("kilosort2.5"): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - # sorting_kilosort = si.run_sorter( - # sorter_name = sorter_name, - # recording = sglx_si_recording, - # output_folder = kilosort_dir, - # docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", - # **params - # ) - sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' + sorting_kilosort = si.run_sorter( + sorter_name=sorter_name, + recording=sglx_si_recording, + output_folder=kilosort_dir, + docker_image=f"spikeinterface/{sorter_name}-compiled-base:latest", + **params, + ) + sorting_save_path = kilosort_dir / "sorting_kilosort.pkl" sorting_kilosort.dump_to_pickle(sorting_save_path) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = sic.load_extractor(recording_fullpath) + oe_si_recording = sic.load_extractor(recording_fullpath) assert len(oe_probe.recording_info["recording_files"]) == 1 - if clustering_method.startswith('kilosort2.5'): + if clustering_method.startswith("kilosort2.5"): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - # sorting_kilosort = si.run_sorter( - # sorter_name = sorter_name, - # recording = oe_si_recording, - # output_folder = kilosort_dir, - # docker_image = f"spikeinterface/{sorter_name}-compiled-base:latest", - # **params - # ) - sorting_save_path = kilosort_dir / 'sorting_kilosort.pkl' + sorting_kilosort = si.run_sorter( + sorter_name=sorter_name, + recording=oe_si_recording, + output_folder=kilosort_dir, + docker_image=f"spikeinterface/{sorter_name}-compiled-base:latest", + **params, + ) + sorting_save_path = kilosort_dir / "sorting_kilosort.pkl" sorting_kilosort.dump_to_pickle(sorting_save_path) - # sorting_kilosort.save(folder=kilosort_dir, n_jobs=20, chunk_size=30000) self.insert1( { **key, + "sorting_filename": list(sorting_save_path.parts)[-1], "execution_time": execution_time, "execution_duration": ( datetime.utcnow() - execution_time @@ -321,6 +346,7 @@ def make(self, key): } ) + @schema class PostProcessing(dj.Imported): """A processing table to handle each clustering task.""" @@ -345,53 +371,100 @@ def make(self, key): params = (PreProcessing & key).fetch1("params") if acq_software == "SpikeGLX": - sorting_file = kilosort_dir / 'sorting_kilosort' - recording_file = kilosort_dir / 'sglx_recording_cmr.json' - sglx_si_recording = sic.load_extractor(recording_file) + recording_filename = (PreProcessing & key).fetch1("recording_filename") + sorting_file = kilosort_dir / "sorting_kilosort" + filtered_recording_file = kilosort_dir / recording_filename + sglx_si_recording_filtered = sic.load_extractor(recording_file) sorting_kilosort = sic.load_extractor(sorting_file) - we_kilosort = si.WaveformExtractor.create(sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True) + we_kilosort = si.WaveformExtractor.create( + sglx_si_recording_filtered, + sorting_kilosort, + "waveforms", + remove_if_exists=True, + ) + we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) unit_id0 = sorting_kilosort.unit_ids[0] waveforms = we_kilosort.get_waveforms(unit_id0) template 
= we_kilosort.get_template(unit_id0) snrs = si.compute_snrs(we_kilosort) - - # QC Metrics - si_violations_ratio, isi_violations_rate, isi_violations_count = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - metrics = si.compute_quality_metrics(we_kilosort, metric_names=["firing_rate","snr","presence_ratio","isi_violation", - "num_spikes","amplitude_cutoff","amplitude_median","sliding_rp_violation","rp_violation","drift"]) + # QC Metrics + ( + si_violations_ratio, + isi_violations_rate, + isi_violations_count, + ) = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + metrics = si.compute_quality_metrics( + we_kilosort, + metric_names=[ + "firing_rate", + "snr", + "presence_ratio", + "isi_violation", + "num_spikes", + "amplitude_cutoff", + "amplitude_median", + "sliding_rp_violation", + "rp_violation", + "drift", + ], + ) sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) # ["firing_rate","snr","presence_ratio","isi_violation", # "number_violation","amplitude_cutoff","isolation_distance","l_ratio","d_prime","nn_hit_rate", # "nn_miss_rate","silhouette_core","cumulative_drift","contamination_rate"]) - - we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - + we_savedir = kilosort_dir / "we_kilosort" + we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) elif acq_software == "Open Ephys": - sorting_file = kilosort_dir / 'sorting_kilosort' - recording_file = kilosort_dir / 'sglx_recording_cmr.json' + sorting_file = kilosort_dir / "sorting_kilosort" + recording_file = kilosort_dir / "sglx_recording_cmr.json" sglx_si_recording = sic.load_extractor(recording_file) sorting_kilosort = sic.load_extractor(sorting_file) - we_kilosort = si.WaveformExtractor.create(sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True) + we_kilosort = si.WaveformExtractor.create( + sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True + ) + we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) unit_id0 = sorting_kilosort.unit_ids[0] waveforms = we_kilosort.get_waveforms(unit_id0) template = we_kilosort.get_template(unit_id0) snrs = si.compute_snrs(we_kilosort) - - # QC Metrics - si_violations_ratio, isi_violations_rate, isi_violations_count = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - metrics = si.compute_quality_metrics(we_kilosort, metric_names=["firing_rate","snr","presence_ratio","isi_violation", - "num_spikes","amplitude_cutoff","amplitude_median","sliding_rp_violation","rp_violation","drift"]) + # QC Metrics + # Apply waveform extractor extensions + spike_locations = si.compute_spike_locations(we_kilosort) + spike_amplitudes = si.compute_spike_amplitudes(we_kilosort) + unit_locations = si.compute_unit_locations(we_kilosort) + template_metrics = si.compute_template_metrics(we_kilosort) + noise_levels = si.compute_noise_levels(we_kilosort) + drift_metrics = si.compute_drift_metrics(we_kilosort) + + (isi_violations_ratio, isi_violations_count) = si.compute_isi_violations( + we_kilosort, isi_threshold_ms=1.5 + ) + (isi_histograms, bins) = si.compute_isi_histograms(we_kilosort) + metrics = si.compute_quality_metrics( + we_kilosort, + metric_names=[ + "firing_rate", + "snr", + "presence_ratio", + "isi_violation", + "num_spikes", + "amplitude_cutoff", + "amplitude_median", + # "sliding_rp_violation", + "rp_violation", + "drift", + ], + ) sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, 
chunk_size=30000) - we_kilosort.save_to_folder('we_kilosort',kilosort_dir, n_jobs=-1, chunk_size=30000) - + we_kilosort.save("we_kilosort", kilosort_dir, n_jobs=-1, chunk_size=30000) self.insert1( { @@ -409,8 +482,9 @@ def make(self, key): {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) + ## Example SI parameter set -''' +""" {'detect_threshold': 6, 'projection_threshold': [10, 4], 'preclust_threshold': 8, @@ -447,4 +521,4 @@ def make(self, key): 'ZeroChannelPad': False, 'DeepInterpolation': False, 'Resample': False}} -''' \ No newline at end of file +""" From 1febd7e05111aa19679c64f86947973e0b533ebf Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Fri, 17 Feb 2023 19:37:12 -0600 Subject: [PATCH 031/146] post processing waveform extractor extensions --- element_array_ephys/spike_sorting/si_clustering.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 13f129d8..a018119d 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -441,8 +441,9 @@ def make(self, key): unit_locations = si.compute_unit_locations(we_kilosort) template_metrics = si.compute_template_metrics(we_kilosort) noise_levels = si.compute_noise_levels(we_kilosort) + pcs = si.compute_principal_components(we_kilosort) drift_metrics = si.compute_drift_metrics(we_kilosort) - + template_similarity = si.compute_tempoate_similarity(we_kilosort) (isi_violations_ratio, isi_violations_count) = si.compute_isi_violations( we_kilosort, isi_threshold_ms=1.5 ) From a478e0679ee5ae3747a324880415f090791cb868 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Mon, 20 Feb 2023 16:21:23 -0600 Subject: [PATCH 032/146] Fix data loading bug related to cluster_groups and KSLabel df key --- element_array_ephys/readers/kilosort.py | 42 +++++++++++++------------ 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/element_array_ephys/readers/kilosort.py b/element_array_ephys/readers/kilosort.py index abddee74..e88ba335 100644 --- a/element_array_ephys/readers/kilosort.py +++ b/element_array_ephys/readers/kilosort.py @@ -1,19 +1,16 @@ -import logging -import pathlib -import re -from datetime import datetime from os import path - -import numpy as np +from datetime import datetime +import pathlib import pandas as pd - +import numpy as np +import re +import logging from .utils import convert_to_number log = logging.getLogger(__name__) class Kilosort: - _kilosort_core_files = [ "params.py", "amplitudes.npy", @@ -118,7 +115,8 @@ def _load(self): # Read the Cluster Groups for cluster_pattern, cluster_col_name in zip( - ["cluster_group.*", "cluster_KSLabel.*"], ["group", "KSLabel"] + ["cluster_group.*", "cluster_KSLabel.*", "cluster_group.*"], + ["group", "KSLabel", "KSLabel"], ): try: cluster_file = next(self._kilosort_dir.glob(cluster_pattern)) @@ -127,22 +125,26 @@ def _load(self): else: cluster_file_suffix = cluster_file.suffix assert cluster_file_suffix in (".tsv", ".xlsx") - break + + if cluster_file_suffix == ".tsv": + df = pd.read_csv(cluster_file, sep="\t", header=0) + elif cluster_file_suffix == ".xlsx": + df = pd.read_excel(cluster_file, engine="openpyxl") + else: + df = pd.read_csv(cluster_file, delimiter="\t") + + try: + self._data["cluster_groups"] = np.array(df[cluster_col_name].values) + self._data["cluster_ids"] = np.array(df["cluster_id"].values) + except KeyError: + continue + else: + break else: raise 
FileNotFoundError( 'Neither "cluster_groups" nor "cluster_KSLabel" file found!' ) - if cluster_file_suffix == ".tsv": - df = pd.read_csv(cluster_file, sep="\t", header=0) - elif cluster_file_suffix == ".xlsx": - df = pd.read_excel(cluster_file, engine="openpyxl") - else: - df = pd.read_csv(cluster_file, delimiter="\t") - - self._data["cluster_groups"] = np.array(df[cluster_col_name].values) - self._data["cluster_ids"] = np.array(df["cluster_id"].values) - def get_best_channel(self, unit): template_idx = self.data["spike_templates"][ np.where(self.data["spike_clusters"] == unit)[0][0] From 634761ddad17bfd024b84567881499fba5e2d46e Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 21 Feb 2023 18:13:02 -0600 Subject: [PATCH 033/146] waveform extraction wip --- element_array_ephys/ephys_no_curation.py | 109 ++++++++++-------- .../spike_sorting/si_clustering.py | 4 +- 2 files changed, 63 insertions(+), 50 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index f4ed4b55..69afaea2 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1,17 +1,17 @@ -import gc -import importlib -import inspect +import datajoint as dj import pathlib import re -from decimal import Decimal - -import datajoint as dj import numpy as np +import inspect +import importlib +import gc +from decimal import Decimal import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory -from . import ephys_report, get_logger, probe -from .readers import kilosort, openephys, spikeglx +from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid +from .readers import spikeglx, kilosort, openephys +from element_array_ephys import probe, get_logger, ephys_report + log = get_logger(__name__) @@ -19,6 +19,9 @@ _linking_module = None +import spikeinterface +import spikeinterface.full as si + def activate( ephys_schema_name: str, @@ -32,7 +35,7 @@ def activate( Args: ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe schema. + probe_schema_name (str): A string containing the name of the probe scehma. create_schema (bool): If True, schema will be created in the database. create_tables (bool): If True, tables related to the schema will be created in the database. linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. @@ -129,7 +132,7 @@ class AcquisitionSoftware(dj.Lookup): """ definition = """ # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys - acq_software: varchar(24) + acq_software: varchar(24) """ contents = zip(["SpikeGLX", "Open Ephys"]) @@ -272,11 +275,11 @@ class EphysRecording(dj.Imported): definition = """ # Ephys recording from a probe insertion for a given session. - -> ProbeInsertion + -> ProbeInsertion --- -> probe.ElectrodeConfig -> AcquisitionSoftware - sampling_rate: float # (Hz) + sampling_rate: float # (Hz) recording_datetime: datetime # datetime of the recording from this probe recording_duration: float # (seconds) duration of the recording from this probe """ @@ -315,8 +318,8 @@ def make(self, key): break else: raise FileNotFoundError( - "Ephys recording data not found!" - " Neither SpikeGLX nor Open Ephys recording files found" + f"Ephys recording data not found!" 
+ f" Neither SpikeGLX nor Open Ephys recording files found" ) supported_probe_types = probe.ProbeType.fetch("probe_type") @@ -471,9 +474,9 @@ class Electrode(dj.Part): definition = """ -> master - -> probe.ElectrodeConfig.Electrode + -> probe.ElectrodeConfig.Electrode --- - lfp: longblob # (uV) recorded lfp at this electrode + lfp: longblob # (uV) recorded lfp at this electrode """ # Only store LFP for every 9th channel, due to high channel density, @@ -614,14 +617,14 @@ class ClusteringParamSet(dj.Lookup): ClusteringMethod (dict): ClusteringMethod primary key. paramset_desc (varchar(128) ): Description of the clustering parameter set. param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Set of clustering parameters + params (longblob) """ definition = """ # Parameter set to be used in a clustering procedure paramset_idx: smallint --- - -> ClusteringMethod + -> ClusteringMethod paramset_desc: varchar(128) param_set_hash: uuid unique index (param_set_hash) @@ -724,18 +727,15 @@ class ClusteringTask(dj.Manual): """ @classmethod - def infer_output_dir( - cls, key, relative: bool = False, mkdir: bool = False - ) -> pathlib.Path: + def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False): """Infer output directory if it is not provided. Args: key (dict): ClusteringTask primary key. Returns: - Expected clustering_output_dir based on the following convention: - processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} - e.g.: sub4/sess1/probe_2/kilosort2_0 + Pathlib.Path: Expected clustering_output_dir based on the following convention: processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} + e.g.: sub4/sess1/probe_2/kilosort2_0 """ processed_dir = pathlib.Path(get_processed_root_data_dir()) session_dir = find_full_path( @@ -802,14 +802,14 @@ class Clustering(dj.Imported): Attributes: ClusteringTask (foreign key): ClusteringTask primary key. clustering_time (datetime): Time when clustering results are generated. - package_version (varchar(16): Package version used for a clustering analysis. + package_version (varchar(16) ): Package version used for a clustering analysis. """ definition = """ # Clustering Procedure -> ClusteringTask --- - clustering_time: datetime # time of generation of this set of clustering results + clustering_time: datetime # time of generation of this set of clustering results package_version='': varchar(16) """ @@ -850,10 +850,6 @@ def make(self, key): spikeglx_meta_filepath.parent ) spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.pop("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) if clustering_method.startswith("pykilosort"): kilosort_triggering.run_pykilosort( @@ -874,7 +870,7 @@ def make(self, key): ks_output_dir=kilosort_dir, params=params, KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, + run_CatGT=True, ) run_kilosort.run_modules() elif acq_software == "Open Ephys": @@ -929,7 +925,7 @@ class CuratedClustering(dj.Imported): definition = """ # Clustering results of the spike sorting step. - -> Clustering + -> Clustering """ class Unit(dj.Part): @@ -946,7 +942,7 @@ class Unit(dj.Part): spike_depths (longblob): Array of depths associated with each spike, relative to each spike. 
""" - definition = """ + definition = """ # Properties of a given unit from a round of clustering (and curation) -> master unit: int @@ -956,7 +952,7 @@ class Unit(dj.Part): spike_count: int # how many spikes in this recording for this unit spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe + spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe """ def make(self, key): @@ -1080,8 +1076,8 @@ class Waveform(dj.Part): # Spike waveforms and their mean across spikes for the given unit -> master -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- + -> probe.ElectrodeConfig.Electrode + --- waveform_mean: longblob # (uV) mean waveform across spikes of the given unit waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit """ @@ -1109,15 +1105,32 @@ def make(self, key): for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") } + waveforms_folder = kilosort_dir / "we_kilosort" + + waveforms_folder = kilosort_dir.rglob(*waveform) + # Mean waveforms need to be extracted from waveform extractor object + if (waveforms_folder).exists(): + we_kilosort = si.load_waveforms(waveforms_folder) + unit_waveforms = we_kilosort.get_all_templates() + + def yield_unit_waveforms(): + for unit_no, unit_waveform in zip( + kilosort_dataset.data["cluster_ids"], unit_waveforms + ): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + + if unit_no in units: + unit_waveform = we_kilosort.get_waveforms(unit_id=unit_no) + mean_templates = we_kilosort.get_templates(unit_id=unit_no) + if (kilosort_dir / "mean_waveforms.npy").exists(): unit_waveforms = np.load( kilosort_dir / "mean_waveforms.npy" ) # unit x channel x sample def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): + for unit_no, unit_waveform in zip(cluster_ids, unit_waveforms): unit_peak_waveform = {} unit_electrode_waveforms = [] if unit_no in units: @@ -1207,7 +1220,7 @@ class QualityMetrics(dj.Imported): definition = """ # Clusters and waveforms metrics - -> CuratedClustering + -> CuratedClustering """ class Cluster(dj.Part): @@ -1232,26 +1245,26 @@ class Cluster(dj.Part): contamination_rate (float): Frequency of spikes in the refractory period. 
""" - definition = """ + definition = """ # Cluster metrics for a particular unit -> master -> CuratedClustering.Unit --- - firing_rate=null: float # (Hz) firing rate for a unit + firing_rate=null: float # (Hz) firing rate for a unit snr=null: float # signal-to-noise ratio for a unit presence_ratio=null: float # fraction of time in which spikes are present isi_violation=null: float # rate of ISI violation as a fraction of overall rate number_violation=null: int # total number of ISI violations amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # + l_ratio=null: float # d_prime=null: float # Classification accuracy based on LDA nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster silhouette_score=null: float # Standard metric for cluster overlap max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # + cumulative_drift=null: float # Cumulative change in spike depth throughout recording + contamination_rate=null: float # """ class Waveform(dj.Part): @@ -1268,10 +1281,10 @@ class Waveform(dj.Part): recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. + velocity_below (float) inverse velocity of waveform propagation from soma toward the bottom of the probe. 
""" - definition = """ + definition = """ # Waveform metrics for a particular unit -> master -> CuratedClustering.Unit diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index a018119d..be50356b 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -464,8 +464,8 @@ def make(self, key): ], ) sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) - - we_kilosort.save("we_kilosort", kilosort_dir, n_jobs=-1, chunk_size=30000) + we_savedir = kilosort_dir / "we_kilosort" + we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) self.insert1( { From ff0dfee68bc45fbc43e42dec19541473ec9090e0 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 22 Feb 2023 20:08:36 -0600 Subject: [PATCH 034/146] modification to handle spike interface waveforms --- element_array_ephys/ephys_no_curation.py | 84 ++++++++++++++++++++---- 1 file changed, 70 insertions(+), 14 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 69afaea2..85ecb1a7 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1105,13 +1105,14 @@ def make(self, key): for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") } - waveforms_folder = kilosort_dir / "we_kilosort" + waveforms_folder = [ + f for f in kilosort_dir.parent.rglob(r"*/waveforms*") if f.is_dir() + ] - waveforms_folder = kilosort_dir.rglob(*waveform) - # Mean waveforms need to be extracted from waveform extractor object - if (waveforms_folder).exists(): - we_kilosort = si.load_waveforms(waveforms_folder) - unit_waveforms = we_kilosort.get_all_templates() + if (kilosort_dir / "mean_waveforms.npy").exists(): + unit_waveforms = np.load( + kilosort_dir / "mean_waveforms.npy" + ) # unit x channel x sample def yield_unit_waveforms(): for unit_no, unit_waveform in zip( @@ -1119,18 +1120,46 @@ def yield_unit_waveforms(): ): unit_peak_waveform = {} unit_electrode_waveforms = [] - if unit_no in units: - unit_waveform = we_kilosort.get_waveforms(unit_id=unit_no) - mean_templates = we_kilosort.get_templates(unit_id=unit_no) + for channel, channel_waveform in zip( + kilosort_dataset.data["channel_map"], unit_waveform + ): + unit_electrode_waveforms.append( + { + **units[unit_no], + **channel2electrodes[channel], + "waveform_mean": channel_waveform, + } + ) + if ( + channel2electrodes[channel]["electrode"] + == units[unit_no]["electrode"] + ): + unit_peak_waveform = { + **units[unit_no], + "peak_electrode_waveform": channel_waveform, + } + yield unit_peak_waveform, unit_electrode_waveforms - if (kilosort_dir / "mean_waveforms.npy").exists(): - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample + # Spike interface mean and peak waveform extraction from we object + + elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): + we_kilosort = si.load_waveforms(waveforms_folder[0].parent) + unit_templates = we_kilosort.get_all_templates() + unit_waveforms = np.reshape( + unit_templates, + ( + unit_templates.shape[1], + unit_templates.shape[3], + unit_templates.shape[2], + ), + ) + # Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms) def yield_unit_waveforms(): - for unit_no, unit_waveform in zip(cluster_ids, unit_waveforms): + for unit_no, unit_waveform in zip( + kilosort_dataset.data["cluster_ids"], 
unit_waveforms + ): unit_peak_waveform = {} unit_electrode_waveforms = [] if unit_no in units: @@ -1154,6 +1183,33 @@ def yield_unit_waveforms(): } yield unit_peak_waveform, unit_electrode_waveforms + # Approach not using spike interface templates (ie. taking mean of each unit waveform) + # def yield_unit_waveforms(): + # for unit_id in we_kilosort.unit_ids: + # unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0) + # unit_peak_waveform = {} + # unit_electrode_waveforms = [] + # if unit_id in units: + # for channel, channel_waveform in zip( + # kilosort_dataset.data["channel_map"], unit_waveform + # ): + # unit_electrode_waveforms.append( + # { + # **units[unit_id], + # **channel2electrodes[channel], + # "waveform_mean": channel_waveform, + # } + # ) + # if ( + # channel2electrodes[channel]["electrode"] + # == units[unit_id]["electrode"] + # ): + # unit_peak_waveform = { + # **units[unit_id], + # "peak_electrode_waveform": channel_waveform, + # } + # yield unit_peak_waveform, unit_electrode_waveforms + else: if acq_software == "SpikeGLX": spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) From cb31229e89d6599c71647a3c7bb34e2498dcd192 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 22 Feb 2023 20:11:37 -0600 Subject: [PATCH 035/146] adjust post processing --- .../spike_sorting/si_clustering.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index be50356b..84f26644 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -130,10 +130,15 @@ def make(self, key): clustering_method in _supported_kilosort_versions ), f'Clustering_method "{clustering_method}" is not supported' + if clustering_method.startswith("kilosort2.5"): + sorter_name = "kilosort2_5" + else: + sorter_name = clustering_method # add additional probe-recording and channels details into `params` # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] + default_params = si.get_default_sorter_params(sorter_name) preprocess_list = params.pop("PreProcessing_params") # If else @@ -406,7 +411,7 @@ def make(self, key): "num_spikes", "amplitude_cutoff", "amplitude_median", - "sliding_rp_violation", + # "sliding_rp_violation", "rp_violation", "drift", ], @@ -436,14 +441,14 @@ def make(self, key): # QC Metrics # Apply waveform extractor extensions - spike_locations = si.compute_spike_locations(we_kilosort) - spike_amplitudes = si.compute_spike_amplitudes(we_kilosort) - unit_locations = si.compute_unit_locations(we_kilosort) - template_metrics = si.compute_template_metrics(we_kilosort) - noise_levels = si.compute_noise_levels(we_kilosort) - pcs = si.compute_principal_components(we_kilosort) - drift_metrics = si.compute_drift_metrics(we_kilosort) - template_similarity = si.compute_tempoate_similarity(we_kilosort) + _ = si.compute_spike_locations(we_kilosort) + _ = si.compute_spike_amplitudes(we_kilosort) + _ = si.compute_unit_locations(we_kilosort) + _ = si.compute_template_metrics(we_kilosort) + _ = si.compute_noise_levels(we_kilosort) + _ = si.compute_principal_components(we_kilosort) + _ = si.compute_drift_metrics(we_kilosort) + _ = si.compute_tempoate_similarity(we_kilosort) (isi_violations_ratio, isi_violations_count) = si.compute_isi_violations( we_kilosort, isi_threshold_ms=1.5 ) From 6098421e9037017be02010672402efd885ce5b24 Mon Sep 17 00:00:00 2001 
From: Sidharth Hulyalkar Date: Mon, 6 Mar 2023 11:56:38 -0600 Subject: [PATCH 036/146] bugfix in postprocessing definition --- element_array_ephys/spike_sorting/si_clustering.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 84f26644..f3fd4c1d 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -10,7 +10,7 @@ - create recording extractor and link it to a probe - bandpass filtering - common mode referencing -2. ClusteringModule - kilosort (MATLAB) - requires GPU and docker/singularity containers +2. SIClustering - kilosort (MATLAB) - requires GPU and docker/singularity containers - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) 3. PostProcessing - for postprocessing steps (no GPU required) - create waveform extractor object @@ -357,7 +357,7 @@ class PostProcessing(dj.Imported): """A processing table to handle each clustering task.""" definition = """ - -> ClusteringModule + -> SIClustering --- execution_time: datetime # datetime of the start of this step execution_duration: float # (hour) execution duration @@ -426,11 +426,11 @@ def make(self, key): elif acq_software == "Open Ephys": sorting_file = kilosort_dir / "sorting_kilosort" recording_file = kilosort_dir / "sglx_recording_cmr.json" - sglx_si_recording = sic.load_extractor(recording_file) + oe_si_recording = sic.load_extractor(recording_file) sorting_kilosort = sic.load_extractor(sorting_file) we_kilosort = si.WaveformExtractor.create( - sglx_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True + oe_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True ) we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) @@ -472,6 +472,9 @@ def make(self, key): we_savedir = kilosort_dir / "we_kilosort" we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) + metrics_savefile = kilosort_dir / "metrics.csv" + metrics.to_csv(metrics_savefile) + self.insert1( { **key, From 60064646c1d3c65d2d37173f0ef686fa5a108387 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 7 Mar 2023 18:17:35 -0600 Subject: [PATCH 037/146] add SI ibl destriping and catGT implementations --- .../spike_sorting/si_clustering.py | 108 +++++++++++------- 1 file changed, 69 insertions(+), 39 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index f3fd4c1d..cb4e1858 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -141,45 +141,6 @@ def make(self, key): default_params = si.get_default_sorter_params(sorter_name) preprocess_list = params.pop("PreProcessing_params") - # If else - # need to figure out ordering - if preprocess_list["Filter"]: - oe_si_recording = sip.FilterRecording(oe_si_recording) - elif preprocess_list["BandpassFilter"]: - oe_si_recording = sip.BandpassFilterRecording(oe_si_recording) - elif preprocess_list["HighpassFilter"]: - oe_si_recording = sip.HighpassFilterRecording(oe_si_recording) - elif preprocess_list["NormalizeByQuantile"]: - oe_si_recording = sip.NormalizeByQuantileRecording(oe_si_recording) - elif preprocess_list["Scale"]: - oe_si_recording = sip.ScaleRecording(oe_si_recording) - elif preprocess_list["Center"]: - oe_si_recording = 
sip.CenterRecording(oe_si_recording) - elif preprocess_list["ZScore"]: - oe_si_recording = sip.ZScoreRecording(oe_si_recording) - elif preprocess_list["Whiten"]: - oe_si_recording = sip.WhitenRecording(oe_si_recording) - elif preprocess_list["CommonReference"]: - oe_si_recording = sip.CommonReferenceRecording(oe_si_recording) - elif preprocess_list["PhaseShift"]: - oe_si_recording = sip.PhaseShiftRecording(oe_si_recording) - elif preprocess_list["Rectify"]: - oe_si_recording = sip.RectifyRecording(oe_si_recording) - elif preprocess_list["Clip"]: - oe_si_recording = sip.ClipRecording(oe_si_recording) - elif preprocess_list["BlankSaturation"]: - oe_si_recording = sip.BlankSaturationRecording(oe_si_recording) - elif preprocess_list["RemoveArtifacts"]: - oe_si_recording = sip.RemoveArtifactsRecording(oe_si_recording) - elif preprocess_list["RemoveBadChannels"]: - oe_si_recording = sip.RemoveBadChannelsRecording(oe_si_recording) - elif preprocess_list["ZeroChannelPad"]: - oe_si_recording = sip.ZeroChannelPadRecording(oe_si_recording) - elif preprocess_list["DeepInterpolation"]: - oe_si_recording = sip.DeepInterpolationRecording(oe_si_recording) - elif preprocess_list["Resample"]: - oe_si_recording = sip.ResampleRecording(oe_si_recording) - if acq_software == "SpikeGLX": # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) sglx_filepath = ephys.get_spikeglx_meta_filepath(key) @@ -212,6 +173,8 @@ def make(self, key): save_file_path = kilosort_dir / save_file_name sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) + sglx_si_recording = run_IBLdestriping(sglx_si_recording) + elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) oe_session_full_path = find_full_path( @@ -492,6 +455,73 @@ def make(self, key): ) +# def runPreProcessList(preprocess_list, recording): +# # If else +# # need to figure out ordering +# if preprocess_list["Filter"]: +# recording = sip.FilterRecording(recording) +# if preprocess_list["BandpassFilter"]: +# recording = sip.BandpassFilterRecording(recording) +# if preprocess_list["HighpassFilter"]: +# recording = sip.HighpassFilterRecording(recording) +# if preprocess_list["NormalizeByQuantile"]: +# recording = sip.NormalizeByQuantileRecording(recording) +# if preprocess_list["Scale"]: +# recording = sip.ScaleRecording(recording) +# if preprocess_list["Center"]: +# recording = sip.CenterRecording(recording) +# if preprocess_list["ZScore"]: +# recording = sip.ZScoreRecording(recording) +# if preprocess_list["Whiten"]: +# recording = sip.WhitenRecording(recording) +# if preprocess_list["CommonReference"]: +# recording = sip.CommonReferenceRecording(recording) +# if preprocess_list["PhaseShift"]: +# recording = sip.PhaseShiftRecording(recording) +# elif preprocess_list["Rectify"]: +# recording = sip.RectifyRecording(recording) +# elif preprocess_list["Clip"]: +# recording = sip.ClipRecording(recording) +# elif preprocess_list["BlankSaturation"]: +# recording = sip.BlankSaturationRecording(recording) +# elif preprocess_list["RemoveArtifacts"]: +# recording = sip.RemoveArtifactsRecording(recording) +# elif preprocess_list["RemoveBadChannels"]: +# recording = sip.RemoveBadChannelsRecording(recording) +# elif preprocess_list["ZeroChannelPad"]: +# recording = sip.ZeroChannelPadRecording(recording) +# elif preprocess_list["DeepInterpolation"]: +# recording = sip.DeepInterpolationRecording(recording) +# elif preprocess_list["Resample"]: +# recording = sip.ResampleRecording(recording) + + 
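# A minimal ordered alternative to the dispatcher sketched above; an editorial
# sketch, not part of this patch. The mixed if/elif chain stops at the first
# matching elif branch, which is the "need to figure out ordering" problem
# noted in the comment; iterating an explicit (name, callable) sequence
# applies every enabled step exactly once, in a fixed order. The callables are
# real spikeinterface.preprocessing functions; the particular ordering shown
# here is an assumption for illustration only.
#
#     import spikeinterface.preprocessing as sip
#
#     ORDERED_STEPS = (
#         ("PhaseShift", sip.phase_shift),
#         ("HighpassFilter", sip.highpass_filter),
#         ("BandpassFilter", sip.bandpass_filter),
#         ("CommonReference", sip.common_reference),
#         ("Whiten", sip.whiten),
#     )
#
#     def run_preprocess_list(recording, preprocess_list):
#         # preprocess_list is the PreProcessing_params dict of booleans used
#         # in this file; unknown or disabled steps are simply skipped.
#         for name, step in ORDERED_STEPS:
#             if preprocess_list.get(name, False):
#                 recording = step(recording)
#         return recording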
+def mimic_IBLdestriping_modified(recording): + # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html) + recording = si.highpass_filter(recording, freq_min=400.0) + bad_channel_ids, channel_labels = si.detect_bad_channels(recording) + # For IBL destriping interpolate bad channels + recording = recording.remove_channels(bad_channel_ids) + recording = si.phase_shift(recording) + recording = si.common_reference(recording, operator="median", reference="global") + return recording + +def mimic_IBLdestriping(recording): + # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. + recording = si.highpass_filter(recording, freq_min=400.0) + bad_channel_ids, channel_labels = si.detect_bad_channels(recording) + # For IBL destriping interpolate bad channels + recording = sip.interpolate_bad_channels(bad_channel_ids) + recording = si.phase_shift(recording) + recording = si.highpass_spatial_filter(recording, operator="median", reference="global") + # For IBL destriping use highpass_spatial_filter used instead of common reference + return recording + +def mimic_catGT(sglx_recording): + sglx_recording = si.phase_shift(sglx_recording) + sglx_recording = si.common_reference(sglx_recording, operator="median", reference="global") + return sglx_recording + ## Example SI parameter set """ {'detect_threshold': 6, From c050875c3f3d98e085c1a2d6aa591952df124632 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 7 Mar 2023 18:24:09 -0600 Subject: [PATCH 038/146] remove preprocess params list --- .../spike_sorting/si_clustering.py | 33 +++++++++++-------- 1 file changed, 20 insertions(+), 13 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index cb4e1858..81e5f0b3 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -139,7 +139,7 @@ def make(self, key): # params["fs"] = params["sample_rate"] default_params = si.get_default_sorter_params(sorter_name) - preprocess_list = params.pop("PreProcessing_params") + # preprocess_list = params.pop("PreProcessing_params") if acq_software == "SpikeGLX": # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) @@ -173,7 +173,7 @@ def make(self, key): save_file_path = kilosort_dir / save_file_name sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) - sglx_si_recording = run_IBLdestriping(sglx_si_recording) + sglx_si_recording = mimic_catGT(sglx_si_recording) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) @@ -208,17 +208,17 @@ def make(self, key): oe_si_recording.set_probe(probe=si_probe) # run preprocessing and save results to output folder - # Switch case to allow for specified preprocessing steps - oe_si_recording_filtered = sip.bandpass_filter( - oe_si_recording, freq_min=300, freq_max=6000 - ) - oe_recording_cmr = sip.common_reference( - oe_si_recording_filtered, reference="global", operator="median" - ) - + # # Switch case to allow for specified preprocessing steps + # oe_si_recording_filtered = sip.bandpass_filter( + # oe_si_recording, freq_min=300, freq_max=6000 + # ) + # oe_recording_cmr = sip.common_reference( + # oe_si_recording_filtered, reference="global", operator="median" + # ) + oe_si_recording = mimic_IBLdestriping(oe_si_recording) save_file_name = "si_recording.pkl" 
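         # The pickle written below stores the lazy preprocessing graph (the
         # chain of operators plus source-file references), not the filtered
         # traces themselves; the downstream SIClustering step re-instantiates
         # the identical recording from this file via load_extractor().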
save_file_path = kilosort_dir / save_file_name - oe_si_recording_filtered.dump_to_pickle(file_path=save_file_path) + oe_si_recording.dump_to_pickle(file_path=save_file_path) self.insert1( { @@ -506,6 +506,7 @@ def mimic_IBLdestriping_modified(recording): recording = si.common_reference(recording, operator="median", reference="global") return recording + def mimic_IBLdestriping(recording): # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. recording = si.highpass_filter(recording, freq_min=400.0) @@ -513,15 +514,21 @@ def mimic_IBLdestriping(recording): # For IBL destriping interpolate bad channels recording = sip.interpolate_bad_channels(bad_channel_ids) recording = si.phase_shift(recording) - recording = si.highpass_spatial_filter(recording, operator="median", reference="global") # For IBL destriping use highpass_spatial_filter used instead of common reference + recording = si.highpass_spatial_filter( + recording, operator="median", reference="global" + ) return recording + def mimic_catGT(sglx_recording): sglx_recording = si.phase_shift(sglx_recording) - sglx_recording = si.common_reference(sglx_recording, operator="median", reference="global") + sglx_recording = si.common_reference( + sglx_recording, operator="median", reference="global" + ) return sglx_recording + ## Example SI parameter set """ {'detect_threshold': 6, From 4ea56c0f3d86bd41ad4d623cc31ef82d131f03a3 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Tue, 7 Mar 2023 18:28:18 -0600 Subject: [PATCH 039/146] preprocessing changes --- .../spike_sorting/si_clustering.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 81e5f0b3..33704081 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -138,7 +138,7 @@ def make(self, key): # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] - default_params = si.get_default_sorter_params(sorter_name) + # default_params = si.get_default_sorter_params(sorter_name) # preprocess_list = params.pop("PreProcessing_params") if acq_software == "SpikeGLX": @@ -163,17 +163,15 @@ def make(self, key): si_probe.set_device_channel_indices(channels_details["channel_ind"]) sglx_si_recording.set_probe(probe=si_probe) - # run preprocessing and save results to output folder - sglx_si_recording_filtered = sip.bandpass_filter( - sglx_si_recording, freq_min=300, freq_max=6000 - ) + # # run preprocessing and save results to output folder + # sglx_si_recording_filtered = sip.bandpass_filter( + # sglx_si_recording, freq_min=300, freq_max=6000 + # ) # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - + sglx_si_recording = mimic_catGT(sglx_si_recording) save_file_name = "si_recording.pkl" save_file_path = kilosort_dir / save_file_name - sglx_si_recording_filtered.dump_to_pickle(file_path=save_file_path) - - sglx_si_recording = mimic_catGT(sglx_si_recording) + sglx_si_recording.dump_to_pickle(file_path=save_file_path) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) From 0a875794726cc3d4c481825ed5566ede22f46514 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Mon, 15 May 2023 18:04:32 -0500 Subject: [PATCH 040/146] Update requirements.txt --- requirements.txt | 3 +-- 1 file 
changed, 1 insertion(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index 528f6349..0d47a42f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,4 @@ plotly pyopenephys>=1.1.6 seaborn scikit-image -spikeinterface -nbformat>=4.2.0 \ No newline at end of file +nbformat>=4.2.0 From 22f1f65fe3773f2e4d8803cab15b694e3921d0a2 Mon Sep 17 00:00:00 2001 From: Sidharth Hulyalkar Date: Wed, 14 Jun 2023 15:36:07 -0500 Subject: [PATCH 041/146] fix spikeglx stream loading --- element_array_ephys/spike_sorting/si_clustering.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 33704081..f99d63d2 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -146,7 +146,12 @@ def make(self, key): sglx_filepath = ephys.get_spikeglx_meta_filepath(key) # Create SI recording extractor object - sglx_si_recording = se.read_spikeglx(folder_path=sglx_filepath.parent) + stream_name = sglx_filepath.stem.split(".", 1)[1] + sglx_si_recording = se.read_spikeglx( + folder_path=sglx_filepath.parent, + stream_name=stream_name, + stream_id=stream_name, + ) channels_details = ephys.get_recording_channels_details(key) xy_coords = [ From b62f16215efe91c9310383e99a72d7abdc8de983 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 11 Oct 2023 18:32:49 -0500 Subject: [PATCH 042/146] build: :pushpin: update requirements.txt & add env,.yml --- env.yml | 7 +++++++ requirements.txt | 3 ++- setup.py | 1 + 3 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 env.yml diff --git a/env.yml b/env.yml new file mode 100644 index 00000000..e9b3ce13 --- /dev/null +++ b/env.yml @@ -0,0 +1,7 @@ +channels: + - conda-forge + - defaults +dependencies: + - pip + - python>=3.7,<3.11 +name: element_array_ephys diff --git a/requirements.txt b/requirements.txt index 0d47a42f..721bfeda 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,10 @@ datajoint>=0.13 element-interface>=0.4.0 ipywidgets +nbformat>=4.2.0 openpyxl plotly pyopenephys>=1.1.6 seaborn scikit-image -nbformat>=4.2.0 +spikeinterface \ No newline at end of file diff --git a/setup.py b/setup.py index 31b9be61..cc538478 100644 --- a/setup.py +++ b/setup.py @@ -22,6 +22,7 @@ setup( name=pkg_name.replace("_", "-"), + python_requires='>=3.7, <3.11', version=__version__, # noqa F821 description="DataJoint Element for Extracellular Array Electrophysiology", long_description=long_description, From 849b576c8982b9231b4b1167740d2d1a2ad1cbdd Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 11 Oct 2023 18:36:51 -0500 Subject: [PATCH 043/146] refactor: :art: clean up spikeinterface import & remove unused import --- .../spike_sorting/si_clustering.py | 157 +++++++++--------- 1 file changed, 77 insertions(+), 80 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index f99d63d2..2cb5bf2e 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -16,16 +16,12 @@ - create waveform extractor object - extract templates, waveforms and snrs - quality_metrics - - """ + import datajoint as dj import os from element_array_ephys import get_logger -from decimal import Decimal -import json -import numpy as np -from datetime import datetime, timedelta +from datetime import datetime from element_interface.utils import 
find_full_path from element_array_ephys.readers import ( @@ -34,13 +30,7 @@ ) import element_array_ephys.probe as probe -import spikeinterface -import spikeinterface.full as si -import spikeinterface.core as sic -import spikeinterface.extractors as se -import spikeinterface.exporters as sie -import spikeinterface.sorters as ss -import spikeinterface.preprocessing as sip +import spikeinterface as si import probeinterface as pi log = get_logger(__name__) @@ -138,7 +128,7 @@ def make(self, key): # params = {**params, **ephys.get_recording_channels_details(key)} # params["fs"] = params["sample_rate"] - # default_params = si.get_default_sorter_params(sorter_name) + # default_params = si.full.get_default_sorter_params(sorter_name) # preprocess_list = params.pop("PreProcessing_params") if acq_software == "SpikeGLX": @@ -147,7 +137,7 @@ def make(self, key): # Create SI recording extractor object stream_name = sglx_filepath.stem.split(".", 1)[1] - sglx_si_recording = se.read_spikeglx( + sglx_si_recording = si.extractors.read_spikeglx( folder_path=sglx_filepath.parent, stream_name=stream_name, stream_id=stream_name, @@ -169,10 +159,10 @@ def make(self, key): sglx_si_recording.set_probe(probe=si_probe) # # run preprocessing and save results to output folder - # sglx_si_recording_filtered = sip.bandpass_filter( + # sglx_si_recording_filtered = si.preprocessing.bandpass_filter( # sglx_si_recording, freq_min=300, freq_max=6000 # ) - # sglx_recording_cmr = sip.common_reference(sglx_si_recording_filtered, reference="global", operator="median") + # sglx_recording_cmr = si.preprocessing.common_reference(sglx_si_recording_filtered, reference="global", operator="median") sglx_si_recording = mimic_catGT(sglx_si_recording) save_file_name = "si_recording.pkl" save_file_path = kilosort_dir / save_file_name @@ -190,8 +180,8 @@ def make(self, key): ] # Create SI recording extractor object - # oe_si_recording = se.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) - oe_si_recording = se.read_openephys( + # oe_si_recording = si.extractors.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) + oe_si_recording = si.extractors.read_openephys( folder_path=oe_session_full_path, stream_name=stream_name ) @@ -212,10 +202,10 @@ def make(self, key): # run preprocessing and save results to output folder # # Switch case to allow for specified preprocessing steps - # oe_si_recording_filtered = sip.bandpass_filter( + # oe_si_recording_filtered = si.preprocessing.bandpass_filter( # oe_si_recording, freq_min=300, freq_max=6000 # ) - # oe_recording_cmr = sip.common_reference( + # oe_recording_cmr = si.preprocessing.common_reference( # oe_si_recording_filtered, reference="global", operator="median" # ) oe_si_recording = mimic_IBLdestriping(oe_si_recording) @@ -265,8 +255,8 @@ def make(self, key): if acq_software == "SpikeGLX": # sglx_probe = ephys.get_openephys_probe_data(key) recording_fullpath = kilosort_dir / recording_filename - # sglx_si_recording = se.load_from_folder(recording_file) - sglx_si_recording = sic.load_extractor(recording_fullpath) + # sglx_si_recording = si.extractors.load_from_folder(recording_file) + sglx_si_recording = si.core.load_extractor(recording_fullpath) # assert len(oe_probe.recording_info["recording_files"]) == 1 ## Assume that the worker process will trigger this sorting step @@ -278,7 +268,7 @@ def make(self, key): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - sorting_kilosort = si.run_sorter( + sorting_kilosort = 
si.full.run_sorter( sorter_name=sorter_name, recording=sglx_si_recording, output_folder=kilosort_dir, @@ -289,13 +279,13 @@ def make(self, key): sorting_kilosort.dump_to_pickle(sorting_save_path) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = sic.load_extractor(recording_fullpath) + oe_si_recording = si.core.load_extractor(recording_fullpath) assert len(oe_probe.recording_info["recording_files"]) == 1 if clustering_method.startswith("kilosort2.5"): sorter_name = "kilosort2_5" else: sorter_name = clustering_method - sorting_kilosort = si.run_sorter( + sorting_kilosort = si.full.run_sorter( sorter_name=sorter_name, recording=oe_si_recording, output_folder=kilosort_dir, @@ -345,10 +335,10 @@ def make(self, key): recording_filename = (PreProcessing & key).fetch1("recording_filename") sorting_file = kilosort_dir / "sorting_kilosort" filtered_recording_file = kilosort_dir / recording_filename - sglx_si_recording_filtered = sic.load_extractor(recording_file) - sorting_kilosort = sic.load_extractor(sorting_file) + sglx_si_recording_filtered = si.core.load_extractor(recording_file) + sorting_kilosort = si.core.load_extractor(sorting_file) - we_kilosort = si.WaveformExtractor.create( + we_kilosort = si.full.WaveformExtractor.create( sglx_si_recording_filtered, sorting_kilosort, "waveforms", @@ -359,15 +349,15 @@ def make(self, key): unit_id0 = sorting_kilosort.unit_ids[0] waveforms = we_kilosort.get_waveforms(unit_id0) template = we_kilosort.get_template(unit_id0) - snrs = si.compute_snrs(we_kilosort) + snrs = si.full.compute_snrs(we_kilosort) # QC Metrics ( si_violations_ratio, isi_violations_rate, isi_violations_count, - ) = si.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - metrics = si.compute_quality_metrics( + ) = si.full.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + metrics = si.full.compute_quality_metrics( we_kilosort, metric_names=[ "firing_rate", @@ -382,7 +372,9 @@ def make(self, key): "drift", ], ) - sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) + si.exporters.export_report( + we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000 + ) # ["firing_rate","snr","presence_ratio","isi_violation", # "number_violation","amplitude_cutoff","isolation_distance","l_ratio","d_prime","nn_hit_rate", # "nn_miss_rate","silhouette_core","cumulative_drift","contamination_rate"]) @@ -392,10 +384,10 @@ def make(self, key): elif acq_software == "Open Ephys": sorting_file = kilosort_dir / "sorting_kilosort" recording_file = kilosort_dir / "sglx_recording_cmr.json" - oe_si_recording = sic.load_extractor(recording_file) - sorting_kilosort = sic.load_extractor(sorting_file) + oe_si_recording = si.core.load_extractor(recording_file) + sorting_kilosort = si.core.load_extractor(sorting_file) - we_kilosort = si.WaveformExtractor.create( + we_kilosort = si.full.WaveformExtractor.create( oe_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True ) we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) @@ -403,23 +395,24 @@ def make(self, key): unit_id0 = sorting_kilosort.unit_ids[0] waveforms = we_kilosort.get_waveforms(unit_id0) template = we_kilosort.get_template(unit_id0) - snrs = si.compute_snrs(we_kilosort) + snrs = si.full.compute_snrs(we_kilosort) # QC Metrics # Apply waveform extractor extensions - _ = si.compute_spike_locations(we_kilosort) - _ = si.compute_spike_amplitudes(we_kilosort) - _ = si.compute_unit_locations(we_kilosort) - _ = 
si.compute_template_metrics(we_kilosort) - _ = si.compute_noise_levels(we_kilosort) - _ = si.compute_principal_components(we_kilosort) - _ = si.compute_drift_metrics(we_kilosort) - _ = si.compute_tempoate_similarity(we_kilosort) - (isi_violations_ratio, isi_violations_count) = si.compute_isi_violations( - we_kilosort, isi_threshold_ms=1.5 - ) - (isi_histograms, bins) = si.compute_isi_histograms(we_kilosort) - metrics = si.compute_quality_metrics( + _ = si.full.compute_spike_locations(we_kilosort) + _ = si.full.compute_spike_amplitudes(we_kilosort) + _ = si.full.compute_unit_locations(we_kilosort) + _ = si.full.compute_template_metrics(we_kilosort) + _ = si.full.compute_noise_levels(we_kilosort) + _ = si.full.compute_principal_components(we_kilosort) + _ = si.full.compute_drift_metrics(we_kilosort) + _ = si.full.compute_tempoate_similarity(we_kilosort) + ( + isi_violations_ratio, + isi_violations_count, + ) = si.full.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) + (isi_histograms, bins) = si.full.compute_isi_histograms(we_kilosort) + metrics = si.full.compute_quality_metrics( we_kilosort, metric_names=[ "firing_rate", @@ -434,7 +427,9 @@ def make(self, key): "drift", ], ) - sie.export_report(we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000) + si.exporters.export_report( + we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000 + ) we_savedir = kilosort_dir / "we_kilosort" we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) @@ -462,71 +457,73 @@ def make(self, key): # # If else # # need to figure out ordering # if preprocess_list["Filter"]: -# recording = sip.FilterRecording(recording) +# recording = si.preprocessing.FilterRecording(recording) # if preprocess_list["BandpassFilter"]: -# recording = sip.BandpassFilterRecording(recording) +# recording = si.preprocessing.BandpassFilterRecording(recording) # if preprocess_list["HighpassFilter"]: -# recording = sip.HighpassFilterRecording(recording) +# recording = si.preprocessing.HighpassFilterRecording(recording) # if preprocess_list["NormalizeByQuantile"]: -# recording = sip.NormalizeByQuantileRecording(recording) +# recording = si.preprocessing.NormalizeByQuantileRecording(recording) # if preprocess_list["Scale"]: -# recording = sip.ScaleRecording(recording) +# recording = si.preprocessing.ScaleRecording(recording) # if preprocess_list["Center"]: -# recording = sip.CenterRecording(recording) +# recording = si.preprocessing.CenterRecording(recording) # if preprocess_list["ZScore"]: -# recording = sip.ZScoreRecording(recording) +# recording = si.preprocessing.ZScoreRecording(recording) # if preprocess_list["Whiten"]: -# recording = sip.WhitenRecording(recording) +# recording = si.preprocessing.WhitenRecording(recording) # if preprocess_list["CommonReference"]: -# recording = sip.CommonReferenceRecording(recording) +# recording = si.preprocessing.CommonReferenceRecording(recording) # if preprocess_list["PhaseShift"]: -# recording = sip.PhaseShiftRecording(recording) +# recording = si.preprocessing.PhaseShiftRecording(recording) # elif preprocess_list["Rectify"]: -# recording = sip.RectifyRecording(recording) +# recording = si.preprocessing.RectifyRecording(recording) # elif preprocess_list["Clip"]: -# recording = sip.ClipRecording(recording) +# recording = si.preprocessing.ClipRecording(recording) # elif preprocess_list["BlankSaturation"]: -# recording = sip.BlankSaturationRecording(recording) +# recording = si.preprocessing.BlankSaturationRecording(recording) # elif preprocess_list["RemoveArtifacts"]: -# recording = 
sip.RemoveArtifactsRecording(recording) +# recording = si.preprocessing.RemoveArtifactsRecording(recording) # elif preprocess_list["RemoveBadChannels"]: -# recording = sip.RemoveBadChannelsRecording(recording) +# recording = si.preprocessing.RemoveBadChannelsRecording(recording) # elif preprocess_list["ZeroChannelPad"]: -# recording = sip.ZeroChannelPadRecording(recording) +# recording = si.preprocessing.ZeroChannelPadRecording(recording) # elif preprocess_list["DeepInterpolation"]: -# recording = sip.DeepInterpolationRecording(recording) +# recording = si.preprocessing.DeepInterpolationRecording(recording) # elif preprocess_list["Resample"]: -# recording = sip.ResampleRecording(recording) +# recording = si.preprocessing.ResampleRecording(recording) def mimic_IBLdestriping_modified(recording): # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html) - recording = si.highpass_filter(recording, freq_min=400.0) - bad_channel_ids, channel_labels = si.detect_bad_channels(recording) + recording = si.full.highpass_filter(recording, freq_min=400.0) + bad_channel_ids, channel_labels = si.full.detect_bad_channels(recording) # For IBL destriping interpolate bad channels recording = recording.remove_channels(bad_channel_ids) - recording = si.phase_shift(recording) - recording = si.common_reference(recording, operator="median", reference="global") + recording = si.full.phase_shift(recording) + recording = si.full.common_reference( + recording, operator="median", reference="global" + ) return recording def mimic_IBLdestriping(recording): # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. - recording = si.highpass_filter(recording, freq_min=400.0) - bad_channel_ids, channel_labels = si.detect_bad_channels(recording) + recording = si.full.highpass_filter(recording, freq_min=400.0) + bad_channel_ids, channel_labels = si.full.detect_bad_channels(recording) # For IBL destriping interpolate bad channels - recording = sip.interpolate_bad_channels(bad_channel_ids) - recording = si.phase_shift(recording) + recording = si.preprocessing.interpolate_bad_channels(bad_channel_ids) + recording = si.full.phase_shift(recording) # For IBL destriping use highpass_spatial_filter used instead of common reference - recording = si.highpass_spatial_filter( + recording = si.full.highpass_spatial_filter( recording, operator="median", reference="global" ) return recording def mimic_catGT(sglx_recording): - sglx_recording = si.phase_shift(sglx_recording) - sglx_recording = si.common_reference( + sglx_recording = si.full.phase_shift(sglx_recording) + sglx_recording = si.full.common_reference( sglx_recording, operator="median", reference="global" ) return sglx_recording From 7836a8b4e6c600ba31b8af7efd0e17030e25b158 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 13 Oct 2023 16:47:09 -0500 Subject: [PATCH 044/146] modify key_source in PreProcessing --- element_array_ephys/spike_sorting/si_clustering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 2cb5bf2e..e8b6517c 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -85,11 +85,11 @@ class PreProcessing(dj.Imported): @property def key_source(self): - return ( + return (( ephys.ClusteringTask * ephys.ClusteringParamSet & {"task_mode": 
"trigger"} & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) - ephys.Clustering + ) - ephys.Clustering).proj() def make(self, key): """Triggers or imports clustering analysis.""" From 6bee166f5fe356ddea10e1a5b391e6daf6929ec9 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Thu, 14 Dec 2023 16:31:31 -0600 Subject: [PATCH 045/146] feat: :sparkles: improve to_probeinterface --- element_array_ephys/readers/probe_geometry.py | 24 ++++++++++++++----- 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/readers/probe_geometry.py b/element_array_ephys/readers/probe_geometry.py index 11e3ae99..7247abe9 100644 --- a/element_array_ephys/readers/probe_geometry.py +++ b/element_array_ephys/readers/probe_geometry.py @@ -132,8 +132,8 @@ def build_npx_probe( return elec_pos_df -def to_probeinterface(electrodes_df): - from probeinterface import Probe +def to_probeinterface(electrodes_df, **kwargs): + import probeinterface as pi probe_df = electrodes_df.copy() probe_df.rename( @@ -145,10 +145,22 @@ def to_probeinterface(electrodes_df): }, inplace=True, ) - probe_df["contact_shapes"] = "square" - probe_df["width"] = 12 - - return Probe.from_dataframe(probe_df) + # Get the contact shapes. By default, it's set to circle with a radius of 10. + contact_shapes = kwargs.get("contact_shapes", "circle") + assert ( + contact_shapes in pi.probe._possible_contact_shapes + ), f"contacts shape should be in {pi.probe._possible_contact_shapes}" + + probe_df["contact_shapes"] = contact_shapes + if contact_shapes == "circle": + probe_df["radius"] = kwargs.get("radius", 10) + elif contact_shapes == "square": + probe_df["width"] = kwargs.get("width", 10) + elif contact_shapes == "rect": + probe_df["width"] = kwargs.get("width") + probe_df["height"] = kwargs.get("height") + + return pi.Probe.from_dataframe(probe_df) def build_electrode_layouts( From 9ae6b4492583429db9931eda93c37d5092b05e8e Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 15 Dec 2023 17:10:07 -0600 Subject: [PATCH 046/146] create preprocessing.py --- .../spike_sorting/preprocessing.py | 85 +++++++++++++++++++ .../spike_sorting/si_clustering.py | 11 ++- 2 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 element_array_ephys/spike_sorting/preprocessing.py diff --git a/element_array_ephys/spike_sorting/preprocessing.py b/element_array_ephys/spike_sorting/preprocessing.py new file mode 100644 index 00000000..77a95792 --- /dev/null +++ b/element_array_ephys/spike_sorting/preprocessing.py @@ -0,0 +1,85 @@ +import spikeinterface as si +from spikeinterface import preprocessing + + +def mimic_catGT(recording): + recording = si.preprocessing.phase_shift(recording) + recording = si.preprocessing.common_reference( + recording, operator="median", reference="global" + ) + return recording + + +def mimic_IBLdestriping(recording): + # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. 
+    recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+    bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+    # For IBL destriping interpolate bad channels
+    recording = si.preprocessing.interpolate_bad_channels(recording, bad_channel_ids)
+    recording = si.preprocessing.phase_shift(recording)
+    # For IBL destriping, highpass_spatial_filter is used instead of common reference
+    recording = si.preprocessing.highpass_spatial_filter(recording)
+    return recording
+
+
+def mimic_IBLdestriping_modified(recording):
+    # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html)
+    recording = si.preprocessing.highpass_filter(recording, freq_min=400.0)
+    bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording)
+    # For IBL destriping interpolate bad channels
+    recording = recording.remove_channels(bad_channel_ids)
+    recording = si.preprocessing.phase_shift(recording)
+    recording = si.preprocessing.common_reference(
+        recording, operator="median", reference="global"
+    )
+    return recording
+
+
+_preprocessing_function = {
+    "catGT": mimic_catGT,
+    "IBLdestriping": mimic_IBLdestriping,
+    "IBLdestriping_modified": mimic_IBLdestriping_modified,
+}
+
+
+## Example SI parameter set
+"""
+{'detect_threshold': 6,
+ 'projection_threshold': [10, 4],
+ 'preclust_threshold': 8,
+ 'car': True,
+ 'minFR': 0.02,
+ 'minfr_goodchannels': 0.1,
+ 'nblocks': 5,
+ 'sig': 20,
+ 'freq_min': 150,
+ 'sigmaMask': 30,
+ 'nPCs': 3,
+ 'ntbuff': 64,
+ 'nfilt_factor': 4,
+ 'NT': None,
+ 'do_correction': True,
+ 'wave_length': 61,
+ 'keep_good_only': False,
+ 'PreProcessing_params': {'Filter': False,
+  'BandpassFilter': True,
+  'HighpassFilter': False,
+  'NotchFilter': False,
+  'NormalizeByQuantile': False,
+  'Scale': False,
+  'Center': False,
+  'ZScore': False,
+  'Whiten': False,
+  'CommonReference': False,
+  'PhaseShift': False,
+  'Rectify': False,
+  'Clip': False,
+  'BlankSaturation': False,
+  'RemoveArtifacts': False,
+  'RemoveBadChannels': False,
+  'ZeroChannelPad': False,
+  'DeepInterpolation': False,
+  'Resample': False}}
+"""
diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py
index e8b6517c..a8d5d8c0 100644
--- a/element_array_ephys/spike_sorting/si_clustering.py
+++ b/element_array_ephys/spike_sorting/si_clustering.py
@@ -31,7 +31,16 @@
 import element_array_ephys.probe as probe
 
 import spikeinterface as si
-import probeinterface as pi
+from element_interface.utils import find_full_path, find_root_directory
+from spikeinterface import sorters
+
+from element_array_ephys import get_logger, probe, readers
+
+from .preprocessing import (
+    mimic_catGT,
+    mimic_IBLdestriping,
+    mimic_IBLdestriping_modified,
+)
 
 log = get_logger(__name__)
 
From 9d5eee66aea3d9b967bd22beeefd2651b93a3b6b Mon Sep 17 00:00:00 2001
From: JaerongA
Date: Fri, 15 Dec 2023 17:12:19 -0600
Subject: [PATCH 047/146] add SI_SORTERS, SI_READERS

---
 .../spike_sorting/si_clustering.py | 27 ++++++++-----------
 1 file changed, 11 insertions(+), 16 deletions(-)

diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py
index a8d5d8c0..b9e9cb2e 100644
--- a/element_array_ephys/spike_sorting/si_clustering.py
+++ b/element_array_ephys/spike_sorting/si_clustering.py
@@ -18,18 +18,10 @@
 - quality_metrics
 """
 
-import datajoint as dj
-import os
-from element_array_ephys import get_logger
from datetime import datetime -from element_interface.utils import find_full_path -from element_array_ephys.readers import ( - spikeglx, - kilosort_triggering, -) -import element_array_ephys.probe as probe - +import datajoint as dj +import probeinterface as pi import spikeinterface as si from element_interface.utils import find_full_path, find_root_directory from spikeinterface import sorters @@ -48,12 +40,6 @@ ephys = None -_supported_kilosort_versions = [ - "kilosort2", - "kilosort2.5", - "kilosort3", -] - def activate( schema_name, @@ -79,6 +65,15 @@ def activate( ) +SI_SORTERS = [s.replace(".", "_") for s in si.sorters.sorter_dict.keys()] + +SI_READERS = { + "Open Ephys": si.extractors.read_openephys, + "SpikeGLX": si.extractors.read_spikeglx, + "Intan": si.extractors.read_intan, +} + + @schema class PreProcessing(dj.Imported): """A table to handle preprocessing of each clustering task.""" From 1fceb90a1816613a0e86f2f7288ba56ba254de40 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 15 Dec 2023 17:15:45 -0600 Subject: [PATCH 048/146] feat: :art: si_clustering.PreProcessing --- .../spike_sorting/si_clustering.py | 149 +++++------------- 1 file changed, 37 insertions(+), 112 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index b9e9cb2e..5510436f 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -76,34 +76,31 @@ def activate( @schema class PreProcessing(dj.Imported): - """A table to handle preprocessing of each clustering task.""" + """A table to handle preprocessing of each clustering task. The output will be serialized and stored as a si_recording.pkl in the output directory.""" definition = """ -> ephys.ClusteringTask --- - recording_filename: varchar(30) # filename where recording object is saved to - params: longblob # finalized parameterset for this run execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration + execution_duration: float # execution duration in hours """ @property def key_source(self): - return (( + return ( ephys.ClusteringTask * ephys.ClusteringParamSet & {"task_mode": "trigger"} - & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) - ephys.Clustering).proj() + & f"clustering_method in {tuple(SI_SORTERS)}" + ) - ephys.Clustering def make(self, key): """Triggers or imports clustering analysis.""" execution_time = datetime.utcnow() - task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - assert task_mode == "trigger", 'Supporting "trigger" task_mode only' + # Set the output directory + acq_software, output_dir = ( + ephys.ClusteringTask * ephys.EphysRecording & key + ).fetch1("acq_software", "clustering_output_dir") if not output_dir: output_dir = ephys.ClusteringTask.infer_output_dir( @@ -113,115 +110,43 @@ def make(self, key): ephys.ClusteringTask.update1( {**key, "clustering_output_dir": output_dir.as_posix()} ) + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method, params = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - assert ( - clustering_method in _supported_kilosort_versions - ), f'Clustering_method "{clustering_method}" is not supported' - - if 
clustering_method.startswith("kilosort2.5"): - sorter_name = "kilosort2_5" - else: - sorter_name = clustering_method - # add additional probe-recording and channels details into `params` - # params = {**params, **ephys.get_recording_channels_details(key)} - # params["fs"] = params["sample_rate"] - - # default_params = si.full.get_default_sorter_params(sorter_name) - # preprocess_list = params.pop("PreProcessing_params") - - if acq_software == "SpikeGLX": - # sglx_session_full_path = find_full_path(ephys.get_ephys_root_data_dir(),ephys.get_session_directory(key)) - sglx_filepath = ephys.get_spikeglx_meta_filepath(key) - - # Create SI recording extractor object - stream_name = sglx_filepath.stem.split(".", 1)[1] - sglx_si_recording = si.extractors.read_spikeglx( - folder_path=sglx_filepath.parent, - stream_name=stream_name, - stream_id=stream_name, - ) - - channels_details = ephys.get_recording_channels_details(key) - xy_coords = [ - list(i) - for i in zip(channels_details["x_coords"], channels_details["y_coords"]) - ] - - # Create SI probe object - si_probe = pi.Probe(ndim=2, si_units="um") - si_probe.set_contacts( - positions=xy_coords, shapes="square", shape_params={"width": 12} - ) - si_probe.create_auto_shape(probe_type="tip") - si_probe.set_device_channel_indices(channels_details["channel_ind"]) - sglx_si_recording.set_probe(probe=si_probe) - - # # run preprocessing and save results to output folder - # sglx_si_recording_filtered = si.preprocessing.bandpass_filter( - # sglx_si_recording, freq_min=300, freq_max=6000 - # ) - # sglx_recording_cmr = si.preprocessing.common_reference(sglx_si_recording_filtered, reference="global", operator="median") - sglx_si_recording = mimic_catGT(sglx_si_recording) - save_file_name = "si_recording.pkl" - save_file_path = kilosort_dir / save_file_name - sglx_si_recording.dump_to_pickle(file_path=save_file_path) - - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - oe_session_full_path = find_full_path( - ephys.get_ephys_root_data_dir(), ephys.get_session_directory(key) - ) + # Create SI recording extractor object + si_recording: si.BaseRecording = SI_READERS[acq_software]( + folder_path=output_dir + ) - assert len(oe_probe.recording_info["recording_files"]) == 1 - stream_name = os.path.split(oe_probe.recording_info["recording_files"][0])[ - 1 - ] - - # Create SI recording extractor object - # oe_si_recording = si.extractors.OpenEphysBinaryRecordingExtractor(folder_path=oe_full_path, stream_name=stream_name) - oe_si_recording = si.extractors.read_openephys( - folder_path=oe_session_full_path, stream_name=stream_name + # Add probe information to recording object + electrode_config_key = ( + probe.ElectrodeConfig * ephys.EphysRecording & key + ).fetch1("KEY") + electrodes_df = ( + ( + probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & electrode_config_key ) + .fetch(format="frame") + .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] + ) - channels_details = ephys.get_recording_channels_details(key) - xy_coords = [ - list(i) - for i in zip(channels_details["x_coords"], channels_details["y_coords"]) - ] + # Create SI probe object + si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) + si_recording.set_probe(probe=si_probe, in_place=True) - # Create SI probe object - si_probe = pi.Probe(ndim=2, si_units="um") - si_probe.set_contacts( - positions=xy_coords, shapes="square", shape_params={"width": 12} - ) - si_probe.create_auto_shape(probe_type="tip") - 
si_probe.set_device_channel_indices(channels_details["channel_ind"]) - oe_si_recording.set_probe(probe=si_probe) - - # run preprocessing and save results to output folder - # # Switch case to allow for specified preprocessing steps - # oe_si_recording_filtered = si.preprocessing.bandpass_filter( - # oe_si_recording, freq_min=300, freq_max=6000 - # ) - # oe_recording_cmr = si.preprocessing.common_reference( - # oe_si_recording_filtered, reference="global", operator="median" - # ) - oe_si_recording = mimic_IBLdestriping(oe_si_recording) - save_file_name = "si_recording.pkl" - save_file_path = kilosort_dir / save_file_name - oe_si_recording.dump_to_pickle(file_path=save_file_path) + # Run preprocessing and save results to output folder + preprocessing_method = "catGT" # where to load this info? + si_recording = { + "catGT": mimic_catGT, + "IBLdestriping": mimic_IBLdestriping, + "IBLdestriping_modified": mimic_IBLdestriping_modified, + }[preprocessing_method](si_recording) + recording_file_name = output_dir / "si_recording.pkl" + si_recording.dump_to_pickle(file_path=recording_file_name) self.insert1( { **key, - "recording_filename": save_file_name, - "params": params, "execution_time": execution_time, "execution_duration": ( datetime.utcnow() - execution_time From df8ed7464ebc3452148c0daa1f1e43987d7035ec Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 19 Dec 2023 21:53:43 -0600 Subject: [PATCH 049/146] feat: :art: si_clustering.SIClustering --- .../spike_sorting/si_clustering.py | 93 +++++++------------ 1 file changed, 35 insertions(+), 58 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 5510436f..debc4336 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -21,10 +21,11 @@ from datetime import datetime import datajoint as dj +import pandas as pd import probeinterface as pi import spikeinterface as si from element_interface.utils import find_full_path, find_root_directory -from spikeinterface import sorters +from spikeinterface import exporters, qualitymetrics, sorters from element_array_ephys import get_logger, probe, readers @@ -65,7 +66,7 @@ def activate( ) -SI_SORTERS = [s.replace(".", "_") for s in si.sorters.sorter_dict.keys()] +SI_SORTERS = [s.replace("_", ".") for s in si.sorters.sorter_dict.keys()] SI_READERS = { "Open Ephys": si.extractors.read_openephys, @@ -141,8 +142,8 @@ def make(self, key): "IBLdestriping": mimic_IBLdestriping, "IBLdestriping_modified": mimic_IBLdestriping_modified, }[preprocessing_method](si_recording) - recording_file_name = output_dir / "si_recording.pkl" - si_recording.dump_to_pickle(file_path=recording_file_name) + recording_file = output_dir / "si_recording.pkl" + si_recording.dump_to_pickle(file_path=recording_file) self.insert1( { @@ -162,72 +163,48 @@ class SIClustering(dj.Imported): definition = """ -> PreProcessing + sorter_name: varchar(30) # name of the sorter used --- - sorting_filename: varchar(30) # filename of saved sorting object - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration + execution_time: datetime # datetime of the start of this step + execution_duration: float # execution duration in hours """ def make(self, key): execution_time = datetime.utcnow() + # Load recording object. 
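+        # The recording itself is the pickled cache that PreProcessing.make saved; its filename is fetched from PreProcessing below.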
output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (PreProcessing & key).fetch1("params") - recording_filename = (PreProcessing & key).fetch1("recording_filename") + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_dir / "si_recording.pkl" + si_recording: si.BaseRecording = si.load_extractor(recording_file) + + # Get sorter method and create output directory. + clustering_method, params = ( + ephys.ClusteringTask * ephys.ClusteringParamSet & key + ).fetch1("clustering_method", "params") + sorter_name = ( + "kilosort_2_5" if clustering_method == "kilsort2.5" else clustering_method + ) + sorter_dir = output_dir / sorter_name + + # Run sorting + si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( + sorter_name=sorter_name, + recording=si_recording, + output_folder=sorter_dir, + verbse=True, + docker_image=True, + **params, + ) - if acq_software == "SpikeGLX": - # sglx_probe = ephys.get_openephys_probe_data(key) - recording_fullpath = kilosort_dir / recording_filename - # sglx_si_recording = si.extractors.load_from_folder(recording_file) - sglx_si_recording = si.core.load_extractor(recording_fullpath) - # assert len(oe_probe.recording_info["recording_files"]) == 1 - - ## Assume that the worker process will trigger this sorting step - # - Will need to store/load the sorter_name, sglx_si_recording object etc. - # - Store in shared EC2 space accessible by all containers (needs to be mounted) - # - Load into the cloud init script, and - # - Option A: Can call this function within a separate container within spike_sorting_worker - if clustering_method.startswith("kilosort2.5"): - sorter_name = "kilosort2_5" - else: - sorter_name = clustering_method - sorting_kilosort = si.full.run_sorter( - sorter_name=sorter_name, - recording=sglx_si_recording, - output_folder=kilosort_dir, - docker_image=f"spikeinterface/{sorter_name}-compiled-base:latest", - **params, - ) - sorting_save_path = kilosort_dir / "sorting_kilosort.pkl" - sorting_kilosort.dump_to_pickle(sorting_save_path) - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - oe_si_recording = si.core.load_extractor(recording_fullpath) - assert len(oe_probe.recording_info["recording_files"]) == 1 - if clustering_method.startswith("kilosort2.5"): - sorter_name = "kilosort2_5" - else: - sorter_name = clustering_method - sorting_kilosort = si.full.run_sorter( - sorter_name=sorter_name, - recording=oe_si_recording, - output_folder=kilosort_dir, - docker_image=f"spikeinterface/{sorter_name}-compiled-base:latest", - **params, - ) - sorting_save_path = kilosort_dir / "sorting_kilosort.pkl" - sorting_kilosort.dump_to_pickle(sorting_save_path) + # Run sorting + sorting_save_path = sorter_dir / "si_sorting.pkl" + si_sorting.dump_to_pickle(sorting_save_path) self.insert1( { **key, - "sorting_filename": list(sorting_save_path.parts)[-1], + "sorter_name": sorter_name, "execution_time": execution_time, "execution_duration": ( datetime.utcnow() - execution_time From 2ed337bf0fa8245a7f8d481dc779b072cfcbadd0 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 20 Dec 2023 15:57:55 -0600 Subject: [PATCH 050/146] feat: :sparkles: add PostProcessing table & clean up --- .../spike_sorting/si_clustering.py | 275 
+++--------------- 1 file changed, 45 insertions(+), 230 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index debc4336..935d7360 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -25,7 +25,7 @@ import probeinterface as pi import spikeinterface as si from element_interface.utils import find_full_path, find_root_directory -from spikeinterface import exporters, qualitymetrics, sorters +from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from element_array_ephys import get_logger, probe, readers @@ -222,126 +222,58 @@ class PostProcessing(dj.Imported): -> SIClustering --- execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration + execution_duration: float # execution duration in hours """ def make(self, key): execution_time = datetime.utcnow() + JOB_KWARGS = dict(n_jobs=-1, chunk_size=30000) + # Load sorting & recording object. output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (PreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - recording_filename = (PreProcessing & key).fetch1("recording_filename") - sorting_file = kilosort_dir / "sorting_kilosort" - filtered_recording_file = kilosort_dir / recording_filename - sglx_si_recording_filtered = si.core.load_extractor(recording_file) - sorting_kilosort = si.core.load_extractor(sorting_file) - - we_kilosort = si.full.WaveformExtractor.create( - sglx_si_recording_filtered, - sorting_kilosort, - "waveforms", - remove_if_exists=True, - ) - we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) - we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) - unit_id0 = sorting_kilosort.unit_ids[0] - waveforms = we_kilosort.get_waveforms(unit_id0) - template = we_kilosort.get_template(unit_id0) - snrs = si.full.compute_snrs(we_kilosort) - - # QC Metrics - ( - si_violations_ratio, - isi_violations_rate, - isi_violations_count, - ) = si.full.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - metrics = si.full.compute_quality_metrics( - we_kilosort, - metric_names=[ - "firing_rate", - "snr", - "presence_ratio", - "isi_violation", - "num_spikes", - "amplitude_cutoff", - "amplitude_median", - # "sliding_rp_violation", - "rp_violation", - "drift", - ], - ) - si.exporters.export_report( - we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000 - ) - # ["firing_rate","snr","presence_ratio","isi_violation", - # "number_violation","amplitude_cutoff","isolation_distance","l_ratio","d_prime","nn_hit_rate", - # "nn_miss_rate","silhouette_core","cumulative_drift","contamination_rate"]) - we_savedir = kilosort_dir / "we_kilosort" - we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) - - elif acq_software == "Open Ephys": - sorting_file = kilosort_dir / "sorting_kilosort" - recording_file = kilosort_dir / "sglx_recording_cmr.json" - oe_si_recording = si.core.load_extractor(recording_file) - sorting_kilosort = si.core.load_extractor(sorting_file) - - we_kilosort = si.full.WaveformExtractor.create( - oe_si_recording, sorting_kilosort, "waveforms", remove_if_exists=True - ) - 
we_kilosort.set_params(ms_before=3.0, ms_after=4.0, max_spikes_per_unit=500) - we_kilosort.run_extract_waveforms(n_jobs=-1, chunk_size=30000) - unit_id0 = sorting_kilosort.unit_ids[0] - waveforms = we_kilosort.get_waveforms(unit_id0) - template = we_kilosort.get_template(unit_id0) - snrs = si.full.compute_snrs(we_kilosort) - - # QC Metrics - # Apply waveform extractor extensions - _ = si.full.compute_spike_locations(we_kilosort) - _ = si.full.compute_spike_amplitudes(we_kilosort) - _ = si.full.compute_unit_locations(we_kilosort) - _ = si.full.compute_template_metrics(we_kilosort) - _ = si.full.compute_noise_levels(we_kilosort) - _ = si.full.compute_principal_components(we_kilosort) - _ = si.full.compute_drift_metrics(we_kilosort) - _ = si.full.compute_tempoate_similarity(we_kilosort) - ( - isi_violations_ratio, - isi_violations_count, - ) = si.full.compute_isi_violations(we_kilosort, isi_threshold_ms=1.5) - (isi_histograms, bins) = si.full.compute_isi_histograms(we_kilosort) - metrics = si.full.compute_quality_metrics( - we_kilosort, - metric_names=[ - "firing_rate", - "snr", - "presence_ratio", - "isi_violation", - "num_spikes", - "amplitude_cutoff", - "amplitude_median", - # "sliding_rp_violation", - "rp_violation", - "drift", - ], - ) - si.exporters.export_report( - we_kilosort, kilosort_dir, n_jobs=-1, chunk_size=30000 - ) - we_savedir = kilosort_dir / "we_kilosort" - we_kilosort.save(we_savedir, n_jobs=-1, chunk_size=30000) + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_dir / "si_recording.pkl" + sorter_dir = output_dir / key["sorter_name"] + sorting_file = sorter_dir / "si_sorting.pkl" + + si_recording: si.BaseRecording = si.load_extractor(recording_file) + si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) + + # Extract waveforms + we: si.WaveformExtractor = si.extract_waveforms( + si_recording, + si_sorting, + folder=sorter_dir / "waveform", # The folder where waveforms are cached + ms_before=3.0, + ms_after=4.0, + max_spikes_per_unit=500, + overwrite=True, + **JOB_KWARGS, + ) - metrics_savefile = kilosort_dir / "metrics.csv" - metrics.to_csv(metrics_savefile) + # Calculate QC Metrics + metrics: pd.DataFrame = si.qualitymetrics.compute_quality_metrics( + we, + metric_names=[ + "firing_rate", + "snr", + "presence_ratio", + "isi_violation", + "num_spikes", + "amplitude_cutoff", + "amplitude_median", + "sliding_rp_violation", + "rp_violation", + "drift", + ], + ) + # Add PCA based metrics. These will be added to the metrics dataframe above. 
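+        # (n_components=5 and mode="by_channel_local" match SpikeInterface's documented defaults; "by_channel_local" fits a separate PCA per channel.)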
+ _ = si.postprocessing.compute_principal_components( + waveform_extractor=we, n_components=5, mode="by_channel_local" + ) # TODO: the parameters need to be checked + metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) + # Save results self.insert1( { **key, @@ -357,120 +289,3 @@ def make(self, key): ephys.Clustering.insert1( {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) - - -# def runPreProcessList(preprocess_list, recording): -# # If else -# # need to figure out ordering -# if preprocess_list["Filter"]: -# recording = si.preprocessing.FilterRecording(recording) -# if preprocess_list["BandpassFilter"]: -# recording = si.preprocessing.BandpassFilterRecording(recording) -# if preprocess_list["HighpassFilter"]: -# recording = si.preprocessing.HighpassFilterRecording(recording) -# if preprocess_list["NormalizeByQuantile"]: -# recording = si.preprocessing.NormalizeByQuantileRecording(recording) -# if preprocess_list["Scale"]: -# recording = si.preprocessing.ScaleRecording(recording) -# if preprocess_list["Center"]: -# recording = si.preprocessing.CenterRecording(recording) -# if preprocess_list["ZScore"]: -# recording = si.preprocessing.ZScoreRecording(recording) -# if preprocess_list["Whiten"]: -# recording = si.preprocessing.WhitenRecording(recording) -# if preprocess_list["CommonReference"]: -# recording = si.preprocessing.CommonReferenceRecording(recording) -# if preprocess_list["PhaseShift"]: -# recording = si.preprocessing.PhaseShiftRecording(recording) -# elif preprocess_list["Rectify"]: -# recording = si.preprocessing.RectifyRecording(recording) -# elif preprocess_list["Clip"]: -# recording = si.preprocessing.ClipRecording(recording) -# elif preprocess_list["BlankSaturation"]: -# recording = si.preprocessing.BlankSaturationRecording(recording) -# elif preprocess_list["RemoveArtifacts"]: -# recording = si.preprocessing.RemoveArtifactsRecording(recording) -# elif preprocess_list["RemoveBadChannels"]: -# recording = si.preprocessing.RemoveBadChannelsRecording(recording) -# elif preprocess_list["ZeroChannelPad"]: -# recording = si.preprocessing.ZeroChannelPadRecording(recording) -# elif preprocess_list["DeepInterpolation"]: -# recording = si.preprocessing.DeepInterpolationRecording(recording) -# elif preprocess_list["Resample"]: -# recording = si.preprocessing.ResampleRecording(recording) - - -def mimic_IBLdestriping_modified(recording): - # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html) - recording = si.full.highpass_filter(recording, freq_min=400.0) - bad_channel_ids, channel_labels = si.full.detect_bad_channels(recording) - # For IBL destriping interpolate bad channels - recording = recording.remove_channels(bad_channel_ids) - recording = si.full.phase_shift(recording) - recording = si.full.common_reference( - recording, operator="median", reference="global" - ) - return recording - - -def mimic_IBLdestriping(recording): - # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. 
- recording = si.full.highpass_filter(recording, freq_min=400.0) - bad_channel_ids, channel_labels = si.full.detect_bad_channels(recording) - # For IBL destriping interpolate bad channels - recording = si.preprocessing.interpolate_bad_channels(bad_channel_ids) - recording = si.full.phase_shift(recording) - # For IBL destriping use highpass_spatial_filter used instead of common reference - recording = si.full.highpass_spatial_filter( - recording, operator="median", reference="global" - ) - return recording - - -def mimic_catGT(sglx_recording): - sglx_recording = si.full.phase_shift(sglx_recording) - sglx_recording = si.full.common_reference( - sglx_recording, operator="median", reference="global" - ) - return sglx_recording - - -## Example SI parameter set -""" -{'detect_threshold': 6, - 'projection_threshold': [10, 4], - 'preclust_threshold': 8, - 'car': True, - 'minFR': 0.02, - 'minfr_goodchannels': 0.1, - 'nblocks': 5, - 'sig': 20, - 'freq_min': 150, - 'sigmaMask': 30, - 'nPCs': 3, - 'ntbuff': 64, - 'nfilt_factor': 4, - 'NT': None, - 'do_correction': True, - 'wave_length': 61, - 'keep_good_only': False, - 'PreProcessing_params': {'Filter': False, - 'BandpassFilter': True, - 'HighpassFilter': False, - 'NotchFilter': False, - 'NormalizeByQuantile': False, - 'Scale': False, - 'Center': False, - 'ZScore': False, - 'Whiten': False, - 'CommonReference': False, - 'PhaseShift': False, - 'Rectify': False, - 'Clip': False, - 'BlankSaturation': False, - 'RemoveArtifacts': False, - 'RemoveBadChannels': False, - 'ZeroChannelPad': False, - 'DeepInterpolation': False, - 'Resample': False}} -""" From f6e3e4624255b9b42e89de3e520c8696bc60089f Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 2 Jan 2024 18:04:25 -0600 Subject: [PATCH 051/146] fix: :bug: fix input/output data directory --- .../spike_sorting/si_clustering.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 935d7360..80449c88 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -18,6 +18,7 @@ - quality_metrics """ +import pathlib from datetime import datetime import datajoint as dj @@ -111,11 +112,19 @@ def make(self, key): ephys.ClusteringTask.update1( {**key, "clustering_output_dir": output_dir.as_posix()} ) - output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + output_full_dir = find_full_path( + ephys.get_ephys_root_data_dir(), output_dir + ) # output directory in the processed data directory # Create SI recording extractor object + data_dir = ( + ephys.get_ephys_root_data_dir()[0] / pathlib.Path(output_dir).parent + ) # raw data directory + stream_names, stream_ids = si.extractors.get_neo_streams( + acq_software.strip().lower(), folder_path=data_dir + ) si_recording: si.BaseRecording = SI_READERS[acq_software]( - folder_path=output_dir + folder_path=data_dir, stream_name=stream_names[0] ) # Add probe information to recording object @@ -142,7 +151,7 @@ def make(self, key): "IBLdestriping": mimic_IBLdestriping, "IBLdestriping_modified": mimic_IBLdestriping_modified, }[preprocessing_method](si_recording) - recording_file = output_dir / "si_recording.pkl" + recording_file = output_full_dir / "si_recording.pkl" si_recording.dump_to_pickle(file_path=recording_file) self.insert1( From e1c0d689d6b7958c231389e8d11b7ef2e326657f Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 3 Jan 2024 11:47:03 -0600 Subject: 
[PATCH 052/146] check for presence of recording file --- .../spike_sorting/si_clustering.py | 107 ++++++++++-------- 1 file changed, 57 insertions(+), 50 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 80449c88..8a5adffb 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -25,7 +25,7 @@ import pandas as pd import probeinterface as pi import spikeinterface as si -from element_interface.utils import find_full_path, find_root_directory +from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from element_array_ephys import get_logger, probe, readers @@ -112,58 +112,65 @@ def make(self, key): ephys.ClusteringTask.update1( {**key, "clustering_output_dir": output_dir.as_posix()} ) + output_dir = pathlib.Path(output_dir) output_full_dir = find_full_path( - ephys.get_ephys_root_data_dir(), output_dir - ) # output directory in the processed data directory - - # Create SI recording extractor object - data_dir = ( - ephys.get_ephys_root_data_dir()[0] / pathlib.Path(output_dir).parent - ) # raw data directory - stream_names, stream_ids = si.extractors.get_neo_streams( - acq_software.strip().lower(), folder_path=data_dir - ) - si_recording: si.BaseRecording = SI_READERS[acq_software]( - folder_path=data_dir, stream_name=stream_names[0] - ) - - # Add probe information to recording object - electrode_config_key = ( - probe.ElectrodeConfig * ephys.EphysRecording & key - ).fetch1("KEY") - electrodes_df = ( - ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key + ephys.get_ephys_root_data_dir(), output_dir.parent + ) # recording object will be stored in the parent dir since it can be re-used for multiple sorters + + recording_file = ( + output_full_dir / "si_recording.pkl" + ) # recording cache to be created for each key + + if not recording_file.exists(): # skip if si_recording.pkl already exists + # Create SI recording extractor object + data_dir = ( + ephys.get_ephys_root_data_dir()[0] / output_dir.parent + ) # raw data directory + stream_names, stream_ids = si.extractors.get_neo_streams( + acq_software.strip().lower(), folder_path=data_dir + ) + si_recording: si.BaseRecording = SI_READERS[acq_software]( + folder_path=data_dir, stream_name=stream_names[0] ) - .fetch(format="frame") - .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] - ) - - # Create SI probe object - si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) - si_recording.set_probe(probe=si_probe, in_place=True) - - # Run preprocessing and save results to output folder - preprocessing_method = "catGT" # where to load this info? 
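(An aside on the `# where to load this info?` note above: two patches later, PATCH 054 settles this by requiring the SpikeInterface-specific entries to live in the ClusteringParamSet params. Under that convention, a params blob might look like the sketch below; the SI_* key names are the ones PATCH 054 validates, while the values are illustrative only, loosely based on defaults appearing elsewhere in this series.)

    params = {
        "SI_PREPROCESSING_METHOD": "catGT",  # or "IBLdestriping" / "IBLdestriping_modified"
        "SI_SORTING_PARAMS": {"detect_threshold": 6},  # forwarded to si.sorters.run_sorter
        "SI_WAVEFORM_EXTRACTION_PARAMS": {"ms_before": 3.0, "ms_after": 4.0},
        "SI_QUALITY_METRICS_PARAMS": {"n_components": 5, "mode": "by_channel_local"},
        "SI_JOB_KWARGS": {"n_jobs": -1, "chunk_size": 30000},
    }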
- si_recording = { - "catGT": mimic_catGT, - "IBLdestriping": mimic_IBLdestriping, - "IBLdestriping_modified": mimic_IBLdestriping_modified, - }[preprocessing_method](si_recording) - recording_file = output_full_dir / "si_recording.pkl" - si_recording.dump_to_pickle(file_path=recording_file) - self.insert1( - { - **key, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) + # Add probe information to recording object + electrode_config_key = ( + probe.ElectrodeConfig * ephys.EphysRecording & key + ).fetch1("KEY") + electrodes_df = ( + ( + probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & electrode_config_key + ) + .fetch(format="frame") + .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] + ) + channels_details = ephys.get_recording_channels_details(key) + + # Create SI probe object + si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) + si_probe.set_device_channel_indices(channels_details["channel_ind"]) + si_recording.set_probe(probe=si_probe, in_place=True) + + # Run preprocessing and save results to output folder + preprocessing_method = "catGT" # where to load this info? + si_recording = { + "catGT": mimic_catGT, + "IBLdestriping": mimic_IBLdestriping, + "IBLdestriping_modified": mimic_IBLdestriping_modified, + }[preprocessing_method](si_recording) + si_recording.dump_to_pickle(file_path=recording_file) + + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) @schema From 653e7e84bcacd3cf7ae382e5236b732db500ad06 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 3 Jan 2024 15:05:53 -0600 Subject: [PATCH 053/146] fix: :bug: fix path & typo --- element_array_ephys/spike_sorting/si_clustering.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_clustering.py index 8a5adffb..32804645 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_clustering.py @@ -190,8 +190,8 @@ def make(self, key): # Load recording object. output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - recording_file = output_dir / "si_recording.pkl" + output_full_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_full_dir.parent / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) # Get sorter method and create output directory. 
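(A note between hunks: the fix below exists because DataJoint stores the method as "kilosort2.5" while SpikeInterface registers sorters with underscores, e.g. "kilosort2_5"; the SI_SORTERS list earlier in this series is built by reversing exactly that substitution. A generic normalization, sketched here rather than taken from the patch, would avoid hard-coding each version:)

    sorter_name = clustering_method.replace(".", "_")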
@@ -199,9 +199,9 @@ def make(self, key): ephys.ClusteringTask * ephys.ClusteringParamSet & key ).fetch1("clustering_method", "params") sorter_name = ( - "kilosort_2_5" if clustering_method == "kilsort2.5" else clustering_method + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - sorter_dir = output_dir / sorter_name + sorter_dir = output_full_dir / sorter_name # Run sorting si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( From 8c25bd21f69bb5151508729f559f7515b8fc3d08 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 5 Jan 2024 12:56:42 -0600 Subject: [PATCH 054/146] code review --- .../{preprocessing.py => si_preprocessing.py} | 2 +- .../{si_clustering.py => si_spike_sorting.py} | 131 +++++++++--------- 2 files changed, 70 insertions(+), 63 deletions(-) rename element_array_ephys/spike_sorting/{preprocessing.py => si_preprocessing.py} (98%) rename element_array_ephys/spike_sorting/{si_clustering.py => si_spike_sorting.py} (72%) diff --git a/element_array_ephys/spike_sorting/preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py similarity index 98% rename from element_array_ephys/spike_sorting/preprocessing.py rename to element_array_ephys/spike_sorting/si_preprocessing.py index 77a95792..2edf443d 100644 --- a/element_array_ephys/spike_sorting/preprocessing.py +++ b/element_array_ephys/spike_sorting/si_preprocessing.py @@ -37,7 +37,7 @@ def mimic_IBLdestriping_modified(recording): return recording -_preprocessing_function = { +preprocessing_function_mapping = { "catGT": mimic_catGT, "IBLdestriping": mimic_IBLdestriping, "IBLdestriping_modified": mimic_IBLdestriping_modified, diff --git a/element_array_ephys/spike_sorting/si_clustering.py b/element_array_ephys/spike_sorting/si_spike_sorting.py similarity index 72% rename from element_array_ephys/spike_sorting/si_clustering.py rename to element_array_ephys/spike_sorting/si_spike_sorting.py index 32804645..f491b5b5 100644 --- a/element_array_ephys/spike_sorting/si_clustering.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -30,11 +30,7 @@ from element_array_ephys import get_logger, probe, readers -from .preprocessing import ( - mimic_catGT, - mimic_IBLdestriping, - mimic_IBLdestriping_modified, -) +from . 
import si_preprocessing log = get_logger(__name__) @@ -100,9 +96,13 @@ def make(self, key): execution_time = datetime.utcnow() # Set the output directory - acq_software, output_dir = ( - ephys.ClusteringTask * ephys.EphysRecording & key - ).fetch1("acq_software", "clustering_output_dir") + acq_software, clustering_method, params = ( + ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key + ).fetch1("acq_software", "clustering_method", "params") + + for req_key in ("SI_PREPROCESSING_METHOD", "SI_SORTING_PARAMS", "SI_WAVEFORM_EXTRACTION_PARAMS", "SI_QUALITY_METRICS_PARAMS"): + if req_key not in params: + raise ValueError(f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution") if not output_dir: output_dir = ephys.ClusteringTask.infer_output_dir( @@ -114,63 +114,68 @@ def make(self, key): ) output_dir = pathlib.Path(output_dir) output_full_dir = find_full_path( - ephys.get_ephys_root_data_dir(), output_dir.parent - ) # recording object will be stored in the parent dir since it can be re-used for multiple sorters + ephys.get_ephys_root_data_dir(), output_dir + ) recording_file = ( output_full_dir / "si_recording.pkl" ) # recording cache to be created for each key - if not recording_file.exists(): # skip if si_recording.pkl already exists - # Create SI recording extractor object - data_dir = ( - ephys.get_ephys_root_data_dir()[0] / output_dir.parent - ) # raw data directory - stream_names, stream_ids = si.extractors.get_neo_streams( - acq_software.strip().lower(), folder_path=data_dir - ) - si_recording: si.BaseRecording = SI_READERS[acq_software]( - folder_path=data_dir, stream_name=stream_names[0] - ) + # Create SI recording extractor object + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) + spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording.validate_file("ap") + data_dir = spikeglx_meta_filepath.parent + elif acq_software == "Open Ephys": + oe_probe = ephys.get_openephys_probe_data(key) + assert len(oe_probe.recording_info["recording_files"]) == 1 + data_dir = oe_probe.recording_info["recording_files"][0] + else: + raise NotImplementedError(f"Not implemented for {acq_software}") + + stream_names, stream_ids = si.extractors.get_neo_streams( + acq_software.strip().lower(), folder_path=data_dir + ) + si_recording: si.BaseRecording = SI_READERS[acq_software]( + folder_path=data_dir, stream_name=stream_names[0] + ) - # Add probe information to recording object - electrode_config_key = ( - probe.ElectrodeConfig * ephys.EphysRecording & key - ).fetch1("KEY") - electrodes_df = ( - ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key - ) - .fetch(format="frame") - .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] - ) - channels_details = ephys.get_recording_channels_details(key) - - # Create SI probe object - si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) - si_probe.set_device_channel_indices(channels_details["channel_ind"]) - si_recording.set_probe(probe=si_probe, in_place=True) - - # Run preprocessing and save results to output folder - preprocessing_method = "catGT" # where to load this info? 
- si_recording = { - "catGT": mimic_catGT, - "IBLdestriping": mimic_IBLdestriping, - "IBLdestriping_modified": mimic_IBLdestriping_modified, - }[preprocessing_method](si_recording) - si_recording.dump_to_pickle(file_path=recording_file) - - self.insert1( - { - **key, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } + # Add probe information to recording object + electrode_config_key = ( + probe.ElectrodeConfig * ephys.EphysRecording & key + ).fetch1("KEY") + electrodes_df = ( + ( + probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & electrode_config_key ) + .fetch(format="frame") + .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] + ) + channels_details = ephys.get_recording_channels_details(key) + + # Create SI probe object + si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) + si_probe.set_device_channel_indices(channels_details["channel_ind"]) + si_recording.set_probe(probe=si_probe, in_place=True) + + # Run preprocessing and save results to output folder + preprocessing_method = params["SI_PREPROCESSING_METHOD"] + si_preproc_func = si_preprocessing.preprocessing_function_mapping[preprocessing_method] + si_recording = si_preproc_func(si_recording) + si_recording.dump_to_pickle(file_path=recording_file) + + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) @schema @@ -203,6 +208,8 @@ def make(self, key): ) sorter_dir = output_full_dir / sorter_name + si_sorting_params = params["SI_SORTING_PARAMS"] + # Run sorting si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, @@ -210,7 +217,7 @@ def make(self, key): output_folder=sorter_dir, verbse=True, docker_image=True, - **params, + **si_sorting_params, ) # Run sorting @@ -255,14 +262,14 @@ def make(self, key): si_recording: si.BaseRecording = si.load_extractor(recording_file) si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) + si_waveform_extraction_params = params["SI_WAVEFORM_EXTRACTION_PARAMS"] + # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( si_recording, si_sorting, folder=sorter_dir / "waveform", # The folder where waveforms are cached - ms_before=3.0, - ms_after=4.0, - max_spikes_per_unit=500, + **si_waveform_extraction_params overwrite=True, **JOB_KWARGS, ) From 90d1ba4e4c87493c4e79ae93ed139ee0f5e3218f Mon Sep 17 00:00:00 2001 From: JaerongA Date: Thu, 1 Feb 2024 15:27:57 -0600 Subject: [PATCH 055/146] feat: :sparkles: modify QualityMetrics make function --- element_array_ephys/ephys_no_curation.py | 34 +++++++++++++++--------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 8ee7ee8b..ca293d95 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1358,24 +1358,34 @@ class Waveform(dj.Part): def make(self, key): """Populates tables with quality metrics data.""" + # Load metrics.csv output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + metric_fp = output_dir / 
"metrics.csv" if not metric_fp.exists(): raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) + + # Conform the dataframe to match the table definition + if "cluster_id" in metrics_df.columns: + metrics_df.set_index("cluster_id", inplace=True) + else: + metrics_df.rename( + columns={metrics_df.columns[0]: "cluster_id"}, inplace=True + ) + metrics_df.set_index("cluster_id", inplace=True) metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) + + metrics_df.rename( + columns={ + "isi_viol": "isi_violation", + "num_viol": "number_violation", + "contam_rate": "contamination_rate", + }, + inplace=True, + ) + metrics_list = [ dict(metrics_df.loc[unit_key["unit"]], **unit_key) for unit_key in (CuratedClustering.Unit & key).fetch("KEY") From cacefaceb0246bbc7b1a86b7e9afbaff335f5a86 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 6 Feb 2024 21:02:47 +0000 Subject: [PATCH 056/146] update si_spike_sorting.PreProcessing make function --- .../spike_sorting/si_spike_sorting.py | 30 ++++++++++++------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index f491b5b5..82729fe7 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -96,13 +96,20 @@ def make(self, key): execution_time = datetime.utcnow() # Set the output directory - acq_software, clustering_method, params = ( + acq_software, output_dir, params = ( ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - for req_key in ("SI_PREPROCESSING_METHOD", "SI_SORTING_PARAMS", "SI_WAVEFORM_EXTRACTION_PARAMS", "SI_QUALITY_METRICS_PARAMS"): + ).fetch1("acq_software", "clustering_output_dir", "params") + + for req_key in ( + "SI_SORTING_PARAMS", + "SI_PREPROCESSING_METHOD", + "SI_WAVEFORM_EXTRACTION_PARAMS", + "SI_QUALITY_METRICS_PARAMS", + ): if req_key not in params: - raise ValueError(f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution") + raise ValueError( + f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution" + ) if not output_dir: output_dir = ephys.ClusteringTask.infer_output_dir( @@ -113,9 +120,7 @@ def make(self, key): {**key, "clustering_output_dir": output_dir.as_posix()} ) output_dir = pathlib.Path(output_dir) - output_full_dir = find_full_path( - ephys.get_ephys_root_data_dir(), output_dir - ) + output_full_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_file = ( output_full_dir / "si_recording.pkl" @@ -124,7 +129,9 @@ def make(self, key): # Create SI recording extractor object if acq_software == "SpikeGLX": spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) + spikeglx_recording = readers.spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) spikeglx_recording.validate_file("ap") data_dir = spikeglx_meta_filepath.parent elif acq_software == "Open Ephys": @@ -161,8 +168,9 @@ def make(self, key): si_recording.set_probe(probe=si_probe, in_place=True) # Run preprocessing and save results to output folder - preprocessing_method = params["SI_PREPROCESSING_METHOD"] - si_preproc_func = 
si_preprocessing.preprocessing_function_mapping[preprocessing_method] + si_preproc_func = si_preprocessing.preprocessing_function_mapping[ + params["SI_PREPROCESSING_METHOD"] + ] si_recording = si_preproc_func(si_recording) si_recording.dump_to_pickle(file_path=recording_file) From f98e1ed332c7c5ca61b9321e461c0101b063a064 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 17:20:15 +0000 Subject: [PATCH 057/146] update SIClustering make function --- .../spike_sorting/si_spike_sorting.py | 26 ++++++++----------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 82729fe7..a2f2db24 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -25,11 +25,10 @@ import pandas as pd import probeinterface as pi import spikeinterface as si +from element_array_ephys import get_logger, probe, readers from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters -from element_array_ephys import get_logger, probe, readers - from . import si_preprocessing log = get_logger(__name__) @@ -202,34 +201,31 @@ def make(self, key): execution_time = datetime.utcnow() # Load recording object. - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - output_full_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - recording_file = output_full_dir.parent / "si_recording.pkl" + clustering_method, output_dir, params = ( + ephys.ClusteringTask * ephys.ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir", "params") + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_dir / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) # Get sorter method and create output directory. 
- clustering_method, params = ( - ephys.ClusteringTask * ephys.ClusteringParamSet & key - ).fetch1("clustering_method", "params") sorter_name = ( "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - sorter_dir = output_full_dir / sorter_name - - si_sorting_params = params["SI_SORTING_PARAMS"] # Run sorting si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, recording=si_recording, - output_folder=sorter_dir, - verbse=True, + output_folder=output_dir / sorter_name, + remove_existing_folder=True, + verbose=True, docker_image=True, - **si_sorting_params, + **params.get("SI_SORTING_PARAMS", {}), ) # Run sorting - sorting_save_path = sorter_dir / "si_sorting.pkl" + sorting_save_path = output_dir / "si_sorting.pkl" si_sorting.dump_to_pickle(sorting_save_path) self.insert1( From 7a060ef5b875dd74189841aa5a72d3eea55d7dda Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 17:21:16 +0000 Subject: [PATCH 058/146] update PostProcessing make function --- .../spike_sorting/si_spike_sorting.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index a2f2db24..5eb2b822 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -254,28 +254,26 @@ class PostProcessing(dj.Imported): def make(self, key): execution_time = datetime.utcnow() - JOB_KWARGS = dict(n_jobs=-1, chunk_size=30000) # Load sorting & recording object. - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") + output_dir, params = (ephys.ClusteringTask & key).fetch1( + "clustering_output_dir", "params" + ) output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_file = output_dir / "si_recording.pkl" - sorter_dir = output_dir / key["sorter_name"] - sorting_file = sorter_dir / "si_sorting.pkl" + sorting_file = output_dir / "si_sorting.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) - si_waveform_extraction_params = params["SI_WAVEFORM_EXTRACTION_PARAMS"] - # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( si_recording, si_sorting, - folder=sorter_dir / "waveform", # The folder where waveforms are cached - **si_waveform_extraction_params + folder=output_dir / "waveform", # The folder where waveforms are cached overwrite=True, - **JOB_KWARGS, + **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), + **params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_size": 30000}), ) # Calculate QC Metrics @@ -296,9 +294,11 @@ def make(self, key): ) # Add PCA based metrics. These will be added to the metrics dataframe above. 
_ = si.postprocessing.compute_principal_components( - waveform_extractor=we, n_components=5, mode="by_channel_local" - ) # TODO: the parameters need to be checked + waveform_extractor=we, **params.get("SI_QUALITY_METRICS_PARAMS", None) + ) + # Save the output (metrics.csv to the output dir) metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) + metrics.to_csv(output_dir / "metrics.csv") # Save results self.insert1( @@ -312,7 +312,7 @@ def make(self, key): } ) - # all finished, insert this `key` into ephys.Clustering + # Once finished, insert this `key` into ephys.Clustering ephys.Clustering.insert1( {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) From 6daf0c5a1f3f40d3700ce09bfa047326b4477cab Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 17:01:52 -0600 Subject: [PATCH 059/146] feat: :sparkles: add n.a. to ClusterQualityLabel --- element_array_ephys/ephys_no_curation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index ca293d95..aa743598 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -703,6 +703,7 @@ class ClusterQualityLabel(dj.Lookup): ("ok", "probably a single unit, but could be contaminated"), ("mua", "multi-unit activity"), ("noise", "bad unit"), + ("n.a.", "not available"), ] From e41ff1daeaddeb2847b426893f430906a86e91d6 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 17:15:05 -0600 Subject: [PATCH 060/146] extract all waveforms --- element_array_ephys/spike_sorting/si_spike_sorting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 5eb2b822..432b6c10 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -271,6 +271,7 @@ def make(self, key): si_recording, si_sorting, folder=output_dir / "waveform", # The folder where waveforms are cached + max_spikes_per_unit=None, overwrite=True, **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), **params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_size": 30000}), From e8d9854f4014302248d483604faaa2b4f2858fc2 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 19:15:34 -0600 Subject: [PATCH 061/146] feat: :sparkles: modify CuratedClustering make function for spike interface --- element_array_ephys/ephys_no_curation.py | 195 ++++++++++++++++------- 1 file changed, 138 insertions(+), 57 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index aa743598..70ce87cf 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -959,75 +959,156 @@ class Unit(dj.Part): def make(self, key): """Automated population of Unit information.""" output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) + if (output_dir / "waveform").exists(): # read from spikeinterface outputs + we: si.WaveformExtractor = si.load_waveforms( + output_dir / "waveform", with_recording=False + ) + si_sorting: si.sorters.BaseSorter = si.load_extractor( + output_dir / 
"sorting.pkl" + ) - sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate) + unit_peak_channel_map: dict[int, int] = si.get_template_extremum_channel( + we, outputs="index" + ) # {unit: peak_channel_index} - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() + spike_count_dict = dict[int, int] = si_sorting.count_num_spikes_per_unit() + # {unit: spike_count} - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / sample_rate - ) - spike_count = len(unit_spike_times) + spikes = si_sorting.to_spike_vector( + extremum_channel_inds=unit_peak_channel_map + ) + + # Get electrode info + electrode_config_key = ( + EphysRecording * probe.ElectrodeConfig & key + ).fetch1("KEY") + + electrode_query = ( + probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode + & electrode_config_key + ) + channel2electrode_map = dict( + zip(*electrode_query.fetch("channel", "electrode")) + ) + + # Get channel to electrode mapping + channel2depth_map = dict(zip(*electrode_query.fetch("channel", "y_coord"))) + + peak_electrode_ind = np.array( + [ + channel2electrode_map[unit_peak_channel_map[unit_id]] + for unit_id in si_sorting.unit_ids + ] + ) + + # Get channel to depth mapping + electrode_depth_ind = np.array( + [ + channel2depth_map[unit_peak_channel_map[unit_id]] + for unit_id in si_sorting.unit_ids + ] + ) + spikes["electrode"] = peak_electrode_ind[spikes["unit_index"]] + spikes["depth"] = electrode_depth_ind[spikes["unit_index"]] + + units = [] + for unit_id in si_sorting.unit_ids: + unit_id = int(unit_id) units.append( { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit + "unit": unit_id, + "cluster_quality_label": "n.a.", + "spike_times": si_sorting.get_unit_spike_train( + unit_id, return_times=True + ), + "spike_count": spike_count_dict[unit_id], + "spike_sites": spikes["electrode"][ + spikes["unit_index"] == unit_id ], - "spike_depths": spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit + "spike_depths": spikes["depth"][ + spikes["unit_index"] == unit_id ], } ) + else: + kilosort_dataset = kilosort.Kilosort(output_dir) + 
acq_software, sample_rate = (EphysRecording & key).fetch1( + "acq_software", "sampling_rate" + ) + + sample_rate = kilosort_dataset.data["params"].get( + "sample_rate", sample_rate + ) + + # ---------- Unit ---------- + # -- Remove 0-spike units + withspike_idx = [ + i + for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) + if (kilosort_dataset.data["spike_clusters"] == u).any() + ] + valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] + valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] + + # -- Spike-times -- + # spike_times_sec_adj > spike_times_sec > spike_times + spike_time_key = ( + "spike_times_sec_adj" + if "spike_times_sec_adj" in kilosort_dataset.data + else ( + "spike_times_sec" + if "spike_times_sec" in kilosort_dataset.data + else "spike_times" + ) + ) + spike_times = kilosort_dataset.data[spike_time_key] + kilosort_dataset.extract_spike_depths() + + # Get channel and electrode-site mapping + channel2electrodes = get_neuropixels_channel2electrode_map( + key, acq_software + ) + + # -- Spike-sites and Spike-depths -- + spike_sites = np.array( + [ + channel2electrodes[s]["electrode"] + for s in kilosort_dataset.data["spike_sites"] + ] + ) + spike_depths = kilosort_dataset.data["spike_depths"] + + # -- Insert unit, label, peak-chn + units = [] + for unit, unit_lbl in zip(valid_units, valid_unit_labels): + if (kilosort_dataset.data["spike_clusters"] == unit).any(): + unit_channel, _ = kilosort_dataset.get_best_channel(unit) + unit_spike_times = ( + spike_times[kilosort_dataset.data["spike_clusters"] == unit] + / sample_rate + ) + spike_count = len(unit_spike_times) + + units.append( + { + "unit": unit, + "cluster_quality_label": unit_lbl, + **channel2electrodes[unit_channel], + "spike_times": unit_spike_times, + "spike_count": spike_count, + "spike_sites": spike_sites[ + kilosort_dataset.data["spike_clusters"] == unit + ], + "spike_depths": spike_depths[ + kilosort_dataset.data["spike_clusters"] == unit + ], + } + ) + self.insert1(key) self.Unit.insert([{**key, **u} for u in units]) From 00b82f81017fdb92459cd334cda8bb3dc49b0fde Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 7 Feb 2024 19:17:15 -0600 Subject: [PATCH 062/146] refactor: :recycle: import si module & re-organize imports --- element_array_ephys/ephys_no_curation.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 70ce87cf..63e72951 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1,17 +1,18 @@ -import datajoint as dj +import gc +import importlib +import inspect import pathlib import re -import numpy as np -import inspect -import importlib -import gc from decimal import Decimal -import pandas as pd -from element_interface.utils import find_root_directory, find_full_path, dict_to_uuid -from .readers import spikeglx, kilosort, openephys -from element_array_ephys import probe, get_logger, ephys_report +import datajoint as dj +import numpy as np +import pandas as pd +from element_array_ephys import ephys_report, get_logger, probe +from element_interface.utils import (dict_to_uuid, find_full_path, + find_root_directory) +from .readers import kilosort, openephys, spikeglx log = get_logger(__name__) @@ -19,8 +20,8 @@ _linking_module = None -import spikeinterface -import spikeinterface.full as si +import spikeinterface as si +from spikeinterface import exporters, postprocessing, qualitymetrics, 
sorters def activate( From b01c36c81595e8e4cb38e22f2dee146986508b79 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 12 Feb 2024 09:39:04 -0600 Subject: [PATCH 063/146] update WaveformSet ingestion --- element_array_ephys/ephys_no_curation.py | 368 ++++++++++++++--------- 1 file changed, 218 insertions(+), 150 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 63e72951..4887096c 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1168,177 +1168,245 @@ class Waveform(dj.Part): def make(self, key): """Populates waveform tables.""" output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - kilosort_dataset = kilosort.Kilosort(kilosort_dir) + if (output_dir / "waveform").exists(): # read from spikeinterface outputs - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") + we: si.WaveformExtractor = si.load_waveforms( + output_dir / "waveform", with_recording=False + ) + unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) + units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } + # Get electrode info + electrode_config_key = ( + EphysRecording * probe.ElectrodeConfig & key + ).fetch1("KEY") - waveforms_folder = [ - f for f in kilosort_dir.parent.rglob(r"*/waveforms*") if f.is_dir() - ] + electrode_query = ( + probe.ProbeType.Electrode.proj() * probe.ElectrodeConfig.Electrode + & electrode_config_key + ) + electrode_info = electrode_query.fetch( + "KEY", order_by="electrode", as_dict=True + ) - if (kilosort_dir / "mean_waveforms.npy").exists(): - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample + # Get mean waveform for each unit from all channels + mean_waveforms = we.get_all_templates( + mode="average" + ) # (unit x sample x channel) - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - # Spike interface mean and peak waveform extraction from we object - - elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): - we_kilosort = si.load_waveforms(waveforms_folder[0].parent) - unit_templates = we_kilosort.get_all_templates() - unit_waveforms = np.reshape( - unit_templates, - ( - unit_templates.shape[1], - unit_templates.shape[3], - 
unit_templates.shape[2], - ), + unit_peak_waveform = [] + unit_electrode_waveforms = [] + + for unit in units: + unit_peak_waveform.append( + { + **unit, + "peak_electrode_waveform": we.get_template( + unit_id=unit["unit"], mode="average", force_dense=True + )[:, unit_id_to_peak_channel_indices[unit["unit"]][0]], + } + ) + + unit_electrode_waveforms.extend( + [ + { + **unit, + **e, + "waveform_mean": mean_waveforms[ + unit["unit"], :, e["electrode"] + ], + } + for e in electrode_info + ] + ) + + self.insert1(key) + self.PeakWaveform.insert(unit_peak_waveform) + self.Waveform.insert(unit_electrode_waveforms) + + else: + kilosort_dataset = kilosort.Kilosort(output_dir) + + acq_software, probe_serial_number = ( + EphysRecording * ProbeInsertion & key + ).fetch1("acq_software", "probe") + + # -- Get channel and electrode-site mapping + recording_key = (EphysRecording & key).fetch1("KEY") + channel2electrodes = get_neuropixels_channel2electrode_map( + recording_key, acq_software ) - # Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms) - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: + # Get all units + units = { + u["unit"]: u + for u in (CuratedClustering.Unit & key).fetch( + as_dict=True, order_by="unit" + ) + } + + waveforms_folder = [ + f for f in output_dir.parent.rglob(r"*/waveforms*") if f.is_dir() + ] + + if (output_dir / "mean_waveforms.npy").exists(): + unit_waveforms = np.load( + output_dir / "mean_waveforms.npy" + ) # unit x channel x sample + + def yield_unit_waveforms(): + for unit_no, unit_waveform in zip( + kilosort_dataset.data["cluster_ids"], unit_waveforms + ): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + if unit_no in units: + for channel, channel_waveform in zip( + kilosort_dataset.data["channel_map"], unit_waveform + ): + unit_electrode_waveforms.append( + { + **units[unit_no], + **channel2electrodes[channel], + "waveform_mean": channel_waveform, + } + ) + if ( + channel2electrodes[channel]["electrode"] + == units[unit_no]["electrode"] + ): + unit_peak_waveform = { + **units[unit_no], + "peak_electrode_waveform": channel_waveform, + } + yield unit_peak_waveform, unit_electrode_waveforms + + # Spike interface mean and peak waveform extraction from we object + + elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): + we_kilosort = si.load_waveforms(waveforms_folder[0].parent) + unit_templates = we_kilosort.get_all_templates() + unit_waveforms = np.reshape( + unit_templates, + ( + unit_templates.shape[1], + unit_templates.shape[3], + unit_templates.shape[2], + ), + ) + + # Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms) + def yield_unit_waveforms(): + for unit_no, unit_waveform in zip( + kilosort_dataset.data["cluster_ids"], unit_waveforms + ): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + if unit_no in units: + for channel, channel_waveform in zip( + kilosort_dataset.data["channel_map"], unit_waveform + ): + unit_electrode_waveforms.append( + { + **units[unit_no], + **channel2electrodes[channel], + "waveform_mean": channel_waveform, + } + ) + if ( + channel2electrodes[channel]["electrode"] + == units[unit_no]["electrode"] + ): + unit_peak_waveform = { + **units[unit_no], + "peak_electrode_waveform": channel_waveform, + } + yield unit_peak_waveform, 
unit_electrode_waveforms + + # Approach not using spike interface templates (ie. taking mean of each unit waveform) + # def yield_unit_waveforms(): + # for unit_id in we_kilosort.unit_ids: + # unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0) + # unit_peak_waveform = {} + # unit_electrode_waveforms = [] + # if unit_id in units: + # for channel, channel_waveform in zip( + # kilosort_dataset.data["channel_map"], unit_waveform + # ): + # unit_electrode_waveforms.append( + # { + # **units[unit_id], + # **channel2electrodes[channel], + # "waveform_mean": channel_waveform, + # } + # ) + # if ( + # channel2electrodes[channel]["electrode"] + # == units[unit_id]["electrode"] + # ): + # unit_peak_waveform = { + # **units[unit_id], + # "peak_electrode_waveform": channel_waveform, + # } + # yield unit_peak_waveform, unit_electrode_waveforms + + else: + if acq_software == "SpikeGLX": + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + neuropixels_recording = spikeglx.SpikeGLX( + spikeglx_meta_filepath.parent + ) + elif acq_software == "Open Ephys": + session_dir = find_full_path( + get_ephys_root_data_dir(), get_session_directory(key) + ) + openephys_dataset = openephys.OpenEphys(session_dir) + neuropixels_recording = openephys_dataset.probes[ + probe_serial_number + ] + + def yield_unit_waveforms(): + for unit_dict in units.values(): + unit_peak_waveform = {} + unit_electrode_waveforms = [] + + spikes = unit_dict["spike_times"] + waveforms = neuropixels_recording.extract_spike_waveforms( + spikes, kilosort_dataset.data["channel_map"] + ) # (sample x channel x spike) + waveforms = waveforms.transpose( + (1, 2, 0) + ) # (channel x spike x sample) for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform + kilosort_dataset.data["channel_map"], waveforms ): unit_electrode_waveforms.append( { - **units[unit_no], + **unit_dict, **channel2electrodes[channel], - "waveform_mean": channel_waveform, + "waveform_mean": channel_waveform.mean(axis=0), + "waveforms": channel_waveform, } ) if ( channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] + == unit_dict["electrode"] ): unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, + **unit_dict, + "peak_electrode_waveform": channel_waveform.mean( + axis=0 + ), } - yield unit_peak_waveform, unit_electrode_waveforms - - # Approach not using spike interface templates (ie. 
taking mean of each unit waveform) - # def yield_unit_waveforms(): - # for unit_id in we_kilosort.unit_ids: - # unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0) - # unit_peak_waveform = {} - # unit_electrode_waveforms = [] - # if unit_id in units: - # for channel, channel_waveform in zip( - # kilosort_dataset.data["channel_map"], unit_waveform - # ): - # unit_electrode_waveforms.append( - # { - # **units[unit_id], - # **channel2electrodes[channel], - # "waveform_mean": channel_waveform, - # } - # ) - # if ( - # channel2electrodes[channel]["electrode"] - # == units[unit_id]["electrode"] - # ): - # unit_peak_waveform = { - # **units[unit_id], - # "peak_electrode_waveform": channel_waveform, - # } - # yield unit_peak_waveform, unit_electrode_waveforms - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( - (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - - yield unit_peak_waveform, unit_electrode_waveforms + yield unit_peak_waveform, unit_electrode_waveforms # insert waveform on a per-unit basis to mitigate potential memory issue self.insert1(key) @@ -1448,7 +1516,7 @@ def make(self, key): if not metric_fp.exists(): raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") metrics_df = pd.read_csv(metric_fp) - + # Conform the dataframe to match the table definition if "cluster_id" in metrics_df.columns: metrics_df.set_index("cluster_id", inplace=True) From 853b66f5bd39edd82ed14de635064f24c52855e4 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 11:38:39 -0600 Subject: [PATCH 064/146] Update element_array_ephys/ephys_no_curation.py Co-authored-by: Kushal Bakshi <52367253+kushalbakshi@users.noreply.github.com> --- element_array_ephys/ephys_no_curation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 4887096c..c82f986f 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -319,8 +319,8 @@ def make(self, key): break else: raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" + "Ephys recording data not found!" 
+ "Neither SpikeGLX nor Open Ephys recording files found" ) supported_probe_types = probe.ProbeType.fetch("probe_type") From 67ebf4e994411617826263782142fcfc270b98f0 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 11:38:47 -0600 Subject: [PATCH 065/146] Update element_array_ephys/ephys_no_curation.py Co-authored-by: Kushal Bakshi <52367253+kushalbakshi@users.noreply.github.com> --- element_array_ephys/ephys_no_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index c82f986f..0efddf9a 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -618,7 +618,7 @@ class ClusteringParamSet(dj.Lookup): ClusteringMethod (dict): ClusteringMethod primary key. paramset_desc (varchar(128) ): Description of the clustering parameter set. param_set_hash (uuid): UUID hash for the parameter set. - params (longblob) + params (longblob): Set of clustering parameters. """ definition = """ From 5fe60434655f8f0c9f8cb2b2ecaa63e4a3a28e2a Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 11:38:51 -0600 Subject: [PATCH 066/146] Update element_array_ephys/ephys_no_curation.py Co-authored-by: Kushal Bakshi <52367253+kushalbakshi@users.noreply.github.com> --- element_array_ephys/ephys_no_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 0efddf9a..92224409 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1488,7 +1488,7 @@ class Waveform(dj.Part): recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float) inverse velocity of waveform propagation from soma toward the bottom of the probe. + velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
""" definition = """ From ac08163cb55ac46d97a1b5995d965be6e8140ed8 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 11:48:41 -0600 Subject: [PATCH 067/146] ci: run test only on the main branch --- .github/workflows/test.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index acaddca0..fec7ce0c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,11 +1,13 @@ name: Test on: push: + branches: + - main pull_request: + branches: + - main workflow_dispatch: jobs: - # devcontainer-build: - # uses: datajoint/.github/.github/workflows/devcontainer-build.yaml@main tests: runs-on: ubuntu-latest strategy: @@ -31,4 +33,3 @@ jobs: run: | python_version=${{matrix.py_ver}} black element_array_ephys --check --verbose --target-version py${python_version//.} - From e95331c54babe966ccbb6ce902463eb48c869c6d Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 14:33:26 -0600 Subject: [PATCH 068/146] build: :heavy_plus_sign: add spikingcircus dependencies --- setup.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index f93247b6..ebf5d114 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,6 @@ from os import path -from setuptools import find_packages, setup +from setuptools import find_packages, setup pkg_name = "element_array_ephys" here = path.abspath(path.dirname(__file__)) @@ -16,7 +16,7 @@ setup( name=pkg_name.replace("_", "-"), - python_requires='>=3.7, <3.11', + python_requires=">=3.7, <3.11", version=__version__, # noqa F821 description="Extracellular Array Electrophysiology DataJoint Element", long_description=long_description, @@ -50,5 +50,6 @@ ], "nwb": ["dandi", "neuroconv[ecephys]", "pynwb"], "tests": ["pre-commit", "pytest", "pytest-cov"], + "spikingcircus": ["hdbscan", "numba"], }, ) From be5135e05ba40bb2054a18e9b00486ea1ba411d0 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 13 Feb 2024 16:00:47 -0600 Subject: [PATCH 069/146] refactor: fix typo & black formatting --- element_array_ephys/ephys_no_curation.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 76fdeefc..b105f8f8 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -35,7 +35,7 @@ def activate( Args: ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe scehma. + probe_schema_name (str): A string containing the name of the probe schema. create_schema (bool): If True, schema will be created in the database. create_tables (bool): If True, tables related to the schema will be created in the database. linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. 
@@ -1174,11 +1174,11 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( output_dir / "waveform", with_recording=False ) - unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( - si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices - ) # {unit: peak_channel_index} + unit_id_to_peak_channel_indices: dict[ + int, np.ndarray + ] = si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices # {unit: peak_channel_index} units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") From 4b6fc0e9fe45f2a6b44be466f0731faad301734c Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 14 Feb 2024 20:10:56 -0600 Subject: [PATCH 070/146] feat: :sparkles: add EphysRecording.Channel part table --- element_array_ephys/ephys_no_curation.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index b105f8f8..25b2c147 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -284,6 +284,15 @@ class EphysRecording(dj.Imported): recording_duration: float # (seconds) duration of the recording from this probe """ + class Channel(dj.Part): + definitoin = """ + -> master + channel_idx: int # channel index + --- + -> probe.ElectrodeConfig.Electrode + channel_name="": varchar(64) + """ + class EphysFile(dj.Part): """Paths of electrophysiology recording files for each insertion. From 48025112da9acaab8b6e043b8da56e54a9e9725d Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 16 Feb 2024 14:25:45 -0600 Subject: [PATCH 071/146] fix: :bug: fix get_logger missing error --- element_array_ephys/__init__.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py index 1c0c7285..3a0e5af6 100644 --- a/element_array_ephys/__init__.py +++ b/element_array_ephys/__init__.py @@ -1 +1,22 @@ +""" +isort:skip_file +""" + +import logging +import os + +import datajoint as dj + + +__all__ = ["ephys", "get_logger"] + +dj.config["enable_python_native_blobs"] = True + + +def get_logger(name): + log = logging.getLogger(name) + log.setLevel(os.getenv("LOGLEVEL", "INFO")) + return log + + from . 
import ephys_acute as ephys From bca67b3e47138cd2d7eaaf6c0b892557ff576786 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 16 Feb 2024 21:03:28 +0000 Subject: [PATCH 072/146] fix typo & remove sorter_name --- element_array_ephys/ephys_no_curation.py | 2 +- element_array_ephys/spike_sorting/si_spike_sorting.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 25b2c147..5894fe16 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -285,7 +285,7 @@ class EphysRecording(dj.Imported): """ class Channel(dj.Part): - definitoin = """ + definition = """ -> master channel_idx: int # channel index --- diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 432b6c10..461987e6 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -191,7 +191,6 @@ class SIClustering(dj.Imported): definition = """ -> PreProcessing - sorter_name: varchar(30) # name of the sorter used --- execution_time: datetime # datetime of the start of this step execution_duration: float # execution duration in hours @@ -231,7 +230,6 @@ def make(self, key): self.insert1( { **key, - "sorter_name": sorter_name, "execution_time": execution_time, "execution_duration": ( datetime.utcnow() - execution_time From 1faa8f2e0e7f78f1097a220e4282ea4f8d6e7d1b Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 16 Feb 2024 21:08:15 +0000 Subject: [PATCH 073/146] feat: :sparkles: add memoized_result implementation in SIClustering --- .../spike_sorting/si_spike_sorting.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 461987e6..6acd5a2b 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -26,7 +26,7 @@ import probeinterface as pi import spikeinterface as si from element_array_ephys import get_logger, probe, readers -from element_interface.utils import find_full_path +from element_interface.utils import find_full_path, memoized_result from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . 
import si_preprocessing @@ -213,7 +213,17 @@ def make(self, key): ) # Run sorting - si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( + @memoized_result( + parameters={**key, **params}, + output_directory=output_dir / sorter_name, + ) + def _run_sorter(*args, **kwargs): + si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(*args, **kwargs) + sorting_save_path = output_dir / sorter_name / "si_sorting.pkl" + si_sorting.dump_to_pickle(sorting_save_path) + return sorting_save_path + + sorting_save_path = _run_sorter( sorter_name=sorter_name, recording=si_recording, output_folder=output_dir / sorter_name, From 58f3a4453a821e67bef49971d74f4e97d9f97ef2 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 16 Feb 2024 22:40:01 +0000 Subject: [PATCH 074/146] create a folder for storing recording pickle object --- .../spike_sorting/si_spike_sorting.py | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 6acd5a2b..3205d056 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -95,11 +95,16 @@ def make(self, key): execution_time = datetime.utcnow() # Set the output directory - acq_software, output_dir, params = ( + clustering_method, acq_software, output_dir, params = ( ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_output_dir", "params") - - for req_key in ( + ).fetch1("clustering_method", "acq_software", "clustering_output_dir", "params") + + # Get sorter method and create output directory. + sorter_name = ( + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method + ) + + for required_key in ( "SI_SORTING_PARAMS", "SI_PREPROCESSING_METHOD", "SI_WAVEFORM_EXTRACTION_PARAMS", @@ -110,6 +115,7 @@ def make(self, key): f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution" ) + # Set directory to store recording file. 
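# If no clustering_output_dir was set on this ClusteringTask, infer one from the
# session directory and write it back to the table, so every downstream step
# (SIClustering, PostProcessing, CuratedClustering) resolves the same per-task tree.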
if not output_dir: output_dir = ephys.ClusteringTask.infer_output_dir( key, relative=True, mkdir=True @@ -118,11 +124,11 @@ def make(self, key): ephys.ClusteringTask.update1( {**key, "clustering_output_dir": output_dir.as_posix()} ) - output_dir = pathlib.Path(output_dir) - output_full_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_dir = output_dir / sorter_name / "recording" + recording_dir.mkdir(parents=True, exist_ok=True) recording_file = ( - output_full_dir / "si_recording.pkl" + recording_dir / "si_recording.pkl" ) # recording cache to be created for each key # Create SI recording extractor object From 4fcea517577095232ef317e4a77421e3181b3632 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 16 Feb 2024 22:41:31 +0000 Subject: [PATCH 075/146] install element_interface from datajoint upstream --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index ebf5d114..532c72f6 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,7 @@ "elements": [ "element-animal @ git+https://github.com/datajoint/element-animal.git", "element-event @ git+https://github.com/datajoint/element-event.git", - "element-interface @ git+https://github.com/datajoint/element-interface.git", + "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results", "element-lab @ git+https://github.com/datajoint/element-lab.git", "element-session @ git+https://github.com/datajoint/element-session.git", ], From 4f0e0204cd5b6c31d78576dd3c3ff9b88c5598c3 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 16 Feb 2024 22:52:49 +0000 Subject: [PATCH 076/146] add required_key for parameters --- element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 3205d056..d09a2c6b 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -110,9 +110,9 @@ def make(self, key): "SI_WAVEFORM_EXTRACTION_PARAMS", "SI_QUALITY_METRICS_PARAMS", ): - if req_key not in params: + if required_key not in params: raise ValueError( - f"{req_key} must be defined in ClusteringParamSet for SpikeInterface execution" + f"{required_key} must be defined in ClusteringParamSet for SpikeInterface execution" ) # Set directory to store recording file. 
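For orientation between the commits above and below: the validation added in PATCH 076 expects the ClusteringParamSet `params` blob to carry one entry per SI_* key. A minimal sketch of such a parameter set follows — the concrete values are illustrative assumptions, not defaults shipped with the element:

params = {
    "SI_PREPROCESSING_METHOD": "catGT",  # must name a function in si_preprocessing
    "SI_SORTING_PARAMS": {"detect_threshold": 6},  # forwarded to si.sorters.run_sorter
    "SI_WAVEFORM_EXTRACTION_PARAMS": {"ms_before": 1.0, "ms_after": 2.0},
    "SI_QUALITY_METRICS_PARAMS": {"n_components": 5, "mode": "by_channel_local"},
    "SI_JOB_KWARGS": {"n_jobs": -1, "chunk_size": 30000},  # optional; this is the fallback
}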
From 83e7a166c18e7b20d83e05a6c050ff02bde9b74e Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 16 Feb 2024 22:55:58 +0000 Subject: [PATCH 077/146] set recording channel info --- element_array_ephys/spike_sorting/si_spike_sorting.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index d09a2c6b..6bf3c4bd 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -165,11 +165,10 @@ def make(self, key): .fetch(format="frame") .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] ) - channels_details = ephys.get_recording_channels_details(key) - + # Create SI probe object si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) - si_probe.set_device_channel_indices(channels_details["channel_ind"]) + si_probe.set_device_channel_indices(range(len(electrodes_df))) si_recording.set_probe(probe=si_probe, in_place=True) # Run preprocessing and save results to output folder From 1d18a39cf22d81d3ce74aa591c5c8f4138895b01 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 16 Feb 2024 22:57:55 +0000 Subject: [PATCH 078/146] fix loading preprocessor --- .../spike_sorting/si_preprocessing.py | 56 ++----------------- .../spike_sorting/si_spike_sorting.py | 4 +- 2 files changed, 5 insertions(+), 55 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py index 2edf443d..07a49293 100644 --- a/element_array_ephys/spike_sorting/si_preprocessing.py +++ b/element_array_ephys/spike_sorting/si_preprocessing.py @@ -2,7 +2,7 @@ from spikeinterface import preprocessing -def mimic_catGT(recording): +def catGT(recording): recording = si.preprocessing.phase_shift(recording) recording = si.preprocessing.common_reference( recording, operator="median", reference="global" @@ -10,7 +10,7 @@ def mimic_catGT(recording): return recording -def mimic_IBLdestriping(recording): +def IBLdestriping(recording): # From International Brain Laboratory. “Spike sorting pipeline for the International Brain Laboratory”. 4 May 2022. 9 Jun 2022. 
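    # Destriping, per the IBL pipeline: high-pass at 400 Hz, then detect bad
    # channels so they can be interpolated rather than contaminate the reference.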
recording = si.preprocessing.highpass_filter(recording, freq_min=400.0) bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording) @@ -24,7 +24,7 @@ def mimic_IBLdestriping(recording): return recording -def mimic_IBLdestriping_modified(recording): +def IBLdestriping_modified(recording): # From SpikeInterface Implementation (https://spikeinterface.readthedocs.io/en/latest/how_to/analyse_neuropixels.html) recording = si.preprocessing.highpass_filter(recording, freq_min=400.0) bad_channel_ids, channel_labels = si.preprocessing.detect_bad_channels(recording) @@ -34,52 +34,4 @@ def mimic_IBLdestriping_modified(recording): recording = si.preprocessing.common_reference( recording, operator="median", reference="global" ) - return recording - - -preprocessing_function_mapping = { - "catGT": mimic_catGT, - "IBLdestriping": mimic_IBLdestriping, - "IBLdestriping_modified": mimic_IBLdestriping_modified, -} - - -## Example SI parameter set -""" -{'detect_threshold': 6, - 'projection_threshold': [10, 4], - 'preclust_threshold': 8, - 'car': True, - 'minFR': 0.02, - 'minfr_goodchannels': 0.1, - 'nblocks': 5, - 'sig': 20, - 'freq_min': 150, - 'sigmaMask': 30, - 'nPCs': 3, - 'ntbuff': 64, - 'nfilt_factor': 4, - 'NT': None, - 'do_correction': True, - 'wave_length': 61, - 'keep_good_only': False, - 'PreProcessing_params': {'Filter': False, - 'BandpassFilter': True, - 'HighpassFilter': False, - 'NotchFilter': False, - 'NormalizeByQuantile': False, - 'Scale': False, - 'Center': False, - 'ZScore': False, - 'Whiten': False, - 'CommonReference': False, - 'PhaseShift': False, - 'Rectify': False, - 'Clip': False, - 'BlankSaturation': False, - 'RemoveArtifacts': False, - 'RemoveBadChannels': False, - 'ZeroChannelPad': False, - 'DeepInterpolation': False, - 'Resample': False}} -""" + return recording \ No newline at end of file diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 6bf3c4bd..c8d2f1b7 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -172,9 +172,7 @@ def make(self, key): si_recording.set_probe(probe=si_probe, in_place=True) # Run preprocessing and save results to output folder - si_preproc_func = si_preprocessing.preprocessing_function_mapping[ - params["SI_PREPROCESSING_METHOD"] - ] + si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"]) si_recording = si_preproc_func(si_recording) si_recording.dump_to_pickle(file_path=recording_file) From b0a863fec91f82a1171ae2c7041f76cf7dc612aa Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 16 Feb 2024 23:18:42 +0000 Subject: [PATCH 079/146] make all output dir non-sharable --- .../spike_sorting/si_spike_sorting.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index c8d2f1b7..13f569e7 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -207,39 +207,35 @@ def make(self, key): ephys.ClusteringTask * ephys.ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir", "params") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - recording_file = output_dir / "si_recording.pkl" - si_recording: si.BaseRecording = si.load_extractor(recording_file) # Get sorter method and create output directory. 
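        # SpikeInterface registers Kilosort 2.5 under "kilosort2_5", so the
        # element's "kilosort2.5" clustering_method label is mapped to the name
        # expected by si.sorters.run_sorter before anything else.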
sorter_name = ( "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - + recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" + si_recording: si.BaseRecording = si.load_extractor(recording_file) + # Run sorting @memoized_result( parameters={**key, **params}, - output_directory=output_dir / sorter_name, + output_directory=output_dir / sorter_name / "spike_sorting", ) def _run_sorter(*args, **kwargs): si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(*args, **kwargs) - sorting_save_path = output_dir / sorter_name / "si_sorting.pkl" + sorting_save_path = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" si_sorting.dump_to_pickle(sorting_save_path) return sorting_save_path sorting_save_path = _run_sorter( sorter_name=sorter_name, recording=si_recording, - output_folder=output_dir / sorter_name, + output_folder=output_dir / sorter_name / "spike_sorting", remove_existing_folder=True, verbose=True, docker_image=True, **params.get("SI_SORTING_PARAMS", {}), ) - # Run sorting - sorting_save_path = output_dir / "si_sorting.pkl" - si_sorting.dump_to_pickle(sorting_save_path) - self.insert1( { **key, @@ -266,13 +262,20 @@ class PostProcessing(dj.Imported): def make(self, key): execution_time = datetime.utcnow() - # Load sorting & recording object. - output_dir, params = (ephys.ClusteringTask & key).fetch1( - "clustering_output_dir", "params" + # Load recording object. + clustering_method, output_dir, params = ( + ephys.ClusteringTask * ephys.ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir", "params") + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + + # Get sorter method and create output directory. + sorter_name = ( + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - recording_file = output_dir / "si_recording.pkl" - sorting_file = output_dir / "si_sorting.pkl" + recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" + sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) @@ -281,7 +284,7 @@ def make(self, key): we: si.WaveformExtractor = si.extract_waveforms( si_recording, si_sorting, - folder=output_dir / "waveform", # The folder where waveforms are cached + folder=output_dir / sorter_name / "waveform", # The folder where waveforms are cached max_spikes_per_unit=None, overwrite=True, **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), @@ -309,8 +312,11 @@ def make(self, key): waveform_extractor=we, **params.get("SI_QUALITY_METRICS_PARAMS", None) ) # Save the output (metrics.csv to the output dir) + metrics_output_dir = output_dir / sorter_name / "metrics" + metrics_output_dir.mkdir(parents=True, exist_ok=True) + metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) - metrics.to_csv(output_dir / "metrics.csv") + metrics.to_csv(metrics_output_dir / "metrics.csv") # Save results self.insert1( From 7d28351baa629e1a72e0d9a52c67f0d6590af689 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 19 Feb 2024 15:30:30 -0600 Subject: [PATCH 080/146] refactor & accept changes from code review --- element_array_ephys/ephys_no_curation.py | 12 +++++----- .../spike_sorting/si_preprocessing.py | 2 +- .../spike_sorting/si_spike_sorting.py | 22 +++++++++++-------- 3 files changed, 20 
insertions(+), 16 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 5894fe16..acf7c76f 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -989,7 +989,7 @@ def make(self, key): extremum_channel_inds=unit_peak_channel_map ) - # Get electrode info + # Get electrode info !#TODO: need to be modified electrode_config_key = ( EphysRecording * probe.ElectrodeConfig & key ).fetch1("KEY") @@ -1183,11 +1183,11 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( output_dir / "waveform", with_recording=False ) - unit_id_to_peak_channel_indices: dict[ - int, np.ndarray - ] = si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices # {unit: peak_channel_index} + unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py index 07a49293..4db5f303 100644 --- a/element_array_ephys/spike_sorting/si_preprocessing.py +++ b/element_array_ephys/spike_sorting/si_preprocessing.py @@ -34,4 +34,4 @@ def IBLdestriping_modified(recording): recording = si.preprocessing.common_reference( recording, operator="median", reference="global" ) - return recording \ No newline at end of file + return recording diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 13f569e7..1b8366dc 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -26,7 +26,7 @@ import probeinterface as pi import spikeinterface as si from element_array_ephys import get_logger, probe, readers -from element_interface.utils import find_full_path, memoized_result +from element_interface.utils import find_full_path # , memoized_result from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . import si_preprocessing @@ -98,12 +98,12 @@ def make(self, key): clustering_method, acq_software, output_dir, params = ( ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key ).fetch1("clustering_method", "acq_software", "clustering_output_dir", "params") - + # Get sorter method and create output directory. 
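        # The sorter name also names the per-sorter output tree under
        # clustering_output_dir (<sorter_name>/recording, .../spike_sorting,
        # .../waveform, .../metrics), so it is derived up front and reused by
        # every later step.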
sorter_name = ( "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - + for required_key in ( "SI_SORTING_PARAMS", "SI_PREPROCESSING_METHOD", @@ -165,7 +165,7 @@ def make(self, key): .fetch(format="frame") .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] ) - + # Create SI probe object si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) si_probe.set_device_channel_indices(range(len(electrodes_df))) @@ -214,7 +214,7 @@ def make(self, key): ) recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) - + # Run sorting @memoized_result( parameters={**key, **params}, @@ -222,7 +222,9 @@ def make(self, key): ) def _run_sorter(*args, **kwargs): si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(*args, **kwargs) - sorting_save_path = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" + sorting_save_path = ( + output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" + ) si_sorting.dump_to_pickle(sorting_save_path) return sorting_save_path @@ -272,7 +274,7 @@ def make(self, key): sorter_name = ( "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) - + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" @@ -284,7 +286,9 @@ def make(self, key): we: si.WaveformExtractor = si.extract_waveforms( si_recording, si_sorting, - folder=output_dir / sorter_name / "waveform", # The folder where waveforms are cached + folder=output_dir + / sorter_name + / "waveform", # The folder where waveforms are cached max_spikes_per_unit=None, overwrite=True, **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), @@ -314,7 +318,7 @@ def make(self, key): # Save the output (metrics.csv to the output dir) metrics_output_dir = output_dir / sorter_name / "metrics" metrics_output_dir.mkdir(parents=True, exist_ok=True) - + metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) metrics.to_csv(metrics_output_dir / "metrics.csv") From 95f5286704ca51a8768a1c8bad41d2ae2de94767 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 19 Feb 2024 16:50:15 -0600 Subject: [PATCH 081/146] remove memoized_result for testing --- .../spike_sorting/si_spike_sorting.py | 21 +++++++------------ 1 file changed, 7 insertions(+), 14 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 1b8366dc..0e3da684 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -26,7 +26,7 @@ import probeinterface as pi import spikeinterface as si from element_array_ephys import get_logger, probe, readers -from element_interface.utils import find_full_path # , memoized_result +from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . 
import si_preprocessing @@ -216,19 +216,7 @@ def make(self, key): si_recording: si.BaseRecording = si.load_extractor(recording_file) # Run sorting - @memoized_result( - parameters={**key, **params}, - output_directory=output_dir / sorter_name / "spike_sorting", - ) - def _run_sorter(*args, **kwargs): - si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter(*args, **kwargs) - sorting_save_path = ( - output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" - ) - si_sorting.dump_to_pickle(sorting_save_path) - return sorting_save_path - - sorting_save_path = _run_sorter( + si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, recording=si_recording, output_folder=output_dir / sorter_name / "spike_sorting", @@ -238,6 +226,11 @@ def _run_sorter(*args, **kwargs): **params.get("SI_SORTING_PARAMS", {}), ) + sorting_save_path = ( + output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" + ) + si_sorting.dump_to_pickle(sorting_save_path) + self.insert1( { **key, From a7ebb9a61c6c2e90fd01d7f5529b8409b36889c3 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 20 Feb 2024 17:59:46 -0600 Subject: [PATCH 082/146] build: :heavy_plus_sign: add element-interface to required packages --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 532c72f6..204008a6 100644 --- a/setup.py +++ b/setup.py @@ -39,12 +39,12 @@ "scikit-image", "nbformat>=4.2.0", "pyopenephys>=1.1.6", + "element-interface @ git+https://github.com/datajoint/element-interface.git", ], extras_require={ "elements": [ "element-animal @ git+https://github.com/datajoint/element-animal.git", "element-event @ git+https://github.com/datajoint/element-event.git", - "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results", "element-lab @ git+https://github.com/datajoint/element-lab.git", "element-session @ git+https://github.com/datajoint/element-session.git", ], From 134ff54eb124896f6fd70f5d335680ed7c2c9a06 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 20 Feb 2024 18:00:07 -0600 Subject: [PATCH 083/146] update pre-commit with the latest hooks --- .pre-commit-config.yaml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d513df7..6d28ef11 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,7 +3,7 @@ exclude: (^.github/|^docs/|^images/) repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.5.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -16,7 +16,7 @@ repos: # black - repo: https://github.com/psf/black - rev: 22.12.0 + rev: 24.2.0 hooks: - id: black - id: black-jupyter @@ -25,7 +25,7 @@ repos: # isort - repo: https://github.com/pycqa/isort - rev: 5.11.2 + rev: 5.13.2 hooks: - id: isort args: ["--profile", "black"] @@ -33,7 +33,7 @@ repos: # flake8 - repo: https://github.com/pycqa/flake8 - rev: 4.0.1 + rev: 7.0.0 hooks: - id: flake8 args: # arguments to configure flake8 From 79724268bd3b0c29f129b2fc3781577a931eb098 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 27 Feb 2024 17:25:37 -0600 Subject: [PATCH 084/146] build: :heavy_plus_sign: Add numba as required package --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 204008a6..52cd38b1 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ "nbformat>=4.2.0", "pyopenephys>=1.1.6", "element-interface @ git+https://github.com/datajoint/element-interface.git", + 
"numba", ], extras_require={ "elements": [ @@ -50,6 +51,6 @@ ], "nwb": ["dandi", "neuroconv[ecephys]", "pynwb"], "tests": ["pre-commit", "pytest", "pytest-cov"], - "spikingcircus": ["hdbscan", "numba"], + "spikingcircus": ["hdbscan"], }, ) From ed11526a00649cf6b1849802720c151a33beb374 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 27 Feb 2024 17:26:24 -0600 Subject: [PATCH 085/146] adjust extract_waveforms parameters --- element_array_ephys/spike_sorting/si_spike_sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 0e3da684..6df25de8 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -282,8 +282,8 @@ def make(self, key): folder=output_dir / sorter_name / "waveform", # The folder where waveforms are cached - max_spikes_per_unit=None, overwrite=True, + allow_unfiltered=True, **params.get("SI_WAVEFORM_EXTRACTION_PARAMS", {}), **params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_size": 30000}), ) From bab86b7c1471f02e2c9d95c1f4ee8442345b3817 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Wed, 28 Feb 2024 22:52:28 -0600 Subject: [PATCH 086/146] refactor: :recycle: update the output dir for CuratedClustering --- element_array_ephys/ephys_no_curation.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index acf7c76f..71648fec 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -967,22 +967,31 @@ class Unit(dj.Part): def make(self, key): """Automated population of Unit information.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") + clustering_method, output_dir = ( + ClusteringTask * ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - if (output_dir / "waveform").exists(): # read from spikeinterface outputs + # Get sorter method and create output directory. 
+ sorter_name = ( + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method + ) + waveform_dir = output_dir / sorter_name / "waveform" + sorting_dir = output_dir / sorter_name / "spike_sorting" + + if waveform_dir.exists(): # read from spikeinterface outputs we: si.WaveformExtractor = si.load_waveforms( - output_dir / "waveform", with_recording=False + waveform_dir, with_recording=False ) si_sorting: si.sorters.BaseSorter = si.load_extractor( - output_dir / "sorting.pkl" + sorting_dir / "si_sorting.pkl" ) unit_peak_channel_map: dict[int, int] = si.get_template_extremum_channel( we, outputs="index" ) # {unit: peak_channel_index} - spike_count_dict = dict[int, int] = si_sorting.count_num_spikes_per_unit() + spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} spikes = si_sorting.to_spike_vector( From 0898ce5cdaeb7656ed7142f395351dfcde652d40 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Thu, 29 Feb 2024 10:08:22 -0600 Subject: [PATCH 087/146] feat: :sparkles: add quality label mapping --- element_array_ephys/ephys_no_curation.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 71648fec..cb91baa9 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1011,6 +1011,21 @@ def make(self, key): zip(*electrode_query.fetch("channel", "electrode")) ) + # Get unit id to quality label mapping + cluster_quality_label_map = {} + try: + cluster_quality_label_map = pd.read_csv( + sorting_dir / "sorter_output" / "cluster_KSLabel.tsv", + delimiter="\t", + ) + cluster_quality_label_map: dict[ + int, str + ] = cluster_quality_label_map.set_index("cluster_id")[ + "KSLabel" + ].to_dict() # {unit: quality_label} + except FileNotFoundError: + pass + # Get channel to electrode mapping channel2depth_map = dict(zip(*electrode_query.fetch("channel", "y_coord"))) @@ -1038,7 +1053,9 @@ def make(self, key): units.append( { "unit": unit_id, - "cluster_quality_label": "n.a.", + "cluster_quality_label": cluster_quality_label_map.get( + unit_id, "n.a." + ), "spike_times": si_sorting.get_unit_spike_train( unit_id, return_times=True ), From 727af24cca6f00fc35651f88f7f258d2ab67e43e Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 10:19:09 -0600 Subject: [PATCH 088/146] feat: :sparkles: Ingest EphysRecording.Channel --- element_array_ephys/ephys_no_curation.py | 242 +++++++++++++---------- 1 file changed, 133 insertions(+), 109 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index cb91baa9..e9adc290 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -8,7 +8,9 @@ import datajoint as dj import numpy as np import pandas as pd +import spikeinterface as si from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory +from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . import ephys_report, probe from .readers import kilosort, openephys, spikeglx @@ -19,9 +21,6 @@ _linking_module = None -import spikeinterface as si -from spikeinterface import exporters, postprocessing, qualitymetrics, sorters - def activate( ephys_schema_name: str, @@ -327,129 +326,154 @@ def make(self, key): break else: raise FileNotFoundError( - "Ephys recording data not found!" + f"Ephys recording data not found! for {key}." 
"Neither SpikeGLX nor Open Ephys recording files found" ) - supported_probe_types = probe.ProbeType.fetch("probe_type") + if acq_software not in AcquisitionSoftware.fetch("acq_software"): + raise NotImplementedError( + f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." + ) - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) + else: + supported_probe_types = probe.ProbeType.fetch("probe_type") + + if acq_software == "SpikeGLX": + for meta_filepath in ephys_meta_filepaths: + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + "No SpikeGLX data found for probe insertion: {}".format(key) + ) + + if spikeglx_meta.probe_model in supported_probe_types: + probe_type = spikeglx_meta.probe_model + electrode_query = probe.ProbeType.Electrode & { + "probe_type": probe_type + } - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip( + *electrode_query.fetch( + "KEY", "shank", "shank_col", "shank_row" + ) + ) + } - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") + electrode_group_members = [ + probe_electrodes[(shank, shank_col, shank_row)] + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap[ + "data" + ] + ] + else: + raise NotImplementedError( + "Processing for neuropixels probe model" + " {} not yet implemented".format(spikeglx_meta.probe_model) ) - } - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) + self.insert1( + { + **key, + **generate_electrode_config( + probe_type, electrode_group_members + ), + "acq_software": acq_software, + "sampling_rate": spikeglx_meta.meta["imSampRate"], + "recording_datetime": spikeglx_meta.recording_time, + "recording_duration": ( + spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(meta_filepath) + ), + } ) - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: 
{}".format(key) + root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) + self.EphysFile.insert1( + {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} ) + elif acq_software == "Open Ephys": + dataset = openephys.OpenEphys(session_dir) + for serial_number, probe_data in dataset.probes.items(): + if str(serial_number) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + "No Open Ephys data found for probe insertion: {}".format(key) + ) - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) + if not probe_data.ap_meta: + raise IOError( + 'No analog signals found - check "structure.oebin" file or "continuous" directory' + ) - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} + if probe_data.probe_model in supported_probe_types: + probe_type = probe_data.probe_model + electrode_query = probe.ProbeType.Electrode & { + "probe_type": probe_type + } - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } + probe_electrodes = { + key["electrode"]: key for key in electrode_query.fetch("KEY") + } - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) + electrode_group_members = [ + probe_electrodes[channel_idx] + for channel_idx in probe_data.ap_meta["channels_indices"] + ] + else: + raise NotImplementedError( + "Processing for neuropixels" + " probe model {} not yet implemented".format( + probe_data.probe_model + ) + ) + + self.insert1( + { + **key, + **generate_electrode_config( + probe_type, electrode_group_members + ), + "acq_software": acq_software, + "sampling_rate": probe_data.ap_meta["sample_rate"], + "recording_datetime": probe_data.recording_info[ + "recording_datetimes" + ][0], + "recording_duration": np.sum( + probe_data.recording_info["recording_durations"] + ), + } ) - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) + root_dir = find_root_directory( + get_ephys_root_data_dir(), + probe_data.recording_info["recording_files"][0], + ) + self.EphysFile.insert( + [ + {**key, "file_path": fp.relative_to(root_dir).as_posix()} + for fp in probe_data.recording_info["recording_files"] + ] + ) + # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough + del probe_data, dataset + gc.collect() - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], + # Insert channel information + # Get channel and electrode-site mapping + channel2electrodes = get_neuropixels_channel2electrode_map( + key, acq_software ) - self.EphysFile.insert( + self.Channel.insert( [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrodes.items() ] ) - # 
explicitly garbage collect "dataset" - # as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) @schema @@ -1209,11 +1233,11 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( output_dir / "waveform", with_recording=False ) - unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( - si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices - ) # {unit: peak_channel_index} + unit_id_to_peak_channel_indices: dict[ + int, np.ndarray + ] = si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices # {unit: peak_channel_index} units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") From 1df41ea6ee8668c5b3ca5cb970dad024f65a668c Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 11:05:05 -0600 Subject: [PATCH 089/146] get channel to electrode mapping in CuratedClustering --- element_array_ephys/ephys_no_curation.py | 30 +++++++++++++----------- 1 file changed, 16 insertions(+), 14 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index e9adc290..c313684f 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -9,7 +9,8 @@ import numpy as np import pandas as pd import spikeinterface as si -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory +from element_interface.utils import (dict_to_uuid, find_full_path, + find_root_directory) from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . 
import ephys_report, probe @@ -1022,7 +1023,7 @@ def make(self, key): extremum_channel_inds=unit_peak_channel_map ) - # Get electrode info !#TODO: need to be modified + # Get electrode & channel info electrode_config_key = ( EphysRecording * probe.ElectrodeConfig & key ).fetch1("KEY") @@ -1030,10 +1031,11 @@ def make(self, key): electrode_query = ( probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode & electrode_config_key - ) + ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel) + channel2electrode_map = dict( - zip(*electrode_query.fetch("channel", "electrode")) - ) + zip(*electrode_query.fetch("channel_idx", "electrode")) + ) # {channel: electrode} # Get unit id to quality label mapping cluster_quality_label_map = {} @@ -1051,24 +1053,24 @@ def make(self, key): pass # Get channel to electrode mapping - channel2depth_map = dict(zip(*electrode_query.fetch("channel", "y_coord"))) + channel2depth_map = dict(zip(*electrode_query.fetch("channel_idx", "y_coord"))) # {channel: depth} peak_electrode_ind = np.array( [ channel2electrode_map[unit_peak_channel_map[unit_id]] for unit_id in si_sorting.unit_ids ] - ) + ) # get the electrode where peak unit activity is recorded # Get channel to depth mapping - electrode_depth_ind = np.array( + channel_depth_ind = np.array( [ channel2depth_map[unit_peak_channel_map[unit_id]] for unit_id in si_sorting.unit_ids ] ) spikes["electrode"] = peak_electrode_ind[spikes["unit_index"]] - spikes["depth"] = electrode_depth_ind[spikes["unit_index"]] + spikes["depth"] = channel_depth_ind[spikes["unit_index"]] units = [] @@ -1233,11 +1235,11 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( output_dir / "waveform", with_recording=False ) - unit_id_to_peak_channel_indices: dict[ - int, np.ndarray - ] = si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices # {unit: peak_channel_index} + unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") From 417219fd4baeeadc1d9e4feeaa29ef04c4b21555 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 12:10:29 -0600 Subject: [PATCH 090/146] refactor: :recycle: Fix metrics directory in QualityMetrics --- element_array_ephys/ephys_no_curation.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index c313684f..70525451 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -9,8 +9,7 @@ import numpy as np import pandas as pd import spikeinterface as si -from element_interface.utils import (dict_to_uuid, find_full_path, - find_root_directory) +from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . 
import ephys_report, probe @@ -1032,7 +1031,7 @@ def make(self, key): probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode & electrode_config_key ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel) - + channel2electrode_map = dict( zip(*electrode_query.fetch("channel_idx", "electrode")) ) # {channel: electrode} @@ -1053,7 +1052,9 @@ def make(self, key): pass # Get channel to electrode mapping - channel2depth_map = dict(zip(*electrode_query.fetch("channel_idx", "y_coord"))) # {channel: depth} + channel2depth_map = dict( + zip(*electrode_query.fetch("channel_idx", "y_coord")) + ) # {channel: depth} peak_electrode_ind = np.array( [ @@ -1570,9 +1571,14 @@ class Waveform(dj.Part): def make(self, key): """Populates tables with quality metrics data.""" # Load metrics.csv - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") + clustering_method, output_dir = ( + ClusteringTask * ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - metric_fp = output_dir / "metrics.csv" + sorter_name = ( + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method + ) + metric_fp = output_dir / sorter_name / "metrics" / "metrics.csv" if not metric_fp.exists(): raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") metrics_df = pd.read_csv(metric_fp) From f70ae4ee1e294b0ba1173fa6a9b1255e2d27f6b3 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 13:06:22 -0600 Subject: [PATCH 091/146] feat: :sparkles: replace get_neuropixels_channel2electrode_map with channel_info --- element_array_ephys/ephys_no_curation.py | 54 ++++++++++++++++-------- 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 70525451..5730ebf3 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1096,7 +1096,7 @@ def make(self, key): } ) - else: + else: # read from kilosort outputs kilosort_dataset = kilosort.Kilosort(output_dir) acq_software, sample_rate = (EphysRecording & key).fetch1( "acq_software", "sampling_rate" @@ -1131,14 +1131,19 @@ def make(self, key): kilosort_dataset.extract_spike_depths() # Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map( - key, acq_software + channel_info = ( + (EphysRecording.Channel & key) + .proj(..., "-channel_name") + .fetch(as_dict=True, order_by="channel_idx") ) + channel_info: dict[int, dict] = { + ch.pop("channel_idx"): ch for ch in channel_info + } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} # -- Spike-sites and Spike-depths -- spike_sites = np.array( [ - channel2electrodes[s]["electrode"] + channel_info[s]["electrode"] for s in kilosort_dataset.data["spike_sites"] ] ) @@ -1157,9 +1162,10 @@ def make(self, key): units.append( { + **key, "unit": unit, "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], + **channel_info[unit_channel], "spike_times": unit_spike_times, "spike_count": spike_count, "spike_sites": spike_sites[ @@ -1228,13 +1234,21 @@ class Waveform(dj.Part): def make(self, key): """Populates waveform tables.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") + clustering_method, output_dir = ( + ClusteringTask * ClusteringParamSet 
& key + ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) + sorter_name = ( + "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method + ) - if (output_dir / "waveform").exists(): # read from spikeinterface outputs + if ( + output_dir / sorter_name / "waveform" + ).exists(): # read from spikeinterface outputs + waveform_dir = output_dir / sorter_name / "waveform" we: si.WaveformExtractor = si.load_waveforms( - output_dir / "waveform", with_recording=False + waveform_dir, with_recording=False ) unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( si.ChannelSparsity.from_best_channels( @@ -1299,11 +1313,15 @@ def make(self, key): EphysRecording * ProbeInsertion & key ).fetch1("acq_software", "probe") - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software + # Get channel and electrode-site mapping + channel_info = ( + (EphysRecording.Channel & key) + .proj(..., "-channel_name") + .fetch(as_dict=True, order_by="channel_idx") ) + channel_info: dict[int, dict] = { + ch.pop("channel_idx"): ch for ch in channel_info + } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} # Get all units units = { @@ -1335,12 +1353,12 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **units[unit_no], - **channel2electrodes[channel], + **channel_info[channel], "waveform_mean": channel_waveform, } ) if ( - channel2electrodes[channel]["electrode"] + channel_info[channel]["electrode"] == units[unit_no]["electrode"] ): unit_peak_waveform = { @@ -1377,12 +1395,12 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **units[unit_no], - **channel2electrodes[channel], + **channel_info[channel], "waveform_mean": channel_waveform, } ) if ( - channel2electrodes[channel]["electrode"] + channel_info[channel]["electrode"] == units[unit_no]["electrode"] ): unit_peak_waveform = { @@ -1451,13 +1469,13 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **unit_dict, - **channel2electrodes[channel], + **channel_info[channel], "waveform_mean": channel_waveform.mean(axis=0), "waveforms": channel_waveform, } ) if ( - channel2electrodes[channel]["electrode"] + channel_info[channel]["electrode"] == unit_dict["electrode"] ): unit_peak_waveform = { From 0ccbec962b5a11107d8fe28906e79bf19c693475 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 21:59:06 +0000 Subject: [PATCH 092/146] fix CuratedClustering make function --- element_array_ephys/ephys_no_curation.py | 36 +++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 5730ebf3..6ddd8fec 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1032,6 +1032,12 @@ def make(self, key): & electrode_config_key ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel) + channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx") + + channel_info: dict[int, dict] = { + ch.pop("channel_idx"): ch for ch in channel_info + } + channel2electrode_map = dict( zip(*electrode_query.fetch("channel_idx", "electrode")) ) # {channel: electrode} @@ -1058,7 +1064,7 @@ def make(self, key): 
        peak_electrode_ind = np.array(
             [
-                channel2electrode_map[unit_peak_channel_map[unit_id]]
+                channel_info[unit_peak_channel_map[unit_id]]["electrode"]
                 for unit_id in si_sorting.unit_ids
             ]
         )  # get the electrode where peak unit activity is recorded
@@ -1066,19 +1072,29 @@ def make(self, key):
 
         # Get channel to depth mapping
         channel_depth_ind = np.array(
             [
-                channel2depth_map[unit_peak_channel_map[unit_id]]
+                channel_info[unit_peak_channel_map[unit_id]]["y_coord"]
                 for unit_id in si_sorting.unit_ids
             ]
         )
-        spikes["electrode"] = peak_electrode_ind[spikes["unit_index"]]
-        spikes["depth"] = channel_depth_ind[spikes["unit_index"]]
+
+        # Assign electrode and depth for each spike
+        new_spikes = np.empty(
+            spikes.shape, spikes.dtype.descr + [("electrode", "<i8"), ("depth", "<i8")]
+        )
+        for field in spikes.dtype.names:
+            new_spikes[field] = spikes[field]
+        new_spikes["electrode"] = peak_electrode_ind[spikes["unit_index"]]
+        new_spikes["depth"] = channel_depth_ind[spikes["unit_index"]]

From: JaerongA
Date: Fri, 1 Mar 2024 23:11:25 +0000
Subject: [PATCH 093/146] improve try except logic

---
 element_array_ephys/ephys_no_curation.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py
index 6ddd8fec..b8c22f05 100644
--- a/element_array_ephys/ephys_no_curation.py
+++ b/element_array_ephys/ephys_no_curation.py
@@ -1043,25 +1043,20 @@ def make(self, key):
         # Get unit id to quality label mapping
-        cluster_quality_label_map = {}
         try:
             cluster_quality_label_map = pd.read_csv(
                 sorting_dir / "sorter_output" / "cluster_KSLabel.tsv",
                 delimiter="\t",
             )
+        except FileNotFoundError:
+            cluster_quality_label_map = {}
+        else:
             cluster_quality_label_map: dict[
                 int, str
             ] = cluster_quality_label_map.set_index("cluster_id")[
                 "KSLabel"
             ].to_dict()  # {unit: quality_label}
-        except FileNotFoundError:
-            pass
-
-        # Get channel to electrode mapping
-        channel2depth_map = dict(
-            zip(*electrode_query.fetch("channel_idx", "y_coord"))
-        )  # {channel: depth}
-
+
+        # Get electrode where peak unit activity is recorded
         peak_electrode_ind = np.array(
             [
                 channel_info[unit_peak_channel_map[unit_id]]["electrode"]
                 for
unit_id in si_sorting.unit_ids ] - ) # get the electrode where peak unit activity is recorded + ) - # Get channel to depth mapping + # Get channel depth channel_depth_ind = np.array( [ channel_info[unit_peak_channel_map[unit_id]]["y_coord"] @@ -1707,7 +1703,7 @@ def get_openephys_probe_data(ephys_recording_key: dict) -> list: def get_neuropixels_channel2electrode_map( ephys_recording_key: dict, acq_software: str -) -> dict: +) -> dict: #TODO: remove this function """Get the channel map for neuropixels probe.""" if acq_software == "SpikeGLX": spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) From 8ab4c58a9d5c744bc3c071f31f360873ff8ee8a2 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 23:15:21 +0000 Subject: [PATCH 095/146] refactor: :recycle: improve if else block in EphysRecording --- element_array_ephys/ephys_no_curation.py | 237 +++++++++++------------ 1 file changed, 118 insertions(+), 119 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index f4ba5b29..5d77a041 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -335,145 +335,144 @@ def make(self, key): f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." ) - else: - supported_probe_types = probe.ProbeType.fetch("probe_type") - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) + supported_probe_types = probe.ProbeType.fetch("probe_type") - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & { - "probe_type": probe_type - } + if acq_software == "SpikeGLX": + for meta_filepath in ephys_meta_filepaths: + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + "No SpikeGLX data found for probe insertion: {}".format(key) + ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch( - "KEY", "shank", "shank_col", "shank_row" - ) - ) - } + if spikeglx_meta.probe_model in supported_probe_types: + probe_type = spikeglx_meta.probe_model + electrode_query = probe.ProbeType.Electrode & { + "probe_type": probe_type + } - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap[ - "data" - ] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip( + *electrode_query.fetch( + "KEY", "shank", "shank_col", "shank_row" + ) ) + } - self.insert1( - { - **key, - **generate_electrode_config( - probe_type, electrode_group_members - ), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } + electrode_group_members = [ + probe_electrodes[(shank, shank_col, 
shank_row)] + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap[ + "data" + ] + ] + else: + raise NotImplementedError( + "Processing for neuropixels probe model" + " {} not yet implemented".format(spikeglx_meta.probe_model) ) - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} + self.insert1( + { + **key, + **generate_electrode_config( + probe_type, electrode_group_members + ), + "acq_software": acq_software, + "sampling_rate": spikeglx_meta.meta["imSampRate"], + "recording_datetime": spikeglx_meta.recording_time, + "recording_duration": ( + spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(meta_filepath) + ), + } + ) + + root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) + self.EphysFile.insert1( + {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} + ) + elif acq_software == "Open Ephys": + dataset = openephys.OpenEphys(session_dir) + for serial_number, probe_data in dataset.probes.items(): + if str(serial_number) == inserted_probe_serial_number: + break + else: + raise FileNotFoundError( + "No Open Ephys data found for probe insertion: {}".format(key) ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: {}".format(key) - ) - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) + if not probe_data.ap_meta: + raise IOError( + 'No analog signals found - check "structure.oebin" file or "continuous" directory' + ) - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & { - "probe_type": probe_type - } + if probe_data.probe_model in supported_probe_types: + probe_type = probe_data.probe_model + electrode_query = probe.ProbeType.Electrode & { + "probe_type": probe_type + } - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } + probe_electrodes = { + key["electrode"]: key for key in electrode_query.fetch("KEY") + } - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format( - probe_data.probe_model - ) + electrode_group_members = [ + probe_electrodes[channel_idx] + for channel_idx in probe_data.ap_meta["channels_indices"] + ] + else: + raise NotImplementedError( + "Processing for neuropixels" + " probe model {} not yet implemented".format( + probe_data.probe_model ) - - self.insert1( - { - **key, - **generate_electrode_config( - probe_type, electrode_group_members - ), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } ) - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - 
) - # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() + self.insert1( + { + **key, + **generate_electrode_config( + probe_type, electrode_group_members + ), + "acq_software": acq_software, + "sampling_rate": probe_data.ap_meta["sample_rate"], + "recording_datetime": probe_data.recording_info[ + "recording_datetimes" + ][0], + "recording_duration": np.sum( + probe_data.recording_info["recording_durations"] + ), + } + ) - # Insert channel information - # Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map( - key, acq_software + root_dir = find_root_directory( + get_ephys_root_data_dir(), + probe_data.recording_info["recording_files"][0], ) - self.Channel.insert( + self.EphysFile.insert( [ - {**key, "channel_idx": channel_idx, **channel_info} - for channel_idx, channel_info in channel2electrodes.items() + {**key, "file_path": fp.relative_to(root_dir).as_posix()} + for fp in probe_data.recording_info["recording_files"] ] ) + # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough + del probe_data, dataset + gc.collect() + + # Insert channel information + # Get channel and electrode-site mapping + channel2electrodes = get_neuropixels_channel2electrode_map( + key, acq_software + ) + self.Channel.insert( + [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrodes.items() + ] + ) @schema From 226142b82614f4964a3f7c8655ffd2ca6f43dd3d Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 1 Mar 2024 23:52:16 +0000 Subject: [PATCH 096/146] feat: :sparkles: Update WaveformSet make function --- element_array_ephys/ephys_no_curation.py | 47 ++++++++---------------- 1 file changed, 16 insertions(+), 31 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 5d77a041..a721e5b6 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1248,6 +1248,16 @@ def make(self, key): "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method ) + # Get channel and electrode-site mapping + channel_info = ( + (EphysRecording.Channel & key) + .proj(..., "-channel_name") + .fetch(as_dict=True, order_by="channel_idx") + ) + channel_info: dict[int, dict] = { + ch.pop("channel_idx"): ch for ch in channel_info + } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} + if ( output_dir / sorter_name / "waveform" ).exists(): # read from spikeinterface outputs @@ -1256,27 +1266,12 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( waveform_dir, with_recording=False ) - unit_id_to_peak_channel_indices: dict[int, np.ndarray] = ( + unit_id_to_peak_channel_map: dict[int, np.ndarray] = ( si.ChannelSparsity.from_best_channels( we, 1, peak_sign="neg" ).unit_id_to_channel_indices ) # {unit: peak_channel_index} - units = (CuratedClustering.Unit & key).fetch("KEY", order_by="unit") - - # Get electrode info - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode.proj() * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - electrode_info = electrode_query.fetch( - "KEY", order_by="electrode", 
as_dict=True - ) - # Get mean waveform for each unit from all channels mean_waveforms = we.get_all_templates( mode="average" @@ -1285,13 +1280,13 @@ def make(self, key): unit_peak_waveform = [] unit_electrode_waveforms = [] - for unit in units: + for unit in (CuratedClustering.Unit & key).fetch("KEY", order_by="unit"): unit_peak_waveform.append( { **unit, "peak_electrode_waveform": we.get_template( unit_id=unit["unit"], mode="average", force_dense=True - )[:, unit_id_to_peak_channel_indices[unit["unit"]][0]], + )[:, unit_id_to_peak_channel_map[unit["unit"]][0]], } ) @@ -1299,12 +1294,12 @@ def make(self, key): [ { **unit, - **e, + **channel_info[c], "waveform_mean": mean_waveforms[ - unit["unit"], :, e["electrode"] + unit["unit"] - 1, :, c ], } - for e in electrode_info + for c in channel_info ] ) @@ -1319,16 +1314,6 @@ def make(self, key): EphysRecording * ProbeInsertion & key ).fetch1("acq_software", "probe") - # Get channel and electrode-site mapping - channel_info = ( - (EphysRecording.Channel & key) - .proj(..., "-channel_name") - .fetch(as_dict=True, order_by="channel_idx") - ) - channel_info: dict[int, dict] = { - ch.pop("channel_idx"): ch for ch in channel_info - } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} - # Get all units units = { u["unit"]: u From ea398391223d6ce8ac26ea22efc2240c86a95525 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 5 Mar 2024 09:24:14 -0600 Subject: [PATCH 097/146] refactor: :fire: remove & get_neuropixels_channel2electrode_map and generate_electrode_config --- element_array_ephys/ephys_no_curation.py | 305 +++++++++++------------ 1 file changed, 146 insertions(+), 159 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index a721e5b6..26608997 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -315,7 +315,7 @@ def make(self, key): "probe" ) - # search session dir and determine acquisition software + # Search session dir and determine acquisition software for ephys_pattern, ephys_acq_type in ( ("*.ap.meta", "SpikeGLX"), ("*.oebin", "Open Ephys"), @@ -338,62 +338,117 @@ def make(self, key): supported_probe_types = probe.ProbeType.fetch("probe_type") if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) + spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) + spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - if spikeglx_meta.probe_model in supported_probe_types: + if spikeglx_meta.probe_model not in supported_probe_types: + raise NotImplementedError( + f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented." 
+ ) + else: probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & { - "probe_type": probe_type - } + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} probe_electrodes = { (shank, shank_col, shank_row): key for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch( - "KEY", "shank", "shank_col", "shank_row" - ) + *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") ) - } - + } # electrode configuration electrode_group_members = [ probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap[ - "data" + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] + ] # recording session-specific electrode configuration + + # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) + electrode_config_hash = dict_to_uuid( + {k["electrode"]: k for k in electrode_group_members} + ) + + electrode_list = sorted( + [k["electrode"] for k in electrode_group_members] + ) + electrode_gaps = ( + [-1] + + np.where(np.diff(electrode_list) > 1)[0].tolist() + + [len(electrode_list) - 1] + ) + electrode_config_name = "; ".join( + [ + f"{electrode_list[start + 1]}-{electrode_list[end]}" + for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) ] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) ) + electrode_config_key = {"electrode_config_hash": electrode_config_hash} + + # Insert into ElectrodeConfig + if not probe.ElectrodeConfig & electrode_config_key: + probe.ElectrodeConfig.insert1( + { + **electrode_config_key, + "probe_type": probe_type, + "electrode_config_name": electrode_config_name, + } + ) + probe.ElectrodeConfig.Electrode.insert( + {**electrode_config_key, **electrode} + for electrode in electrode_group_members + ) + self.insert1( { **key, - **generate_electrode_config( - probe_type, electrode_group_members - ), + "electrode_config_hash": electrode_config_hash, "acq_software": acq_software, "sampling_rate": spikeglx_meta.meta["imSampRate"], "recording_datetime": spikeglx_meta.recording_time, "recording_duration": ( spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) + or spikeglx.retrieve_recording_duration(spikeglx_meta_filepath) ), } ) - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) + root_dir = find_root_directory( + get_ephys_root_data_dir(), spikeglx_meta_filepath + ) self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} + { + **key, + "file_path": spikeglx_meta_filepath.relative_to( + root_dir + ).as_posix(), + } + ) + + # Insert channel information + # Get channel and electrode-site mapping + electrode_query = ( + probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode + & electrode_config_key + ) + + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip( + *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") + ) + } + + channel2electrode_map = { + recorded_site: probe_electrodes[(shank, shank_col, shank_row)] + for recorded_site, (shank, shank_col, shank_row, _) in enumerate( + spikeglx_meta.shankmap["data"] + ) + } + self.Channel.insert( + [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrode_map.items() + ] ) + elif acq_software == "Open Ephys": dataset = openephys.OpenEphys(session_dir) for 
serial_number, probe_data in dataset.probes.items(): @@ -409,11 +464,13 @@ def make(self, key): 'No analog signals found - check "structure.oebin" file or "continuous" directory' ) - if probe_data.probe_model in supported_probe_types: + if probe_data.probe_model not in supported_probe_types: + raise NotImplementedError( + f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented." + ) + else: probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & { - "probe_type": probe_type - } + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} probe_electrodes = { key["electrode"]: key for key in electrode_query.fetch("KEY") @@ -423,20 +480,33 @@ def make(self, key): probe_electrodes[channel_idx] for channel_idx in probe_data.ap_meta["channels_indices"] ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format( - probe_data.probe_model - ) + + # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) + electrode_config_hash = dict_to_uuid( + {k["electrode"]: k for k in electrode_group_members} + ) + + electrode_list = sorted( + [k["electrode"] for k in electrode_group_members] + ) + electrode_gaps = ( + [-1] + + np.where(np.diff(electrode_list) > 1)[0].tolist() + + [len(electrode_list) - 1] + ) + electrode_config_name = "; ".join( + [ + f"{electrode_list[start + 1]}-{electrode_list[end]}" + for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) + ] ) + electrode_config_key = {"electrode_config_hash": electrode_config_hash} + self.insert1( { **key, - **generate_electrode_config( - probe_type, electrode_group_members - ), + "electrode_config_hash": electrode_config_hash, "acq_software": acq_software, "sampling_rate": probe_data.ap_meta["sample_rate"], "recording_datetime": probe_data.recording_info[ @@ -462,17 +532,26 @@ def make(self, key): del probe_data, dataset gc.collect() - # Insert channel information - # Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map( - key, acq_software - ) - self.Channel.insert( - [ - {**key, "channel_idx": channel_idx, **channel_info} - for channel_idx, channel_info in channel2electrodes.items() - ] - ) + probe_dataset = get_openephys_probe_data(key) + electrode_query = ( + probe.ProbeType.Electrode + * probe.ElectrodeConfig.Electrode + * EphysRecording + & key + ) + probe_electrodes = { + key["electrode"]: key for key in electrode_query.fetch("KEY") + } + channel2electrode_map = { + channel_idx: probe_electrodes[channel_idx] + for channel_idx in probe_dataset.ap_meta["channels_indices"] + } + self.Channel.insert( + [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrode_map.items() + ] + ) @schema @@ -1034,7 +1113,7 @@ def make(self, key): channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx") channel_info: dict[int, dict] = { ch.pop("channel_idx"): ch for ch in channel_info - } + } # Get unit id to quality label mapping try: @@ -1050,14 +1129,14 @@ def make(self, key): ] = cluster_quality_label_map.set_index("cluster_id")[ "KSLabel" ].to_dict() # {unit: quality_label} - + # Get electrode where peak unit activity is recorded peak_electrode_ind = np.array( [ channel_info[unit_peak_channel_map[unit_id]]["electrode"] for unit_id in si_sorting.unit_ids ] - ) + ) # Get channel depth channel_depth_ind = np.array( @@ -1066,14 +1145,17 @@ def make(self, key): for unit_id in 
si_sorting.unit_ids ] ) - + # Assign electrode and depth for each spike - new_spikes = np.empty(spikes.shape, spikes.dtype.descr + [('electrode', ' list: return probe_data -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: #TODO: remove this function - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - probe_dataset = get_openephys_probe_data(ephys_recording_key) - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. 
neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key - - def get_recording_channels_details(ephys_recording_key: dict) -> np.array: """Get details of recording channels for a given recording.""" channels_details = {} From 5bfe201293b5c2ee34a0fb0d666ed280caa22cf9 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 5 Mar 2024 14:53:51 -0600 Subject: [PATCH 098/146] Update element_array_ephys/ephys_no_curation.py Co-authored-by: Thinh Nguyen --- element_array_ephys/ephys_no_curation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 26608997..2e45d4fe 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -326,7 +326,7 @@ def make(self, key): break else: raise FileNotFoundError( - f"Ephys recording data not found! for {key}." + f"Ephys recording data not found in {session_dir}." 
"Neither SpikeGLX nor Open Ephys recording files found" ) From 1d805849f20ee1e5ce6a911d06d189f9d699900f Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 5 Mar 2024 15:19:36 -0600 Subject: [PATCH 099/146] add generate_electrode_config_name --- element_array_ephys/ephys_no_curation.py | 96 ++++++++++++------------ 1 file changed, 47 insertions(+), 49 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 2e45d4fe..7ce99df2 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -364,38 +364,10 @@ def make(self, key): electrode_config_hash = dict_to_uuid( {k["electrode"]: k for k in electrode_group_members} ) - - electrode_list = sorted( - [k["electrode"] for k in electrode_group_members] - ) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] + electrode_config_name = generate_electrode_config_name( + probe_type, electrode_group_members ) - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # Insert into ElectrodeConfig - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} - for electrode in electrode_group_members - ) - self.insert1( { **key, @@ -426,7 +398,7 @@ def make(self, key): # Get channel and electrode-site mapping electrode_query = ( probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key + & {"electrode_config_hash": electrode_config_hash} ) probe_electrodes = { @@ -474,34 +446,20 @@ def make(self, key): probe_electrodes = { key["electrode"]: key for key in electrode_query.fetch("KEY") - } + } # electrode configuration electrode_group_members = [ probe_electrodes[channel_idx] for channel_idx in probe_data.ap_meta["channels_indices"] - ] + ] # recording session-specific electrode configuration # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) electrode_config_hash = dict_to_uuid( {k["electrode"]: k for k in electrode_group_members} ) - - electrode_list = sorted( - [k["electrode"] for k in electrode_group_members] - ) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] + electrode_config_name = generate_electrode_config_name( + probe_type, electrode_group_members ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} self.insert1( { @@ -553,6 +511,20 @@ def make(self, key): ] ) + # Insert into probe.ElectrodeConfig (recording configuration) + if not probe.ElectrodeConfig & {"electrode_config_hash": electrode_config_hash}: + probe.ElectrodeConfig.insert1( + { + "probe_type": probe_type, + "electrode_config_hash": electrode_config_hash, + "electrode_config_name": electrode_config_name, + } + ) + probe.ElectrodeConfig.Electrode.insert( + {"electrode_config_hash": electrode_config_hash, **electrode} + for electrode in electrode_group_members + ) + @schema class LFP(dj.Imported): @@ 
-1820,3 +1792,29 @@ def get_recording_channels_details(ephys_recording_key: dict) -> np.array: ) return channels_details + + +def generate_electrode_config_name(probe_type: str, electrode_keys: list) -> str: + """Generate electrode config name. + + Args: + probe_type (str): probe type (e.g. neuropixels 2.0 - SS) + electrode_keys (list): list of keys of the probe.ProbeType.Electrode table + + Returns: + electrode_config_name (str) + """ + electrode_list = sorted([k["electrode"] for k in electrode_keys]) + electrode_gaps = ( + [-1] + + np.where(np.diff(electrode_list) > 1)[0].tolist() + + [len(electrode_list) - 1] + ) + electrode_config_name = "; ".join( + [ + f"{electrode_list[start + 1]}-{electrode_list[end]}" + for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) + ] + ) + + return electrode_config_name From 88ce139bec18a0a01b13943b92c57dbcbc64e074 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Thu, 7 Mar 2024 12:29:50 -0600 Subject: [PATCH 100/146] refactor: :recycle: change sorter_name --- .../spike_sorting/si_spike_sorting.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 6df25de8..a7d1b963 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -100,9 +100,7 @@ def make(self, key): ).fetch1("clustering_method", "acq_software", "clustering_output_dir", "params") # Get sorter method and create output directory. - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) + sorter_name = clustering_method.replace(".", "_") for required_key in ( "SI_SORTING_PARAMS", @@ -209,9 +207,7 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) # Get sorter method and create output directory. - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) + sorter_name = clustering_method.replace(".", "_") recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) @@ -264,10 +260,7 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) # Get sorter method and create output directory. 
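# ---------------------------------------------------------------------------
# Aside (illustrative, not part of this patch): spikeinterface registers
# sorters under underscore-delimited ids, so the one-line replace() adopted
# below maps the DataJoint clustering_method onto a spikeinterface sorter id:
#
#   "kilosort2.5".replace(".", "_")  # -> "kilosort2_5"
#   "kilosort3".replace(".", "_")    # -> "kilosort3" (unchanged, no dot)
#
# The result can be validated against si.sorters.sorter_dict.keys() or
# si.sorters.installed_sorters().
# ---------------------------------------------------------------------------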
- sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) - + sorter_name = clustering_method.replace(".", "_") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" From 7eaefa49852f0e0807497de37a3c19d30ee1c5f2 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Fri, 8 Mar 2024 11:11:38 -0600 Subject: [PATCH 101/146] address review comments for generate_electrode_config_entry --- element_array_ephys/ephys_no_curation.py | 55 +++++++++++------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 7ce99df2..a6edbe54 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -360,18 +360,14 @@ def make(self, key): for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] ] # recording session-specific electrode configuration - # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid( - {k["electrode"]: k for k in electrode_group_members} - ) - electrode_config_name = generate_electrode_config_name( + econfig_entry, econfig_electrodes = generate_electrode_config_entry( probe_type, electrode_group_members ) self.insert1( { **key, - "electrode_config_hash": electrode_config_hash, + "electrode_config_hash": econfig_entry["electrode_config_hash"], "acq_software": acq_software, "sampling_rate": spikeglx_meta.meta["imSampRate"], "recording_datetime": spikeglx_meta.recording_time, @@ -398,7 +394,7 @@ def make(self, key): # Get channel and electrode-site mapping electrode_query = ( probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & {"electrode_config_hash": electrode_config_hash} + & {"electrode_config_hash": econfig_entry["electrode_config_hash"]} ) probe_electrodes = { @@ -453,18 +449,14 @@ def make(self, key): for channel_idx in probe_data.ap_meta["channels_indices"] ] # recording session-specific electrode configuration - # Compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid( - {k["electrode"]: k for k in electrode_group_members} - ) - electrode_config_name = generate_electrode_config_name( + econfig_entry, econfig_electrodes = generate_electrode_config_entry( probe_type, electrode_group_members ) self.insert1( { **key, - "electrode_config_hash": electrode_config_hash, + "electrode_config_hash": econfig_entry["electrode_config_hash"], "acq_software": acq_software, "sampling_rate": probe_data.ap_meta["sample_rate"], "recording_datetime": probe_data.recording_info[ @@ -512,18 +504,11 @@ def make(self, key): ) # Insert into probe.ElectrodeConfig (recording configuration) - if not probe.ElectrodeConfig & {"electrode_config_hash": electrode_config_hash}: - probe.ElectrodeConfig.insert1( - { - "probe_type": probe_type, - "electrode_config_hash": electrode_config_hash, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {"electrode_config_hash": electrode_config_hash, **electrode} - for electrode in electrode_group_members - ) + if not probe.ElectrodeConfig & { + "electrode_config_hash": econfig_entry["electrode_config_hash"] + }: + probe.ElectrodeConfig.insert1(econfig_entry) + probe.ElectrodeConfig.Electrode.insert(econfig_electrodes) 
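# ---------------------------------------------------------------------------
# Aside (illustrative sketch, not part of this patch): the
# generate_electrode_config_entry() helper above compresses the sorted
# electrode list into contiguous ranges via np.diff to build
# electrode_config_name. A minimal, self-contained version of that naming
# logic (the function name `config_name` is hypothetical, for demonstration
# only):
#
#   import numpy as np
#
#   def config_name(electrodes):
#       electrode_list = sorted(electrodes)
#       gaps = (
#           [-1]
#           + np.where(np.diff(electrode_list) > 1)[0].tolist()
#           + [len(electrode_list) - 1]
#       )
#       return "; ".join(
#           f"{electrode_list[start + 1]}-{electrode_list[end]}"
#           for start, end in zip(gaps[:-1], gaps[1:])
#       )
#
#   config_name([0, 1, 2, 5, 6])   # -> "0-2; 5-6"
#   config_name(list(range(384)))  # -> "0-383" (full-probe configuration)
# ---------------------------------------------------------------------------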
@schema @@ -1794,16 +1779,19 @@ def get_recording_channels_details(ephys_recording_key: dict) -> np.array: return channels_details -def generate_electrode_config_name(probe_type: str, electrode_keys: list) -> str: - """Generate electrode config name. +def generate_electrode_config_entry(probe_type: str, electrode_keys: list) -> dict: + """Generate and insert new ElectrodeConfig Args: probe_type (str): probe type (e.g. neuropixels 2.0 - SS) electrode_keys (list): list of keys of the probe.ProbeType.Electrode table Returns: - electrode_config_name (str) + dict: representing a key of the probe.ElectrodeConfig table """ + # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) + electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) + electrode_list = sorted([k["electrode"] for k in electrode_keys]) electrode_gaps = ( [-1] @@ -1816,5 +1804,14 @@ def generate_electrode_config_name(probe_type: str, electrode_keys: list) -> str for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) ] ) + electrode_config_key = {"electrode_config_hash": electrode_config_hash} + econfig_entry = { + **electrode_config_key, + "probe_type": probe_type, + "electrode_config_name": electrode_config_name, + } + econfig_electrodes = [ + {**electrode, **electrode_config_key} for electrode in electrode_keys + ] - return electrode_config_name + return econfig_entry, econfig_electrodes From d47be56dd8baeb88d8238701c1aa5ada457d3c36 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 11 Mar 2024 18:39:58 -0500 Subject: [PATCH 102/146] refactor: :art: refactor PostProcessing --- element_array_ephys/spike_sorting/si_spike_sorting.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index a7d1b963..a3f54e1e 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -253,15 +253,14 @@ class PostProcessing(dj.Imported): def make(self, key): execution_time = datetime.utcnow() - # Load recording object. + # Load recording & sorting object. clustering_method, output_dir, params = ( ephys.ClusteringTask * ephys.ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir", "params") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - # Get sorter method and create output directory. 
sorter_name = clustering_method.replace(".", "_") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" @@ -301,14 +300,13 @@ def make(self, key): _ = si.postprocessing.compute_principal_components( waveform_extractor=we, **params.get("SI_QUALITY_METRICS_PARAMS", None) ) + metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) + # Save the output (metrics.csv to the output dir) metrics_output_dir = output_dir / sorter_name / "metrics" metrics_output_dir.mkdir(parents=True, exist_ok=True) - - metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) metrics.to_csv(metrics_output_dir / "metrics.csv") - # Save results self.insert1( { **key, From 8dfc8583017864ce23d22865b8075339a274b1f3 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 11 Mar 2024 18:43:09 -0500 Subject: [PATCH 103/146] chore: :art: run docker if the package is not built into spikeinterface --- element_array_ephys/spike_sorting/si_spike_sorting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index a3f54e1e..c74ee9d4 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -205,23 +205,23 @@ def make(self, key): ephys.ClusteringTask * ephys.ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir", "params") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - # Get sorter method and create output directory. sorter_name = clustering_method.replace(".", "_") recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" si_recording: si.BaseRecording = si.load_extractor(recording_file) # Run sorting + # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package. si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, recording=si_recording, output_folder=output_dir / sorter_name / "spike_sorting", remove_existing_folder=True, verbose=True, - docker_image=True, + docker_image=sorter_name not in si.sorters.installed_sorters(), **params.get("SI_SORTING_PARAMS", {}), ) + # Save sorting object sorting_save_path = ( output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" ) From 6e20a11e92455220eebc065cd4870dc856739544 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 11 Mar 2024 18:45:21 -0500 Subject: [PATCH 104/146] refactor: :recycle: clean up import & docstring --- .../spike_sorting/si_spike_sorting.py | 21 ++----------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index c74ee9d4..36956d8f 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -1,34 +1,17 @@ """ -The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the -"spikeinterface" pipeline. 
-Spikeinterface developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface) - -The DataJoint pipeline currently incorporated Spikeinterfaces approach of running Kilosort using a container - -The follow pipeline features intermediary tables: -1. PreProcessing - for preprocessing steps (no GPU required) - - create recording extractor and link it to a probe - - bandpass filtering - - common mode referencing -2. SIClustering - kilosort (MATLAB) - requires GPU and docker/singularity containers - - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git) -3. PostProcessing - for postprocessing steps (no GPU required) - - create waveform extractor object - - extract templates, waveforms and snrs - - quality_metrics +The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the "spikeinterface" pipeline. Spikeinterface was developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface) """ -import pathlib from datetime import datetime import datajoint as dj import pandas as pd -import probeinterface as pi import spikeinterface as si from element_array_ephys import get_logger, probe, readers from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters +from .. import get_logger, probe, readers from . import si_preprocessing log = get_logger(__name__) From 8d04e10ce8e08370d9052b1a085bfbdf89d53fbd Mon Sep 17 00:00:00 2001 From: JaerongA Date: Mon, 11 Mar 2024 18:55:01 -0500 Subject: [PATCH 105/146] revert: :art: replace SI_READERS with si_extractor --- .../spike_sorting/si_spike_sorting.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 36956d8f..2ebe90ba 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -47,12 +47,6 @@ def activate( SI_SORTERS = [s.replace("_", ".") for s in si.sorters.sorter_dict.keys()] -SI_READERS = { - "Open Ephys": si.extractors.read_openephys, - "SpikeGLX": si.extractors.read_spikeglx, - "Intan": si.extractors.read_intan, -} - @schema class PreProcessing(dj.Imported): @@ -108,9 +102,7 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_dir = output_dir / sorter_name / "recording" recording_dir.mkdir(parents=True, exist_ok=True) - recording_file = ( - recording_dir / "si_recording.pkl" - ) # recording cache to be created for each key + recording_file = recording_dir / "si_recording.pkl" # Create SI recording extractor object if acq_software == "SpikeGLX": @@ -125,12 +117,16 @@ def make(self, key): assert len(oe_probe.recording_info["recording_files"]) == 1 data_dir = oe_probe.recording_info["recording_files"][0] else: - raise NotImplementedError(f"Not implemented for {acq_software}") + si_extractor: si.extractors.neoextractors = ( + si.extractors.extractorlist.recording_extractor_full_dict[ + acq_software.replace(" ", "").lower() + ] + ) # data extractor object stream_names, stream_ids = si.extractors.get_neo_streams( acq_software.strip().lower(), folder_path=data_dir ) - si_recording: si.BaseRecording = SI_READERS[acq_software]( + si_recording: si.BaseRecording = si_extractor[acq_software]( folder_path=data_dir, 
stream_name=stream_names[0] ) From bb39194aeb2a06390be6b0415afe0bd46310dbbf Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 12 Mar 2024 10:39:15 -0500 Subject: [PATCH 106/146] fix acq_software name --- element_array_ephys/spike_sorting/si_spike_sorting.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 2ebe90ba..12fc069b 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -117,14 +117,13 @@ def make(self, key): assert len(oe_probe.recording_info["recording_files"]) == 1 data_dir = oe_probe.recording_info["recording_files"][0] else: + acq_software = acq_software.replace(" ", "").lower() si_extractor: si.extractors.neoextractors = ( - si.extractors.extractorlist.recording_extractor_full_dict[ - acq_software.replace(" ", "").lower() - ] + si.extractors.extractorlist.recording_extractor_full_dict[acq_software] ) # data extractor object stream_names, stream_ids = si.extractors.get_neo_streams( - acq_software.strip().lower(), folder_path=data_dir + acq_software, folder_path=data_dir ) si_recording: si.BaseRecording = si_extractor[acq_software]( folder_path=data_dir, stream_name=stream_names[0] From 01ff816fd4e2077c3b9c3ff4c5c439faddbb43c9 Mon Sep 17 00:00:00 2001 From: JaerongA Date: Tue, 12 Mar 2024 10:54:39 -0500 Subject: [PATCH 107/146] feat: :ambulance: make all secondary attributes nullable in QualityMetrics some sorters don't output values expected by the table --- element_array_ephys/ephys_no_curation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index a6edbe54..bfb3e2ad 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1603,8 +1603,8 @@ class Waveform(dj.Part): -> master -> CuratedClustering.Unit --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough + amplitude=null: float # (uV) absolute difference between waveform peak and trough + duration=null: float # (ms) time between waveform peak and trough halfwidth=null: float # (ms) spike width at half max amplitude pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak From 67a1ffc767261e5a9c7d9e7c85d418005c3dac80 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 17 Apr 2024 09:15:31 -0500 Subject: [PATCH 108/146] feat: save spike interface results with relative path --- element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 12fc069b..ba310d6e 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -150,7 +150,7 @@ def make(self, key): # Run preprocessing and save results to output folder si_preproc_func = getattr(si_preprocessing, params["SI_PREPROCESSING_METHOD"]) si_recording = si_preproc_func(si_recording) - si_recording.dump_to_pickle(file_path=recording_file) + si_recording.dump_to_pickle(file_path=recording_file, 
relative_to=output_dir) self.insert1( { @@ -203,7 +203,7 @@ def make(self, key): sorting_save_path = ( output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" ) - si_sorting.dump_to_pickle(sorting_save_path) + si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir) self.insert1( { From d44dbaa03aa8debb2f9d15fe60811a4fcb52a535 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 11:50:15 -0500 Subject: [PATCH 109/146] fix(spikeglx): bugfix loading spikeglx data --- element_array_ephys/ephys_no_curation.py | 11 +++++++++-- element_array_ephys/spike_sorting/si_spike_sorting.py | 11 +++++++---- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index bfb3e2ad..2dde282b 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -338,8 +338,15 @@ def make(self, key): supported_probe_types = probe.ProbeType.fetch("probe_type") if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) + for meta_filepath in ephys_meta_filepaths: + spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) + if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: + spikeglx_meta_filepath = meta_filepath + break + else: + raise FileNotFoundError( + "No SpikeGLX data found for probe insertion: {}".format(key) + ) if spikeglx_meta.probe_model not in supported_probe_types: raise NotImplementedError( diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 12fc069b..0b53bf1d 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -117,10 +117,13 @@ def make(self, key): assert len(oe_probe.recording_info["recording_files"]) == 1 data_dir = oe_probe.recording_info["recording_files"][0] else: - acq_software = acq_software.replace(" ", "").lower() - si_extractor: si.extractors.neoextractors = ( - si.extractors.extractorlist.recording_extractor_full_dict[acq_software] - ) # data extractor object + raise NotImplementedError( + f"SpikeInterface processing for {acq_software} not yet implemented." 
+ ) + acq_software = acq_software.replace(" ", "").lower() + si_extractor: si.extractors.neoextractors = ( + si.extractors.extractorlist.recording_extractor_full_dict[acq_software] + ) # data extractor object stream_names, stream_ids = si.extractors.get_neo_streams( acq_software, folder_path=data_dir From d86928bf41a2bb0e30c7136d74fc485c9de2b90f Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 12:13:15 -0500 Subject: [PATCH 110/146] fix: bugfix inserting `ElectrodeConfig` --- element_array_ephys/ephys_no_curation.py | 94 ++++++++++++------------ 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 2dde282b..dcb2ded6 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -371,31 +371,30 @@ def make(self, key): probe_type, electrode_group_members ) - self.insert1( - { - **key, - "electrode_config_hash": econfig_entry["electrode_config_hash"], - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(spikeglx_meta_filepath) - ), - } - ) + ephys_recording_entry = { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "acq_software": acq_software, + "sampling_rate": spikeglx_meta.meta["imSampRate"], + "recording_datetime": spikeglx_meta.recording_time, + "recording_duration": ( + spikeglx_meta.recording_duration + or spikeglx.retrieve_recording_duration(spikeglx_meta_filepath) + ), + } root_dir = find_root_directory( get_ephys_root_data_dir(), spikeglx_meta_filepath ) - self.EphysFile.insert1( + + ephys_file_entries = [ { **key, "file_path": spikeglx_meta_filepath.relative_to( root_dir ).as_posix(), } - ) + ] # Insert channel information # Get channel and electrode-site mapping @@ -417,13 +416,11 @@ def make(self, key): spikeglx_meta.shankmap["data"] ) } - self.Channel.insert( - [ - {**key, "channel_idx": channel_idx, **channel_info} - for channel_idx, channel_info in channel2electrode_map.items() - ] - ) + ephys_channel_entries = [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrode_map.items() + ] elif acq_software == "Open Ephys": dataset = openephys.OpenEphys(session_dir) for serial_number, probe_data in dataset.probes.items(): @@ -460,31 +457,29 @@ def make(self, key): probe_type, electrode_group_members ) - self.insert1( - { - **key, - "electrode_config_hash": econfig_entry["electrode_config_hash"], - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) + ephys_recording_entry = { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "acq_software": acq_software, + "sampling_rate": probe_data.ap_meta["sample_rate"], + "recording_datetime": probe_data.recording_info["recording_datetimes"][ + 0 + ], + "recording_duration": np.sum( + probe_data.recording_info["recording_durations"] + ), + } root_dir = find_root_directory( get_ephys_root_data_dir(), probe_data.recording_info["recording_files"][0], ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in 
probe_data.recording_info["recording_files"] - ] - ) + + ephys_file_entries = [ + {**key, "file_path": fp.relative_to(root_dir).as_posix()} + for fp in probe_data.recording_info["recording_files"] + ] + # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough del probe_data, dataset gc.collect() @@ -503,11 +498,14 @@ def make(self, key): channel_idx: probe_electrodes[channel_idx] for channel_idx in probe_dataset.ap_meta["channels_indices"] } - self.Channel.insert( - [ - {**key, "channel_idx": channel_idx, **channel_info} - for channel_idx, channel_info in channel2electrode_map.items() - ] + + ephys_channel_entries = [ + {**key, "channel_idx": channel_idx, **channel_info} + for channel_idx, channel_info in channel2electrode_map.items() + ] + else: + raise NotImplementedError( + f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." ) # Insert into probe.ElectrodeConfig (recording configuration) @@ -517,6 +515,10 @@ def make(self, key): probe.ElectrodeConfig.insert1(econfig_entry) probe.ElectrodeConfig.Electrode.insert(econfig_electrodes) + self.insert1(ephys_recording_entry) + self.EphysFile.insert(ephys_file_entries) + self.Channel.insert(ephys_channel_entries) + @schema class LFP(dj.Imported): From f8ffd7760cb1be6ac19d24e37ebf69d11d773972 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 3 Apr 2024 14:14:35 -0500 Subject: [PATCH 111/146] feat(spikesorting): save to phy and generate report --- element_array_ephys/spike_sorting/si_spike_sorting.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 90a88260..52c96709 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -288,6 +288,11 @@ def make(self, key): metrics_output_dir.mkdir(parents=True, exist_ok=True) metrics.to_csv(metrics_output_dir / "metrics.csv") + # Save to phy format + si.exporters.export_to_phy(waveform_extractor=we, output_folder=output_dir / sorter_name / "phy") + # Generate spike interface report + si.exporters.export_report(waveform_extractor=we, output_folder=output_dir / sorter_name / "spikeinterface_report") + self.insert1( { **key, From 7309082858b5210dcbf9566f2e8afd72416e9655 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 12:35:39 -0500 Subject: [PATCH 112/146] chore: cleanup init --- element_array_ephys/__init__.py | 21 ------------------- .../spike_sorting/ecephys_spike_sorting.py | 3 +-- .../spike_sorting/si_spike_sorting.py | 5 ++--- 3 files changed, 3 insertions(+), 26 deletions(-) diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py index 3a0e5af6..1c0c7285 100644 --- a/element_array_ephys/__init__.py +++ b/element_array_ephys/__init__.py @@ -1,22 +1 @@ -""" -isort:skip_file -""" - -import logging -import os - -import datajoint as dj - - -__all__ = ["ephys", "get_logger"] - -dj.config["enable_python_native_blobs"] = True - - -def get_logger(name): - log = logging.getLogger(name) - log.setLevel(os.getenv("LOGLEVEL", "INFO")) - return log - - from . 
import ephys_acute as ephys diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py index 4de349eb..3a43c384 100644 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py @@ -22,7 +22,6 @@ import datajoint as dj -from element_array_ephys import get_logger from decimal import Decimal import json from datetime import datetime, timedelta @@ -33,7 +32,7 @@ kilosort_triggering, ) -log = get_logger(__name__) +log = dj.logger schema = dj.schema() diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 52c96709..306c1eb6 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -7,14 +7,13 @@ import datajoint as dj import pandas as pd import spikeinterface as si -from element_array_ephys import get_logger, probe, readers +from element_array_ephys import probe, readers from element_interface.utils import find_full_path from spikeinterface import exporters, postprocessing, qualitymetrics, sorters -from .. import get_logger, probe, readers from . import si_preprocessing -log = get_logger(__name__) +log = dj.logger schema = dj.schema() From d778b1e7d8822173ad43d60707fbb8fa8c7ff801 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 13:24:45 -0500 Subject: [PATCH 113/146] fix: update channel-electrode mapping --- element_array_ephys/ephys_no_curation.py | 164 ++++++++---------- .../spike_sorting/si_spike_sorting.py | 13 +- 2 files changed, 81 insertions(+), 96 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index dcb2ded6..68251309 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1040,51 +1040,47 @@ def make(self, key): ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - # Get sorter method and create output directory. - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method + # Get channel and electrode-site mapping + electrode_query = ( + (EphysRecording.Channel & key) + .proj(..., "-channel_name") ) - waveform_dir = output_dir / sorter_name / "waveform" - sorting_dir = output_dir / sorter_name / "spike_sorting" + channel2electrode_map = electrode_query.fetch(as_dict=True) + channel2electrode_map: dict[int, dict] = { + chn.pop("channel_idx"): chn for chn in channel2electrode_map + } - if waveform_dir.exists(): # read from spikeinterface outputs - we: si.WaveformExtractor = si.load_waveforms( - waveform_dir, with_recording=False - ) + # Get sorter method and create output directory. 
+ sorter_name = clustering_method.replace(".", "_") + si_waveform_dir = output_dir / sorter_name / "waveform" + si_sorting_dir = output_dir / sorter_name / "spike_sorting" + + if si_waveform_dir.exists(): + + # Read from spikeinterface outputs + we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False) si_sorting: si.sorters.BaseSorter = si.load_extractor( - sorting_dir / "si_sorting.pkl" + si_sorting_dir / "si_sorting.pkl" ) - unit_peak_channel_map: dict[int, int] = si.get_template_extremum_channel( - we, outputs="index" - ) # {unit: peak_channel_index} + unit_peak_channel: dict[int, int] = si.get_template_extremum_channel( + we, outputs="id" + ) # {unit: peak_channel_id} spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} - spikes = si_sorting.to_spike_vector( - extremum_channel_inds=unit_peak_channel_map - ) - - # Get electrode & channel info - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) * (dj.U("electrode", "channel_idx") & EphysRecording.Channel) + spikes = si_sorting.to_spike_vector() - channel_info = electrode_query.fetch(as_dict=True, order_by="channel_idx") - channel_info: dict[int, dict] = { - ch.pop("channel_idx"): ch for ch in channel_info + # reorder channel2electrode_map according to recording channel ids + channel2electrode_map = { + chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids } # Get unit id to quality label mapping try: cluster_quality_label_map = pd.read_csv( - sorting_dir / "sorter_output" / "cluster_KSLabel.tsv", + si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv", delimiter="\t", ) except FileNotFoundError: @@ -1099,7 +1095,7 @@ def make(self, key): # Get electrode where peak unit activity is recorded peak_electrode_ind = np.array( [ - channel_info[unit_peak_channel_map[unit_id]]["electrode"] + channel2electrode_map[unit_peak_channel[unit_id]]["electrode"] for unit_id in si_sorting.unit_ids ] ) @@ -1107,7 +1103,7 @@ def make(self, key): # Get channel depth channel_depth_ind = np.array( [ - channel_info[unit_peak_channel_map[unit_id]]["y_coord"] + channel2electrode_map[unit_peak_channel[unit_id]]["y_coord"] for unit_id in si_sorting.unit_ids ] ) @@ -1132,7 +1128,7 @@ def make(self, key): units.append( { **key, - **channel_info[unit_peak_channel_map[unit_id]], + **channel2electrode_map[unit_peak_channel[unit_id]], "unit": unit_id, "cluster_quality_label": cluster_quality_label_map.get( unit_id, "n.a." 
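# A NumPy-only sketch of the masking idiom in the surrounding hunks. Recent
# spikeinterface returns one structured array for all units from
# `sorting.to_spike_vector()`; selecting on `unit_index` recovers each unit's
# spike train. The dtype below merely imitates that layout for illustration.
import numpy as np

spikes = np.array(
    [(100, 0), (250, 1), (400, 0), (900, 1)],
    dtype=[("sample_index", "i8"), ("unit_index", "i8")],
)
sampling_rate = 30_000.0
for unit_index in np.unique(spikes["unit_index"]):
    mask = spikes["unit_index"] == unit_index
    spike_times = spikes["sample_index"][mask] / sampling_rate  # samples -> seconds
    print(int(unit_index), int(mask.sum()), spike_times)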
@@ -1143,10 +1139,10 @@ def make(self, key): "spike_count": spike_count_dict[unit_id], "spike_sites": new_spikes["electrode"][ new_spikes["unit_index"] == unit_id - ], + ], "spike_depths": new_spikes["depth"][ new_spikes["unit_index"] == unit_id - ], + ], } ) @@ -1184,20 +1180,10 @@ def make(self, key): spike_times = kilosort_dataset.data[spike_time_key] kilosort_dataset.extract_spike_depths() - # Get channel and electrode-site mapping - channel_info = ( - (EphysRecording.Channel & key) - .proj(..., "-channel_name") - .fetch(as_dict=True, order_by="channel_idx") - ) - channel_info: dict[int, dict] = { - ch.pop("channel_idx"): ch for ch in channel_info - } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} - # -- Spike-sites and Spike-depths -- spike_sites = np.array( [ - channel_info[s]["electrode"] + channel2electrode_map[s]["electrode"] for s in kilosort_dataset.data["spike_sites"] ] ) @@ -1219,7 +1205,7 @@ def make(self, key): **key, "unit": unit, "cluster_quality_label": unit_lbl, - **channel_info[unit_channel], + **channel2electrode_map[unit_channel], "spike_times": unit_spike_times, "spike_count": spike_count, "spike_sites": spike_sites[ @@ -1292,33 +1278,31 @@ def make(self, key): ClusteringTask * ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) + sorter_name = clustering_method.replace(".", "_") # Get channel and electrode-site mapping - channel_info = ( + electrode_query = ( (EphysRecording.Channel & key) .proj(..., "-channel_name") - .fetch(as_dict=True, order_by="channel_idx") ) - channel_info: dict[int, dict] = { - ch.pop("channel_idx"): ch for ch in channel_info - } # e.g., {0: {'subject': 'sglx', 'session_id': 912231859, 'insertion_number': 1, 'electrode_config_hash': UUID('8d4cc6d8-a02d-42c8-bf27-7459c39ea0ee'), 'probe_type': 'neuropixels 1.0 - 3A', 'electrode': 0}} + channel2electrode_map = electrode_query.fetch(as_dict=True) + channel2electrode_map: dict[int, dict] = { + chn.pop("channel_idx"): chn for chn in channel2electrode_map + } - if ( - output_dir / sorter_name / "waveform" - ).exists(): # read from spikeinterface outputs + si_waveform_dir = output_dir / sorter_name / "waveform" + if si_waveform_dir.exists(): # read from spikeinterface outputs + we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False) + unit_id_to_peak_channel_map: dict[ + int, np.ndarray + ] = si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices # {unit: peak_channel_index} - waveform_dir = output_dir / sorter_name / "waveform" - we: si.WaveformExtractor = si.load_waveforms( - waveform_dir, with_recording=False - ) - unit_id_to_peak_channel_map: dict[int, np.ndarray] = ( - si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices - ) # {unit: peak_channel_index} + # reorder channel2electrode_map according to recording channel ids + channel2electrode_map = { + chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids + } # Get mean waveform for each unit from all channels mean_waveforms = we.get_all_templates( @@ -1329,30 +1313,32 @@ def make(self, key): unit_electrode_waveforms = [] for unit in (CuratedClustering.Unit & 
key).fetch("KEY", order_by="unit"): + unit_waveforms = we.get_template( + unit_id=unit["unit"], mode="average", force_dense=True + ) # (sample x channel) + peak_chn_idx = list(we.channel_ids).index( + unit_id_to_peak_channel_map[unit["unit"]][0] + ) unit_peak_waveform.append( { **unit, - "peak_electrode_waveform": we.get_template( - unit_id=unit["unit"], mode="average", force_dense=True - )[:, unit_id_to_peak_channel_map[unit["unit"]][0]], + "peak_electrode_waveform": unit_waveforms[:, peak_chn_idx], } ) - unit_electrode_waveforms.extend( [ { **unit, - **channel_info[c], - "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c], + **channel2electrode_map[c], + "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c_idx], } - for c in channel_info + for c_idx, c in enumerate(channel2electrode_map) ] ) self.insert1(key) self.PeakWaveform.insert(unit_peak_waveform) self.Waveform.insert(unit_electrode_waveforms) - else: kilosort_dataset = kilosort.Kilosort(output_dir) @@ -1390,12 +1376,12 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **units[unit_no], - **channel_info[channel], + **channel2electrode_map[channel], "waveform_mean": channel_waveform, } ) if ( - channel_info[channel]["electrode"] + channel2electrode_map[channel]["electrode"] == units[unit_no]["electrode"] ): unit_peak_waveform = { @@ -1405,7 +1391,6 @@ def yield_unit_waveforms(): yield unit_peak_waveform, unit_electrode_waveforms # Spike interface mean and peak waveform extraction from we object - elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): we_kilosort = si.load_waveforms(waveforms_folder[0].parent) unit_templates = we_kilosort.get_all_templates() @@ -1432,12 +1417,12 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **units[unit_no], - **channel_info[channel], + **channel2electrode_map[channel], "waveform_mean": channel_waveform, } ) if ( - channel_info[channel]["electrode"] + channel2electrode_map[channel]["electrode"] == units[unit_no]["electrode"] ): unit_peak_waveform = { @@ -1506,13 +1491,13 @@ def yield_unit_waveforms(): unit_electrode_waveforms.append( { **unit_dict, - **channel_info[channel], + **channel2electrode_map[channel], "waveform_mean": channel_waveform.mean(axis=0), "waveforms": channel_waveform, } ) if ( - channel_info[channel]["electrode"] + channel2electrode_map[channel]["electrode"] == unit_dict["electrode"] ): unit_peak_waveform = { @@ -1630,12 +1615,15 @@ def make(self, key): ClusteringTask * ClusteringParamSet & key ).fetch1("clustering_method", "clustering_output_dir") output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - sorter_name = ( - "kilosort2_5" if clustering_method == "kilosort2.5" else clustering_method - ) - metric_fp = output_dir / sorter_name / "metrics" / "metrics.csv" - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") + sorter_name = clustering_method.replace(".", "_") + + # find metric_fp + for metric_fp in [output_dir / "metrics.csv", output_dir / sorter_name / "metrics" / "metrics.csv"]: + if metric_fp.exists(): + break + else: + raise FileNotFoundError(f"QC metrics file not found in: {output_dir}") + metrics_df = pd.read_csv(metric_fp) # Conform the dataframe to match the table definition diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 306c1eb6..d14746fb 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -132,21 
+132,18 @@ def make(self, key): ) # Add probe information to recording object - electrode_config_key = ( - probe.ElectrodeConfig * ephys.EphysRecording & key - ).fetch1("KEY") electrodes_df = ( ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key + ephys.EphysRecording.Channel * probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + & key ) .fetch(format="frame") - .reset_index()[["electrode", "x_coord", "y_coord", "shank"]] + .reset_index() ) # Create SI probe object - si_probe = readers.probe_geometry.to_probeinterface(electrodes_df) - si_probe.set_device_channel_indices(range(len(electrodes_df))) + si_probe = readers.probe_geometry.to_probeinterface(electrodes_df[["electrode", "x_coord", "y_coord", "shank"]]) + si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values) si_recording.set_probe(probe=si_probe, in_place=True) # Run preprocessing and save results to output folder From 015341c1127300e10e9011ec5d49a96abc3322f0 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 26 Apr 2024 17:23:29 -0500 Subject: [PATCH 114/146] feat: test spikeinterface for spikeglx data --- element_array_ephys/ephys_no_curation.py | 144 ++++++++---------- .../spike_sorting/si_preprocessing.py | 2 +- .../spike_sorting/si_spike_sorting.py | 8 +- 3 files changed, 72 insertions(+), 82 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 68251309..333a189a 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -352,24 +352,24 @@ def make(self, key): raise NotImplementedError( f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented." ) - else: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } # electrode configuration - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] # recording session-specific electrode configuration - - econfig_entry, econfig_electrodes = generate_electrode_config_entry( - probe_type, electrode_group_members + probe_type = spikeglx_meta.probe_model + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} + + probe_electrodes = { + (shank, shank_col, shank_row): key + for key, shank, shank_col, shank_row in zip( + *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") ) + } # electrode configuration + electrode_group_members = [ + probe_electrodes[(shank, shank_col, shank_row)] + for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] + ] # recording session-specific electrode configuration + + econfig_entry, econfig_electrodes = generate_electrode_config_entry( + probe_type, electrode_group_members + ) ephys_recording_entry = { **key, @@ -398,18 +398,6 @@ def make(self, key): # Insert channel information # Get channel and electrode-site mapping - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & {"electrode_config_hash": econfig_entry["electrode_config_hash"]} - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - channel2electrode_map = { 
recorded_site: probe_electrodes[(shank, shank_col, shank_row)] for recorded_site, (shank, shank_col, shank_row, _) in enumerate( @@ -418,7 +406,12 @@ def make(self, key): } ephys_channel_entries = [ - {**key, "channel_idx": channel_idx, **channel_info} + { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "channel_idx": channel_idx, + **channel_info, + } for channel_idx, channel_info in channel2electrode_map.items() ] elif acq_software == "Open Ephys": @@ -438,24 +431,24 @@ def make(self, key): if probe_data.probe_model not in supported_probe_types: raise NotImplementedError( - f"Processing for neuropixels probe model {spikeglx_meta.probe_model} not yet implemented." + f"Processing for neuropixels probe model {probe_data.probe_model} not yet implemented." ) - else: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } # electrode configuration + probe_type = probe_data.probe_model + electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] # recording session-specific electrode configuration + probe_electrodes = { + key["electrode"]: key for key in electrode_query.fetch("KEY") + } # electrode configuration - econfig_entry, econfig_electrodes = generate_electrode_config_entry( - probe_type, electrode_group_members - ) + electrode_group_members = [ + probe_electrodes[channel_idx] + for channel_idx in probe_data.ap_meta["channels_indices"] + ] # recording session-specific electrode configuration + + econfig_entry, econfig_electrodes = generate_electrode_config_entry( + probe_type, electrode_group_members + ) ephys_recording_entry = { **key, @@ -480,29 +473,24 @@ def make(self, key): for fp in probe_data.recording_info["recording_files"] ] - # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - - probe_dataset = get_openephys_probe_data(key) - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } channel2electrode_map = { channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] + for channel_idx in probe_data.ap_meta["channels_indices"] } ephys_channel_entries = [ - {**key, "channel_idx": channel_idx, **channel_info} + { + **key, + "electrode_config_hash": econfig_entry["electrode_config_hash"], + "channel_idx": channel_idx, + **channel_info, + } for channel_idx, channel_info in channel2electrode_map.items() ] + + # Explicitly garbage collect "dataset" as these may have large memory footprint and may not be cleared fast enough + del probe_data, dataset + gc.collect() else: raise NotImplementedError( f"Processing ephys files from acquisition software of type {acq_software} is not yet implemented." 
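# Schematic of the control flow this patch converges on (names match the diff,
# bodies elided). Both acquisition branches now build the same three plain data
# structures, and nothing touches the database until parsing succeeds:
#
#     if acq_software == "SpikeGLX":
#         ephys_recording_entry, ephys_file_entries, ephys_channel_entries = ...
#     elif acq_software == "Open Ephys":
#         ephys_recording_entry, ephys_file_entries, ephys_channel_entries = ...
#     else:
#         raise NotImplementedError(...)
#     self.insert1(ephys_recording_entry)
#     self.EphysFile.insert(ephys_file_entries)
#     self.Channel.insert(ephys_channel_entries)
#
# DataJoint already wraps make() in a transaction, but grouping the inserts keeps
# the branches symmetric and raises the unsupported-software error before any row
# is written.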
@@ -1041,10 +1029,7 @@ def make(self, key): output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) # Get channel and electrode-site mapping - electrode_query = ( - (EphysRecording.Channel & key) - .proj(..., "-channel_name") - ) + electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") channel2electrode_map = electrode_query.fetch(as_dict=True) channel2electrode_map: dict[int, dict] = { chn.pop("channel_idx"): chn for chn in channel2electrode_map @@ -1058,7 +1043,9 @@ def make(self, key): if si_waveform_dir.exists(): # Read from spikeinterface outputs - we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False) + we: si.WaveformExtractor = si.load_waveforms( + si_waveform_dir, with_recording=False + ) si_sorting: si.sorters.BaseSorter = si.load_extractor( si_sorting_dir / "si_sorting.pkl" ) @@ -1139,10 +1126,10 @@ def make(self, key): "spike_count": spike_count_dict[unit_id], "spike_sites": new_spikes["electrode"][ new_spikes["unit_index"] == unit_id - ], + ], "spike_depths": new_spikes["depth"][ new_spikes["unit_index"] == unit_id - ], + ], } ) @@ -1281,10 +1268,7 @@ def make(self, key): sorter_name = clustering_method.replace(".", "_") # Get channel and electrode-site mapping - electrode_query = ( - (EphysRecording.Channel & key) - .proj(..., "-channel_name") - ) + electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") channel2electrode_map = electrode_query.fetch(as_dict=True) channel2electrode_map: dict[int, dict] = { chn.pop("channel_idx"): chn for chn in channel2electrode_map @@ -1292,12 +1276,14 @@ def make(self, key): si_waveform_dir = output_dir / sorter_name / "waveform" if si_waveform_dir.exists(): # read from spikeinterface outputs - we: si.WaveformExtractor = si.load_waveforms(si_waveform_dir, with_recording=False) - unit_id_to_peak_channel_map: dict[ - int, np.ndarray - ] = si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices # {unit: peak_channel_index} + we: si.WaveformExtractor = si.load_waveforms( + si_waveform_dir, with_recording=False + ) + unit_id_to_peak_channel_map: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} # reorder channel2electrode_map according to recording channel ids channel2electrode_map = { @@ -1391,6 +1377,7 @@ def yield_unit_waveforms(): yield unit_peak_waveform, unit_electrode_waveforms # Spike interface mean and peak waveform extraction from we object + elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): we_kilosort = si.load_waveforms(waveforms_folder[0].parent) unit_templates = we_kilosort.get_all_templates() @@ -1618,7 +1605,10 @@ def make(self, key): sorter_name = clustering_method.replace(".", "_") # find metric_fp - for metric_fp in [output_dir / "metrics.csv", output_dir / sorter_name / "metrics" / "metrics.csv"]: + for metric_fp in [ + output_dir / "metrics.csv", + output_dir / sorter_name / "metrics" / "metrics.csv", + ]: if metric_fp.exists(): break else: diff --git a/element_array_ephys/spike_sorting/si_preprocessing.py b/element_array_ephys/spike_sorting/si_preprocessing.py index 4db5f303..22adbdca 100644 --- a/element_array_ephys/spike_sorting/si_preprocessing.py +++ b/element_array_ephys/spike_sorting/si_preprocessing.py @@ -2,7 +2,7 @@ from spikeinterface import preprocessing -def catGT(recording): +def CatGT(recording): recording = si.preprocessing.phase_shift(recording) recording = 
si.preprocessing.common_reference( recording, operator="median", reference="global" diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index d14746fb..c1a906ea 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -127,7 +127,7 @@ def make(self, key): stream_names, stream_ids = si.extractors.get_neo_streams( acq_software, folder_path=data_dir ) - si_recording: si.BaseRecording = si_extractor[acq_software]( + si_recording: si.BaseRecording = si_extractor( folder_path=data_dir, stream_name=stream_names[0] ) @@ -184,7 +184,7 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) sorter_name = clustering_method.replace(".", "_") recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" - si_recording: si.BaseRecording = si.load_extractor(recording_file) + si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir) # Run sorting # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package. @@ -241,8 +241,8 @@ def make(self, key): recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" - si_recording: si.BaseRecording = si.load_extractor(recording_file) - si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file) + si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir) + si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file, base_folder=output_dir) # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( From 05ccfdb80cee7418e58322ebb3bbb9f4a1df6b8e Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Mon, 29 Apr 2024 11:46:32 -0500 Subject: [PATCH 115/146] fix: update ingestion from spikeinterface results --- element_array_ephys/ephys_no_curation.py | 137 +++++------------------ 1 file changed, 27 insertions(+), 110 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 333a189a..0cf2021c 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1040,18 +1040,16 @@ def make(self, key): si_waveform_dir = output_dir / sorter_name / "waveform" si_sorting_dir = output_dir / sorter_name / "spike_sorting" - if si_waveform_dir.exists(): - - # Read from spikeinterface outputs + if si_waveform_dir.exists(): # Read from spikeinterface outputs we: si.WaveformExtractor = si.load_waveforms( si_waveform_dir, with_recording=False ) si_sorting: si.sorters.BaseSorter = si.load_extractor( - si_sorting_dir / "si_sorting.pkl" + si_sorting_dir / "si_sorting.pkl", base_folder=output_dir ) unit_peak_channel: dict[int, int] = si.get_template_extremum_channel( - we, outputs="id" + we, outputs="index" ) # {unit: peak_channel_id} spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() @@ -1061,7 +1059,8 @@ def make(self, key): # reorder channel2electrode_map according to recording channel ids channel2electrode_map = { - chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids + chn_idx: channel2electrode_map[chn_idx] + for chn_idx in we.channel_ids_to_indices(we.channel_ids) } # Get unit id to quality label mapping @@ -1090,7 +1089,7 @@ def make(self, key): # Get channel depth channel_depth_ind = np.array( [ - 
channel2electrode_map[unit_peak_channel[unit_id]]["y_coord"] + we.get_probe().contact_positions[unit_peak_channel[unit_id]][1] for unit_id in si_sorting.unit_ids ] ) @@ -1132,7 +1131,6 @@ def make(self, key): ], } ) - else: # read from kilosort outputs kilosort_dataset = kilosort.Kilosort(output_dir) acq_software, sample_rate = (EphysRecording & key).fetch1( @@ -1286,46 +1284,38 @@ def make(self, key): ) # {unit: peak_channel_index} # reorder channel2electrode_map according to recording channel ids + channel_indices = we.channel_ids_to_indices(we.channel_ids).tolist() channel2electrode_map = { - chn_id: channel2electrode_map[int(chn_id)] for chn_id in we.channel_ids + chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices } - # Get mean waveform for each unit from all channels - mean_waveforms = we.get_all_templates( - mode="average" - ) # (unit x sample x channel) - - unit_peak_waveform = [] - unit_electrode_waveforms = [] - - for unit in (CuratedClustering.Unit & key).fetch("KEY", order_by="unit"): - unit_waveforms = we.get_template( - unit_id=unit["unit"], mode="average", force_dense=True - ) # (sample x channel) - peak_chn_idx = list(we.channel_ids).index( - unit_id_to_peak_channel_map[unit["unit"]][0] - ) - unit_peak_waveform.append( - { + def yield_unit_waveforms(): + for unit in (CuratedClustering.Unit & key).fetch( + "KEY", order_by="unit" + ): + # Get mean waveform for this unit from all channels - (sample x channel) + unit_waveforms = we.get_template( + unit_id=unit["unit"], mode="average", force_dense=True + ) + peak_chn_idx = channel_indices.index( + unit_id_to_peak_channel_map[unit["unit"]][0] + ) + unit_peak_waveform = { **unit, "peak_electrode_waveform": unit_waveforms[:, peak_chn_idx], } - ) - unit_electrode_waveforms.extend( - [ + + unit_electrode_waveforms = [ { **unit, - **channel2electrode_map[c], - "waveform_mean": mean_waveforms[unit["unit"] - 1, :, c_idx], + **channel2electrode_map[chn_idx], + "waveform_mean": unit_waveforms[:, chn_idx], } - for c_idx, c in enumerate(channel2electrode_map) + for chn_idx in channel_indices ] - ) - self.insert1(key) - self.PeakWaveform.insert(unit_peak_waveform) - self.Waveform.insert(unit_electrode_waveforms) - else: + yield unit_peak_waveform, unit_electrode_waveforms + else: # read from kilosort outputs kilosort_dataset = kilosort.Kilosort(output_dir) acq_software, probe_serial_number = ( @@ -1340,10 +1330,6 @@ def make(self, key): ) } - waveforms_folder = [ - f for f in output_dir.parent.rglob(r"*/waveforms*") if f.is_dir() - ] - if (output_dir / "mean_waveforms.npy").exists(): unit_waveforms = np.load( output_dir / "mean_waveforms.npy" @@ -1376,75 +1362,6 @@ def yield_unit_waveforms(): } yield unit_peak_waveform, unit_electrode_waveforms - # Spike interface mean and peak waveform extraction from we object - - elif len(waveforms_folder) > 0 & (waveforms_folder[0]).exists(): - we_kilosort = si.load_waveforms(waveforms_folder[0].parent) - unit_templates = we_kilosort.get_all_templates() - unit_waveforms = np.reshape( - unit_templates, - ( - unit_templates.shape[1], - unit_templates.shape[3], - unit_templates.shape[2], - ), - ) - - # Approach assumes unit_waveforms was generated correctly (templates are actually the same as mean_waveforms) - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - 
kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrode_map[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrode_map[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - # Approach not using spike interface templates (ie. taking mean of each unit waveform) - # def yield_unit_waveforms(): - # for unit_id in we_kilosort.unit_ids: - # unit_waveform = np.mean(we_kilosort.get_waveforms(unit_id), 0) - # unit_peak_waveform = {} - # unit_electrode_waveforms = [] - # if unit_id in units: - # for channel, channel_waveform in zip( - # kilosort_dataset.data["channel_map"], unit_waveform - # ): - # unit_electrode_waveforms.append( - # { - # **units[unit_id], - # **channel2electrodes[channel], - # "waveform_mean": channel_waveform, - # } - # ) - # if ( - # channel2electrodes[channel]["electrode"] - # == units[unit_id]["electrode"] - # ): - # unit_peak_waveform = { - # **units[unit_id], - # "peak_electrode_waveform": channel_waveform, - # } - # yield unit_peak_waveform, unit_electrode_waveforms - else: if acq_software == "SpikeGLX": spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) From 93895a965471902b3a3aa5448c7648ce09432928 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 8 May 2024 00:23:28 +0200 Subject: [PATCH 116/146] Refactor Quality Metrics Logic + blackformatting --- element_array_ephys/ephys_no_curation.py | 22 +++-- .../spike_sorting/si_spike_sorting.py | 84 +++++++++++++------ 2 files changed, 73 insertions(+), 33 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 0cf2021c..b0a8bc26 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1277,11 +1277,11 @@ def make(self, key): we: si.WaveformExtractor = si.load_waveforms( si_waveform_dir, with_recording=False ) - unit_id_to_peak_channel_map: dict[int, np.ndarray] = ( - si.ChannelSparsity.from_best_channels( - we, 1, peak_sign="neg" - ).unit_id_to_channel_indices - ) # {unit: peak_channel_index} + unit_id_to_peak_channel_map: dict[ + int, np.ndarray + ] = si.ChannelSparsity.from_best_channels( + we, 1, peak_sign="neg" + ).unit_id_to_channel_indices # {unit: peak_channel_index} # reorder channel2electrode_map according to recording channel ids channel_indices = we.channel_ids_to_indices(we.channel_ids).tolist() @@ -1315,6 +1315,7 @@ def yield_unit_waveforms(): ] yield unit_peak_waveform, unit_electrode_waveforms + else: # read from kilosort outputs kilosort_dataset = kilosort.Kilosort(output_dir) @@ -1546,9 +1547,14 @@ def make(self, key): metrics_df.rename( columns={ - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", + "isi_violations_ratio": "isi_violation", + "isi_violations_count": "number_violation", + "silhouette": "silhouette_score", + "rp_contamination": "contamination_rate", + "drift_ptp": "max_drift", + "drift_mad": "cumulative_drift", + "half_width": "halfwidth", + "peak_trough_ratio": "pt_ratio", }, inplace=True, ) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index c1a906ea..94f12f84 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ 
-134,7 +134,9 @@ def make(self, key): # Add probe information to recording object electrodes_df = ( ( - ephys.EphysRecording.Channel * probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode + ephys.EphysRecording.Channel + * probe.ElectrodeConfig.Electrode + * probe.ProbeType.Electrode & key ) .fetch(format="frame") @@ -142,7 +144,9 @@ def make(self, key): ) # Create SI probe object - si_probe = readers.probe_geometry.to_probeinterface(electrodes_df[["electrode", "x_coord", "y_coord", "shank"]]) + si_probe = readers.probe_geometry.to_probeinterface( + electrodes_df[["electrode", "x_coord", "y_coord", "shank"]] + ) si_probe.set_device_channel_indices(electrodes_df["channel_idx"].values) si_recording.set_probe(probe=si_probe, in_place=True) @@ -184,7 +188,9 @@ def make(self, key): output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) sorter_name = clustering_method.replace(".", "_") recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" - si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir) + si_recording: si.BaseRecording = si.load_extractor( + recording_file, base_folder=output_dir + ) # Run sorting # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package. @@ -241,8 +247,12 @@ def make(self, key): recording_file = output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" - si_recording: si.BaseRecording = si.load_extractor(recording_file, base_folder=output_dir) - si_sorting: si.sorters.BaseSorter = si.load_extractor(sorting_file, base_folder=output_dir) + si_recording: si.BaseRecording = si.load_extractor( + recording_file, base_folder=output_dir + ) + si_sorting: si.sorters.BaseSorter = si.load_extractor( + sorting_file, base_folder=output_dir + ) # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( @@ -257,27 +267,46 @@ def make(self, key): **params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_size": 30000}), ) - # Calculate QC Metrics - metrics: pd.DataFrame = si.qualitymetrics.compute_quality_metrics( - we, - metric_names=[ - "firing_rate", - "snr", - "presence_ratio", - "isi_violation", - "num_spikes", - "amplitude_cutoff", - "amplitude_median", - "sliding_rp_violation", - "rp_violation", - "drift", - ], - ) - # Add PCA based metrics. These will be added to the metrics dataframe above. + # Calculate Cluster and Waveform Metrics + + # To provide waveform_principal_component _ = si.postprocessing.compute_principal_components( waveform_extractor=we, **params.get("SI_QUALITY_METRICS_PARAMS", None) ) - metrics = si.qualitymetrics.compute_quality_metrics(waveform_extractor=we) + + # To estimate the location of each spike in the sorting output. + # The drift metrics require the `spike_locations` waveform extension. + _ = si.postprocessing.compute_spike_locations(waveform_extractor=we) + + # The `sd_ratio` metric requires the `spike_amplitudes` waveform extension. + # It is highly recommended before calculating amplitude-based quality metrics. + _ = si.postprocessing.compute_spike_amplitudes(waveform_extractor=we) + + # To compute correlograms for spike trains. + _ = si.postprocessing.compute_correlograms(we) + + metric_names = si.qualitymetrics.get_quality_metric_list() + metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) + + # To compute commonly used cluster quality metrics. 
+ qc_metrics = si.qualitymetrics.compute_quality_metrics( + waveform_extractor=we, + metric_names=metric_names, + ) + + # To compute commonly used waveform/template metrics. + template_metric_names = si.postprocessing.get_template_metric_names() + template_metric_names.extend(["amplitude", "duration"]) + + template_metrics = si.postprocessing.compute_template_metrics( + waveform_extractor=we, + include_multi_channel_metrics=True, + metric_names=template_metric_names, + ) + + # Save the output (metrics.csv to the output dir) + metrics = pd.DataFrame() + metrics = pd.concat([qc_metrics, template_metrics], axis=1) # Save the output (metrics.csv to the output dir) metrics_output_dir = output_dir / sorter_name / "metrics" @@ -285,9 +314,14 @@ def make(self, key): metrics.to_csv(metrics_output_dir / "metrics.csv") # Save to phy format - si.exporters.export_to_phy(waveform_extractor=we, output_folder=output_dir / sorter_name / "phy") + si.exporters.export_to_phy( + waveform_extractor=we, output_folder=output_dir / sorter_name / "phy" + ) # Generate spike interface report - si.exporters.export_report(waveform_extractor=we, output_folder=output_dir / sorter_name / "spikeinterface_report") + si.exporters.export_report( + waveform_extractor=we, + output_folder=output_dir / sorter_name / "spikeinterface_report", + ) self.insert1( { From bd3bb8e9eccb7df3f44fce7398549325f994dec8 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 8 May 2024 10:34:09 -0500 Subject: [PATCH 117/146] Update si_spike_sorting.py --- element_array_ephys/spike_sorting/si_spike_sorting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index c1a906ea..1aea4ad0 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -42,6 +42,7 @@ def activate( create_tables=create_tables, add_objects=ephys.__dict__, ) + ephys.Clustering.key_source -= PreProcessing.key_source.proj() SI_SORTERS = [s.replace("_", ".") for s in si.sorters.sorter_dict.keys()] From 403d1df30c18eb63f84b200ea8a861c59d9d6ac5 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 9 May 2024 18:57:31 +0200 Subject: [PATCH 118/146] update `postprocessing` logic --- element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 94f12f84..4c90337e 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -286,7 +286,7 @@ def make(self, key): _ = si.postprocessing.compute_correlograms(we) metric_names = si.qualitymetrics.get_quality_metric_list() - metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) + metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) # To compute commonly used cluster quality metrics. 
qc_metrics = si.qualitymetrics.compute_quality_metrics( @@ -308,7 +308,7 @@ def make(self, key): metrics = pd.DataFrame() metrics = pd.concat([qc_metrics, template_metrics], axis=1) - # Save the output (metrics.csv to the output dir) + # Save metrics.csv to the output dir metrics_output_dir = output_dir / sorter_name / "metrics" metrics_output_dir.mkdir(parents=True, exist_ok=True) metrics.to_csv(metrics_output_dir / "metrics.csv") From c934e67ea6e5de2e30b35dbc10ab547e49917159 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 24 May 2024 11:00:06 -0500 Subject: [PATCH 119/146] feat: prototyping with the new `sorting_analyzer` --- .../spike_sorting/si_spike_sorting.py | 27 +++++++++++++++++-- setup.py | 2 +- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index ab803490..f7cb1e57 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -255,6 +255,29 @@ def make(self, key): sorting_file, base_folder=output_dir ) + # Sorting Analyzer + analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer" + if analyzer_output_dir.exists(): + sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir) + else: + sorting_analyzer = si.create_sorting_analyzer( + sorting=si_sorting, + recording=si_recording, + format="binary_folder", + folder=analyzer_output_dir, + sparse=True, + overwrite=True, + ) + + job_kwargs = params.get("SI_JOB_KWARGS", {"n_jobs": -1, "chunk_duration": "1s"}) + all_computable_extensions = ['random_spikes', 'waveforms', 'templates', 'noise_levels', 'amplitude_scalings', 'correlograms', 'isi_histograms', 'principal_components', 'spike_amplitudes', 'spike_locations', 'template_metrics', 'template_similarity', 'unit_locations', 'quality_metrics'] + extensions_to_compute = ['random_spikes', 'waveforms', 'templates', 'noise_levels', + 'spike_amplitudes', 'spike_locations', 'unit_locations', + 'principal_components', + 'template_metrics', 'quality_metrics'] + + sorting_analyzer.compute(extensions_to_compute, **job_kwargs) + # Extract waveforms we: si.WaveformExtractor = si.extract_waveforms( si_recording, @@ -287,7 +310,7 @@ def make(self, key): _ = si.postprocessing.compute_correlograms(we) metric_names = si.qualitymetrics.get_quality_metric_list() - metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) + # metric_names.extend(si.qualitymetrics.get_quality_pca_metric_list()) # TODO: temporarily removed # To compute commonly used cluster quality metrics. qc_metrics = si.qualitymetrics.compute_quality_metrics( @@ -297,7 +320,7 @@ def make(self, key): # To compute commonly used waveform/template metrics. template_metric_names = si.postprocessing.get_template_metric_names() - template_metric_names.extend(["amplitude", "duration"]) + template_metric_names.extend(["amplitude", "duration"]) # TODO: does this do anything? 
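# Re the TODO above: "amplitude" and "duration" are presumably not registered
# template metrics (get_template_metric_names() yields names such as
# peak_to_valley, peak_trough_ratio, half_width), so extending the list with
# them likely has no effect; a later patch instead renames peak_to_valley to
# "duration" at ingestion time. A quick sanity check, as a sketch:
#     print(si.postprocessing.get_template_metric_names())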
template_metrics = si.postprocessing.compute_template_metrics( waveform_extractor=we, diff --git a/setup.py b/setup.py index 52cd38b1..e62719d8 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ "openpyxl", "plotly", "seaborn", - "spikeinterface", + "spikeinterface>=0.101.0", "scikit-image", "nbformat>=4.2.0", "pyopenephys>=1.1.6", From 3666cda077448cc40d7b7e9c219c9c489396cbd6 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 24 May 2024 14:34:05 -0500 Subject: [PATCH 120/146] feat: update ingestion to be compatible with spikeinterface 0.101+ --- element_array_ephys/ephys_no_curation.py | 209 ++++++++---------- .../spike_sorting/si_spike_sorting.py | 93 ++------ 2 files changed, 116 insertions(+), 186 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index b0a8bc26..413868da 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1037,98 +1037,69 @@ def make(self, key): # Get sorter method and create output directory. sorter_name = clustering_method.replace(".", "_") - si_waveform_dir = output_dir / sorter_name / "waveform" - si_sorting_dir = output_dir / sorter_name / "spike_sorting" + si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" - if si_waveform_dir.exists(): # Read from spikeinterface outputs - we: si.WaveformExtractor = si.load_waveforms( - si_waveform_dir, with_recording=False - ) - si_sorting: si.sorters.BaseSorter = si.load_extractor( - si_sorting_dir / "si_sorting.pkl", base_folder=output_dir - ) + if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) + si_sorting = sorting_analyzer.sorting - unit_peak_channel: dict[int, int] = si.get_template_extremum_channel( - we, outputs="index" - ) # {unit: peak_channel_id} + # Find representative channel for each unit + unit_peak_channel: dict[int, np.ndarray] = ( + si.ChannelSparsity.from_best_channels( + sorting_analyzer, 1, peak_sign="neg" + ).unit_id_to_channel_indices + ) # {unit: peak_channel_index} + unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()} spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} - spikes = si_sorting.to_spike_vector() - # reorder channel2electrode_map according to recording channel ids channel2electrode_map = { chn_idx: channel2electrode_map[chn_idx] - for chn_idx in we.channel_ids_to_indices(we.channel_ids) + for chn_idx in sorting_analyzer.channel_ids_to_indices( + sorting_analyzer.channel_ids + ) } # Get unit id to quality label mapping - try: - cluster_quality_label_map = pd.read_csv( - si_sorting_dir / "sorter_output" / "cluster_KSLabel.tsv", - delimiter="\t", + cluster_quality_label_map = { + int(unit_id): ( + si_sorting.get_unit_property(unit_id, "KSLabel") + if "KSLabel" in si_sorting.get_property_keys() + else "n.a." 
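# "KSLabel" is Kilosort's own curation label ("good" or "mua"); if the
# sorting output carries no KSLabel property, units fall back to "n.a." here.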
) - except FileNotFoundError: - cluster_quality_label_map = {} - else: - cluster_quality_label_map: dict[ - int, str - ] = cluster_quality_label_map.set_index("cluster_id")[ - "KSLabel" - ].to_dict() # {unit: quality_label} - - # Get electrode where peak unit activity is recorded - peak_electrode_ind = np.array( - [ - channel2electrode_map[unit_peak_channel[unit_id]]["electrode"] - for unit_id in si_sorting.unit_ids - ] - ) - - # Get channel depth - channel_depth_ind = np.array( - [ - we.get_probe().contact_positions[unit_peak_channel[unit_id]][1] - for unit_id in si_sorting.unit_ids - ] - ) - - # Assign electrode and depth for each spike - new_spikes = np.empty( - spikes.shape, - spikes.dtype.descr + [("electrode", " Date: Fri, 24 May 2024 14:52:45 -0500 Subject: [PATCH 121/146] format: black formatting --- element_array_ephys/ephys_no_curation.py | 10 +++++++--- .../spike_sorting/si_spike_sorting.py | 19 ++++++++++++------- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 413868da..99247e35 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1256,7 +1256,9 @@ def make(self, key): unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()} # reorder channel2electrode_map according to recording channel ids - channel_indices = sorting_analyzer.channel_ids_to_indices(sorting_analyzer.channel_ids).tolist() + channel_indices = sorting_analyzer.channel_ids_to_indices( + sorting_analyzer.channel_ids + ).tolist() channel2electrode_map = { chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices } @@ -1500,7 +1502,9 @@ def make(self, key): if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data() - template_metrics = sorting_analyzer.get_extension("template_metrics").get_data() + template_metrics = sorting_analyzer.get_extension( + "template_metrics" + ).get_data() metrics_df = pd.concat([qc_metrics, template_metrics], axis=1) metrics_df.rename( @@ -1514,7 +1518,7 @@ def make(self, key): "drift_mad": "cumulative_drift", "half_width": "halfwidth", "peak_trough_ratio": "pt_ratio", - "peak_to_valley": "duration" + "peak_to_valley": "duration", }, inplace=True, ) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 55c6efdd..33201d86 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -270,28 +270,33 @@ def make(self, key): overwrite=True, ) - job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get("job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"}) + job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get( + "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"} + ) extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {}) # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions() # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified) - extensions_to_compute = {ext_name: extensions_params[ext_name] - for ext_name in sorting_analyzer.get_computable_extensions() - if ext_name in extensions_params} + extensions_to_compute = { + ext_name: extensions_params[ext_name] + for ext_name in 
sorting_analyzer.get_computable_extensions() + if ext_name in extensions_params + } sorting_analyzer.compute(extensions_to_compute, **job_kwargs) # Save to phy format if params["SI_POSTPROCESSING_PARAMS"].get("export_to_phy", False): si.exporters.export_to_phy( - sorting_analyzer=sorting_analyzer, output_folder=output_dir / sorter_name / "phy", - **job_kwargs + sorting_analyzer=sorting_analyzer, + output_folder=output_dir / sorter_name / "phy", + **job_kwargs, ) # Generate spike interface report if params["SI_POSTPROCESSING_PARAMS"].get("export_report", True): si.exporters.export_report( sorting_analyzer=sorting_analyzer, output_folder=output_dir / sorter_name / "spikeinterface_report", - **job_kwargs + **job_kwargs, ) self.insert1( From 07a09f6152b9632ce713287a85dedd0ad1bf8e9b Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 24 May 2024 15:28:52 -0500 Subject: [PATCH 122/146] chore: code clean up --- .../spike_sorting/si_spike_sorting.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 33201d86..a0ff2035 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -80,11 +80,9 @@ def make(self, key): sorter_name = clustering_method.replace(".", "_") for required_key in ( - "SI_SORTING_PARAMS", "SI_PREPROCESSING_METHOD", + "SI_SORTING_PARAMS", "SI_POSTPROCESSING_PARAMS", - "SI_WAVEFORM_EXTRACTION_PARAMS", - "SI_QUALITY_METRICS_PARAMS", ): if required_key not in params: raise ValueError( @@ -256,6 +254,10 @@ def make(self, key): sorting_file, base_folder=output_dir ) + job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get( + "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"} + ) + # Sorting Analyzer analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer" if (analyzer_output_dir / "extensions").exists(): @@ -268,14 +270,12 @@ def make(self, key): folder=analyzer_output_dir, sparse=True, overwrite=True, + **job_kwargs ) - job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get( - "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"} - ) - extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {}) # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions() # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified) + extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {}) extensions_to_compute = { ext_name: extensions_params[ext_name] for ext_name in sorting_analyzer.get_computable_extensions() From 3fcf542d1435f4f891f2bbf93eaa3668da1986ea Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 24 May 2024 15:29:09 -0500 Subject: [PATCH 123/146] update: update requirements to install `SpikeInterface` from github (latest version) --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e62719d8..f1ba9c90 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ "openpyxl", "plotly", "seaborn", - "spikeinterface>=0.101.0", + "spikeinterface @ git+https://github.com/SpikeInterface/spikeinterface.git", "scikit-image", "nbformat>=4.2.0", "pyopenephys>=1.1.6", From 76dfc94568bf28296da18905d0b187588bc99397 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 28 May 2024 10:32:19 -0500 Subject: [PATCH 124/146] fix: minor bug in spikes ingestion --- element_array_ephys/ephys_no_curation.py | 15 
++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 99247e35..9222ccd2 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1048,8 +1048,8 @@ def make(self, key): si.ChannelSparsity.from_best_channels( sorting_analyzer, 1, peak_sign="neg" ).unit_id_to_channel_indices - ) # {unit: peak_channel_index} - unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()} + ) + unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()} spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} @@ -1076,9 +1076,9 @@ def make(self, key): spikes_df = pd.DataFrame(spike_locations.spikes) units = [] - for unit_id in si_sorting.unit_ids: + for unit_idx, unit_id in enumerate(si_sorting.unit_ids): unit_id = int(unit_id) - unit_spikes_df = spikes_df[spikes_df.unit_index == unit_id] + unit_spikes_df = spikes_df[spikes_df.unit_index == unit_idx] spike_sites = np.array( [ channel2electrode_map[chn_idx]["electrode"] @@ -1087,6 +1087,9 @@ def make(self, key): ) unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index] _, spike_depths = zip(*unit_spikes_loc) # x-coordinates, y-coordinates + spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True) + + assert len(spike_times) == len(spike_sites) == len(spike_depths) units.append( { @@ -1094,9 +1097,7 @@ def make(self, key): **channel2electrode_map[unit_peak_channel[unit_id]], "unit": unit_id, "cluster_quality_label": cluster_quality_label_map[unit_id], - "spike_times": si_sorting.get_unit_spike_train( - unit_id, return_times=True - ), + "spike_times": spike_times, "spike_count": spike_count_dict[unit_id], "spike_sites": spike_sites, "spike_depths": spike_depths, From 9094754b6f23bd65a71390094ac509e06d22b34c Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 28 May 2024 10:38:59 -0500 Subject: [PATCH 125/146] update: bump version --- CHANGELOG.md | 5 +++++ element_array_ephys/version.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5e45e427..5d81dcba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,11 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. 
+## [0.4.0] - 2024-05-28 + ++ Add - support for SpikeInterface version >= 0.101.0 (updated API) + + ## [0.3.4] - 2024-03-22 + Add - pytest diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py index 148bac24..2e6de55a 100644 --- a/element_array_ephys/version.py +++ b/element_array_ephys/version.py @@ -1,3 +1,3 @@ """Package metadata.""" -__version__ = "0.3.4" +__version__ = "0.4.0" From 51e2ced3f36fa1b69bacf69ea1fbf295c84eaf16 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 28 May 2024 13:14:00 -0500 Subject: [PATCH 126/146] feat: add `memoized_result` on spike sorting --- CHANGELOG.md | 1 + .../spike_sorting/si_spike_sorting.py | 103 ++++++++++-------- setup.py | 2 +- 3 files changed, 60 insertions(+), 46 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d81dcba..cd8bb5b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and ## [0.4.0] - 2024-05-28 + Add - support for SpikeInterface version >= 0.101.0 (updated API) ++ Add - feature for memoization of spike sorting results (prevent duplicated runs) ## [0.3.4] - 2024-03-22 diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index a0ff2035..dff74dd7 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -8,7 +8,7 @@ import pandas as pd import spikeinterface as si from element_array_ephys import probe, readers -from element_interface.utils import find_full_path +from element_interface.utils import find_full_path, memoized_result from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . import si_preprocessing @@ -192,23 +192,29 @@ def make(self, key): recording_file, base_folder=output_dir ) + sorting_params = params["SI_SORTING_PARAMS"] + sorting_output_dir = output_dir / sorter_name / "spike_sorting" + # Run sorting - # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package. - si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( - sorter_name=sorter_name, - recording=si_recording, - output_folder=output_dir / sorter_name / "spike_sorting", - remove_existing_folder=True, - verbose=True, - docker_image=sorter_name not in si.sorters.installed_sorters(), - **params.get("SI_SORTING_PARAMS", {}), + @memoized_result( + uniqueness_dict=sorting_params, + output_directory=sorting_output_dir, ) + def _run_sorter(): + # Sorting performed in a dedicated docker environment if the sorter is not built in the spikeinterface package. 
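# How the memoized_result decorator above is understood to behave (a hedged
# sketch -- the actual implementation lives in element_interface.utils): it
# fingerprints `uniqueness_dict` and invokes the wrapped function only when
# `output_directory` already holds results saved under the same fingerprint,
# in which case the run is skipped. Conceptually:
#
#     fingerprint = dict_to_uuid(sorting_params)  # dict_to_uuid is from element_interface.utils
#     if not _results_already_saved(sorting_output_dir, fingerprint):  # _results_already_saved is hypothetical
#         _run_sorter()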
+ si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( + sorter_name=sorter_name, + recording=si_recording, + output_folder=sorting_output_dir, + remove_existing_folder=True, + verbose=True, + docker_image=sorter_name not in si.sorters.installed_sorters(), + **sorting_params, + ) - # Save sorting object - sorting_save_path = ( - output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" - ) - si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir) + # Save sorting object + sorting_save_path = sorting_output_dir / "si_sorting.pkl" + si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir) self.insert1( { @@ -254,15 +260,20 @@ def make(self, key): sorting_file, base_folder=output_dir ) - job_kwargs = params["SI_POSTPROCESSING_PARAMS"].get( + postprocessing_params = params["SI_POSTPROCESSING_PARAMS"] + + job_kwargs = postprocessing_params.get( "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"} ) - # Sorting Analyzer analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer" - if (analyzer_output_dir / "extensions").exists(): - sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir) - else: + + @memoized_result( + uniqueness_dict=postprocessing_params, + output_directory=analyzer_output_dir, + ) + def _sorting_analyzer_compute(): + # Sorting Analyzer sorting_analyzer = si.create_sorting_analyzer( sorting=si_sorting, recording=si_recording, @@ -273,31 +284,33 @@ def make(self, key): **job_kwargs ) - # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions() - # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified) - extensions_params = params["SI_POSTPROCESSING_PARAMS"].get("extensions", {}) - extensions_to_compute = { - ext_name: extensions_params[ext_name] - for ext_name in sorting_analyzer.get_computable_extensions() - if ext_name in extensions_params - } - - sorting_analyzer.compute(extensions_to_compute, **job_kwargs) - - # Save to phy format - if params["SI_POSTPROCESSING_PARAMS"].get("export_to_phy", False): - si.exporters.export_to_phy( - sorting_analyzer=sorting_analyzer, - output_folder=output_dir / sorter_name / "phy", - **job_kwargs, - ) - # Generate spike interface report - if params["SI_POSTPROCESSING_PARAMS"].get("export_report", True): - si.exporters.export_report( - sorting_analyzer=sorting_analyzer, - output_folder=output_dir / sorter_name / "spikeinterface_report", - **job_kwargs, - ) + # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions() + # each extension is parameterized by params specified in extensions_params dictionary (skip if not specified) + extensions_params = postprocessing_params.get("extensions", {}) + extensions_to_compute = { + ext_name: extensions_params[ext_name] + for ext_name in sorting_analyzer.get_computable_extensions() + if ext_name in extensions_params + } + + sorting_analyzer.compute(extensions_to_compute, **job_kwargs) + + # Save to phy format + if postprocessing_params.get("export_to_phy", False): + si.exporters.export_to_phy( + sorting_analyzer=sorting_analyzer, + output_folder=analyzer_output_dir / "phy", + **job_kwargs, + ) + # Generate spike interface report + if postprocessing_params.get("export_report", True): + si.exporters.export_report( + sorting_analyzer=sorting_analyzer, + output_folder=analyzer_output_dir / "spikeinterface_report", + **job_kwargs, + ) + + _sorting_analyzer_compute() self.insert1( { diff --git a/setup.py b/setup.py index 
f1ba9c90..66789740 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ "scikit-image", "nbformat>=4.2.0", "pyopenephys>=1.1.6", - "element-interface @ git+https://github.com/datajoint/element-interface.git", + "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results", "numba", ], extras_require={ From 0afb4529de262fbee6b21461e5aec58765fd0e12 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 28 May 2024 14:22:20 -0500 Subject: [PATCH 127/146] chore: minor code cleanup --- element_array_ephys/ephys_no_curation.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 9222ccd2..b49d4422 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -8,14 +8,12 @@ import datajoint as dj import numpy as np import pandas as pd -import spikeinterface as si from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory -from spikeinterface import exporters, postprocessing, qualitymetrics, sorters from . import ephys_report, probe from .readers import kilosort, openephys, spikeglx -log = dj.logger +logger = dj.logger schema = dj.schema() @@ -824,7 +822,7 @@ def infer_output_dir(cls, key, relative: bool = False, mkdir: bool = False): if mkdir: output_dir.mkdir(parents=True, exist_ok=True) - log.info(f"{output_dir} created!") + logger.info(f"{output_dir} created!") return output_dir.relative_to(processed_dir) if relative else output_dir @@ -1040,6 +1038,8 @@ def make(self, key): si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" if si_sorting_analyzer_dir.exists(): # Read from spikeinterface outputs + import spikeinterface as si + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) si_sorting = sorting_analyzer.sorting @@ -1246,6 +1246,8 @@ def make(self, key): si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs + import spikeinterface as si + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) # Find representative channel for each unit @@ -1501,6 +1503,8 @@ def make(self, key): si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs + import spikeinterface as si + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data() template_metrics = sorting_analyzer.get_extension( From e8f445c3b4b532b3159638e71d231e2048939a90 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 28 May 2024 16:47:22 -0500 Subject: [PATCH 128/146] fix: merge fix & formatting --- element_array_ephys/spike_sorting/si_spike_sorting.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index dff74dd7..9e14f636 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -248,7 +248,6 @@ def make(self, key): ).fetch1("clustering_method", "clustering_output_dir", "params") output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) sorter_name = clustering_method.replace(".", "_") - output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) recording_file 
= output_dir / sorter_name / "recording" / "si_recording.pkl" sorting_file = output_dir / sorter_name / "spike_sorting" / "si_sorting.pkl" @@ -281,7 +280,7 @@ def _sorting_analyzer_compute(): folder=analyzer_output_dir, sparse=True, overwrite=True, - **job_kwargs + **job_kwargs, ) # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions() From 6155f13fd755ac76ec79fdd1594b0e96ef8d550b Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 28 May 2024 17:01:10 -0500 Subject: [PATCH 129/146] fix: calling `_run_sorter()` --- element_array_ephys/spike_sorting/si_spike_sorting.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 9e14f636..5c1d6567 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -216,6 +216,8 @@ def _run_sorter(): sorting_save_path = sorting_output_dir / "si_sorting.pkl" si_sorting.dump_to_pickle(sorting_save_path, relative_to=output_dir) + _run_sorter() + self.insert1( { **key, From f6a52d9d3f31b7ebe2853da4545551898cfa50ae Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 28 May 2024 20:07:27 -0500 Subject: [PATCH 130/146] chore: more robust channel mapping --- element_array_ephys/ephys_no_curation.py | 29 ++++++++---------------- 1 file changed, 10 insertions(+), 19 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index b49d4422..142f350b 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1028,9 +1028,8 @@ def make(self, key): # Get channel and electrode-site mapping electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") - channel2electrode_map = electrode_query.fetch(as_dict=True) channel2electrode_map: dict[int, dict] = { - chn.pop("channel_idx"): chn for chn in channel2electrode_map + chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True) } # Get sorter method and create output directory. 
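A hedged reading of the next hunk: keying `channel2electrode_map` by the
analyzer probe's contact order keeps template columns aligned with electrode
sites even when the sorting analyzer holds only a channel subset. A usage
sketch, with names as in the surrounding code:

    peak_idx = unit_peak_channel[unit_id]  # column index into the unit's template
    electrode = channel2electrode_map[peak_idx]["electrode"]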
@@ -1054,12 +1053,10 @@ def make(self, key): spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} - # reorder channel2electrode_map according to recording channel ids + # update channel2electrode_map to match with probe's channel index channel2electrode_map = { - chn_idx: channel2electrode_map[chn_idx] - for chn_idx in sorting_analyzer.channel_ids_to_indices( - sorting_analyzer.channel_ids - ) + idx: channel2electrode_map[int(chn_idx)] + for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids) } # Get unit id to quality label mapping @@ -1239,9 +1236,8 @@ def make(self, key): # Get channel and electrode-site mapping electrode_query = (EphysRecording.Channel & key).proj(..., "-channel_name") - channel2electrode_map = electrode_query.fetch(as_dict=True) channel2electrode_map: dict[int, dict] = { - chn.pop("channel_idx"): chn for chn in channel2electrode_map + chn.pop("channel_idx"): chn for chn in electrode_query.fetch(as_dict=True) } si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" @@ -1258,12 +1254,10 @@ def make(self, key): ) # {unit: peak_channel_index} unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()} - # reorder channel2electrode_map according to recording channel ids - channel_indices = sorting_analyzer.channel_ids_to_indices( - sorting_analyzer.channel_ids - ).tolist() + # update channel2electrode_map to match with probe's channel index channel2electrode_map = { - chn_idx: channel2electrode_map[chn_idx] for chn_idx in channel_indices + idx: channel2electrode_map[int(chn_idx)] + for idx, chn_idx in enumerate(sorting_analyzer.get_probe().contact_ids) } templates = sorting_analyzer.get_extension("templates") @@ -1276,12 +1270,9 @@ def yield_unit_waveforms(): unit_waveforms = templates.get_unit_template( unit_id=unit["unit"], operator="average" ) - peak_chn_idx = channel_indices.index( - unit_peak_channel[unit["unit"]] - ) unit_peak_waveform = { **unit, - "peak_electrode_waveform": unit_waveforms[:, peak_chn_idx], + "peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]], } unit_electrode_waveforms = [ @@ -1290,7 +1281,7 @@ def yield_unit_waveforms(): **channel2electrode_map[chn_idx], "waveform_mean": unit_waveforms[:, chn_idx], } - for chn_idx in channel_indices + for chn_idx in channel2electrode_map ] yield unit_peak_waveform, unit_electrode_waveforms From 1ff92dd15db6ff9e8458f53ec96fdffb6b93305d Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 29 May 2024 16:09:16 -0500 Subject: [PATCH 131/146] fix: use relative path for phy output --- element_array_ephys/spike_sorting/si_spike_sorting.py | 1 + 1 file changed, 1 insertion(+) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 5c1d6567..93619303 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -301,6 +301,7 @@ def _sorting_analyzer_compute(): si.exporters.export_to_phy( sorting_analyzer=sorting_analyzer, output_folder=analyzer_output_dir / "phy", + use_relative_path=True, **job_kwargs, ) # Generate spike interface report From b45970974df001319a4ebae182bf291313f5e39a Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 29 May 2024 16:16:21 -0500 Subject: [PATCH 132/146] feat: in data ingestion, set peak_sign="both" --- element_array_ephys/ephys_no_curation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git 
a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 142f350b..8eadba49 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1045,7 +1045,7 @@ def make(self, key): # Find representative channel for each unit unit_peak_channel: dict[int, np.ndarray] = ( si.ChannelSparsity.from_best_channels( - sorting_analyzer, 1, peak_sign="neg" + sorting_analyzer, 1, peak_sign="both" ).unit_id_to_channel_indices ) unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()} @@ -1249,7 +1249,7 @@ def make(self, key): # Find representative channel for each unit unit_peak_channel: dict[int, np.ndarray] = ( si.ChannelSparsity.from_best_channels( - sorting_analyzer, 1, peak_sign="neg" + sorting_analyzer, 1, peak_sign="both" ).unit_id_to_channel_indices ) # {unit: peak_channel_index} unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()} From 1a1b18f8a52b83298bffc8d82555ccc147151dd1 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Mon, 3 Jun 2024 13:22:49 -0500 Subject: [PATCH 133/146] feat: replace `output_folder` with `folder` when calling `run_sorter`, use default value for `peak_sign` --- element_array_ephys/ephys_no_curation.py | 21 ++++++++++++------- .../spike_sorting/si_spike_sorting.py | 2 +- 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys_no_curation.py index 8eadba49..891cee0f 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1045,10 +1045,13 @@ def make(self, key): # Find representative channel for each unit unit_peak_channel: dict[int, np.ndarray] = ( si.ChannelSparsity.from_best_channels( - sorting_analyzer, 1, peak_sign="both" + sorting_analyzer, + 1, ).unit_id_to_channel_indices ) - unit_peak_channel: dict[int, int] = {u: chn[0] for u, chn in unit_peak_channel.items()} + unit_peak_channel: dict[int, int] = { + u: chn[0] for u, chn in unit_peak_channel.items() + } spike_count_dict: dict[int, int] = si_sorting.count_num_spikes_per_unit() # {unit: spike_count} @@ -1084,7 +1087,9 @@ def make(self, key): ) unit_spikes_loc = spike_locations.get_data()[unit_spikes_df.index] _, spike_depths = zip(*unit_spikes_loc) # x-coordinates, y-coordinates - spike_times = si_sorting.get_unit_spike_train(unit_id, return_times=True) + spike_times = si_sorting.get_unit_spike_train( + unit_id, return_times=True + ) assert len(spike_times) == len(spike_sites) == len(spike_depths) @@ -1243,13 +1248,13 @@ def make(self, key): si_sorting_analyzer_dir = output_dir / sorter_name / "sorting_analyzer" if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs import spikeinterface as si - + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) # Find representative channel for each unit unit_peak_channel: dict[int, np.ndarray] = ( si.ChannelSparsity.from_best_channels( - sorting_analyzer, 1, peak_sign="both" + sorting_analyzer, 1 ).unit_id_to_channel_indices ) # {unit: peak_channel_index} unit_peak_channel = {u: chn[0] for u, chn in unit_peak_channel.items()} @@ -1272,7 +1277,9 @@ def yield_unit_waveforms(): ) unit_peak_waveform = { **unit, - "peak_electrode_waveform": unit_waveforms[:, unit_peak_channel[unit["unit"]]], + "peak_electrode_waveform": unit_waveforms[ + :, unit_peak_channel[unit["unit"]] + ], } unit_electrode_waveforms = [ @@ -1495,7 +1502,7 @@ def make(self, key): si_sorting_analyzer_dir = output_dir / 
sorter_name / "sorting_analyzer" if si_sorting_analyzer_dir.exists(): # read from spikeinterface outputs import spikeinterface as si - + sorting_analyzer = si.load_sorting_analyzer(folder=si_sorting_analyzer_dir) qc_metrics = sorting_analyzer.get_extension("quality_metrics").get_data() template_metrics = sorting_analyzer.get_extension( diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 93619303..57aa0ba1 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -205,7 +205,7 @@ def _run_sorter(): si_sorting: si.sorters.BaseSorter = si.sorters.run_sorter( sorter_name=sorter_name, recording=si_recording, - output_folder=sorting_output_dir, + folder=sorting_output_dir, remove_existing_folder=True, verbose=True, docker_image=sorter_name not in si.sorters.installed_sorters(), From 4e645ebd9b83f5e607e1d18188c0c3ce5f84eb4a Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Wed, 5 Jun 2024 16:15:25 -0500 Subject: [PATCH 134/146] fix: remove `job_kwargs` for sparsity calculation - memory error in linux container --- element_array_ephys/spike_sorting/si_spike_sorting.py | 1 - 1 file changed, 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 57aa0ba1..b93d9c10 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -282,7 +282,6 @@ def _sorting_analyzer_compute(): folder=analyzer_output_dir, sparse=True, overwrite=True, - **job_kwargs, ) # The order of extension computation is drawn from sorting_analyzer.get_computable_extensions() From 38fdfb2a5fd44f1115aa4f1660482e1639eaa3c2 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 5 Jul 2024 10:59:37 -0500 Subject: [PATCH 135/146] feat: separate `export` (phy and report) into a separate table --- .../spike_sorting/si_spike_sorting.py | 94 +++++++++++++++---- 1 file changed, 78 insertions(+), 16 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index b93d9c10..463af3df 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -239,6 +239,7 @@ class PostProcessing(dj.Imported): --- execution_time: datetime # datetime of the start of this step execution_duration: float # execution duration in hours + do_si_export=1: bool # whether to export to phy """ def make(self, key): @@ -295,22 +296,6 @@ def _sorting_analyzer_compute(): sorting_analyzer.compute(extensions_to_compute, **job_kwargs) - # Save to phy format - if postprocessing_params.get("export_to_phy", False): - si.exporters.export_to_phy( - sorting_analyzer=sorting_analyzer, - output_folder=analyzer_output_dir / "phy", - use_relative_path=True, - **job_kwargs, - ) - # Generate spike interface report - if postprocessing_params.get("export_report", True): - si.exporters.export_report( - sorting_analyzer=sorting_analyzer, - output_folder=analyzer_output_dir / "spikeinterface_report", - **job_kwargs, - ) - _sorting_analyzer_compute() self.insert1( @@ -321,6 +306,8 @@ def _sorting_analyzer_compute(): datetime.utcnow() - execution_time ).total_seconds() / 3600, + "do_si_export": postprocessing_params.get("export_to_phy", False) + or postprocessing_params.get("export_report", False), } ) @@ -328,3 +315,78 @@ def _sorting_analyzer_compute(): 
ephys.Clustering.insert1( {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True ) + + +@schema +class SIExport(dj.Computed): + """A SpikeInterface export report and to Phy""" + + definition = """ + -> PostProcessing + --- + execution_time: datetime + execution_duration: float + """ + + @property + def key_source(self): + return PostProcessing & "do_si_export = 1" + + def make(self, key): + execution_time = datetime.utcnow() + + clustering_method, output_dir, params = ( + ephys.ClusteringTask * ephys.ClusteringParamSet & key + ).fetch1("clustering_method", "clustering_output_dir", "params") + output_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) + sorter_name = clustering_method.replace(".", "_") + + postprocessing_params = params["SI_POSTPROCESSING_PARAMS"] + + job_kwargs = postprocessing_params.get( + "job_kwargs", {"n_jobs": -1, "chunk_duration": "1s"} + ) + + analyzer_output_dir = output_dir / sorter_name / "sorting_analyzer" + sorting_analyzer = si.load_sorting_analyzer(folder=analyzer_output_dir) + + @memoized_result( + uniqueness_dict=postprocessing_params, + output_directory=analyzer_output_dir / "phy", + ) + def _export_to_phy(): + # Save to phy format + si.exporters.export_to_phy( + sorting_analyzer=sorting_analyzer, + output_folder=analyzer_output_dir / "phy", + use_relative_path=True, + **job_kwargs, + ) + + @memoized_result( + uniqueness_dict=postprocessing_params, + output_directory=analyzer_output_dir / "spikeinterface_report", + ) + def _export_report(): + # Generate spike interface report + si.exporters.export_report( + sorting_analyzer=sorting_analyzer, + output_folder=analyzer_output_dir / "spikeinterface_report", + **job_kwargs, + ) + + if postprocessing_params.get("export_report", False): + _export_report() + if postprocessing_params.get("export_to_phy", False): + _export_to_phy() + + self.insert1( + { + **key, + "execution_time": execution_time, + "execution_duration": ( + datetime.utcnow() - execution_time + ).total_seconds() + / 3600, + } + ) From a4a8380405673bf2c85861223afa0c9e5e481296 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 5 Jul 2024 11:00:42 -0500 Subject: [PATCH 136/146] fix: export default to `False` --- element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 463af3df..6f2d7b53 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -239,7 +239,7 @@ class PostProcessing(dj.Imported): --- execution_time: datetime # datetime of the start of this step execution_duration: float # execution duration in hours - do_si_export=1: bool # whether to export to phy + do_si_export=0: bool # whether to export to phy """ def make(self, key): @@ -331,7 +331,7 @@ class SIExport(dj.Computed): @property def key_source(self): return PostProcessing & "do_si_export = 1" - + def make(self, key): execution_time = datetime.utcnow() From 1f05998e25d848b6aeb73231fa90e616580cd1d8 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 5 Jul 2024 16:45:40 -0500 Subject: [PATCH 137/146] fix: `spikes` object no longer available from `ComputeSpikeLocations` (https://github.com/SpikeInterface/spikeinterface/pull/3015) --- element_array_ephys/ephys_no_curation.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/element_array_ephys/ephys_no_curation.py 
b/element_array_ephys/ephys_no_curation.py index 891cee0f..5df8bad0 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys_no_curation.py @@ -1073,7 +1073,9 @@ def make(self, key): } spike_locations = sorting_analyzer.get_extension("spike_locations") - spikes_df = pd.DataFrame(spike_locations.spikes) + extremum_channel_inds = si.template_tools.get_template_extremum_channel(sorting_analyzer, outputs="index") + spikes_df = pd.DataFrame( + sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)) units = [] for unit_idx, unit_id in enumerate(si_sorting.unit_ids): From 7cd8ac8ce8eeb731f149924279bb3b0d990caa45 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 5 Jul 2024 21:26:48 -0500 Subject: [PATCH 138/146] chore: code cleanup --- element_array_ephys/spike_sorting/si_spike_sorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 6f2d7b53..8624e073 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -9,7 +9,7 @@ import spikeinterface as si from element_array_ephys import probe, readers from element_interface.utils import find_full_path, memoized_result -from spikeinterface import exporters, postprocessing, qualitymetrics, sorters +from spikeinterface import exporters, extractors, sorters from . import si_preprocessing From c87e49332f90386acc8eb696e65f87bfd7b6ae24 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Sat, 6 Jul 2024 08:00:44 -0500 Subject: [PATCH 139/146] fix: recording_extractor_full_dict is deprecated (https://github.com/SpikeInterface/spikeinterface/pull/3153) --- .../spike_sorting/si_spike_sorting.py | 27 +++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 8624e073..7133b81c 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -111,25 +111,30 @@ def make(self, key): ) spikeglx_recording.validate_file("ap") data_dir = spikeglx_meta_filepath.parent + + si_extractor = si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor + stream_names, stream_ids = si.extractors.get_neo_streams( + acq_software, folder_path=data_dir + ) + si_recording: si.BaseRecording = si_extractor( + folder_path=data_dir, stream_name=stream_names[0] + ) elif acq_software == "Open Ephys": oe_probe = ephys.get_openephys_probe_data(key) assert len(oe_probe.recording_info["recording_files"]) == 1 data_dir = oe_probe.recording_info["recording_files"][0] + si_extractor = si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor + + stream_names, stream_ids = si.extractors.get_neo_streams( + acq_software, folder_path=data_dir + ) + si_recording: si.BaseRecording = si_extractor( + folder_path=data_dir, stream_name=stream_names[0] + ) else: raise NotImplementedError( f"SpikeInterface processing for {acq_software} not yet implemented." 
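# note on the hunk above: both branches still pass the raw acq_software
# label to get_neo_streams(); PATCH 140/146 below swaps in the neo
# stream-format names "spikeglx" / "openephysbinary"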
) - acq_software = acq_software.replace(" ", "").lower() - si_extractor: si.extractors.neoextractors = ( - si.extractors.extractorlist.recording_extractor_full_dict[acq_software] - ) # data extractor object - - stream_names, stream_ids = si.extractors.get_neo_streams( - acq_software, folder_path=data_dir - ) - si_recording: si.BaseRecording = si_extractor( - folder_path=data_dir, stream_name=stream_names[0] - ) # Add probe information to recording object electrodes_df = ( From 097d9bbf7694e40a839b4cebb49890d1acd325f1 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 30 Jul 2024 18:26:37 -0500 Subject: [PATCH 140/146] fix: bugfix spikeinterface extractor name --- element_array_ephys/spike_sorting/si_spike_sorting.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 7133b81c..550ae4a1 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -114,7 +114,7 @@ def make(self, key): si_extractor = si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor stream_names, stream_ids = si.extractors.get_neo_streams( - acq_software, folder_path=data_dir + "spikeglx", folder_path=data_dir ) si_recording: si.BaseRecording = si_extractor( folder_path=data_dir, stream_name=stream_names[0] @@ -126,7 +126,7 @@ def make(self, key): si_extractor = si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor stream_names, stream_ids = si.extractors.get_neo_streams( - acq_software, folder_path=data_dir + "openephysbinary", folder_path=data_dir ) si_recording: si.BaseRecording = si_extractor( folder_path=data_dir, stream_name=stream_names[0] From b6f131b814ed9dba2e2cc38d6918df52668dd590 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 15 Aug 2024 17:24:09 -0500 Subject: [PATCH 141/146] update: element-interface `main` branch --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 66789740..f1ba9c90 100644 --- a/setup.py +++ b/setup.py @@ -39,7 +39,7 @@ "scikit-image", "nbformat>=4.2.0", "pyopenephys>=1.1.6", - "element-interface @ git+https://github.com/datajoint/element-interface.git@dev_memoized_results", + "element-interface @ git+https://github.com/datajoint/element-interface.git", "numba", ], extras_require={ From ccd23fc413d126f897c20cececcc35b86cb5190f Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 10 Sep 2024 12:04:48 -0500 Subject: [PATCH 142/146] rearrange(all): major refactor of modules --- CHANGELOG.md | 9 + element_array_ephys/__init__.py | 4 +- .../{ephys_no_curation.py => ephys.py} | 13 +- element_array_ephys/ephys_acute.py | 1594 ----------------- element_array_ephys/ephys_chronic.py | 1523 ---------------- element_array_ephys/ephys_precluster.py | 1435 --------------- element_array_ephys/ephys_report.py | 14 +- element_array_ephys/export/nwb/nwb.py | 9 +- .../spike_sorting/si_spike_sorting.py | 28 +- element_array_ephys/version.py | 2 +- tests/tutorial_pipeline.py | 6 +- 11 files changed, 43 insertions(+), 4594 deletions(-) rename element_array_ephys/{ephys_no_curation.py => ephys.py} (99%) delete mode 100644 element_array_ephys/ephys_acute.py delete mode 100644 element_array_ephys/ephys_chronic.py delete mode 100644 element_array_ephys/ephys_precluster.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 7068216b..34d1a2e4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,15 @@ Observes [Semantic 
Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. + +## [1.0.0] - 2024-09-10 + ++ Update - No longer support multiple variations of the ephys module; keep only the `ephys_no_curation` module, renamed to `ephys` ++ Update - Remove other ephys modules (e.g. `ephys_acute`, `ephys_chronic`) (moved to different branches) ++ Update - Add support for `SpikeInterface` ++ Update - Remove support for `ecephys_spike_sorting` (moved to a different branch) ++ Update - Simplify the "activate" mechanism + ## [0.4.0] - 2024-08-16 + Add - support for SpikeInterface version >= 0.101.0 (updated API) diff --git a/element_array_ephys/__init__.py b/element_array_ephys/__init__.py index 1c0c7285..079950b4 100644 --- a/element_array_ephys/__init__.py +++ b/element_array_ephys/__init__.py @@ -1 +1,3 @@ -from . import ephys_acute as ephys +from . import ephys + +ephys_no_curation = ephys # alias for backward compatibility diff --git a/element_array_ephys/ephys_no_curation.py b/element_array_ephys/ephys.py similarity index 99% rename from element_array_ephys/ephys_no_curation.py rename to element_array_ephys/ephys.py index 5df8bad0..3025d289 100644 --- a/element_array_ephys/ephys_no_curation.py +++ b/element_array_ephys/ephys.py @@ -10,7 +10,7 @@ import pandas as pd from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory -from . import ephys_report, probe +from . import probe from .readers import kilosort, openephys, spikeglx logger = dj.logger @@ -22,7 +22,6 @@ def activate( ephys_schema_name: str, - probe_schema_name: str = None, *, create_schema: bool = True, create_tables: bool = True, @@ -32,7 +31,6 @@ def activate( Args: ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe schema. create_schema (bool): If True, schema will be created in the database. create_tables (bool): If True, tables related to the schema will be created in the database. linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. @@ -46,7 +44,6 @@ def activate( get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s). get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings. get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory.
- """ if isinstance(linking_module, str): @@ -58,17 +55,15 @@ def activate( global _linking_module _linking_module = linking_module - # activate - probe.activate( - probe_schema_name, create_schema=create_schema, create_tables=create_tables - ) + if not probe.schema.is_activated(): + raise RuntimeError("Please activate the `probe` schema first.") + schema.activate( ephys_schema_name, create_schema=create_schema, create_tables=create_tables, add_objects=_linking_module.__dict__, ) - ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name) # -------------- Functions required by the elements-ephys --------------- diff --git a/element_array_ephys/ephys_acute.py b/element_array_ephys/ephys_acute.py deleted file mode 100644 index c2627fc9..00000000 --- a/element_array_ephys/ephys_acute.py +++ /dev/null @@ -1,1594 +0,0 @@ -import gc -import importlib -import inspect -import pathlib -import re -from decimal import Decimal - -import datajoint as dj -import numpy as np -import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory - -from . import ephys_report, probe -from .readers import kilosort, openephys, spikeglx - -log = dj.logger - -schema = dj.schema() - -_linking_module = None - - -def activate( - ephys_schema_name: str, - probe_schema_name: str = None, - *, - create_schema: bool = True, - create_tables: bool = True, - linking_module: str = None, -): - """Activates the `ephys` and `probe` schemas. - - Args: - ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe schema. - create_schema (bool): If True, schema will be created in the database. - create_tables (bool): If True, tables related to the schema will be created in the database. - linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. - - Dependencies: - Upstream tables: - Session: A parent table to ProbeInsertion - Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported. - - Functions: - get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s). - get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings. - get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory. - """ - - if isinstance(linking_module, str): - linking_module = importlib.import_module(linking_module) - assert inspect.ismodule( - linking_module - ), "The argument 'dependency' must be a module's name or a module" - - global _linking_module - _linking_module = linking_module - - probe.activate( - probe_schema_name, create_schema=create_schema, create_tables=create_tables - ) - schema.activate( - ephys_schema_name, - create_schema=create_schema, - create_tables=create_tables, - add_objects=_linking_module.__dict__, - ) - ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name) - - -# -------------- Functions required by the elements-ephys --------------- - - -def get_ephys_root_data_dir() -> list: - """Fetches absolute data path to ephys data directories. - - The absolute path here is used as a reference for all downstream relative paths used in DataJoint. - - Returns: - A list of the absolute path(s) to ephys data directories. 
- """ - root_directories = _linking_module.get_ephys_root_data_dir() - if isinstance(root_directories, (str, pathlib.Path)): - root_directories = [root_directories] - - if hasattr(_linking_module, "get_processed_root_data_dir"): - root_directories.append(_linking_module.get_processed_root_data_dir()) - - return root_directories - - -def get_session_directory(session_key: dict) -> str: - """Retrieve the session directory with Neuropixels for the given session. - - Args: - session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database. - - Returns: - A string for the path to the session directory. - """ - return _linking_module.get_session_directory(session_key) - - -def get_processed_root_data_dir() -> str: - """Retrieve the root directory for all processed data. - - Returns: - A string for the full path to the root directory for processed data. - """ - - if hasattr(_linking_module, "get_processed_root_data_dir"): - return _linking_module.get_processed_root_data_dir() - else: - return get_ephys_root_data_dir()[0] - - -# ----------------------------- Table declarations ---------------------- - - -@schema -class AcquisitionSoftware(dj.Lookup): - """Name of software used for recording electrophysiological data. - - Attributes: - acq_software ( varchar(24) ): Acquisition software, e.g,. SpikeGLX, OpenEphys - """ - - definition = """ # Software used for recording of neuropixels probes - acq_software: varchar(24) - """ - contents = zip(["SpikeGLX", "Open Ephys"]) - - -@schema -class ProbeInsertion(dj.Manual): - """Information about probe insertion across subjects and sessions. - - Attributes: - Session (foreign key): Session primary key. - insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session. - probe.Probe (str): probe.Probe primary key. - """ - - definition = """ - # Probe insertion implanted into an animal for a given session. - -> Session - insertion_number: tinyint unsigned - --- - -> probe.Probe - """ - - @classmethod - def auto_generate_entries(cls, session_key): - """Automatically populate entries in ProbeInsertion table for a session.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(session_key) - ) - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" 
- f" Neither SpikeGLX nor Open Ephys recording files found in: {session_dir}" - ) - - probe_list, probe_insertion_list = [], [] - if acq_software == "SpikeGLX": - for meta_fp_idx, meta_filepath in enumerate(ephys_meta_filepaths): - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - - probe_key = { - "probe_type": spikeglx_meta.probe_model, - "probe": spikeglx_meta.probe_SN, - } - if probe_key["probe"] not in [p["probe"] for p in probe_list]: - probe_list.append(probe_key) - - probe_dir = meta_filepath.parent - try: - probe_number = re.search("(imec)?\d{1}$", probe_dir.name).group() - probe_number = int(probe_number.replace("imec", "")) - except AttributeError: - probe_number = meta_fp_idx - - probe_insertion_list.append( - { - **session_key, - "probe": spikeglx_meta.probe_SN, - "insertion_number": int(probe_number), - } - ) - elif acq_software == "Open Ephys": - loaded_oe = openephys.OpenEphys(session_dir) - for probe_idx, oe_probe in enumerate(loaded_oe.probes.values()): - probe_key = { - "probe_type": oe_probe.probe_model, - "probe": oe_probe.probe_SN, - } - if probe_key["probe"] not in [p["probe"] for p in probe_list]: - probe_list.append(probe_key) - probe_insertion_list.append( - { - **session_key, - "probe": oe_probe.probe_SN, - "insertion_number": probe_idx, - } - ) - else: - raise NotImplementedError(f"Unknown acquisition software: {acq_software}") - - probe.Probe.insert(probe_list, skip_duplicates=True) - cls.insert(probe_insertion_list, skip_duplicates=True) - - -@schema -class InsertionLocation(dj.Manual): - """Stereotaxic location information for each probe insertion. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - SkullReference (dict): SkullReference primary key. - ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive. - ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive. - depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative. - Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis. - phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis. - """ - - definition = """ - # Brain Location of a given probe insertion. - -> ProbeInsertion - --- - -> SkullReference - ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive - ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive - depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative - theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis - phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis - beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior - """ - - -@schema -class EphysRecording(dj.Imported): - """Automated table with electrophysiology recording information for each probe inserted during an experimental session. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key. - AcquisitionSoftware (dict): AcquisitionSoftware primary key. - sampling_rate (float): sampling rate of the recording in Hertz (Hz). 
- recording_datetime (datetime): datetime of the recording from this probe. - recording_duration (float): duration of the entire recording from this probe in seconds. - """ - - definition = """ - # Ephys recording from a probe insertion for a given session. - -> ProbeInsertion - --- - -> probe.ElectrodeConfig - -> AcquisitionSoftware - sampling_rate: float # (Hz) - recording_datetime: datetime # datetime of the recording from this probe - recording_duration: float # (seconds) duration of the recording from this probe - """ - - class EphysFile(dj.Part): - """Paths of electrophysiology recording files for each insertion. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - file_path (varchar(255) ): relative file path for electrophysiology recording. - """ - - definition = """ - # Paths of files of a given EphysRecording round. - -> master - file_path: varchar(255) # filepath relative to root data directory - """ - - def make(self, key): - """Populates table with electrophysiology recording information.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" - f" in {session_dir}" - ) - - supported_probe_types = probe.ProbeType.fetch("probe_type") - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) - - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe 
insertion: {}".format(key) - ) - - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) - - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) - - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - # explicitly garbage collect "dataset" - # as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) - - -@schema -class LFP(dj.Imported): - """Extracts local field potentials (LFP) from an electrophysiology recording. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - lfp_sampling_rate (float): Sampling rate for LFPs in Hz. - lfp_time_stamps (longblob): Time stamps with respect to the start of the recording. - lfp_mean (longblob): Overall mean LFP across electrodes. - """ - - definition = """ - # Acquired local field potential (LFP) from a given Ephys recording. - -> EphysRecording - --- - lfp_sampling_rate: float # (Hz) - lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) - lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) - """ - - class Electrode(dj.Part): - """Saves local field potential data for each electrode. - - Attributes: - LFP (foreign key): LFP primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - lfp (longblob): LFP recording at this electrode in microvolts. 
- """ - - definition = """ - -> master - -> probe.ElectrodeConfig.Electrode - --- - lfp: longblob # (uV) recorded lfp at this electrode - """ - - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - - def make(self, key): - """Populates the LFP tables.""" - acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software") - - electrode_keys, lfp = [], [] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - - lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ - -1 :: -self._skip_channel_counts - ] - - # Extract LFP data at specified channels and convert to uV - lfp = spikeglx_recording.lf_timeseries[ - :, lfp_channel_ind - ] # (sample x channel) - lfp = ( - lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind] - ).T # (channel x sample) - - self.insert1( - dict( - key, - lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"], - lfp_time_stamps=( - np.arange(lfp.shape[1]) - / spikeglx_recording.lfmeta.meta["imSampRate"] - ), - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - for recorded_site in lfp_channel_ind: - shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[ - "data" - ][recorded_site] - electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - lfp_channel_ind = np.r_[ - len(oe_probe.lfp_meta["channels_indices"]) - - 1 : 0 : -self._skip_channel_counts - ] - - # (sample x channel) - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps - - self.insert1( - dict( - key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_keys.extend( - probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind - ) - else: - raise NotImplementedError( - f"LFP extraction from acquisition software" - f" of type {acq_software} is not yet implemented" - ) - - # single insert in loop to mitigate potential memory issue - for electrode_key, lfp_trace in zip(electrode_keys, lfp): - self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace}) - - -# ------------ Clustering -------------- - - -@schema -class ClusteringMethod(dj.Lookup): - """Kilosort clustering method. - - Attributes: - clustering_method (foreign key, varchar(16) ): Kilosort clustering method. - clustering_methods_desc (varchar(1000) ): Additional description of the clustering method. 
- """ - - definition = """ - # Method for clustering - clustering_method: varchar(16) - --- - clustering_method_desc: varchar(1000) - """ - - contents = [ - ("kilosort2", "kilosort2 clustering method"), - ("kilosort2.5", "kilosort2.5 clustering method"), - ("kilosort3", "kilosort3 clustering method"), - ] - - -@schema -class ClusteringParamSet(dj.Lookup): - """Parameters to be used in clustering procedure for spike sorting. - - Attributes: - paramset_idx (foreign key): Unique ID for the clustering parameter set. - ClusteringMethod (dict): ClusteringMethod primary key. - paramset_desc (varchar(128) ): Description of the clustering parameter set. - param_set_hash (uuid): UUID hash for the parameter set. - params (longblob): Parameters for clustering with Kilosort. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> ClusteringMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, - clustering_method: str, - paramset_desc: str, - params: dict, - paramset_idx: int = None, - ): - """Inserts new parameters into the ClusteringParamSet table. - - Args: - clustering_method (str): name of the clustering method. - paramset_desc (str): description of the parameter set - params (dict): clustering parameters - paramset_idx (int, optional): Unique parameter set ID. Defaults to None. - """ - if paramset_idx is None: - paramset_idx = ( - dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0 - ) + 1 - - param_dict = { - "clustering_method": clustering_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid( - {**params, "clustering_method": clustering_method} - ), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # If not same name: human error, trying to add the same paramset with different name - raise dj.DataJointError( - f"The specified param-set already exists" - f" - with paramset_idx: {existing_paramset_idx}" - ) - else: - if {"paramset_idx": paramset_idx} in cls.proj(): - raise dj.DataJointError( - f"The specified paramset_idx {paramset_idx} already exists," - f" please pick a different one." - ) - cls.insert1(param_dict) - - -@schema -class ClusterQualityLabel(dj.Lookup): - """Quality label for each spike sorted cluster. - - Attributes: - cluster_quality_label (foreign key, varchar(100) ): Cluster quality type. - cluster_quality_description ( varchar(4000) ): Description of the cluster quality type. - """ - - definition = """ - # Quality - cluster_quality_label: varchar(100) # cluster quality type - e.g. 'good', 'MUA', 'noise', etc. - --- - cluster_quality_description: varchar(4000) - """ - contents = [ - ("good", "single unit"), - ("ok", "probably a single unit, but could be contaminated"), - ("mua", "multi-unit activity"), - ("noise", "bad unit"), - ] - - -@schema -class ClusteringTask(dj.Manual): - """A clustering task to spike sort electrophysiology datasets. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - ClusteringParamSet (foreign key): ClusteringParamSet primary key. 
- clustering_output_dir ( varchar (255) ): Relative path to output clustering results. - task_mode (enum): `Trigger` computes clustering or and `load` imports existing data. - """ - - definition = """ - # Manual table for defining a clustering task ready to be run - -> EphysRecording - -> ClusteringParamSet - --- - clustering_output_dir='': varchar(255) # clustering output directory relative to the clustering root data directory - task_mode='load': enum('load', 'trigger') # 'load': load computed analysis results, 'trigger': trigger computation - """ - - @classmethod - def infer_output_dir( - cls, key: dict, relative: bool = False, mkdir: bool = False - ) -> pathlib.Path: - """Infer output directory if it is not provided. - - Args: - key (dict): ClusteringTask primary key. - - Returns: - Expected clustering_output_dir based on the following convention: - processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx} - e.g.: sub4/sess1/probe_2/kilosort2_0 - """ - processed_dir = pathlib.Path(get_processed_root_data_dir()) - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - root_dir = find_root_directory(get_ephys_root_data_dir(), session_dir) - - method = ( - (ClusteringParamSet * ClusteringMethod & key) - .fetch1("clustering_method") - .replace(".", "-") - ) - - output_dir = ( - processed_dir - / session_dir.relative_to(root_dir) - / f'probe_{key["insertion_number"]}' - / f'{method}_{key["paramset_idx"]}' - ) - - if mkdir: - output_dir.mkdir(parents=True, exist_ok=True) - log.info(f"{output_dir} created!") - - return output_dir.relative_to(processed_dir) if relative else output_dir - - @classmethod - def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0): - """Autogenerate entries based on a particular ephys recording. - - Args: - ephys_recording_key (dict): EphysRecording primary key. - paramset_idx (int, optional): Parameter index to use for clustering task. Defaults to 0. - """ - key = {**ephys_recording_key, "paramset_idx": paramset_idx} - - processed_dir = get_processed_root_data_dir() - output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True) - - try: - kilosort.Kilosort( - output_dir - ) # check if the directory is a valid Kilosort output - except FileNotFoundError: - task_mode = "trigger" - else: - task_mode = "load" - - cls.insert1( - { - **key, - "clustering_output_dir": output_dir.relative_to( - processed_dir - ).as_posix(), - "task_mode": task_mode, - } - ) - - -@schema -class Clustering(dj.Imported): - """A processing table to handle each clustering task. - - Attributes: - ClusteringTask (foreign key): ClusteringTask primary key. - clustering_time (datetime): Time when clustering results are generated. - package_version ( varchar(16) ): Package version used for a clustering analysis. 
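# Spelling out the output-directory convention from infer_output_dir above for
# a hypothetical key; note the "." -> "-" substitution applied to the
# clustering method name before it is used in the path:
method = "kilosort2.5".replace(".", "-")  # -> "kilosort2-5"
# processed_dir / session_dir / probe_{insertion_number} / {method}_{paramset_idx}
# e.g. key = {"insertion_number": 2, "paramset_idx": 0} under session sub4/sess1
# yields <processed_dir>/sub4/sess1/probe_2/kilosort2-5_0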
- """ - - definition = """ - # Clustering Procedure - -> ClusteringTask - --- - clustering_time: datetime # time of generation of this set of clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Triggers or imports clustering analysis.""" - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - if not output_dir: - output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) - # update clustering_output_dir - ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "load": - kilosort.Kilosort( - kilosort_dir - ) # check if the directory is a valid Kilosort output - elif task_mode == "trigger": - acq_software, clustering_method, params = ( - ClusteringTask * EphysRecording * ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - if "kilosort" in clustering_method: - from element_array_ephys.readers import kilosort_triggering - - # add additional probe-recording and channels details into `params` - params = {**params, **get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.pop("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) - - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=spikeglx_recording.root_dir - / (spikeglx_recording.root_name + ".ap.bin"), - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=pathlib.Path( - oe_probe.recording_info["recording_files"][0] - ) - / "continuous.dat", - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort.run_modules() - else: - raise NotImplementedError( - f"Automatic triggering of {clustering_method}" - f" clustering analysis is not yet supported" - ) - - else: - raise ValueError(f"Unknown task mode: {task_mode}") - - creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) - 
self.insert1({**key, "clustering_time": creation_time, "package_version": ""}) - - -@schema -class Curation(dj.Manual): - """Curation procedure table. - - Attributes: - Clustering (foreign key): Clustering primary key. - curation_id (foreign key, int): Unique curation ID. - curation_time (datetime): Time when curation results are generated. - curation_output_dir ( varchar(255) ): Output directory of the curated results. - quality_control (bool): If True, this clustering result has undergone quality control. - manual_curation (bool): If True, manual curation has been performed on this clustering result. - curation_note ( varchar(2000) ): Notes about the curation task. - """ - - definition = """ - # Manual curation procedure - -> Clustering - curation_id: int - --- - curation_time: datetime # time of generation of this set of curated clustering results - curation_output_dir: varchar(255) # output directory of the curated results, relative to root data directory - quality_control: bool # has this clustering result undergone quality control? - manual_curation: bool # has manual curation been performed on this clustering result? - curation_note='': varchar(2000) - """ - - def create1_from_clustering_task(self, key, curation_note=""): - """ - A function to create a new corresponding "Curation" for a particular - "ClusteringTask" - """ - if key not in Clustering(): - raise ValueError( - f"No corresponding entry in Clustering available" - f" for: {key}; do `Clustering.populate(key)`" - ) - - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - creation_time, is_curated, is_qc = kilosort.extract_clustering_info( - kilosort_dir - ) - # Synthesize curation_id - curation_id = ( - dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n") - ) - self.insert1( - { - **key, - "curation_id": curation_id, - "curation_time": creation_time, - "curation_output_dir": output_dir, - "quality_control": is_qc, - "manual_curation": is_curated, - "curation_note": curation_note, - } - ) - - -@schema -class CuratedClustering(dj.Imported): - """Clustering results after curation. - - Attributes: - Curation (foreign key): Curation primary key. - """ - - definition = """ - # Clustering results of a curation. - -> Curation - """ - - class Unit(dj.Part): - """Single unit properties after clustering and curation. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - unit (foreign key, int): Unique integer identifying a single unit. - probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key. - ClusteringQualityLabel (dict): CLusteringQualityLabel primary key. - spike_count (int): Number of spikes in this recording for this unit. - spike_times (longblob): Spike times of this unit, relative to start time of EphysRecording. - spike_sites (longblob): Array of electrode associated with each spike. - spike_depths (longblob): Array of depths associated with each spike, relative to each spike. 
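# Usage sketch for create1_from_clustering_task above. curation_id is
# synthesized per Clustering entry via ifnull(max(curation_id)+1, 1), so it
# starts at 1 and increments, while curation_time / quality_control /
# manual_curation are read back out of the Kilosort output directory. The key
# below stands in for a full Clustering primary key and is hypothetical.
clustering_key = {"subject": "sub4", "insertion_number": 2, "paramset_idx": 0}
Curation().create1_from_clustering_task(clustering_key, curation_note="initial sort")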
- """ - - definition = """ - # Properties of a given unit from a round of clustering (and curation) - -> master - unit: int - --- - -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit - -> ClusterQualityLabel - spike_count: int # how many spikes in this recording for this unit - spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording - spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe - """ - - def make(self, key): - """Automated population of Unit information.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) - - sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate) - - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() - - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / sample_rate - ) - spike_count = len(unit_spike_times) - - units.append( - { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": ( - spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ] - if spike_depths is not None - else None - ), - } - ) - - self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) - - -@schema -class WaveformSet(dj.Imported): - """A set of spike waveforms for units out of a given CuratedClustering. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # A set of spike waveforms for units out of a given CuratedClustering - -> CuratedClustering - """ - - class PeakWaveform(dj.Part): - """Mean waveform across spikes for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. 
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode. - """ - - definition = """ - # Mean waveform across spikes for a given unit at its representative electrode - -> master - -> CuratedClustering.Unit - --- - peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode - """ - - class Waveform(dj.Part): - """Spike waveforms for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - waveform_mean (longblob): mean waveform across spikes of the unit in microvolts. - waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit. - """ - - definition = """ - # Spike waveforms and their mean across spikes for the given unit - -> master - -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- - waveform_mean: longblob # (uV) mean waveform across spikes of the given unit - waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit - """ - - def make(self, key): - """Populates waveform tables.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") - - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) - - is_qc = (Curation & key).fetch1("quality_control") - - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } - - if is_qc: - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample - - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, 
kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( - (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - - yield unit_peak_waveform, unit_electrode_waveforms - - # insert waveform on a per-unit basis to mitigate potential memory issue - self.insert1(key) - for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): - if unit_peak_waveform: - self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) - if unit_electrode_waveforms: - self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) - - -@schema -class QualityMetrics(dj.Imported): - """Clustering and waveform quality metrics. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # Clusters and waveforms metrics - -> CuratedClustering - """ - - class Cluster(dj.Part): - """Cluster metrics for a unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - firing_rate (float): Firing rate of the unit. - snr (float): Signal-to-noise ratio for a unit. - presence_ratio (float): Fraction of time where spikes are present. - isi_violation (float): rate of ISI violation as a fraction of overall rate. - number_violation (int): Total ISI violations. - amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram. - isolation_distance (float): Distance to nearest cluster. - l_ratio (float): Amount of empty space between a cluster and other spikes in dataset. - d_prime (float): Classification accuracy based on LDA. - nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster. - nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster. - silhouette_core (float): Maximum change in spike depth throughout recording. - cumulative_drift (float): Cumulative change in spike depth throughout recording. - contamination_rate (float): Frequency of spikes in the refractory period. 
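# Shape bookkeeping in the two waveform paths above: extract_spike_waveforms
# returns (sample x channel x spike), which is transposed to
# (channel x spike x sample) so that iterating yields one (spike x sample)
# array per channel; the QC path's mean_waveforms.npy is instead
# (unit x channel x sample). A toy check of the transpose, with made-up sizes:
import numpy as np

w = np.zeros((82, 384, 500))  # (sample x channel x spike), hypothetical
w = w.transpose((1, 2, 0))
print(w.shape)                # -> (384, 500, 82)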
- """ - - definition = """ - # Cluster metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - firing_rate=null: float # (Hz) firing rate for a unit - snr=null: float # signal-to-noise ratio for a unit - presence_ratio=null: float # fraction of time in which spikes are present - isi_violation=null: float # rate of ISI violation as a fraction of overall rate - number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram - isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # - d_prime=null: float # Classification accuracy based on LDA - nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster - nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # - """ - - class Waveform(dj.Part): - """Waveform metrics for a particular unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - amplitude (float): Absolute difference between waveform peak and trough in microvolts. - duration (float): Time between waveform peak and trough in milliseconds. - halfwidth (float): Spike width at half max amplitude. - pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0. - repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak. - recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. - spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. - velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
- """ - - definition = """ - # Waveform metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough - halfwidth=null: float # (ms) spike width at half max amplitude - pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 - repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak - recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail - spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe - velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe - velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe - """ - - def make(self, key): - """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) - metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) - metrics_list = [ - dict(metrics_df.loc[unit_key["unit"]], **unit_key) - for unit_key in (CuratedClustering.Unit & key).fetch("KEY") - ] - - self.insert1(key) - self.Cluster.insert(metrics_list, ignore_extra_fields=True) - self.Waveform.insert(metrics_list, ignore_extra_fields=True) - - -# ---------------- HELPER FUNCTIONS ---------------- - - -def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str: - """Get spikeGLX data filepath.""" - # attempt to retrieve from EphysRecording.EphysFile - spikeglx_meta_filepath = pathlib.Path( - ( - EphysRecording.EphysFile - & ephys_recording_key - & 'file_path LIKE "%.ap.meta"' - ).fetch1("file_path") - ) - - try: - spikeglx_meta_filepath = find_full_path( - get_ephys_root_data_dir(), spikeglx_meta_filepath - ) - except FileNotFoundError: - # if not found, search in session_dir again - if not spikeglx_meta_filepath.exists(): - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & ephys_recording_key - ).fetch1("probe") - - spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")] - for meta_filepath in spikeglx_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - spikeglx_meta_filepath = meta_filepath - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format( - ephys_recording_key - ) - ) - - return spikeglx_meta_filepath - - -def get_openephys_probe_data(ephys_recording_key: dict) -> list: - """Get OpenEphys probe data from file.""" - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & 
ephys_recording_key - ).fetch1("probe") - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - loaded_oe = openephys.OpenEphys(session_dir) - probe_data = loaded_oe.probes[inserted_probe_serial_number] - - # explicitly garbage collect "loaded_oe" - # as these may have large memory footprint and may not be cleared fast enough - del loaded_oe - gc.collect() - - return probe_data - - -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - probe_dataset = get_openephys_probe_data(ephys_recording_key) - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. 
neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key - - -def get_recording_channels_details(ephys_recording_key: dict) -> np.array: - """Get details of recording channels for a given recording.""" - channels_details = {} - - acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1( - "acq_software", "sampling_rate" - ) - - probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1( - "probe_type" - ) - channels_details["probe_type"] = { - "neuropixels 1.0 - 3A": "3A", - "neuropixels 1.0 - 3B": "NP1", - "neuropixels UHD": "NP1100", - "neuropixels 2.0 - SS": "NP21", - "neuropixels 2.0 - MS": "NP24", - }[probe_type] - - electrode_config_key = ( - probe.ElectrodeConfig * EphysRecording & ephys_recording_key - ).fetch1("KEY") - ( - channels_details["channel_ind"], - channels_details["x_coords"], - channels_details["y_coords"], - channels_details["shank_ind"], - ) = ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key - ).fetch( - "electrode", "x_coord", "y_coord", "shank" - ) - channels_details["sample_rate"] = sample_rate - channels_details["num_channels"] = len(channels_details["channel_ind"]) - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0] - channels_details["connected"] = np.array( - [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]] - ) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(ephys_recording_key) - channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0] - channels_details["connected"] = np.array( - [ - int(v == 1) - for c, v in oe_probe.channels_connected.items() - if c in channels_details["channel_ind"] - ] - ) - - return channels_details diff --git a/element_array_ephys/ephys_chronic.py b/element_array_ephys/ephys_chronic.py deleted file mode 100644 index 772e885f..00000000 --- a/element_array_ephys/ephys_chronic.py +++ /dev/null @@ -1,1523 +0,0 @@ -import gc -import importlib -import inspect -import pathlib -from decimal import Decimal - -import datajoint as dj -import numpy as np -import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory - -from . 
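# The electrode_config_name built in generate_electrode_config above compresses
# the sorted electrode list into contiguous ranges by locating gaps
# (np.diff > 1). A toy run with a hypothetical electrode list:
import numpy as np

electrode_list = [0, 1, 2, 10, 11]
gaps = (
    [-1]
    + np.where(np.diff(electrode_list) > 1)[0].tolist()
    + [len(electrode_list) - 1]
)
name = "; ".join(
    f"{electrode_list[start + 1]}-{electrode_list[end]}"
    for start, end in zip(gaps[:-1], gaps[1:])
)
print(name)  # -> "0-2; 10-11"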
import ephys_report, probe -from .readers import kilosort, openephys, spikeglx - -log = dj.logger - -schema = dj.schema() - -_linking_module = None - - -def activate( - ephys_schema_name: str, - probe_schema_name: str = None, - *, - create_schema: bool = True, - create_tables: bool = True, - linking_module: str = None, -): - """Activates the `ephys` and `probe` schemas. - - Args: - ephys_schema_name (str): A string containing the name of the ephys schema. - probe_schema_name (str): A string containing the name of the probe schema. - create_schema (bool): If True, schema will be created in the database. - create_tables (bool): If True, tables related to the schema will be created in the database. - linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. - - Dependencies: - Upstream tables: - Session: A parent table to ProbeInsertion - Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported. - - Functions: - get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s). - get_session_direction(session_key: dict): Returns path to electrophysiology data for the a particular session as a list of strings. - get_processed_data_dir(): Optional. Returns absolute path for processed data. Defaults to root directory. - """ - - if isinstance(linking_module, str): - linking_module = importlib.import_module(linking_module) - assert inspect.ismodule( - linking_module - ), "The argument 'dependency' must be a module's name or a module" - - global _linking_module - _linking_module = linking_module - - probe.activate( - probe_schema_name, create_schema=create_schema, create_tables=create_tables - ) - schema.activate( - ephys_schema_name, - create_schema=create_schema, - create_tables=create_tables, - add_objects=_linking_module.__dict__, - ) - ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name) - - -# -------------- Functions required by the elements-ephys --------------- - - -def get_ephys_root_data_dir() -> list: - """Fetches absolute data path to ephys data directories. - - The absolute path here is used as a reference for all downstream relative paths used in DataJoint. - - Returns: - A list of the absolute path(s) to ephys data directories. - """ - root_directories = _linking_module.get_ephys_root_data_dir() - if isinstance(root_directories, (str, pathlib.Path)): - root_directories = [root_directories] - - if hasattr(_linking_module, "get_processed_root_data_dir"): - root_directories.append(_linking_module.get_processed_root_data_dir()) - - return root_directories - - -def get_session_directory(session_key: dict) -> str: - """Retrieve the session directory with Neuropixels for the given session. - - Args: - session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database. - - Returns: - A string for the path to the session directory. - """ - return _linking_module.get_session_directory(session_key) - - -def get_processed_root_data_dir() -> str: - """Retrieve the root directory for all processed data. - - Returns: - A string for the full path to the root directory for processed data. 
- """ - - if hasattr(_linking_module, "get_processed_root_data_dir"): - return _linking_module.get_processed_root_data_dir() - else: - return get_ephys_root_data_dir()[0] - - -# ----------------------------- Table declarations ---------------------- - - -@schema -class AcquisitionSoftware(dj.Lookup): - """Name of software used for recording electrophysiological data. - - Attributes: - acq_software ( varchar(24) ): Acquisition software, e.g,. SpikeGLX, OpenEphys - """ - - definition = """ # Software used for recording of neuropixels probes - acq_software: varchar(24) - """ - contents = zip(["SpikeGLX", "Open Ephys"]) - - -@schema -class ProbeInsertion(dj.Manual): - """Information about probe insertion across subjects and sessions. - - Attributes: - Session (foreign key): Session primary key. - insertion_number (foreign key, str): Unique insertion number for each probe insertion for a given session. - probe.Probe (str): probe.Probe primary key. - """ - - definition = """ - # Probe insertion chronically implanted into an animal. - -> Subject - insertion_number: tinyint unsigned - --- - -> probe.Probe - insertion_datetime=null: datetime - """ - - -@schema -class InsertionLocation(dj.Manual): - """Stereotaxic location information for each probe insertion. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - SkullReference (dict): SkullReference primary key. - ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive. - ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive. - depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative. - Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis. - phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis. - """ - - definition = """ - # Brain Location of a given probe insertion. - -> ProbeInsertion - --- - -> SkullReference - ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive - ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive - depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative - theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis - phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis - beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior - """ - - -@schema -class EphysRecording(dj.Imported): - """Automated table with electrophysiology recording information for each probe inserted during an experimental session. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key. - AcquisitionSoftware (dict): AcquisitionSoftware primary key. - sampling_rate (float): sampling rate of the recording in Hertz (Hz). - recording_datetime (datetime): datetime of the recording from this probe. - recording_duration (float): duration of the entire recording from this probe in seconds. - """ - - definition = """ - # Ephys recording from a probe insertion for a given session. 
- -> Session - -> ProbeInsertion - --- - -> probe.ElectrodeConfig - -> AcquisitionSoftware - sampling_rate: float # (Hz) - recording_datetime: datetime # datetime of the recording from this probe - recording_duration: float # (seconds) duration of the recording from this probe - """ - - class EphysFile(dj.Part): - """Paths of electrophysiology recording files for each insertion. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - file_path (varchar(255) ): relative file path for electrophysiology recording. - """ - - definition = """ - # Paths of files of a given EphysRecording round. - -> master - file_path: varchar(255) # filepath relative to root data directory - """ - - def make(self, key): - """Populates table with electrophysiology recording information.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = list(session_dir.rglob(ephys_pattern)) - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" - f" in {session_dir}" - ) - - supported_probe_types = probe.ProbeType.fetch("probe_type") - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - f"No SpikeGLX data found for probe insertion: {key}" - + " The probe serial number does not match." 
- ) - - if spikeglx_meta.probe_model in supported_probe_types: - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: {}".format(key) - ) - - if not probe_data.ap_meta: - raise IOError( - 'No analog signals found - check "structure.oebin" file or "continuous" directory' - ) - - if probe_data.probe_model in supported_probe_types: - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_indices"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) - - root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - # explicitly garbage collect "dataset" - # as these may have large memory footprint and may not be cleared fast enough - del probe_data, dataset - gc.collect() - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) - - -@schema -class LFP(dj.Imported): - """Extracts local field potentials (LFP) from an electrophysiology recording. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - lfp_sampling_rate (float): Sampling rate for LFPs in Hz. - lfp_time_stamps (longblob): Time stamps with respect to the start of the recording. - lfp_mean (longblob): Overall mean LFP across electrodes. 
- """ - - definition = """ - # Acquired local field potential (LFP) from a given Ephys recording. - -> EphysRecording - --- - lfp_sampling_rate: float # (Hz) - lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) - lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) - """ - - class Electrode(dj.Part): - """Saves local field potential data for each electrode. - - Attributes: - LFP (foreign key): LFP primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - lfp (longblob): LFP recording at this electrode in microvolts. - """ - - definition = """ - -> master - -> probe.ElectrodeConfig.Electrode - --- - lfp: longblob # (uV) recorded lfp at this electrode - """ - - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - - def make(self, key): - """Populates the LFP tables.""" - acq_software = (EphysRecording * ProbeInsertion & key).fetch1("acq_software") - - electrode_keys, lfp = [], [] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - - lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ - -1 :: -self._skip_channel_counts - ] - - # Extract LFP data at specified channels and convert to uV - lfp = spikeglx_recording.lf_timeseries[ - :, lfp_channel_ind - ] # (sample x channel) - lfp = ( - lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind] - ).T # (channel x sample) - - self.insert1( - dict( - key, - lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"], - lfp_time_stamps=( - np.arange(lfp.shape[1]) - / spikeglx_recording.lfmeta.meta["imSampRate"] - ), - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - for recorded_site in lfp_channel_ind: - shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[ - "data" - ][recorded_site] - electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - lfp_channel_ind = np.r_[ - len(oe_probe.lfp_meta["channels_indices"]) - - 1 : 0 : -self._skip_channel_counts - ] - - # (sample x channel) - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps - - self.insert1( - dict( - key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_keys.extend( - probe_electrodes[channel_idx] for channel_idx in lfp_channel_ind - ) - else: - raise NotImplementedError( - f"LFP extraction from acquisition software" - f" of type {acq_software} is not yet implemented" - ) - - # single insert in loop to mitigate potential memory issue - for electrode_key, lfp_trace in zip(electrode_keys, lfp): - 
self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace})
-
-
-# ------------ Clustering --------------
-
-
-@schema
-class ClusteringMethod(dj.Lookup):
-    """Kilosort clustering method.
-
-    Attributes:
-        clustering_method (foreign key, varchar(16) ): Kilosort clustering method.
-        clustering_method_desc (varchar(1000) ): Additional description of the clustering method.
-    """
-
-    definition = """
-    # Method for clustering
-    clustering_method: varchar(16)
-    ---
-    clustering_method_desc: varchar(1000)
-    """
-
-    contents = [
-        ("kilosort2", "kilosort2 clustering method"),
-        ("kilosort2.5", "kilosort2.5 clustering method"),
-        ("kilosort3", "kilosort3 clustering method"),
-    ]
-
-
-@schema
-class ClusteringParamSet(dj.Lookup):
-    """Parameters to be used in clustering procedure for spike sorting.
-
-    Attributes:
-        paramset_idx (foreign key): Unique ID for the clustering parameter set.
-        ClusteringMethod (dict): ClusteringMethod primary key.
-        paramset_desc (varchar(128) ): Description of the clustering parameter set.
-        param_set_hash (uuid): UUID hash for the parameter set.
-        params (longblob): Parameters for clustering with Kilosort.
-    """
-
-    definition = """
-    # Parameter set to be used in a clustering procedure
-    paramset_idx: smallint
-    ---
-    -> ClusteringMethod
-    paramset_desc: varchar(128)
-    param_set_hash: uuid
-    unique index (param_set_hash)
-    params: longblob  # dictionary of all applicable parameters
-    """
-
-    @classmethod
-    def insert_new_params(
-        cls,
-        clustering_method: str,
-        paramset_desc: str,
-        params: dict,
-        paramset_idx: int = None,
-    ):
-        """Inserts new parameters into the ClusteringParamSet table.
-
-        Args:
-            clustering_method (str): name of the clustering method.
-            paramset_desc (str): description of the parameter set.
-            params (dict): clustering parameters.
-            paramset_idx (int, optional): Unique parameter set ID. Defaults to None.
-        """
-        if paramset_idx is None:
-            paramset_idx = (
-                dj.U().aggr(cls, n="max(paramset_idx)").fetch1("n") or 0
-            ) + 1
-
-        param_dict = {
-            "clustering_method": clustering_method,
-            "paramset_idx": paramset_idx,
-            "paramset_desc": paramset_desc,
-            "params": params,
-            "param_set_hash": dict_to_uuid(
-                {**params, "clustering_method": clustering_method}
-            ),
-        }
-        param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
-        if param_query:  # If the specified param-set already exists
-            existing_paramset_idx = param_query.fetch1("paramset_idx")
-            if (
-                existing_paramset_idx == paramset_idx
-            ):  # If the existing set has the same paramset_idx: job done
-                return
-            else:  # the same params already exist under a different paramset_idx
-                raise dj.DataJointError(
-                    f"The specified param-set already exists"
-                    f" - with paramset_idx: {existing_paramset_idx}"
-                )
-        else:
-            if {"paramset_idx": paramset_idx} in cls.proj():
-                raise dj.DataJointError(
-                    f"The specified paramset_idx {paramset_idx} already exists,"
-                    f" please pick a different one."
-                )
-            cls.insert1(param_dict)
-
-
-@schema
-class ClusterQualityLabel(dj.Lookup):
-    """Quality label for each spike sorted cluster.
-
-    Attributes:
-        cluster_quality_label (foreign key, varchar(100) ): Cluster quality type.
-        cluster_quality_description (varchar(4000) ): Description of the cluster quality type.
-    """
-
-    definition = """
-    # Quality
-    cluster_quality_label: varchar(100)  # cluster quality type - e.g. 'good', 'MUA', 'noise', etc.
-    ---
-    cluster_quality_description: varchar(4000)
-    """
-    contents = [
-        ("good", "single unit"),
-        ("ok", "probably a single unit, but could be contaminated"),
-        ("mua", "multi-unit activity"),
-        ("noise", "bad unit"),
-    ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
-    """A clustering task to spike sort electrophysiology datasets.
-
-    Attributes:
-        EphysRecording (foreign key): EphysRecording primary key.
-        ClusteringParamSet (foreign key): ClusteringParamSet primary key.
-        clustering_output_dir (varchar(255) ): Relative path to output clustering results.
-        task_mode (enum): `trigger` computes clustering; `load` imports existing results.
-    """
-
-    definition = """
-    # Manual table for defining a clustering task ready to be run
-    -> EphysRecording
-    -> ClusteringParamSet
-    ---
-    clustering_output_dir='': varchar(255)  # clustering output directory relative to the clustering root data directory
-    task_mode='load': enum('load', 'trigger')  # 'load': load computed analysis results, 'trigger': trigger computation
-    """
-
-    @classmethod
-    def infer_output_dir(cls, key, relative=False, mkdir=False) -> pathlib.Path:
-        """Infer output directory if it is not provided.
-
-        Args:
-            key (dict): ClusteringTask primary key.
-
-        Returns:
-            Expected clustering_output_dir based on the following convention:
-            processed_dir / session_dir / probe_{insertion_number} / {clustering_method}_{paramset_idx}
-            e.g.: sub4/sess1/probe_2/kilosort2_0
-        """
-        processed_dir = pathlib.Path(get_processed_root_data_dir())
-        sess_dir = find_full_path(get_ephys_root_data_dir(), get_session_directory(key))
-        root_dir = find_root_directory(get_ephys_root_data_dir(), sess_dir)
-
-        method = (
-            (ClusteringParamSet * ClusteringMethod & key)
-            .fetch1("clustering_method")
-            .replace(".", "-")
-        )
-
-        output_dir = (
-            processed_dir
-            / sess_dir.relative_to(root_dir)
-            / f'probe_{key["insertion_number"]}'
-            / f'{method}_{key["paramset_idx"]}'
-        )
-
-        if mkdir:
-            output_dir.mkdir(parents=True, exist_ok=True)
-            log.info(f"{output_dir} created!")
-
-        return output_dir.relative_to(processed_dir) if relative else output_dir
-
-    @classmethod
-    def auto_generate_entries(cls, ephys_recording_key: dict, paramset_idx: int = 0):
-        """Autogenerate entries based on a particular ephys recording.
-
-        Args:
-            ephys_recording_key (dict): EphysRecording primary key.
-            paramset_idx (int, optional): Parameter index to use for the clustering task. Defaults to 0.
-        """
-        key = {**ephys_recording_key, "paramset_idx": paramset_idx}
-
-        processed_dir = get_processed_root_data_dir()
-        output_dir = ClusteringTask.infer_output_dir(key, relative=False, mkdir=True)
-
-        try:
-            kilosort.Kilosort(
-                output_dir
-            )  # check if the directory is a valid Kilosort output
-        except FileNotFoundError:
-            task_mode = "trigger"
-        else:
-            task_mode = "load"
-
-        cls.insert1(
-            {
-                **key,
-                "clustering_output_dir": output_dir.relative_to(
-                    processed_dir
-                ).as_posix(),
-                "task_mode": task_mode,
-            }
-        )
-
-
-@schema
-class Clustering(dj.Imported):
-    """A processing table to handle each clustering task.
-
-    Attributes:
-        ClusteringTask (foreign key): ClusteringTask primary key.
-        clustering_time (datetime): Time when clustering results are generated.
-        package_version (varchar(16) ): Package version used for a clustering analysis.
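-
-    Example:
-        A minimal usage sketch, assuming an activated pipeline; `task_key` below
-        is hypothetical and only illustrates the shape of a ClusteringTask key:
-
-        >>> task_key = {"subject": "sub4", "insertion_number": 2, "paramset_idx": 0}
-        >>> Clustering.populate(task_key, display_progress=True)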
- """ - - definition = """ - # Clustering Procedure - -> ClusteringTask - --- - clustering_time: datetime # time of generation of this set of clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Triggers or imports clustering analysis.""" - task_mode, output_dir = (ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - if not output_dir: - output_dir = ClusteringTask.infer_output_dir(key, relative=True, mkdir=True) - # update clustering_output_dir - ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "load": - kilosort.Kilosort( - kilosort_dir - ) # check if the directory is a valid Kilosort output - elif task_mode == "trigger": - acq_software, clustering_method, params = ( - ClusteringTask * EphysRecording * ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - if "kilosort" in clustering_method: - from element_array_ephys.readers import kilosort_triggering - - # add additional probe-recording and channels details into `params` - params = {**params, **get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX( - spikeglx_meta_filepath.parent - ) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.pop("run_CatGT", True) - and "_tcat." not in spikeglx_meta_filepath.stem - ) - - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=spikeglx_recording.root_dir - / (spikeglx_recording.root_name + ".ap.bin"), - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - if clustering_method.startswith("pykilosort"): - kilosort_triggering.run_pykilosort( - continuous_file=pathlib.Path( - oe_probe.recording_info["recording_files"][0] - ) - / "continuous.dat", - kilosort_output_directory=kilosort_dir, - channel_ind=params.pop("channel_ind"), - x_coords=params.pop("x_coords"), - y_coords=params.pop("y_coords"), - shank_ind=params.pop("shank_ind"), - connected=params.pop("connected"), - sample_rate=params.pop("sample_rate"), - params=params, - ) - else: - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort.run_modules() - else: - raise NotImplementedError( - f"Automatic triggering of {clustering_method}" - f" clustering analysis is not yet supported" - ) - - else: - raise ValueError(f"Unknown task mode: {task_mode}") - - creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir) - 
self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
-    """Curation procedure table.
-
-    Attributes:
-        Clustering (foreign key): Clustering primary key.
-        curation_id (foreign key, int): Unique curation ID.
-        curation_time (datetime): Time when curation results are generated.
-        curation_output_dir (varchar(255) ): Output directory of the curated results.
-        quality_control (bool): If True, this clustering result has undergone quality control.
-        manual_curation (bool): If True, manual curation has been performed on this clustering result.
-        curation_note (varchar(2000) ): Notes about the curation task.
-    """
-
-    definition = """
-    # Manual curation procedure
-    -> Clustering
-    curation_id: int
-    ---
-    curation_time: datetime             # time of generation of this set of curated clustering results
-    curation_output_dir: varchar(255)   # output directory of the curated results, relative to root data directory
-    quality_control: bool               # has this clustering result undergone quality control?
-    manual_curation: bool               # has manual curation been performed on this clustering result?
-    curation_note='': varchar(2000)
-    """
-
-    def create1_from_clustering_task(self, key, curation_note: str = ""):
-        """Create a new Curation entry corresponding to a given ClusteringTask key."""
-        if key not in Clustering():
-            raise ValueError(
-                f"No corresponding entry in Clustering available"
-                f" for: {key}; do `Clustering.populate(key)`"
-            )
-
-        task_mode, output_dir = (ClusteringTask & key).fetch1(
-            "task_mode", "clustering_output_dir"
-        )
-        kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
-        creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
-            kilosort_dir
-        )
-        # Synthesize curation_id
-        curation_id = (
-            dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
-        )
-        self.insert1(
-            {
-                **key,
-                "curation_id": curation_id,
-                "curation_time": creation_time,
-                "curation_output_dir": output_dir,
-                "quality_control": is_qc,
-                "manual_curation": is_curated,
-                "curation_note": curation_note,
-            }
-        )
-
-
-@schema
-class CuratedClustering(dj.Imported):
-    """Clustering results after curation.
-
-    Attributes:
-        Curation (foreign key): Curation primary key.
-    """
-
-    definition = """
-    # Clustering results of a curation.
-    -> Curation
-    """
-
-    class Unit(dj.Part):
-        """Single unit properties after clustering and curation.
-
-        Attributes:
-            CuratedClustering (foreign key): CuratedClustering primary key.
-            unit (foreign key, int): Unique integer identifying a single unit.
-            probe.ElectrodeConfig.Electrode (dict): probe.ElectrodeConfig.Electrode primary key.
-            ClusterQualityLabel (foreign key): ClusterQualityLabel primary key.
-            spike_count (int): Number of spikes in this recording for this unit.
-            spike_times (longblob): Spike times of this unit, relative to the start of the EphysRecording.
-            spike_sites (longblob): Array of electrode associated with each spike.
-            spike_depths (longblob): Array of depths associated with each spike, relative to the (0, 0) of the probe.
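-
-        Example:
-            A minimal fetch sketch; the unit number is hypothetical:
-
-            >>> (CuratedClustering.Unit & {"unit": 1}).fetch("spike_times")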
- """ - - definition = """ - # Properties of a given unit from a round of clustering (and curation) - -> master - unit: int - --- - -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit - -> ClusterQualityLabel - spike_count: int # how many spikes in this recording for this unit - spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording - spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe - """ - - def make(self, key): - """Automated population of Unit information.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software, sample_rate = (EphysRecording & key).fetch1( - "acq_software", "sampling_rate" - ) - - sample_rate = kilosort_dataset.data["params"].get("sample_rate", sample_rate) - - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() - - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / sample_rate - ) - spike_count = len(unit_spike_times) - - units.append( - { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": ( - spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ] - if spike_depths is not None - else None - ), - } - ) - - self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) - - -@schema -class WaveformSet(dj.Imported): - """A set of spike waveforms for units out of a given CuratedClustering. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # A set of spike waveforms for units out of a given CuratedClustering - -> CuratedClustering - """ - - class PeakWaveform(dj.Part): - """Mean waveform across spikes for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. 
- CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode. - """ - - definition = """ - # Mean waveform across spikes for a given unit at its representative electrode - -> master - -> CuratedClustering.Unit - --- - peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode - """ - - class Waveform(dj.Part): - """Spike waveforms for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - waveform_mean (longblob): mean waveform across spikes of the unit in microvolts. - waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit. - """ - - definition = """ - # Spike waveforms and their mean across spikes for the given unit - -> master - -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- - waveform_mean: longblob # (uV) mean waveform across spikes of the given unit - waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit - """ - - def make(self, key): - """Populates waveform tables.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") - - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) - - is_qc = (Curation & key).fetch1("quality_control") - - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } - - if is_qc: - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample - - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, 
kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( - (1, 2, 0) - ) # (channel x spike x sample) - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], waveforms - ): - unit_electrode_waveforms.append( - { - **unit_dict, - **channel2electrodes[channel], - "waveform_mean": channel_waveform.mean(axis=0), - "waveforms": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == unit_dict["electrode"] - ): - unit_peak_waveform = { - **unit_dict, - "peak_electrode_waveform": channel_waveform.mean( - axis=0 - ), - } - - yield unit_peak_waveform, unit_electrode_waveforms - - # insert waveform on a per-unit basis to mitigate potential memory issue - self.insert1(key) - for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms(): - if unit_peak_waveform: - self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True) - if unit_electrode_waveforms: - self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True) - - -@schema -class QualityMetrics(dj.Imported): - """Clustering and waveform quality metrics. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # Clusters and waveforms metrics - -> CuratedClustering - """ - - class Cluster(dj.Part): - """Cluster metrics for a unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - firing_rate (float): Firing rate of the unit. - snr (float): Signal-to-noise ratio for a unit. - presence_ratio (float): Fraction of time where spikes are present. - isi_violation (float): rate of ISI violation as a fraction of overall rate. - number_violation (int): Total ISI violations. - amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram. - isolation_distance (float): Distance to nearest cluster. - l_ratio (float): Amount of empty space between a cluster and other spikes in dataset. - d_prime (float): Classification accuracy based on LDA. - nn_hit_rate (float): Fraction of neighbors for target cluster that are also in target cluster. - nn_miss_rate (float): Fraction of neighbors outside target cluster that are in the target cluster. - silhouette_core (float): Maximum change in spike depth throughout recording. - cumulative_drift (float): Cumulative change in spike depth throughout recording. - contamination_rate (float): Frequency of spikes in the refractory period. 
- """ - - definition = """ - # Cluster metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - firing_rate=null: float # (Hz) firing rate for a unit - snr=null: float # signal-to-noise ratio for a unit - presence_ratio=null: float # fraction of time in which spikes are present - isi_violation=null: float # rate of ISI violation as a fraction of overall rate - number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram - isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # - d_prime=null: float # Classification accuracy based on LDA - nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster - nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # - """ - - class Waveform(dj.Part): - """Waveform metrics for a particular unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - amplitude (float): Absolute difference between waveform peak and trough in microvolts. - duration (float): Time between waveform peak and trough in milliseconds. - halfwidth (float): Spike width at half max amplitude. - pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0. - repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak. - recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. - spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. - velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
- """ - - definition = """ - # Waveform metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough - halfwidth=null: float # (ms) spike width at half max amplitude - pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 - repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak - recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail - spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe - velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe - velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe - """ - - def make(self, key): - """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) - metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) - metrics_list = [ - dict(metrics_df.loc[unit_key["unit"]], **unit_key) - for unit_key in (CuratedClustering.Unit & key).fetch("KEY") - ] - - self.insert1(key) - self.Cluster.insert(metrics_list, ignore_extra_fields=True) - self.Waveform.insert(metrics_list, ignore_extra_fields=True) - - -# ---------------- HELPER FUNCTIONS ---------------- - - -def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str: - """Get spikeGLX data filepath.""" - # attempt to retrieve from EphysRecording.EphysFile - spikeglx_meta_filepath = pathlib.Path( - ( - EphysRecording.EphysFile - & ephys_recording_key - & 'file_path LIKE "%.ap.meta"' - ).fetch1("file_path") - ) - - try: - spikeglx_meta_filepath = find_full_path( - get_ephys_root_data_dir(), spikeglx_meta_filepath - ) - except FileNotFoundError: - # if not found, search in session_dir again - if not spikeglx_meta_filepath.exists(): - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & ephys_recording_key - ).fetch1("probe") - - spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")] - for meta_filepath in spikeglx_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - spikeglx_meta_filepath = meta_filepath - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format( - ephys_recording_key - ) - ) - - return spikeglx_meta_filepath - - -def get_openephys_probe_data(ephys_recording_key: dict) -> list: - """Get OpenEphys probe data from file.""" - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & 
ephys_recording_key - ).fetch1("probe") - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - loaded_oe = openephys.OpenEphys(session_dir) - probe_data = loaded_oe.probes[inserted_probe_serial_number] - - # explicitly garbage collect "loaded_oe" - # as these may have large memory footprint and may not be cleared fast enough - del loaded_oe - gc.collect() - - return probe_data - - -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - probe_dataset = get_openephys_probe_data(ephys_recording_key) - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_indices"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. 
neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key - - -def get_recording_channels_details(ephys_recording_key: dict) -> np.array: - """Get details of recording channels for a given recording.""" - channels_details = {} - - acq_software, sample_rate = (EphysRecording & ephys_recording_key).fetch1( - "acq_software", "sampling_rate" - ) - - probe_type = (ProbeInsertion * probe.Probe & ephys_recording_key).fetch1( - "probe_type" - ) - channels_details["probe_type"] = { - "neuropixels 1.0 - 3A": "3A", - "neuropixels 1.0 - 3B": "NP1", - "neuropixels UHD": "NP1100", - "neuropixels 2.0 - SS": "NP21", - "neuropixels 2.0 - MS": "NP24", - }[probe_type] - - electrode_config_key = ( - probe.ElectrodeConfig * EphysRecording & ephys_recording_key - ).fetch1("KEY") - ( - channels_details["channel_ind"], - channels_details["x_coords"], - channels_details["y_coords"], - channels_details["shank_ind"], - ) = ( - probe.ElectrodeConfig.Electrode * probe.ProbeType.Electrode - & electrode_config_key - ).fetch( - "electrode", "x_coord", "y_coord", "shank" - ) - channels_details["sample_rate"] = sample_rate - channels_details["num_channels"] = len(channels_details["channel_ind"]) - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - channels_details["uVPerBit"] = spikeglx_recording.get_channel_bit_volts("ap")[0] - channels_details["connected"] = np.array( - [v for *_, v in spikeglx_recording.apmeta.shankmap["data"]] - ) - elif acq_software == "Open Ephys": - oe_probe = get_openephys_probe_data(ephys_recording_key) - channels_details["uVPerBit"] = oe_probe.ap_meta["channels_gains"][0] - channels_details["connected"] = np.array( - [ - int(v == 1) - for c, v in oe_probe.channels_connected.items() - if c in channels_details["channel_ind"] - ] - ) - - return channels_details diff --git a/element_array_ephys/ephys_precluster.py b/element_array_ephys/ephys_precluster.py deleted file mode 100644 index 4d52c610..00000000 --- a/element_array_ephys/ephys_precluster.py +++ /dev/null @@ -1,1435 +0,0 @@ -import importlib -import inspect -import re - -import datajoint as dj -import numpy as np -import pandas as pd -from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory - -from . 
import ephys_report, probe
-from .readers import kilosort, openephys, spikeglx
-
-schema = dj.schema()
-
-_linking_module = None
-
-
-def activate(
-    ephys_schema_name: str,
-    probe_schema_name: str = None,
-    *,
-    create_schema: bool = True,
-    create_tables: bool = True,
-    linking_module: str = None,
-):
-    """Activates the `ephys` and `probe` schemas.
-
-    Args:
-        ephys_schema_name (str): A string containing the name of the ephys schema.
-        probe_schema_name (str): A string containing the name of the probe schema.
-        create_schema (bool): If True, schema will be created in the database.
-        create_tables (bool): If True, tables related to the schema will be created in the database.
-        linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema.
-
-    Dependencies:
-        Upstream tables:
-            Session: A parent table to ProbeInsertion.
-            Probe: A parent table to EphysRecording. Probe information is required before electrophysiology data is imported.
-
-        Functions:
-            get_ephys_root_data_dir(): Returns absolute path for root data director(y/ies) with all electrophysiological recording sessions, as a list of string(s).
-            get_session_directory(session_key: dict): Returns the path to the electrophysiology data for a particular session.
-    """
-
-    if isinstance(linking_module, str):
-        linking_module = importlib.import_module(linking_module)
-    assert inspect.ismodule(
-        linking_module
-    ), "The argument 'linking_module' must be a module name or a module"
-
-    global _linking_module
-    _linking_module = linking_module
-
-    probe.activate(
-        probe_schema_name, create_schema=create_schema, create_tables=create_tables
-    )
-    schema.activate(
-        ephys_schema_name,
-        create_schema=create_schema,
-        create_tables=create_tables,
-        add_objects=_linking_module.__dict__,
-    )
-    ephys_report.activate(f"{ephys_schema_name}_report", ephys_schema_name)
-
-
-# -------------- Functions required by the elements-ephys ---------------
-
-
-def get_ephys_root_data_dir() -> list:
-    """Fetches absolute data path to ephys data directories.
-
-    The absolute path here is used as a reference for all downstream relative paths used in DataJoint.
-
-    Returns:
-        A list of the absolute path(s) to ephys data directories.
-    """
-    return _linking_module.get_ephys_root_data_dir()
-
-
-def get_session_directory(session_key: dict) -> str:
-    """Retrieve the session directory with Neuropixels data for the given session.
-
-    Args:
-        session_key (dict): A dictionary mapping subject to an entry in the subject table, and session_datetime corresponding to a session in the database.
-
-    Returns:
-        A string for the path to the session directory.
-    """
-    return _linking_module.get_session_directory(session_key)
-
-
-# ----------------------------- Table declarations ----------------------
-
-
-@schema
-class AcquisitionSoftware(dj.Lookup):
-    """Name of software used for recording electrophysiological data.
-
-    Attributes:
-        acq_software ( varchar(24) ): Acquisition software, e.g., SpikeGLX, Open Ephys.
-    """
-
-    definition = """  # Name of software used for recording of neuropixels probes - SpikeGLX or Open Ephys
-    acq_software: varchar(24)
-    """
-    contents = zip(["SpikeGLX", "Open Ephys"])
-
-
-@schema
-class ProbeInsertion(dj.Manual):
-    """Information about probe insertion across subjects and sessions.
-
-    Attributes:
-        Session (foreign key): Session primary key.
-        insertion_number (foreign key, int): Unique insertion number for each probe insertion for a given session.
- probe.Probe (str): probe.Probe primary key. - """ - - definition = """ - # Probe insertion implanted into an animal for a given session. - -> Session - insertion_number: tinyint unsigned - --- - -> probe.Probe - """ - - -@schema -class InsertionLocation(dj.Manual): - """Stereotaxic location information for each probe insertion. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - SkullReference (dict): SkullReference primary key. - ap_location (decimal (6, 2) ): Anterior-posterior location in micrometers. Reference is 0 with anterior values positive. - ml_location (decimal (6, 2) ): Medial-lateral location in micrometers. Reference is zero with right side values positive. - depth (decimal (6, 2) ): Manipulator depth relative to the surface of the brain at zero. Ventral is negative. - Theta (decimal (5, 2) ): elevation - rotation about the ml-axis in degrees relative to positive z-axis. - phi (decimal (5, 2) ): azimuth - rotation about the dv-axis in degrees relative to the positive x-axis - - """ - - definition = """ - # Brain Location of a given probe insertion. - -> ProbeInsertion - --- - -> SkullReference - ap_location: decimal(6, 2) # (um) anterior-posterior; ref is 0; more anterior is more positive - ml_location: decimal(6, 2) # (um) medial axis; ref is 0 ; more right is more positive - depth: decimal(6, 2) # (um) manipulator depth relative to surface of the brain (0); more ventral is more negative - theta=null: decimal(5, 2) # (deg) - elevation - rotation about the ml-axis [0, 180] - w.r.t the z+ axis - phi=null: decimal(5, 2) # (deg) - azimuth - rotation about the dv-axis [0, 360] - w.r.t the x+ axis - beta=null: decimal(5, 2) # (deg) rotation about the shank of the probe [-180, 180] - clockwise is increasing in degree - 0 is the probe-front facing anterior - """ - - -@schema -class EphysRecording(dj.Imported): - """Automated table with electrophysiology recording information for each probe inserted during an experimental session. - - Attributes: - ProbeInsertion (foreign key): ProbeInsertion primary key. - probe.ElectrodeConfig (dict): probe.ElectrodeConfig primary key. - AcquisitionSoftware (dict): AcquisitionSoftware primary key. - sampling_rate (float): sampling rate of the recording in Hertz (Hz). - recording_datetime (datetime): datetime of the recording from this probe. - recording_duration (float): duration of the entire recording from this probe in seconds. - """ - - definition = """ - # Ephys recording from a probe insertion for a given session. - -> ProbeInsertion - --- - -> probe.ElectrodeConfig - -> AcquisitionSoftware - sampling_rate: float # (Hz) - recording_datetime: datetime # datetime of the recording from this probe - recording_duration: float # (seconds) duration of the recording from this probe - """ - - class EphysFile(dj.Part): - """Paths of electrophysiology recording files for each insertion. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - file_path (varchar(255) ): relative file path for electrophysiology recording. - """ - - definition = """ - # Paths of files of a given EphysRecording round. 
- -> master - file_path: varchar(255) # filepath relative to root data directory - """ - - def make(self, key): - """Populates table with electrophysiology recording information.""" - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - # search session dir and determine acquisition software - for ephys_pattern, ephys_acq_type in ( - ("*.ap.meta", "SpikeGLX"), - ("*.oebin", "Open Ephys"), - ): - ephys_meta_filepaths = [fp for fp in session_dir.rglob(ephys_pattern)] - if ephys_meta_filepaths: - acq_software = ephys_acq_type - break - else: - raise FileNotFoundError( - f"Ephys recording data not found!" - f" Neither SpikeGLX nor Open Ephys recording files found" - f" in {session_dir}" - ) - - if acq_software == "SpikeGLX": - for meta_filepath in ephys_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) - - if re.search("(1.0|2.0)", spikeglx_meta.probe_model): - probe_type = spikeglx_meta.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - electrode_group_members = [ - probe_electrodes[(shank, shank_col, shank_row)] - for shank, shank_col, shank_row, _ in spikeglx_meta.shankmap["data"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels probe model" - " {} not yet implemented".format(spikeglx_meta.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": spikeglx_meta.meta["imSampRate"], - "recording_datetime": spikeglx_meta.recording_time, - "recording_duration": ( - spikeglx_meta.recording_duration - or spikeglx.retrieve_recording_duration(meta_filepath) - ), - } - ) - - root_dir = find_root_directory(get_ephys_root_data_dir(), meta_filepath) - self.EphysFile.insert1( - {**key, "file_path": meta_filepath.relative_to(root_dir).as_posix()} - ) - elif acq_software == "Open Ephys": - dataset = openephys.OpenEphys(session_dir) - for serial_number, probe_data in dataset.probes.items(): - if str(serial_number) == inserted_probe_serial_number: - break - else: - raise FileNotFoundError( - "No Open Ephys data found for probe insertion: {}".format(key) - ) - - if re.search("(1.0|2.0)", probe_data.probe_model): - probe_type = probe_data.probe_model - electrode_query = probe.ProbeType.Electrode & {"probe_type": probe_type} - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - electrode_group_members = [ - probe_electrodes[channel_idx] - for channel_idx in probe_data.ap_meta["channels_ids"] - ] - else: - raise NotImplementedError( - "Processing for neuropixels" - " probe model {} not yet implemented".format(probe_data.probe_model) - ) - - self.insert1( - { - **key, - **generate_electrode_config(probe_type, electrode_group_members), - "acq_software": acq_software, - "sampling_rate": probe_data.ap_meta["sample_rate"], - "recording_datetime": probe_data.recording_info[ - "recording_datetimes" - ][0], - "recording_duration": np.sum( - probe_data.recording_info["recording_durations"] - ), - } - ) - 
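-            # NOTE: file paths below are recorded relative to the root data
-            # directory so that EphysFile entries remain portable across machines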
- root_dir = find_root_directory( - get_ephys_root_data_dir(), - probe_data.recording_info["recording_files"][0], - ) - self.EphysFile.insert( - [ - {**key, "file_path": fp.relative_to(root_dir).as_posix()} - for fp in probe_data.recording_info["recording_files"] - ] - ) - else: - raise NotImplementedError( - f"Processing ephys files from" - f" acquisition software of type {acq_software} is" - f" not yet implemented" - ) - - -@schema -class PreClusterMethod(dj.Lookup): - """Pre-clustering method - - Attributes: - precluster_method (foreign key, varchar(16) ): Pre-clustering method for the dataset. - precluster_method_desc(varchar(1000) ): Pre-clustering method description. - """ - - definition = """ - # Method for pre-clustering - precluster_method: varchar(16) - --- - precluster_method_desc: varchar(1000) - """ - - contents = [("catgt", "Time shift, Common average referencing, Zeroing")] - - -@schema -class PreClusterParamSet(dj.Lookup): - """Parameters for the pre-clustering method. - - Attributes: - paramset_idx (foreign key): Unique parameter set ID. - PreClusterMethod (dict): PreClusterMethod query for this dataset. - paramset_desc (varchar(128) ): Description for the pre-clustering parameter set. - param_set_hash (uuid): Unique hash for parameter set. - params (longblob): All parameters for the pre-clustering method. - """ - - definition = """ - # Parameter set to be used in a clustering procedure - paramset_idx: smallint - --- - -> PreClusterMethod - paramset_desc: varchar(128) - param_set_hash: uuid - unique index (param_set_hash) - params: longblob # dictionary of all applicable parameters - """ - - @classmethod - def insert_new_params( - cls, precluster_method: str, paramset_idx: int, paramset_desc: str, params: dict - ): - param_dict = { - "precluster_method": precluster_method, - "paramset_idx": paramset_idx, - "paramset_desc": paramset_desc, - "params": params, - "param_set_hash": dict_to_uuid(params), - } - param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} - - if param_query: # If the specified param-set already exists - existing_paramset_idx = param_query.fetch1("paramset_idx") - if ( - existing_paramset_idx == paramset_idx - ): # If the existing set has the same paramset_idx: job done - return - else: # If not same name: human error, trying to add the same paramset with different name - raise dj.DataJointError( - "The specified param-set" - " already exists - paramset_idx: {}".format(existing_paramset_idx) - ) - else: - cls.insert1(param_dict) - - -@schema -class PreClusterParamSteps(dj.Manual): - """Ordered list of parameter sets that will be run. - - Attributes: - precluster_param_steps_id (foreign key): Unique ID for the pre-clustering parameter sets to be run. - precluster_param_steps_name (varchar(32) ): User-friendly name for the parameter steps. - precluster_param_steps_desc (varchar(128) ): Description of the parameter steps. - """ - - definition = """ - # Ordered list of paramset_idx that are to be run - # When pre-clustering is not performed, do not create an entry in `Step` Part table - precluster_param_steps_id: smallint - --- - precluster_param_steps_name: varchar(32) - precluster_param_steps_desc: varchar(128) - """ - - class Step(dj.Part): - """Define the order of operations for parameter sets. - - Attributes: - PreClusterParamSteps (foreign key): PreClusterParamSteps primary key. - step_number (foreign key, smallint): Order of operations. - PreClusterParamSet (dict): PreClusterParamSet to be used in pre-clustering. 
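-
-        Example:
-            A sketch of defining a single-step ordering; the IDs below are
-            hypothetical:
-
-            >>> PreClusterParamSteps.insert1(
-            ...     {"precluster_param_steps_id": 0,
-            ...      "precluster_param_steps_name": "catgt_only",
-            ...      "precluster_param_steps_desc": "run CatGT as the only step"})
-            >>> PreClusterParamSteps.Step.insert1(
-            ...     {"precluster_param_steps_id": 0, "step_number": 1, "paramset_idx": 0})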
- """ - - definition = """ - -> master - step_number: smallint # Order of operations - --- - -> PreClusterParamSet - """ - - -@schema -class PreClusterTask(dj.Manual): - """Defines a pre-clustering task ready to be run. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - PreclusterParamSteps (foreign key): PreClusterParam Steps primary key. - precluster_output_dir (varchar(255) ): relative path to directory for storing results of pre-clustering. - task_mode (enum ): `none` (no pre-clustering), `load` results from file, or `trigger` automated pre-clustering. - """ - - definition = """ - # Manual table for defining a clustering task ready to be run - -> EphysRecording - -> PreClusterParamSteps - --- - precluster_output_dir='': varchar(255) # pre-clustering output directory relative to the root data directory - task_mode='none': enum('none','load', 'trigger') # 'none': no pre-clustering analysis - # 'load': load analysis results - # 'trigger': trigger computation - """ - - -@schema -class PreCluster(dj.Imported): - """ - A processing table to handle each PreClusterTask: - - Attributes: - PreClusterTask (foreign key): PreClusterTask primary key. - precluster_time (datetime): Time of generation of this set of pre-clustering results. - package_version (varchar(16) ): Package version used for performing pre-clustering. - """ - - definition = """ - -> PreClusterTask - --- - precluster_time: datetime # time of generation of this set of pre-clustering results - package_version='': varchar(16) - """ - - def make(self, key): - """Populate pre-clustering tables.""" - task_mode, output_dir = (PreClusterTask & key).fetch1( - "task_mode", "precluster_output_dir" - ) - precluster_output_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - if task_mode == "none": - if len((PreClusterParamSteps.Step & key).fetch()) > 0: - raise ValueError( - "There are entries in the PreClusterParamSteps.Step " - "table and task_mode=none" - ) - creation_time = (EphysRecording & key).fetch1("recording_datetime") - elif task_mode == "load": - acq_software = (EphysRecording & key).fetch1("acq_software") - inserted_probe_serial_number = (ProbeInsertion * probe.Probe & key).fetch1( - "probe" - ) - - if acq_software == "SpikeGLX": - for meta_filepath in precluster_output_dir.rglob("*.ap.meta"): - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - creation_time = spikeglx_meta.recording_time - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format(key) - ) - else: - raise NotImplementedError( - f"Pre-clustering analysis of {acq_software}" "is not yet supported." - ) - elif task_mode == "trigger": - raise NotImplementedError( - "Automatic triggering of" - " pre-clustering analysis is not yet supported." - ) - else: - raise ValueError(f"Unknown task mode: {task_mode}") - - self.insert1({**key, "precluster_time": creation_time, "package_version": ""}) - - -@schema -class LFP(dj.Imported): - """Extracts local field potentials (LFP) from an electrophysiology recording. - - Attributes: - EphysRecording (foreign key): EphysRecording primary key. - lfp_sampling_rate (float): Sampling rate for LFPs in Hz. - lfp_time_stamps (longblob): Time stamps with respect to the start of the recording. - lfp_mean (longblob): Overall mean LFP across electrodes. - """ - - definition = """ - # Acquired local field potential (LFP) from a given Ephys recording. 
- -> PreCluster - --- - lfp_sampling_rate: float # (Hz) - lfp_time_stamps: longblob # (s) timestamps with respect to the start of the recording (recording_timestamp) - lfp_mean: longblob # (uV) mean of LFP across electrodes - shape (time,) - """ - - class Electrode(dj.Part): - """Saves local field potential data for each electrode. - - Attributes: - LFP (foreign key): LFP primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - lfp (longblob): LFP recording at this electrode in microvolts. - """ - - definition = """ - -> master - -> probe.ElectrodeConfig.Electrode - --- - lfp: longblob # (uV) recorded lfp at this electrode - """ - - # Only store LFP for every 9th channel, due to high channel density, - # close-by channels exhibit highly similar LFP - _skip_channel_counts = 9 - - def make(self, key): - """Populates the LFP tables.""" - acq_software, probe_sn = (EphysRecording * ProbeInsertion & key).fetch1( - "acq_software", "probe" - ) - - electrode_keys, lfp = [], [] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - - lfp_channel_ind = spikeglx_recording.lfmeta.recording_channels[ - -1 :: -self._skip_channel_counts - ] - - # Extract LFP data at specified channels and convert to uV - lfp = spikeglx_recording.lf_timeseries[ - :, lfp_channel_ind - ] # (sample x channel) - lfp = ( - lfp * spikeglx_recording.get_channel_bit_volts("lf")[lfp_channel_ind] - ).T # (channel x sample) - - self.insert1( - dict( - key, - lfp_sampling_rate=spikeglx_recording.lfmeta.meta["imSampRate"], - lfp_time_stamps=( - np.arange(lfp.shape[1]) - / spikeglx_recording.lfmeta.meta["imSampRate"] - ), - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - for recorded_site in lfp_channel_ind: - shank, shank_col, shank_row, _ = spikeglx_recording.apmeta.shankmap[ - "data" - ][recorded_site] - electrode_keys.append(probe_electrodes[(shank, shank_col, shank_row)]) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - - loaded_oe = openephys.OpenEphys(session_dir) - oe_probe = loaded_oe.probes[probe_sn] - - lfp_channel_ind = np.arange(len(oe_probe.lfp_meta["channels_ids"]))[ - -1 :: -self._skip_channel_counts - ] - - lfp = oe_probe.lfp_timeseries[:, lfp_channel_ind] # (sample x channel) - lfp = ( - lfp * np.array(oe_probe.lfp_meta["channels_gains"])[lfp_channel_ind] - ).T # (channel x sample) - lfp_timestamps = oe_probe.lfp_timestamps - - self.insert1( - dict( - key, - lfp_sampling_rate=oe_probe.lfp_meta["sample_rate"], - lfp_time_stamps=lfp_timestamps, - lfp_mean=lfp.mean(axis=0), - ) - ) - - electrode_query = ( - probe.ProbeType.Electrode - * probe.ElectrodeConfig.Electrode - * EphysRecording - & key - ) - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - for channel_idx in np.array(oe_probe.lfp_meta["channels_ids"])[ - lfp_channel_ind - ]: - electrode_keys.append(probe_electrodes[channel_idx]) - else: - raise NotImplementedError( - f"LFP extraction from acquisition software" - f" of type {acq_software} is not yet implemented" - ) - - # single insert in 
loop to mitigate potential memory issue
-        for electrode_key, lfp_trace in zip(electrode_keys, lfp):
-            self.Electrode.insert1({**key, **electrode_key, "lfp": lfp_trace})
-
-
-# ------------ Clustering --------------
-
-
-@schema
-class ClusteringMethod(dj.Lookup):
-    """Kilosort clustering method.
-
-    Attributes:
-        clustering_method (foreign key, varchar(16) ): Kilosort clustering method.
-        clustering_method_desc (varchar(1000) ): Additional description of the clustering method.
-    """
-
-    definition = """
-    # Method for clustering
-    clustering_method: varchar(16)
-    ---
-    clustering_method_desc: varchar(1000)
-    """
-
-    contents = [
-        ("kilosort", "kilosort clustering method"),
-        ("kilosort2", "kilosort2 clustering method"),
-    ]
-
-
-@schema
-class ClusteringParamSet(dj.Lookup):
-    """Parameters to be used in clustering procedure for spike sorting.
-
-    Attributes:
-        paramset_idx (foreign key): Unique ID for the clustering parameter set.
-        ClusteringMethod (dict): ClusteringMethod primary key.
-        paramset_desc (varchar(128) ): Description of the clustering parameter set.
-        param_set_hash (uuid): UUID hash for the parameter set.
-        params (longblob): Dictionary of all applicable clustering parameters.
-    """
-
-    definition = """
-    # Parameter set to be used in a clustering procedure
-    paramset_idx: smallint
-    ---
-    -> ClusteringMethod
-    paramset_desc: varchar(128)
-    param_set_hash: uuid
-    unique index (param_set_hash)
-    params: longblob  # dictionary of all applicable parameters
-    """
-
-    @classmethod
-    def insert_new_params(
-        cls, processing_method: str, paramset_idx: int, paramset_desc: str, params: dict
-    ):
-        """Inserts new parameters into the ClusteringParamSet table.
-
-        Args:
-            processing_method (str): Name of the clustering method.
-            paramset_idx (int): Unique parameter set ID.
-            paramset_desc (str): Description of the parameter set.
-            params (dict): Clustering parameters.
-        """
-        param_dict = {
-            "clustering_method": processing_method,
-            "paramset_idx": paramset_idx,
-            "paramset_desc": paramset_desc,
-            "params": params,
-            "param_set_hash": dict_to_uuid(params),
-        }
-        param_query = cls & {"param_set_hash": param_dict["param_set_hash"]}
-
-        if param_query:  # If the specified param-set already exists
-            existing_paramset_idx = param_query.fetch1("paramset_idx")
-            if (
-                existing_paramset_idx == paramset_idx
-            ):  # If the existing set has the same paramset_idx: job done
-                return
-            else:  # same params with a different paramset_idx: human error, adding the same paramset twice
-                raise dj.DataJointError(
-                    "The specified param-set"
-                    " already exists - paramset_idx: {}".format(existing_paramset_idx)
-                )
-        else:
-            cls.insert1(param_dict)
-
-
-@schema
-class ClusterQualityLabel(dj.Lookup):
-    """Quality label for each spike sorted cluster.
-
-    Attributes:
-        cluster_quality_label (foreign key, varchar(100) ): Cluster quality type.
-        cluster_quality_description (varchar(4000) ): Description of the cluster quality type.
-    """
-
-    definition = """
-    # Quality
-    cluster_quality_label: varchar(100)
-    ---
-    cluster_quality_description: varchar(4000)
-    """
-    contents = [
-        ("good", "single unit"),
-        ("ok", "probably a single unit, but could be contaminated"),
-        ("mua", "multi-unit activity"),
-        ("noise", "bad unit"),
-    ]
-
-
-@schema
-class ClusteringTask(dj.Manual):
-    """A clustering task to spike sort electrophysiology datasets.
-
-    Attributes:
-        EphysRecording (foreign key): EphysRecording primary key.
-        ClusteringParamSet (foreign key): ClusteringParamSet primary key.
-        clustering_output_dir (varchar(255) ): Relative path to the clustering output directory.
-        task_mode (enum): `trigger` computes clustering; `load` imports existing results.
-    """
-
-    definition = """
-    # Manual table for defining a clustering task ready to be run
-    -> PreCluster
-    -> ClusteringParamSet
-    ---
-    clustering_output_dir: varchar(255)  # clustering output directory relative to the clustering root data directory
-    task_mode='load': enum('load', 'trigger')  # 'load': load computed analysis results, 'trigger': trigger computation
-    """
-
-
-@schema
-class Clustering(dj.Imported):
-    """A processing table to handle each clustering task.
-
-    Attributes:
-        ClusteringTask (foreign key): ClusteringTask primary key.
-        clustering_time (datetime): Time when clustering results are generated.
-        package_version (varchar(16) ): Package version used for a clustering analysis.
-    """
-
-    definition = """
-    # Clustering Procedure
-    -> ClusteringTask
-    ---
-    clustering_time: datetime  # time of generation of this set of clustering results
-    package_version='': varchar(16)
-    """
-
-    def make(self, key):
-        """Triggers or imports clustering analysis."""
-        task_mode, output_dir = (ClusteringTask & key).fetch1(
-            "task_mode", "clustering_output_dir"
-        )
-        kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
-        if task_mode == "load":
-            _ = kilosort.Kilosort(
-                kilosort_dir
-            )  # check if the directory is a valid Kilosort output
-            creation_time, _, _ = kilosort.extract_clustering_info(kilosort_dir)
-        elif task_mode == "trigger":
-            raise NotImplementedError(
-                "Automatic triggering of clustering analysis is not yet supported"
-            )
-        else:
-            raise ValueError(f"Unknown task mode: {task_mode}")
-
-        self.insert1({**key, "clustering_time": creation_time, "package_version": ""})
-
-
-@schema
-class Curation(dj.Manual):
-    """Curation procedure table.
-
-    Attributes:
-        Clustering (foreign key): Clustering primary key.
-        curation_id (foreign key, int): Unique curation ID.
-        curation_time (datetime): Time when curation results are generated.
-        curation_output_dir (varchar(255) ): Output directory of the curated results.
-        quality_control (bool): If True, this clustering result has undergone quality control.
-        manual_curation (bool): If True, manual curation has been performed on this clustering result.
-        curation_note (varchar(2000) ): Notes about the curation task.
-    """
-
-    definition = """
-    # Manual curation procedure
-    -> Clustering
-    curation_id: int
-    ---
-    curation_time: datetime  # time of generation of this set of curated clustering results
-    curation_output_dir: varchar(255)  # output directory of the curated results, relative to root data directory
-    quality_control: bool  # has this clustering result undergone quality control?
-    manual_curation: bool  # has manual curation been performed on this clustering result?
-    curation_note='': varchar(2000)
-    """
-
-    def create1_from_clustering_task(self, key, curation_note: str = ""):
-        """Create a new "Curation" entry corresponding to a given
-        "ClusteringTask", using the results found in its output directory.
-        """
-        if key not in Clustering():
-            raise ValueError(
-                f"No corresponding entry in Clustering available"
-                f" for: {key}; do `Clustering.populate(key)`"
-            )
-
-        task_mode, output_dir = (ClusteringTask & key).fetch1(
-            "task_mode", "clustering_output_dir"
-        )
-        kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir)
-
-        creation_time, is_curated, is_qc = kilosort.extract_clustering_info(
-            kilosort_dir
-        )
-        # Synthesize curation_id as (max existing curation_id + 1), starting at 1
-        curation_id = (
-            dj.U().aggr(self & key, n="ifnull(max(curation_id)+1,1)").fetch1("n")
-        )
-        self.insert1(
-            {
-                **key,
-                "curation_id": curation_id,
-                "curation_time": creation_time,
-                "curation_output_dir": output_dir,
-                "quality_control": is_qc,
-                "manual_curation": is_curated,
-                "curation_note": curation_note,
-            }
-        )
-
-
-@schema
-class CuratedClustering(dj.Imported):
-    """Clustering results after curation.
-
-    Attributes:
-        Curation (foreign key): Curation primary key.
-    """
-
-    definition = """
-    # Clustering results of a curation.
-    -> Curation
-    """
-
-    class Unit(dj.Part):
-        """Single unit properties after clustering and curation.
-
-        Attributes:
-            CuratedClustering (foreign key): CuratedClustering primary key.
-            unit (foreign key, int): Unique integer identifying a single unit.
-            probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key.
-            ClusterQualityLabel (foreign key): ClusterQualityLabel primary key.
-            spike_count (int): Number of spikes in this recording for this unit.
-            spike_times (longblob): Spike times of this unit, relative to the start of the EphysRecording.
-            spike_sites (longblob): Array of the electrode associated with each spike.
-            spike_depths (longblob): Array of depths associated with each spike, relative to the (0, 0) of the probe.
- """ - - definition = """ - # Properties of a given unit from a round of clustering (and curation) - -> master - unit: int - --- - -> probe.ElectrodeConfig.Electrode # electrode with highest waveform amplitude for this unit - -> ClusterQualityLabel - spike_count: int # how many spikes in this recording for this unit - spike_times: longblob # (s) spike times of this unit, relative to the start of the EphysRecording - spike_sites : longblob # array of electrode associated with each spike - spike_depths=null : longblob # (um) array of depths associated with each spike, relative to the (0, 0) of the probe - """ - - def make(self, key): - """Automated population of Unit information.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - acq_software = (EphysRecording & key).fetch1("acq_software") - - # ---------- Unit ---------- - # -- Remove 0-spike units - withspike_idx = [ - i - for i, u in enumerate(kilosort_dataset.data["cluster_ids"]) - if (kilosort_dataset.data["spike_clusters"] == u).any() - ] - valid_units = kilosort_dataset.data["cluster_ids"][withspike_idx] - valid_unit_labels = kilosort_dataset.data["cluster_groups"][withspike_idx] - # -- Get channel and electrode-site mapping - channel2electrodes = get_neuropixels_channel2electrode_map(key, acq_software) - - # -- Spike-times -- - # spike_times_sec_adj > spike_times_sec > spike_times - spike_time_key = ( - "spike_times_sec_adj" - if "spike_times_sec_adj" in kilosort_dataset.data - else ( - "spike_times_sec" - if "spike_times_sec" in kilosort_dataset.data - else "spike_times" - ) - ) - spike_times = kilosort_dataset.data[spike_time_key] - kilosort_dataset.extract_spike_depths() - - # -- Spike-sites and Spike-depths -- - spike_sites = np.array( - [ - channel2electrodes[s]["electrode"] - for s in kilosort_dataset.data["spike_sites"] - ] - ) - spike_depths = kilosort_dataset.data["spike_depths"] - - # -- Insert unit, label, peak-chn - units = [] - for unit, unit_lbl in zip(valid_units, valid_unit_labels): - if (kilosort_dataset.data["spike_clusters"] == unit).any(): - unit_channel, _ = kilosort_dataset.get_best_channel(unit) - unit_spike_times = ( - spike_times[kilosort_dataset.data["spike_clusters"] == unit] - / kilosort_dataset.data["params"]["sample_rate"] - ) - spike_count = len(unit_spike_times) - - units.append( - { - "unit": unit, - "cluster_quality_label": unit_lbl, - **channel2electrodes[unit_channel], - "spike_times": unit_spike_times, - "spike_count": spike_count, - "spike_sites": spike_sites[ - kilosort_dataset.data["spike_clusters"] == unit - ], - "spike_depths": ( - spike_depths[ - kilosort_dataset.data["spike_clusters"] == unit - ] - if spike_depths is not None - else None - ), - } - ) - - self.insert1(key) - self.Unit.insert([{**key, **u} for u in units]) - - -@schema -class WaveformSet(dj.Imported): - """A set of spike waveforms for units out of a given CuratedClustering. - - Attributes: - CuratedClustering (foreign key): CuratedClustering primary key. - """ - - definition = """ - # A set of spike waveforms for units out of a given CuratedClustering - -> CuratedClustering - """ - - class PeakWaveform(dj.Part): - """Mean waveform across spikes for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. 
- peak_electrode_waveform (longblob): Mean waveform for a given unit at its representative electrode. - """ - - definition = """ - # Mean waveform across spikes for a given unit at its representative electrode - -> master - -> CuratedClustering.Unit - --- - peak_electrode_waveform: longblob # (uV) mean waveform for a given unit at its representative electrode - """ - - class Waveform(dj.Part): - """Spike waveforms for a given unit. - - Attributes: - WaveformSet (foreign key): WaveformSet primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - probe.ElectrodeConfig.Electrode (foreign key): probe.ElectrodeConfig.Electrode primary key. - waveform_mean (longblob): mean waveform across spikes of the unit in microvolts. - waveforms (longblob): waveforms of a sampling of spikes at the given electrode and unit. - """ - - definition = """ - # Spike waveforms and their mean across spikes for the given unit - -> master - -> CuratedClustering.Unit - -> probe.ElectrodeConfig.Electrode - --- - waveform_mean: longblob # (uV) mean waveform across spikes of the given unit - waveforms=null: longblob # (uV) (spike x sample) waveforms of a sampling of spikes at the given electrode for the given unit - """ - - def make(self, key): - """Populates waveform tables.""" - output_dir = (Curation & key).fetch1("curation_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - kilosort_dataset = kilosort.Kilosort(kilosort_dir) - - acq_software, probe_serial_number = ( - EphysRecording * ProbeInsertion & key - ).fetch1("acq_software", "probe") - - # -- Get channel and electrode-site mapping - recording_key = (EphysRecording & key).fetch1("KEY") - channel2electrodes = get_neuropixels_channel2electrode_map( - recording_key, acq_software - ) - - is_qc = (Curation & key).fetch1("quality_control") - - # Get all units - units = { - u["unit"]: u - for u in (CuratedClustering.Unit & key).fetch(as_dict=True, order_by="unit") - } - - if is_qc: - unit_waveforms = np.load( - kilosort_dir / "mean_waveforms.npy" - ) # unit x channel x sample - - def yield_unit_waveforms(): - for unit_no, unit_waveform in zip( - kilosort_dataset.data["cluster_ids"], unit_waveforms - ): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - if unit_no in units: - for channel, channel_waveform in zip( - kilosort_dataset.data["channel_map"], unit_waveform - ): - unit_electrode_waveforms.append( - { - **units[unit_no], - **channel2electrodes[channel], - "waveform_mean": channel_waveform, - } - ) - if ( - channel2electrodes[channel]["electrode"] - == units[unit_no]["electrode"] - ): - unit_peak_waveform = { - **units[unit_no], - "peak_electrode_waveform": channel_waveform, - } - yield unit_peak_waveform, unit_electrode_waveforms - - else: - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = get_spikeglx_meta_filepath(key) - neuropixels_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - neuropixels_recording = openephys_dataset.probes[probe_serial_number] - - def yield_unit_waveforms(): - for unit_dict in units.values(): - unit_peak_waveform = {} - unit_electrode_waveforms = [] - - spikes = unit_dict["spike_times"] - waveforms = neuropixels_recording.extract_spike_waveforms( - spikes, kilosort_dataset.data["channel_map"] - ) # (sample x channel x spike) - waveforms = waveforms.transpose( 
-                        (1, 2, 0)
-                    )  # (channel x spike x sample)
-                    for channel, channel_waveform in zip(
-                        kilosort_dataset.data["channel_map"], waveforms
-                    ):
-                        unit_electrode_waveforms.append(
-                            {
-                                **unit_dict,
-                                **channel2electrodes[channel],
-                                "waveform_mean": channel_waveform.mean(axis=0),
-                                "waveforms": channel_waveform,
-                            }
-                        )
-                        if (
-                            channel2electrodes[channel]["electrode"]
-                            == unit_dict["electrode"]
-                        ):
-                            unit_peak_waveform = {
-                                **unit_dict,
-                                "peak_electrode_waveform": channel_waveform.mean(
-                                    axis=0
-                                ),
-                            }
-
-                    yield unit_peak_waveform, unit_electrode_waveforms
-
-        # insert waveforms on a per-unit basis to mitigate potential memory issues
-        self.insert1(key)
-        for unit_peak_waveform, unit_electrode_waveforms in yield_unit_waveforms():
-            self.PeakWaveform.insert1(unit_peak_waveform, ignore_extra_fields=True)
-            self.Waveform.insert(unit_electrode_waveforms, ignore_extra_fields=True)
-
-
-@schema
-class QualityMetrics(dj.Imported):
-    """Clustering and waveform quality metrics.
-
-    Attributes:
-        CuratedClustering (foreign key): CuratedClustering primary key.
-    """
-
-    definition = """
-    # Clusters and waveforms metrics
-    -> CuratedClustering
-    """
-
-    class Cluster(dj.Part):
-        """Cluster metrics for a unit.
-
-        Attributes:
-            QualityMetrics (foreign key): QualityMetrics primary key.
-            CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key.
-            firing_rate (float): Firing rate of the unit in Hz.
-            snr (float): Signal-to-noise ratio for a unit.
-            presence_ratio (float): Fraction of time in which spikes are present.
-            isi_violation (float): Rate of ISI violation as a fraction of overall rate.
-            number_violation (int): Total number of ISI violations.
-            amplitude_cutoff (float): Estimate of miss rate based on amplitude histogram.
-            isolation_distance (float): Distance to the nearest cluster in Mahalanobis space.
-            l_ratio (float): Amount of empty space between a cluster and other spikes in the dataset.
-            d_prime (float): Classification accuracy based on LDA.
-            nn_hit_rate (float): Fraction of neighbors for the target cluster that are also in the target cluster.
-            nn_miss_rate (float): Fraction of neighbors outside the target cluster that are in the target cluster.
-            silhouette_score (float): Standard metric for cluster overlap.
-            max_drift (float): Maximum change in spike depth throughout the recording.
-            cumulative_drift (float): Cumulative change in spike depth throughout the recording.
-            contamination_rate (float): Frequency of spikes in the refractory period.
- """ - - definition = """ - # Cluster metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - firing_rate=null: float # (Hz) firing rate for a unit - snr=null: float # signal-to-noise ratio for a unit - presence_ratio=null: float # fraction of time in which spikes are present - isi_violation=null: float # rate of ISI violation as a fraction of overall rate - number_violation=null: int # total number of ISI violations - amplitude_cutoff=null: float # estimate of miss rate based on amplitude histogram - isolation_distance=null: float # distance to nearest cluster in Mahalanobis space - l_ratio=null: float # - d_prime=null: float # Classification accuracy based on LDA - nn_hit_rate=null: float # Fraction of neighbors for target cluster that are also in target cluster - nn_miss_rate=null: float # Fraction of neighbors outside target cluster that are in target cluster - silhouette_score=null: float # Standard metric for cluster overlap - max_drift=null: float # Maximum change in spike depth throughout recording - cumulative_drift=null: float # Cumulative change in spike depth throughout recording - contamination_rate=null: float # - """ - - class Waveform(dj.Part): - """Waveform metrics for a particular unit. - - Attributes: - QualityMetrics (foreign key): QualityMetrics primary key. - CuratedClustering.Unit (foreign key): CuratedClustering.Unit primary key. - amplitude (float): Absolute difference between waveform peak and trough in microvolts. - duration (float): Time between waveform peak and trough in milliseconds. - halfwidth (float): Spike width at half max amplitude. - pt_ratio (float): Absolute amplitude of peak divided by absolute amplitude of trough relative to 0. - repolarization_slope (float): Slope of the regression line fit to first 30 microseconds from trough to peak. - recovery_slope (float): Slope of the regression line fit to first 30 microseconds from peak to tail. - spread (float): The range with amplitude over 12-percent of maximum amplitude along the probe. - velocity_above (float): inverse velocity of waveform propagation from soma to the top of the probe. - velocity_below (float): inverse velocity of waveform propagation from soma toward the bottom of the probe. 
- """ - - definition = """ - # Waveform metrics for a particular unit - -> master - -> CuratedClustering.Unit - --- - amplitude: float # (uV) absolute difference between waveform peak and trough - duration: float # (ms) time between waveform peak and trough - halfwidth=null: float # (ms) spike width at half max amplitude - pt_ratio=null: float # absolute amplitude of peak divided by absolute amplitude of trough relative to 0 - repolarization_slope=null: float # the repolarization slope was defined by fitting a regression line to the first 30us from trough to peak - recovery_slope=null: float # the recovery slope was defined by fitting a regression line to the first 30us from peak to tail - spread=null: float # (um) the range with amplitude above 12-percent of the maximum amplitude along the probe - velocity_above=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the top of the probe - velocity_below=null: float # (s/m) inverse velocity of waveform propagation from the soma toward the bottom of the probe - """ - - def make(self, key): - """Populates tables with quality metrics data.""" - output_dir = (ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(get_ephys_root_data_dir(), output_dir) - - metric_fp = kilosort_dir / "metrics.csv" - rename_dict = { - "isi_viol": "isi_violation", - "num_viol": "number_violation", - "contam_rate": "contamination_rate", - } - - if not metric_fp.exists(): - raise FileNotFoundError(f"QC metrics file not found: {metric_fp}") - - metrics_df = pd.read_csv(metric_fp) - metrics_df.set_index("cluster_id", inplace=True) - metrics_df.replace([np.inf, -np.inf], np.nan, inplace=True) - metrics_df.columns = metrics_df.columns.str.lower() - metrics_df.rename(columns=rename_dict, inplace=True) - metrics_list = [ - dict(metrics_df.loc[unit_key["unit"]], **unit_key) - for unit_key in (CuratedClustering.Unit & key).fetch("KEY") - ] - - self.insert1(key) - self.Cluster.insert(metrics_list, ignore_extra_fields=True) - self.Waveform.insert(metrics_list, ignore_extra_fields=True) - - -# ---------------- HELPER FUNCTIONS ---------------- - - -def get_spikeglx_meta_filepath(ephys_recording_key: dict) -> str: - """Get spikeGLX data filepath.""" - # attempt to retrieve from EphysRecording.EphysFile - spikeglx_meta_filepath = ( - EphysRecording.EphysFile & ephys_recording_key & 'file_path LIKE "%.ap.meta"' - ).fetch1("file_path") - - try: - spikeglx_meta_filepath = find_full_path( - get_ephys_root_data_dir(), spikeglx_meta_filepath - ) - except FileNotFoundError: - # if not found, search in session_dir again - if not spikeglx_meta_filepath.exists(): - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - inserted_probe_serial_number = ( - ProbeInsertion * probe.Probe & ephys_recording_key - ).fetch1("probe") - - spikeglx_meta_filepaths = [fp for fp in session_dir.rglob("*.ap.meta")] - for meta_filepath in spikeglx_meta_filepaths: - spikeglx_meta = spikeglx.SpikeGLXMeta(meta_filepath) - if str(spikeglx_meta.probe_SN) == inserted_probe_serial_number: - spikeglx_meta_filepath = meta_filepath - break - else: - raise FileNotFoundError( - "No SpikeGLX data found for probe insertion: {}".format( - ephys_recording_key - ) - ) - - return spikeglx_meta_filepath - - -def get_neuropixels_channel2electrode_map( - ephys_recording_key: dict, acq_software: str -) -> dict: - """Get the channel map for neuropixels probe.""" - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = 
get_spikeglx_meta_filepath(ephys_recording_key) - spikeglx_meta = spikeglx.SpikeGLXMeta(spikeglx_meta_filepath) - electrode_config_key = ( - EphysRecording * probe.ElectrodeConfig & ephys_recording_key - ).fetch1("KEY") - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode - & electrode_config_key - ) - - probe_electrodes = { - (shank, shank_col, shank_row): key - for key, shank, shank_col, shank_row in zip( - *electrode_query.fetch("KEY", "shank", "shank_col", "shank_row") - ) - } - - channel2electrode_map = { - recorded_site: probe_electrodes[(shank, shank_col, shank_row)] - for recorded_site, (shank, shank_col, shank_row, _) in enumerate( - spikeglx_meta.shankmap["data"] - ) - } - elif acq_software == "Open Ephys": - session_dir = find_full_path( - get_ephys_root_data_dir(), get_session_directory(ephys_recording_key) - ) - openephys_dataset = openephys.OpenEphys(session_dir) - probe_serial_number = (ProbeInsertion & ephys_recording_key).fetch1("probe") - probe_dataset = openephys_dataset.probes[probe_serial_number] - - electrode_query = ( - probe.ProbeType.Electrode * probe.ElectrodeConfig.Electrode * EphysRecording - & ephys_recording_key - ) - - probe_electrodes = { - key["electrode"]: key for key in electrode_query.fetch("KEY") - } - - channel2electrode_map = { - channel_idx: probe_electrodes[channel_idx] - for channel_idx in probe_dataset.ap_meta["channels_ids"] - } - - return channel2electrode_map - - -def generate_electrode_config(probe_type: str, electrode_keys: list) -> dict: - """Generate and insert new ElectrodeConfig - - Args: - probe_type (str): probe type (e.g. neuropixels 2.0 - SS) - electrode_keys (list): list of keys of the probe.ProbeType.Electrode table - - Returns: - dict: representing a key of the probe.ElectrodeConfig table - """ - # compute hash for the electrode config (hash of dict of all ElectrodeConfig.Electrode) - electrode_config_hash = dict_to_uuid({k["electrode"]: k for k in electrode_keys}) - - electrode_list = sorted([k["electrode"] for k in electrode_keys]) - electrode_gaps = ( - [-1] - + np.where(np.diff(electrode_list) > 1)[0].tolist() - + [len(electrode_list) - 1] - ) - electrode_config_name = "; ".join( - [ - f"{electrode_list[start + 1]}-{electrode_list[end]}" - for start, end in zip(electrode_gaps[:-1], electrode_gaps[1:]) - ] - ) - - electrode_config_key = {"electrode_config_hash": electrode_config_hash} - - # ---- make new ElectrodeConfig if needed ---- - if not probe.ElectrodeConfig & electrode_config_key: - probe.ElectrodeConfig.insert1( - { - **electrode_config_key, - "probe_type": probe_type, - "electrode_config_name": electrode_config_name, - } - ) - probe.ElectrodeConfig.Electrode.insert( - {**electrode_config_key, **electrode} for electrode in electrode_keys - ) - - return electrode_config_key diff --git a/element_array_ephys/ephys_report.py b/element_array_ephys/ephys_report.py index 48bcf613..c962d33d 100644 --- a/element_array_ephys/ephys_report.py +++ b/element_array_ephys/ephys_report.py @@ -7,26 +7,24 @@ import datajoint as dj from element_interface.utils import dict_to_uuid -from . import probe +from . import probe, ephys schema = dj.schema() -ephys = None - -def activate(schema_name, ephys_schema_name, *, create_schema=True, create_tables=True): +def activate(schema_name, *, create_schema=True, create_tables=True): """Activate the current schema. Args: schema_name (str): schema name on the database server to activate the `ephys_report` schema. 
-        ephys_schema_name (str): schema name of the activated ephys element for which
-            this ephys_report schema will be downstream from.
         create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist.
         create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist.
     """
+    if not probe.schema.is_activated():
+        raise RuntimeError("Please activate the `probe` schema first.")
+    if not ephys.schema.is_activated():
+        raise RuntimeError("Please activate the `ephys` schema first.")
 
-    global ephys
-    ephys = dj.create_virtual_module("ephys", ephys_schema_name)
     schema.activate(
         schema_name,
         create_schema=create_schema,
diff --git a/element_array_ephys/export/nwb/nwb.py b/element_array_ephys/export/nwb/nwb.py
index a45eb754..8d7da8f5 100644
--- a/element_array_ephys/export/nwb/nwb.py
+++ b/element_array_ephys/export/nwb/nwb.py
@@ -17,14 +17,7 @@
 from spikeinterface import extractors
 from tqdm import tqdm
 
-from ... import ephys_no_curation as ephys
-from ... import probe
-
-ephys_mode = os.getenv("EPHYS_MODE", dj.config["custom"].get("ephys_mode", "acute"))
-if ephys_mode != "no-curation":
-    raise NotImplementedError(
-        "This export function is designed for the no_curation " + "schema"
-    )
+from ... import probe, ephys
 
 
 class DecimalEncoder(json.JSONEncoder):
diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py
index 550ae4a1..547fd8ce 100644
--- a/element_array_ephys/spike_sorting/si_spike_sorting.py
+++ b/element_array_ephys/spike_sorting/si_spike_sorting.py
@@ -1,5 +1,7 @@
 """
-The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the "spikeinterface" pipeline. Spikeinterface was developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface)
+The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the "spikeinterface" pipeline.
+SpikeInterface was developed by Alessio Buccino, Samuel Garcia, Cole Hurwitz, Jeremy Magland, and Matthias Hennig (https://github.com/SpikeInterface)
+If you use this pipeline, please cite SpikeInterface and the relevant sorter(s) used in your publication (see https://github.com/SpikeInterface for additional citation details).
 """
 
 from datetime import datetime
@@ -7,7 +9,7 @@
 import datajoint as dj
 import pandas as pd
 import spikeinterface as si
-from element_array_ephys import probe, readers
+from element_array_ephys import probe, ephys, readers
 from element_interface.utils import find_full_path, memoized_result
 from spikeinterface import exporters, extractors, sorters
@@ -17,25 +19,25 @@
 schema = dj.schema()
 
-ephys = None
-
 
 def activate(
     schema_name,
     *,
-    ephys_module,
     create_schema=True,
     create_tables=True,
 ):
+    """Activate the current schema.
+
+    Args:
+        schema_name (str): schema name on the database server to activate the `si_spike_sorting` schema.
+        create_schema (bool, optional): If True (default), create schema in the database if it does not yet exist.
+        create_tables (bool, optional): If True (default), create tables in the database if they do not yet exist.
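+
+    Example:
+        A minimal usage sketch (illustrative only; the `neuro_` schema-name
+        prefix is hypothetical). The `probe` and `ephys` schemas must be
+        activated first, as enforced by the checks in this function:
+
+        >>> from element_array_ephys import probe, ephys
+        >>> from element_array_ephys.spike_sorting import si_spike_sorting
+        >>> probe.activate("neuro_probe")
+        >>> ephys.activate("neuro_ephys", linking_module=__name__)
+        >>> si_spike_sorting.activate("neuro_si_spike_sorting")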
""" - activate(schema_name, *, create_schema=True, create_tables=True, activated_ephys=None) - :param schema_name: schema name on the database server to activate the `spike_sorting` schema - :param ephys_module: the activated ephys element for which this `spike_sorting` schema will be downstream from - :param create_schema: when True (default), create schema in the database if it does not yet exist. - :param create_tables: when True (default), create tables in the database if they do not yet exist. - """ - global ephys - ephys = ephys_module + if not probe.schema.is_activated(): + raise RuntimeError("Please activate the `probe` schema first.") + if not ephys.schema.is_activated(): + raise RuntimeError("Please activate the `ephys` schema first.") + schema.activate( schema_name, create_schema=create_schema, diff --git a/element_array_ephys/version.py b/element_array_ephys/version.py index 2e6de55a..19ba4c76 100644 --- a/element_array_ephys/version.py +++ b/element_array_ephys/version.py @@ -1,3 +1,3 @@ """Package metadata.""" -__version__ = "0.4.0" +__version__ = "1.0.0" diff --git a/tests/tutorial_pipeline.py b/tests/tutorial_pipeline.py index 74b27ddc..1b27027d 100644 --- a/tests/tutorial_pipeline.py +++ b/tests/tutorial_pipeline.py @@ -3,7 +3,7 @@ import datajoint as dj from element_animal import subject from element_animal.subject import Subject -from element_array_ephys import probe, ephys_no_curation as ephys, ephys_report +from element_array_ephys import probe, ephys, ephys_report from element_lab import lab from element_lab.lab import Lab, Location, Project, Protocol, Source, User from element_lab.lab import Device as Equipment @@ -62,7 +62,9 @@ def get_session_directory(session_key): return pathlib.Path(session_directory) -ephys.activate(db_prefix + "ephys", db_prefix + "probe", linking_module=__name__) +probe.activate(db_prefix + "probe") +ephys.activate(db_prefix + "ephys", linking_module=__name__) +ephys_report.activate(db_prefix + "ephys_report") probe.create_neuropixels_probe_types() From 0eef1cbaec2494b7dec7a5af2e8d9d62986280cb Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 10 Sep 2024 15:07:34 -0500 Subject: [PATCH 143/146] rearrange: remove the `ecephys_spike_sorting` flow --- element_array_ephys/ephys.py | 2 +- .../spike_sorting/ecephys_spike_sorting.py | 317 ------------------ .../kilosort_triggering.py | 0 3 files changed, 1 insertion(+), 318 deletions(-) delete mode 100644 element_array_ephys/spike_sorting/ecephys_spike_sorting.py rename element_array_ephys/{readers => spike_sorting}/kilosort_triggering.py (100%) diff --git a/element_array_ephys/ephys.py b/element_array_ephys/ephys.py index 3025d289..f17527c1 100644 --- a/element_array_ephys/ephys.py +++ b/element_array_ephys/ephys.py @@ -897,7 +897,7 @@ def make(self, key): ).fetch1("acq_software", "clustering_method", "params") if "kilosort" in clustering_method: - from element_array_ephys.readers import kilosort_triggering + from .spike_sorting import kilosort_triggering # add additional probe-recording and channels details into `params` params = {**params, **get_recording_channels_details(key)} diff --git a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py b/element_array_ephys/spike_sorting/ecephys_spike_sorting.py deleted file mode 100644 index 3a43c384..00000000 --- a/element_array_ephys/spike_sorting/ecephys_spike_sorting.py +++ /dev/null @@ -1,317 +0,0 @@ -""" -The following DataJoint pipeline implements the sequence of steps in the spike-sorting routine featured in the -"ecephys_spike_sorting" 
pipeline.
-The "ecephys_spike_sorting" pipeline was originally developed by the Allen Institute (https://github.com/AllenInstitute/ecephys_spike_sorting) for Neuropixels data acquired with the Open Ephys acquisition system.
-It was then forked by Jennifer Colonell from the Janelia Research Campus (https://github.com/jenniferColonell/ecephys_spike_sorting) to support the SpikeGLX acquisition system.
-
-At DataJoint, we forked from Jennifer's fork and implemented a version that supports both Open Ephys and SpikeGLX.
-https://github.com/datajoint-company/ecephys_spike_sorting
-
-The following pipeline features three tables:
-1. KilosortPreProcessing - for preprocessing steps (no GPU required)
-    - median_subtraction for Open Ephys
-    - or the CatGT step for SpikeGLX
-2. KilosortClustering - kilosort (MATLAB) - requires GPU
-    - supports kilosort 2.0, 2.5 or 3.0 (https://github.com/MouseLand/Kilosort.git)
-3. KilosortPostProcessing - for postprocessing steps (no GPU required)
-    - kilosort_postprocessing
-    - noise_templates
-    - mean_waveforms
-    - quality_metrics
-"""
-
-
-import datajoint as dj
-from decimal import Decimal
-import json
-from datetime import datetime, timedelta
-
-from element_interface.utils import find_full_path
-from element_array_ephys.readers import (
-    spikeglx,
-    kilosort_triggering,
-)
-
-log = dj.logger
-
-schema = dj.schema()
-
-ephys = None
-
-_supported_kilosort_versions = [
-    "kilosort2",
-    "kilosort2.5",
-    "kilosort3",
-]
-
-
-def activate(
-    schema_name,
-    *,
-    ephys_module,
-    create_schema=True,
-    create_tables=True,
-):
-    """
-    activate(schema_name, *, ephys_module, create_schema=True, create_tables=True)
-    :param schema_name: schema name on the database server to activate the `spike_sorting` schema
-    :param ephys_module: the activated ephys element for which this `spike_sorting` schema will be downstream from
-    :param create_schema: when True (default), create schema in the database if it does not yet exist.
-    :param create_tables: when True (default), create tables in the database if they do not yet exist.
- """ - global ephys - ephys = ephys_module - schema.activate( - schema_name, - create_schema=create_schema, - create_tables=create_tables, - add_objects=ephys.__dict__, - ) - - -@schema -class KilosortPreProcessing(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> ephys.ClusteringTask - --- - params: longblob # finalized parameterset for this run - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - @property - def key_source(self): - return ( - ephys.ClusteringTask * ephys.ClusteringParamSet - & {"task_mode": "trigger"} - & 'clustering_method in ("kilosort2", "kilosort2.5", "kilosort3")' - ) - ephys.Clustering - - def make(self, key): - """Triggers or imports clustering analysis.""" - execution_time = datetime.utcnow() - - task_mode, output_dir = (ephys.ClusteringTask & key).fetch1( - "task_mode", "clustering_output_dir" - ) - - assert task_mode == "trigger", 'Supporting "trigger" task_mode only' - - if not output_dir: - output_dir = ephys.ClusteringTask.infer_output_dir( - key, relative=True, mkdir=True - ) - # update clustering_output_dir - ephys.ClusteringTask.update1( - {**key, "clustering_output_dir": output_dir.as_posix()} - ) - - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method, params = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method", "params") - - assert ( - clustering_method in _supported_kilosort_versions - ), f'Clustering_method "{clustering_method}" is not supported' - - # add additional probe-recording and channels details into `params` - params = {**params, **ephys.get_recording_channels_details(key)} - params["fs"] = params["sample_rate"] - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - run_CatGT = ( - params.get("run_CatGT", True) - and "_tcat." 
not in spikeglx_meta_filepath.stem - ) - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=run_CatGT, - ) - run_kilosort.run_CatGT() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = ["depth_estimation", "median_subtraction"] - run_kilosort.run_modules() - - self.insert1( - { - **key, - "params": params, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - -@schema -class KilosortClustering(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> KilosortPreProcessing - --- - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - def make(self, key): - execution_time = datetime.utcnow() - - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (KilosortPreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, - ) - run_kilosort._modules = ["kilosort_helper"] - run_kilosort._CatGT_finished = True - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = ["kilosort_helper"] - run_kilosort.run_modules() - - self.insert1( - { - **key, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - -@schema -class KilosortPostProcessing(dj.Imported): - """A processing table to handle each clustering task.""" - - definition = """ - -> KilosortClustering - --- - modules_status: longblob # dictionary of summary status for all modules - execution_time: datetime # datetime of the start of this step - execution_duration: float # (hour) execution duration - """ - - def make(self, key): - execution_time = datetime.utcnow() - - output_dir = (ephys.ClusteringTask & key).fetch1("clustering_output_dir") - kilosort_dir = 
find_full_path(ephys.get_ephys_root_data_dir(), output_dir) - - acq_software, clustering_method = ( - ephys.ClusteringTask * ephys.EphysRecording * ephys.ClusteringParamSet & key - ).fetch1("acq_software", "clustering_method") - - params = (KilosortPreProcessing & key).fetch1("params") - - if acq_software == "SpikeGLX": - spikeglx_meta_filepath = ephys.get_spikeglx_meta_filepath(key) - spikeglx_recording = spikeglx.SpikeGLX(spikeglx_meta_filepath.parent) - spikeglx_recording.validate_file("ap") - - run_kilosort = kilosort_triggering.SGLXKilosortPipeline( - npx_input_dir=spikeglx_meta_filepath.parent, - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - run_CatGT=True, - ) - run_kilosort._modules = [ - "kilosort_postprocessing", - "noise_templates", - "mean_waveforms", - "quality_metrics", - ] - run_kilosort._CatGT_finished = True - run_kilosort.run_modules() - elif acq_software == "Open Ephys": - oe_probe = ephys.get_openephys_probe_data(key) - - assert len(oe_probe.recording_info["recording_files"]) == 1 - - # run kilosort - run_kilosort = kilosort_triggering.OpenEphysKilosortPipeline( - npx_input_dir=oe_probe.recording_info["recording_files"][0], - ks_output_dir=kilosort_dir, - params=params, - KS2ver=f'{Decimal(clustering_method.replace("kilosort", "")):.1f}', - ) - run_kilosort._modules = [ - "kilosort_postprocessing", - "noise_templates", - "mean_waveforms", - "quality_metrics", - ] - run_kilosort.run_modules() - - with open(run_kilosort._modules_input_hash_fp) as f: - modules_status = json.load(f) - - self.insert1( - { - **key, - "modules_status": modules_status, - "execution_time": execution_time, - "execution_duration": ( - datetime.utcnow() - execution_time - ).total_seconds() - / 3600, - } - ) - - # all finished, insert this `key` into ephys.Clustering - ephys.Clustering.insert1( - {**key, "clustering_time": datetime.utcnow()}, allow_direct_insert=True - ) diff --git a/element_array_ephys/readers/kilosort_triggering.py b/element_array_ephys/spike_sorting/kilosort_triggering.py similarity index 100% rename from element_array_ephys/readers/kilosort_triggering.py rename to element_array_ephys/spike_sorting/kilosort_triggering.py From c2bd5adb07096a04c7afc28418f1f996533fdf8b Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Tue, 10 Sep 2024 15:23:04 -0500 Subject: [PATCH 144/146] chore: clean up diagrams --- ...n.svg => attached_array_ephys_element.svg} | 0 images/attached_array_ephys_element_acute.svg | 451 --------------- .../attached_array_ephys_element_chronic.svg | 456 --------------- ...ttached_array_ephys_element_precluster.svg | 535 ------------------ 4 files changed, 1442 deletions(-) rename images/{attached_array_ephys_element_no_curation.svg => attached_array_ephys_element.svg} (100%) delete mode 100644 images/attached_array_ephys_element_acute.svg delete mode 100644 images/attached_array_ephys_element_chronic.svg delete mode 100644 images/attached_array_ephys_element_precluster.svg diff --git a/images/attached_array_ephys_element_no_curation.svg b/images/attached_array_ephys_element.svg similarity index 100% rename from images/attached_array_ephys_element_no_curation.svg rename to images/attached_array_ephys_element.svg diff --git a/images/attached_array_ephys_element_acute.svg b/images/attached_array_ephys_element_acute.svg deleted file mode 100644 index 5b2bc265..00000000 --- a/images/attached_array_ephys_element_acute.svg +++ /dev/null @@ -1,451 +0,0 @@ - - - - - -ephys.ProbeInsertion - - 
-ephys.ProbeInsertion - - - - - -ephys.InsertionLocation - - -ephys.InsertionLocation - - - - - -ephys.ProbeInsertion->ephys.InsertionLocation - - - - -ephys.EphysRecording - - -ephys.EphysRecording - - - - - -ephys.ProbeInsertion->ephys.EphysRecording - - - - -ephys.QualityMetrics - - -ephys.QualityMetrics - - - - - -ephys.QualityMetrics.Cluster - - -ephys.QualityMetrics.Cluster - - - - - -ephys.QualityMetrics->ephys.QualityMetrics.Cluster - - - - -ephys.QualityMetrics.Waveform - - -ephys.QualityMetrics.Waveform - - - - - -ephys.QualityMetrics->ephys.QualityMetrics.Waveform - - - - -probe.ElectrodeConfig - - -probe.ElectrodeConfig - - - - - -probe.ElectrodeConfig.Electrode - - -probe.ElectrodeConfig.Electrode - - - - - -probe.ElectrodeConfig->probe.ElectrodeConfig.Electrode - - - - -probe.ElectrodeConfig->ephys.EphysRecording - - - - -ephys.AcquisitionSoftware - - -ephys.AcquisitionSoftware - - - - - -ephys.AcquisitionSoftware->ephys.EphysRecording - - - - -SkullReference - - -SkullReference - - - - - -SkullReference->ephys.InsertionLocation - - - - -ephys.ClusteringParamSet - - -ephys.ClusteringParamSet - - - - - -ephys.ClusteringTask - - -ephys.ClusteringTask - - - - - -ephys.ClusteringParamSet->ephys.ClusteringTask - - - - -ephys.LFP.Electrode - - -ephys.LFP.Electrode - - - - - -ephys.ClusterQualityLabel - - -ephys.ClusterQualityLabel - - - - - -ephys.CuratedClustering.Unit - - -ephys.CuratedClustering.Unit - - - - - -ephys.ClusterQualityLabel->ephys.CuratedClustering.Unit - - - - -ephys.WaveformSet.Waveform - - -ephys.WaveformSet.Waveform - - - - - -ephys.Clustering - - -ephys.Clustering - - - - - -ephys.ClusteringTask->ephys.Clustering - - - - -probe.ProbeType - - -probe.ProbeType - - - - - -probe.ProbeType->probe.ElectrodeConfig - - - - -probe.Probe - - -probe.Probe - - - - - -probe.ProbeType->probe.Probe - - - - -probe.ProbeType.Electrode - - -probe.ProbeType.Electrode - - - - - -probe.ProbeType->probe.ProbeType.Electrode - - - - -ephys.Curation - - -ephys.Curation - - - - - -ephys.Clustering->ephys.Curation - - - - -ephys.LFP - - -ephys.LFP - - - - - -ephys.LFP->ephys.LFP.Electrode - - - - -probe.Probe->ephys.ProbeInsertion - - - - -ephys.CuratedClustering - - -ephys.CuratedClustering - - - - - -ephys.CuratedClustering->ephys.QualityMetrics - - - - -ephys.WaveformSet - - -ephys.WaveformSet - - - - - -ephys.CuratedClustering->ephys.WaveformSet - - - - -ephys.CuratedClustering->ephys.CuratedClustering.Unit - - - - -subject.Subject - - -subject.Subject - - - - - -session.Session - - -session.Session - - - - - -subject.Subject->session.Session - - - - -probe.ElectrodeConfig.Electrode->ephys.LFP.Electrode - - - - -probe.ElectrodeConfig.Electrode->ephys.WaveformSet.Waveform - - - - -probe.ElectrodeConfig.Electrode->ephys.CuratedClustering.Unit - - - - -ephys.Curation->ephys.CuratedClustering - - - - -ephys.ClusteringMethod - - -ephys.ClusteringMethod - - - - - -ephys.ClusteringMethod->ephys.ClusteringParamSet - - - - -ephys.WaveformSet.PeakWaveform - - -ephys.WaveformSet.PeakWaveform - - - - - -session.Session->ephys.ProbeInsertion - - - - -ephys.EphysRecording.EphysFile - - -ephys.EphysRecording.EphysFile - - - - - -ephys.WaveformSet->ephys.WaveformSet.Waveform - - - - -ephys.WaveformSet->ephys.WaveformSet.PeakWaveform - - - - -ephys.CuratedClustering.Unit->ephys.WaveformSet.Waveform - - - - -ephys.CuratedClustering.Unit->ephys.QualityMetrics.Cluster - - - - -ephys.CuratedClustering.Unit->ephys.QualityMetrics.Waveform - - - - 
-ephys.CuratedClustering.Unit->ephys.WaveformSet.PeakWaveform - - - - -ephys.EphysRecording->ephys.ClusteringTask - - - - -ephys.EphysRecording->ephys.LFP - - - - -ephys.EphysRecording->ephys.EphysRecording.EphysFile - - - - -probe.ProbeType.Electrode->probe.ElectrodeConfig.Electrode - - - - \ No newline at end of file diff --git a/images/attached_array_ephys_element_chronic.svg b/images/attached_array_ephys_element_chronic.svg deleted file mode 100644 index 808a2f17..00000000 --- a/images/attached_array_ephys_element_chronic.svg +++ /dev/null @@ -1,456 +0,0 @@ - - - - - -ephys.Curation - - -ephys.Curation - - - - - -ephys.CuratedClustering - - -ephys.CuratedClustering - - - - - -ephys.Curation->ephys.CuratedClustering - - - - -ephys.AcquisitionSoftware - - -ephys.AcquisitionSoftware - - - - - -ephys.EphysRecording - - -ephys.EphysRecording - - - - - -ephys.AcquisitionSoftware->ephys.EphysRecording - - - - -ephys.ProbeInsertion - - -ephys.ProbeInsertion - - - - - -ephys.ProbeInsertion->ephys.EphysRecording - - - - -ephys.InsertionLocation - - -ephys.InsertionLocation - - - - - -ephys.ProbeInsertion->ephys.InsertionLocation - - - - -subject.Subject - - -subject.Subject - - - - - -subject.Subject->ephys.ProbeInsertion - - - - -session.Session - - -session.Session - - - - - -subject.Subject->session.Session - - - - -ephys.WaveformSet.PeakWaveform - - -ephys.WaveformSet.PeakWaveform - - - - - -ephys.EphysRecording.EphysFile - - -ephys.EphysRecording.EphysFile - - - - - -ephys.EphysRecording->ephys.EphysRecording.EphysFile - - - - -ephys.ClusteringTask - - -ephys.ClusteringTask - - - - - -ephys.EphysRecording->ephys.ClusteringTask - - - - -ephys.LFP - - -ephys.LFP - - - - - -ephys.EphysRecording->ephys.LFP - - - - -probe.Probe - - -probe.Probe - - - - - -probe.Probe->ephys.ProbeInsertion - - - - -ephys.QualityMetrics - - -ephys.QualityMetrics - - - - - -ephys.QualityMetrics.Waveform - - -ephys.QualityMetrics.Waveform - - - - - -ephys.QualityMetrics->ephys.QualityMetrics.Waveform - - - - -ephys.QualityMetrics.Cluster - - -ephys.QualityMetrics.Cluster - - - - - -ephys.QualityMetrics->ephys.QualityMetrics.Cluster - - - - -ephys.ClusteringParamSet - - -ephys.ClusteringParamSet - - - - - -ephys.ClusteringParamSet->ephys.ClusteringTask - - - - -ephys.WaveformSet.Waveform - - -ephys.WaveformSet.Waveform - - - - - -probe.ProbeType - - -probe.ProbeType - - - - - -probe.ProbeType->probe.Probe - - - - -probe.ElectrodeConfig - - -probe.ElectrodeConfig - - - - - -probe.ProbeType->probe.ElectrodeConfig - - - - -probe.ProbeType.Electrode - - -probe.ProbeType.Electrode - - - - - -probe.ProbeType->probe.ProbeType.Electrode - - - - -ephys.Clustering - - -ephys.Clustering - - - - - -ephys.ClusteringTask->ephys.Clustering - - - - -ephys.LFP.Electrode - - -ephys.LFP.Electrode - - - - - -ephys.LFP->ephys.LFP.Electrode - - - - -session.Session->ephys.EphysRecording - - - - -ephys.Clustering->ephys.Curation - - - - -probe.ElectrodeConfig.Electrode - - -probe.ElectrodeConfig.Electrode - - - - - -probe.ElectrodeConfig.Electrode->ephys.WaveformSet.Waveform - - - - -probe.ElectrodeConfig.Electrode->ephys.LFP.Electrode - - - - -ephys.CuratedClustering.Unit - - -ephys.CuratedClustering.Unit - - - - - -probe.ElectrodeConfig.Electrode->ephys.CuratedClustering.Unit - - - - -ephys.WaveformSet - - -ephys.WaveformSet - - - - - -ephys.WaveformSet->ephys.WaveformSet.PeakWaveform - - - - -ephys.WaveformSet->ephys.WaveformSet.Waveform - - - - -probe.ElectrodeConfig->ephys.EphysRecording - - - - 
-probe.ElectrodeConfig->probe.ElectrodeConfig.Electrode - - - - -probe.ProbeType.Electrode->probe.ElectrodeConfig.Electrode - - - - -ephys.CuratedClustering.Unit->ephys.WaveformSet.PeakWaveform - - - - -ephys.CuratedClustering.Unit->ephys.WaveformSet.Waveform - - - - -ephys.CuratedClustering.Unit->ephys.QualityMetrics.Waveform - - - - -ephys.CuratedClustering.Unit->ephys.QualityMetrics.Cluster - - - - -ephys.ClusteringMethod - - -ephys.ClusteringMethod - - - - - -ephys.ClusteringMethod->ephys.ClusteringParamSet - - - - -ephys.CuratedClustering->ephys.QualityMetrics - - - - -ephys.CuratedClustering->ephys.WaveformSet - - - - -ephys.CuratedClustering->ephys.CuratedClustering.Unit - - - - -ephys.ClusterQualityLabel - - -ephys.ClusterQualityLabel - - - - - -ephys.ClusterQualityLabel->ephys.CuratedClustering.Unit - - - - -SkullReference - - -SkullReference - - - - - -SkullReference->ephys.InsertionLocation - - - - \ No newline at end of file diff --git a/images/attached_array_ephys_element_precluster.svg b/images/attached_array_ephys_element_precluster.svg deleted file mode 100644 index 7d854d2e..00000000 --- a/images/attached_array_ephys_element_precluster.svg +++ /dev/null @@ -1,535 +0,0 @@ - - - - - -ephys.AcquisitionSoftware - - -ephys.AcquisitionSoftware - - - - - -ephys.EphysRecording - - -ephys.EphysRecording - - - - - -ephys.AcquisitionSoftware->ephys.EphysRecording - - - - -ephys.QualityMetrics.Waveform - - -ephys.QualityMetrics.Waveform - - - - - -ephys.PreClusterTask - - -ephys.PreClusterTask - - - - - -ephys.EphysRecording->ephys.PreClusterTask - - - - -ephys.EphysRecording.EphysFile - - -ephys.EphysRecording.EphysFile - - - - - -ephys.EphysRecording->ephys.EphysRecording.EphysFile - - - - -ephys.PreCluster - - -ephys.PreCluster - - - - - -ephys.PreClusterTask->ephys.PreCluster - - - - -probe.ProbeType.Electrode - - -probe.ProbeType.Electrode - - - - - -probe.ElectrodeConfig.Electrode - - -probe.ElectrodeConfig.Electrode - - - - - -probe.ProbeType.Electrode->probe.ElectrodeConfig.Electrode - - - - -ephys.LFP - - -ephys.LFP - - - - - -ephys.PreCluster->ephys.LFP - - - - -ephys.ClusteringTask - - -ephys.ClusteringTask - - - - - -ephys.PreCluster->ephys.ClusteringTask - - - - -ephys.LFP.Electrode - - -ephys.LFP.Electrode - - - - - -probe.ElectrodeConfig.Electrode->ephys.LFP.Electrode - - - - -ephys.CuratedClustering.Unit - - -ephys.CuratedClustering.Unit - - - - - -probe.ElectrodeConfig.Electrode->ephys.CuratedClustering.Unit - - - - -ephys.WaveformSet.Waveform - - -ephys.WaveformSet.Waveform - - - - - -probe.ElectrodeConfig.Electrode->ephys.WaveformSet.Waveform - - - - -ephys.Curation - - -ephys.Curation - - - - - -ephys.CuratedClustering - - -ephys.CuratedClustering - - - - - -ephys.Curation->ephys.CuratedClustering - - - - -probe.ElectrodeConfig - - -probe.ElectrodeConfig - - - - - -probe.ElectrodeConfig->ephys.EphysRecording - - - - -probe.ElectrodeConfig->probe.ElectrodeConfig.Electrode - - - - -ephys.QualityMetrics - - -ephys.QualityMetrics - - - - - -ephys.CuratedClustering->ephys.QualityMetrics - - - - -ephys.WaveformSet - - -ephys.WaveformSet - - - - - -ephys.CuratedClustering->ephys.WaveformSet - - - - -ephys.CuratedClustering->ephys.CuratedClustering.Unit - - - - -ephys.InsertionLocation - - -ephys.InsertionLocation - - - - - -SkullReference - - -SkullReference - - - - - -SkullReference->ephys.InsertionLocation - - - - -ephys.QualityMetrics->ephys.QualityMetrics.Waveform - - - - -ephys.QualityMetrics.Cluster - - -ephys.QualityMetrics.Cluster - - - - - 
-ephys.QualityMetrics->ephys.QualityMetrics.Cluster - - - - -ephys.PreClusterParamSteps.Step - - -ephys.PreClusterParamSteps.Step - - - - - -ephys.ClusterQualityLabel - - -ephys.ClusterQualityLabel - - - - - -ephys.ClusterQualityLabel->ephys.CuratedClustering.Unit - - - - -session.Session - - -session.Session - - - - - -ephys.ProbeInsertion - - -ephys.ProbeInsertion - - - - - -session.Session->ephys.ProbeInsertion - - - - -ephys.ClusteringMethod - - -ephys.ClusteringMethod - - - - - -ephys.ClusteringParamSet - - -ephys.ClusteringParamSet - - - - - -ephys.ClusteringMethod->ephys.ClusteringParamSet - - - - -ephys.WaveformSet.PeakWaveform - - -ephys.WaveformSet.PeakWaveform - - - - - -ephys.WaveformSet->ephys.WaveformSet.PeakWaveform - - - - -ephys.WaveformSet->ephys.WaveformSet.Waveform - - - - -subject.Subject - - -subject.Subject - - - - - -subject.Subject->session.Session - - - - -ephys.LFP->ephys.LFP.Electrode - - - - -ephys.CuratedClustering.Unit->ephys.QualityMetrics.Waveform - - - - -ephys.CuratedClustering.Unit->ephys.QualityMetrics.Cluster - - - - -ephys.CuratedClustering.Unit->ephys.WaveformSet.PeakWaveform - - - - -ephys.CuratedClustering.Unit->ephys.WaveformSet.Waveform - - - - -ephys.Clustering - - -ephys.Clustering - - - - - -ephys.ClusteringTask->ephys.Clustering - - - - -probe.Probe - - -probe.Probe - - - - - -probe.Probe->ephys.ProbeInsertion - - - - -ephys.PreClusterMethod - - -ephys.PreClusterMethod - - - - - -ephys.PreClusterParamSet - - -ephys.PreClusterParamSet - - - - - -ephys.PreClusterMethod->ephys.PreClusterParamSet - - - - -ephys.ClusteringParamSet->ephys.ClusteringTask - - - - -probe.ProbeType - - -probe.ProbeType - - - - - -probe.ProbeType->probe.ProbeType.Electrode - - - - -probe.ProbeType->probe.ElectrodeConfig - - - - -probe.ProbeType->probe.Probe - - - - -ephys.ProbeInsertion->ephys.EphysRecording - - - - -ephys.ProbeInsertion->ephys.InsertionLocation - - - - -ephys.PreClusterParamSteps - - -ephys.PreClusterParamSteps - - - - - -ephys.PreClusterParamSteps->ephys.PreClusterTask - - - - -ephys.PreClusterParamSteps->ephys.PreClusterParamSteps.Step - - - - -ephys.Clustering->ephys.Curation - - - - -ephys.PreClusterParamSet->ephys.PreClusterParamSteps.Step - - - - \ No newline at end of file From 497110816058ae0655cac8f9414b4622905f78a6 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 19 Sep 2024 13:29:45 -0500 Subject: [PATCH 145/146] fix: use tempfile.TemporaryDirectory --- element_array_ephys/ephys_report.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/element_array_ephys/ephys_report.py b/element_array_ephys/ephys_report.py index c962d33d..0c6836a0 100644 --- a/element_array_ephys/ephys_report.py +++ b/element_array_ephys/ephys_report.py @@ -2,6 +2,7 @@ import datetime import pathlib +import tempfile from uuid import UUID import datajoint as dj @@ -53,7 +54,7 @@ class ProbeLevelReport(dj.Computed): def make(self, key): from .plotting.probe_level import plot_driftmap - save_dir = _make_save_dir() + save_dir = tempfile.TemporaryDirectory() units = ephys.CuratedClustering.Unit & key & "cluster_quality_label='good'" @@ -88,13 +89,15 @@ def make(self, key): fig_dict = _save_figs( figs=(fig,), fig_names=("drift_map_plot",), - save_dir=save_dir, + save_dir=save_dir.name, fig_prefix=fig_prefix, extension=".png", ) self.insert1({**key, **fig_dict, "shank": shank_no}) + save_dir.cleanup() + @schema class UnitLevelReport(dj.Computed): @@ -266,17 +269,10 @@ def make(self, key): ) -def _make_save_dir(root_dir: pathlib.Path = 
None) -> pathlib.Path: - if root_dir is None: - root_dir = pathlib.Path().absolute() - save_dir = root_dir / "temp_ephys_figures" - save_dir.mkdir(parents=True, exist_ok=True) - return save_dir - - def _save_figs( figs, fig_names, save_dir, fig_prefix, extension=".png" ) -> dict[str, pathlib.Path]: + save_dir = pathlib.Path(save_dir) fig_dict = {} for fig, fig_name in zip(figs, fig_names): fig_filepath = save_dir / (fig_prefix + "_" + fig_name + extension) From 63df4cda7d5ab97d1d52195acf0d7031fe2496f6 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 19 Sep 2024 16:27:13 -0500 Subject: [PATCH 146/146] format: black --- element_array_ephys/ephys.py | 9 +++++++-- element_array_ephys/spike_sorting/si_spike_sorting.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/element_array_ephys/ephys.py b/element_array_ephys/ephys.py index 02e1366e..ad9bb8d7 100644 --- a/element_array_ephys/ephys.py +++ b/element_array_ephys/ephys.py @@ -1068,9 +1068,14 @@ def make(self, key): } spike_locations = sorting_analyzer.get_extension("spike_locations") - extremum_channel_inds = si.template_tools.get_template_extremum_channel(sorting_analyzer, outputs="index") + extremum_channel_inds = si.template_tools.get_template_extremum_channel( + sorting_analyzer, outputs="index" + ) spikes_df = pd.DataFrame( - sorting_analyzer.sorting.to_spike_vector(extremum_channel_inds=extremum_channel_inds)) + sorting_analyzer.sorting.to_spike_vector( + extremum_channel_inds=extremum_channel_inds + ) + ) units = [] for unit_idx, unit_id in enumerate(si_sorting.unit_ids): diff --git a/element_array_ephys/spike_sorting/si_spike_sorting.py b/element_array_ephys/spike_sorting/si_spike_sorting.py index 547fd8ce..e2f011e1 100644 --- a/element_array_ephys/spike_sorting/si_spike_sorting.py +++ b/element_array_ephys/spike_sorting/si_spike_sorting.py @@ -114,7 +114,9 @@ def make(self, key): spikeglx_recording.validate_file("ap") data_dir = spikeglx_meta_filepath.parent - si_extractor = si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor + si_extractor = ( + si.extractors.neoextractors.spikeglx.SpikeGLXRecordingExtractor + ) stream_names, stream_ids = si.extractors.get_neo_streams( "spikeglx", folder_path=data_dir ) @@ -125,7 +127,9 @@ def make(self, key): oe_probe = ephys.get_openephys_probe_data(key) assert len(oe_probe.recording_info["recording_files"]) == 1 data_dir = oe_probe.recording_info["recording_files"][0] - si_extractor = si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor + si_extractor = ( + si.extractors.neoextractors.openephys.OpenEphysBinaryRecordingExtractor + ) stream_names, stream_ids = si.extractors.get_neo_streams( "openephysbinary", folder_path=data_dir