diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile
index 7cb9a7b..6c34fc5 100644
--- a/.devcontainer/Dockerfile
+++ b/.devcontainer/Dockerfile
@@ -1,7 +1,7 @@
-FROM python:3.9-slim@sha256:5f0192a4f58a6ce99f732fe05e3b3d00f12ae62e183886bca3ebe3d202686c7f
+FROM python:3.10-slim@sha256:5f0192a4f58a6ce99f732fe05e3b3d00f12ae62e183886bca3ebe3d202686c7f

 ENV PATH /usr/local/bin:$PATH
-ENV PYTHON_VERSION 3.9.17
+ENV PYTHON_VERSION 3.10.13

 RUN \
     adduser --system --disabled-password --shell /bin/bash vscode && \
@@ -32,11 +32,13 @@ COPY ./ /tmp/element-miniscope/
 RUN \
     # pipeline dependencies
     apt-get install gcc g++ ffmpeg libsm6 libxext6 -y && \
-    pip install numpy Cython && \
-    pip install --no-cache-dir -e /tmp/element-miniscope[elements,caiman_requirements,caiman] && \
-    caimanmanager.py install && \
+    pip install --no-cache-dir -e /tmp/element-miniscope[elements]
+RUN cd ./tmp && git clone https://github.com/datajoint/CaImAn.git && cd ./CaImAn && \
+    pip install -r requirements.txt && pip install -e . && \
+    caimanmanager install --inplace && cd ~ && \
     # clean up
     rm -rf /tmp/element-miniscope && \
+    rm -rf /tmp/CaImAn && \
     apt-get clean

 ENV DJ_HOST fakeservices.datajoint.io
diff --git a/element_miniscope/miniscope.py b/element_miniscope/miniscope.py
index d5c34ba..560b7fc 100644
--- a/element_miniscope/miniscope.py
+++ b/element_miniscope/miniscope.py
@@ -1,21 +1,24 @@
 import csv
+import cv2
 import importlib
 import inspect
 import json
 import pathlib
 from datetime import datetime
+from typing import Union

-import cv2
 import datajoint as dj
 import numpy as np
+import pandas as pd
 from element_interface.utils import dict_to_uuid, find_full_path, find_root_directory

 from . import miniscope_report

+logger = dj.logger
+
 schema = dj.Schema()

 _linking_module = None

-logger = dj.logger

 def activate(
@@ -91,6 +94,25 @@ def get_miniscope_root_data_dir() -> list:
     return root_directories


+def get_processed_root_data_dir() -> Union[str, pathlib.Path]:
+    """Retrieve the root directory for all processed data.
+
+    All data paths and directories in DataJoint Elements are recommended to be stored as
+    relative paths (posix format), with respect to some user-configured "root"
+    directory, which varies from machine to machine (e.g. different mounted drive
+    locations).
+
+    Returns:
+        dir (str | pathlib.Path): Absolute path of the processed miniscope root data
+            directory.
+    """
+
+    if hasattr(_linking_module, "get_processed_root_data_dir"):
+        return _linking_module.get_processed_root_data_dir()
+    else:
+        return get_miniscope_root_data_dir()[0]
+
+
 def get_session_directory(session_key: dict) -> str:
     """Pulls session directory information from database.
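Illustrative sketch (not part of the patch): the `get_processed_root_data_dir` accessor added above defers to the linking module and otherwise falls back to the first raw-data root. Assuming a workflow that keeps its paths under `dj.config["custom"]` (the usual DataJoint Elements convention; the exact key names below are assumptions), the workflow-side helpers could look like:

    import pathlib
    import datajoint as dj

    def get_miniscope_root_data_dir() -> list:
        # one or more root directories containing the raw miniscope data
        return dj.config["custom"]["miniscope_root_data_dir"]

    def get_processed_root_data_dir() -> pathlib.Path:
        # optional separate root for processed results; fall back to the first raw root
        processed = dj.config["custom"].get("processed_data_dir")
        return pathlib.Path(processed or get_miniscope_root_data_dir()[0])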
@@ -110,13 +132,15 @@ class AcquisitionSoftware(dj.Lookup):
     """Software used for miniscope acquisition.

+    Required to define a miniscope recording.
+
     Attributes:
-        acq_software (varchar(24) ): Name of the miniscope acquisition software."""
+        acq_software (str): Name of the miniscope acquisition software."""

     definition = """
     acq_software: varchar(24)
     """
-    contents = zip(["Miniscope-DAQ-V3", "Miniscope-DAQ-V4"])
+    contents = zip(["Miniscope-DAQ-V3", "Miniscope-DAQ-V4", "Inscopix"])


 @schema
@@ -135,21 +159,21 @@ class Channel(dj.Lookup):

 @schema
 class Recording(dj.Manual):
-    """Table for discrete recording sessions with the miniscope.
+    """Recording defined by a measurement done using a miniscope and an acquisition software.

     Attributes:
-        Session (foreign key): Session primary key.
-        recording_id (foreign key, int): Unique recording ID.
-        Device: Lookup table for miniscope device information.
-        AcquisitionSoftware: Lookup table for miniscope acquisition software.
-        recording_notes (varchar(4095) ): notes about the recording session.
+        Session (foreign key): A primary key from Session.
+        recording_id (int): Unique recording ID.
+        Device (foreign key, optional): A primary key from Device.
+        AcquisitionSoftware (foreign key): A primary key from AcquisitionSoftware.
+        recording_notes (str, optional): notes about the recording session.
     """

     definition = """
     -> Session
     recording_id: int
     ---
-    -> Device
+    -> [nullable] Device
     -> AcquisitionSoftware
     recording_notes='' : varchar(4095) # free-notes
     """
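Illustrative usage (not part of the patch): with "Inscopix" now listed in AcquisitionSoftware and the Device reference made nullable, a recording can be registered and its metadata ingested roughly as follows. The `session_key` fields are placeholders for an assumed upstream Session table, not part of this element.

    from element_miniscope import miniscope

    session_key = {"subject": "subject1", "session_datetime": "2024-01-01 12:00:00"}  # hypothetical

    miniscope.Recording.insert1(
        {
            **session_key,
            "recording_id": 0,
            "acq_software": "Inscopix",  # Device intentionally omitted (nullable in this patch)
            "recording_notes": "",
        }
    )
    miniscope.RecordingInfo.populate(display_progress=True)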
@@ -160,8 +184,8 @@ class RecordingLocation(dj.Manual):
     """Brain location where the miniscope recording is acquired.

     Attributes:
-        Recording (foreign key): Recording primary key.
-        Anatomical Location: Select the anatomical region where recording was acquired.
+        Recording (foreign key): A primary key from Recording.
+        AnatomicalLocation (foreign key): A primary key from AnatomicalLocation.
     """

     definition = """
@@ -174,10 +198,10 @@ class RecordingLocation(dj.Manual):

 @schema
 class RecordingInfo(dj.Imported):
-    """Automated table with recording metadata.
+    """Information about the recording extracted from the recorded files.

     Attributes:
-        Recording (foreign key): Recording primary key.
+        Recording (foreign key): A primary key from Recording.
         nchannels (tinyint): Number of recording channels.
         nframes (int): Number of recorded frames.
         px_height (smallint): Height in pixels.
         px_width (smallint): Width in pixels.
         um_height (float): Height in microns.
         um_width (float): Width in microns.
         fps (float): Frames per second.
@@ -188,7 +212,6 @@ class RecordingInfo(dj.Imported):
         gain (float): Recording gain.
         spatial_downsample (tinyint): Amount of downsampling applied.
         led_power (float): LED power used for the recording.
-        time_stamps (longblob): Time stamps for each frame.
         recording_datetime (datetime): Datetime of the recording.
         recording_duration (float): Total recording duration (seconds).
     """
@@ -199,19 +222,42 @@ class RecordingInfo(dj.Imported):
     definition = """
     -> Recording
     ---
     nchannels            : tinyint   # number of channels
     nframes              : int       # number of recorded frames
-    px_height=null       : smallint  # height in pixels
-    px_width=null        : smallint  # width in pixels
-    um_height=null       : float     # height in microns
-    um_width=null        : float     # width in microns
+    ndepths=1            : tinyint   # number of depths
+    px_height            : smallint  # height in pixels
+    px_width             : smallint  # width in pixels
     fps                  : float     # (Hz) frames per second
-    gain=null            : float     # recording gain
-    spatial_downsample=1 : tinyint   # e.g. 1, 2, 4, 8. 1 for no downsampling
-    led_power            : float     # LED power used in the given recording
-    time_stamps          : longblob  # time stamps of each frame
     recording_datetime=null : datetime  # datetime of the recording
     recording_duration=null : float     # (seconds) duration of the recording
     """

+    class Config(dj.Part):
+        """Recording metadata and configuration.
+
+        Attributes:
+            Recording (foreign key): A primary key from RecordingInfo.
+            config (longblob): Recording metadata and configuration.
+        """
+
+        definition = """
+        -> master
+        ---
+        config: longblob  # recording metadata and configuration
+        """
+
+    class Timestamps(dj.Part):
+        """Recording timestamps for each frame.
+
+        Attributes:
+            Recording (foreign key): A primary key from RecordingInfo.
+            timestamps (longblob): Recording timestamps for each frame.
+        """
+
+        definition = """
+        -> master
+        ---
+        timestamps: longblob
+        """
+
     class File(dj.Part):
         """File path to recording file relative to root data directory.
@@ -239,10 +285,11 @@ def make(self, key):
             get_miniscope_root_data_dir(), recording_directory
         )

-        recording_filepaths = [
-            file_path.as_posix() for file_path in recording_path.glob("*.avi")
-        ]
-
+        recording_filepaths = (
+            [file_path.as_posix() for file_path in recording_path.glob("*.avi")]
+            if acq_software != "Inscopix"
+            else [file_path.as_posix() for file_path in recording_path.rglob("*.avi")]
+        )
         if not recording_filepaths:
             raise FileNotFoundError(f"No .avi files found in " f"{recording_directory}")

@@ -270,14 +317,15 @@ def make(self, key):
             fps = video.get(cv2.CAP_PROP_FPS)

         elif acq_software == "Miniscope-DAQ-V4":
-            recording_metadata = list(recording_path.glob("metaData.json"))[0]
-            recording_timestamps = list(recording_path.glob("timeStamps.csv"))[0]
-
-            if not recording_metadata.exists():
+            try:
+                recording_metadata = next(recording_path.glob("metaData.json"))
+            except StopIteration:
                 raise FileNotFoundError(
                     f"No .json file found in " f"{recording_directory}"
                 )
-            if not recording_timestamps.exists():
+            try:
+                recording_timestamps = next(recording_path.glob("timeStamps.csv"))
+            except StopIteration:
                 raise FileNotFoundError(
                     f"No timestamp (*.csv) file found in " f"{recording_directory}"
                 )
@@ -293,12 +341,21 @@ def make(self, key):
             px_height = metadata["ROI"]["height"]
             px_width = metadata["ROI"]["width"]
             fps = int(metadata["frameRate"].replace("FPS", ""))
-            gain = metadata["gain"]
-            spatial_downsample = 1  # Assumes no spatial downsampling
-            led_power = metadata["led0"]
-            time_stamps = np.array(
-                [list(map(int, time_stamps[i])) for i in range(1, len(time_stamps))]
-            )
+            time_stamps = np.array(time_stamps[1:], dtype=float)[:, 0]
+
+        elif acq_software == "Inscopix":
+            inscopix_metadata = next(recording_path.glob("session.json"))
+            timestamps_file = next(recording_path.glob("*/*timestamps.csv"))
+            metadata = json.load(open(inscopix_metadata))
+            recording_timestamps = pd.read_csv(timestamps_file)
+
+            nchannels = len(metadata["manual"]["mScope"]["ledMaxPower"])
+            nframes = len(recording_timestamps)
+            fps = metadata["microscope"]["fps"]["fps"]
+            time_stamps = (recording_timestamps[" time (ms)"] / 1000).values
+            px_height = metadata["microscope"]["fov"]["height"]
+            px_width = metadata["microscope"]["fov"]["width"]
+
         else:
             raise NotImplementedError(
                 f"Loading routine not implemented for {acq_software}"
@@ -314,10 +371,6 @@ def make(self, key):
                 px_height=px_height,
                 px_width=px_width,
                 fps=fps,
-                gain=gain,
-                spatial_downsample=spatial_downsample,
-                led_power=led_power,
-                time_stamps=time_stamps,
                 recording_duration=nframes / fps,
             )
         )
@@ -337,25 +390,32 @@ def make(self, key):
             ]
         )

+        if acq_software == "Inscopix" or acq_software == "Miniscope-DAQ-V4":
+            self.Timestamps.insert1(dict(**key, timestamps=time_stamps))
+            self.Config.insert1(
+                dict(
+                    **key,
+                    config=metadata,
+                )
+            )
+

 # Trigger a processing routine -------------------------------------------------


 @schema
 class ProcessingMethod(dj.Lookup):
-    """Method or analysis software to process miniscope acquisition.
+    """Package used for processing of miniscope data (e.g. CaImAn, etc.).

     Attributes:
-        processing_method (foreign key, varchar16): Recording processing method (e.g. CaImAn).
-        processing_method_desc (varchar(1000) ): Additional information about the processing method.
+        processing_method (str): Processing method.
+        processing_method_desc (str): Processing method description.
     """

-    definition = """
-    # Method, package, analysis software used for processing of miniscope data
-    # (e.g. CaImAn, etc.)
+    definition = """# Package used for processing of miniscope data (e.g. CaImAn, etc.).
     processing_method: varchar(16)
     ---
-    processing_method_desc='': varchar(1000)
+    processing_method_desc: varchar(1000)
     """

     contents = [("caiman", "caiman analysis suite")]

@@ -363,25 +423,27 @@ class ProcessingMethod(dj.Lookup):

 @schema
 class ProcessingParamSet(dj.Lookup):
-    """Parameters of the processing method.
+    """Parameter set used for the processing of miniscope recordings,
+    including both the analysis suite and its respective input parameters.
+
+    A hash of the parameters of the analysis suite is also stored in order
+    to avoid duplicated entries.

     Attributes:
-        paramset_idx (foreign key, smallint): Unique parameter set ID.
-        ProcessingMethod (varchar(16) ): ProcessingMethod from the lookup table.
-        paramset_desc (varchar(128) ): Description of the parameter set.
-        paramset_set_hash (uuid): UUID hash for parameter set.
-        params (longblob): Dictionary of all parameters for the processing method.
+        paramset_idx (int): Unique parameter set ID.
+        ProcessingMethod (foreign key): A primary key from ProcessingMethod.
+        paramset_desc (str): Parameter set description.
+        param_set_hash (uuid): A universally unique identifier for the parameter set.
+        params (longblob): Parameter set, a dictionary of all applicable parameters to the analysis suite.
     """

-    definition = """
-    # Parameter set used for processing of miniscope data
-    paramset_idx: smallint
+    definition = """# Processing Parameter set
+    paramset_idx: smallint  # Unique parameter set ID.
     ---
     -> ProcessingMethod
-    paramset_desc: varchar(128)
-    param_set_hash: uuid
-    unique index (param_set_hash)
-    params: longblob  # dictionary of all applicable parameters
+    paramset_desc: varchar(1280)  # Parameter set description
+    param_set_hash: uuid  # A universally unique identifier for the parameter set
+    unique index (param_set_hash)
+    params: longblob  # Parameter set, a dictionary of all applicable parameters to the analysis suite.
     """

     @classmethod
@@ -391,7 +453,6 @@ def insert_new_params(
         paramset_idx: int,
         paramset_desc: str,
         params: dict,
-        processing_method_desc: str = "",
     ):
         """Insert new parameter set.

@@ -407,7 +468,11 @@ def insert_new_params(
         """

         ProcessingMethod.insert1(
-            {"processing_method": processing_method}, skip_duplicates=True
+            {
+                "processing_method": processing_method,
+                "processing_method_desc": "caiman analysis",
+            },
+            skip_duplicates=True,
         )
         param_dict = {
             "processing_method": processing_method,
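Illustrative usage (not part of the patch): with the updated `insert_new_params` signature (the `processing_method_desc` argument is removed), a CaImAn parameter set might be registered as below. The parameter values are placeholders for illustration, not a recommended configuration.

    from element_miniscope import miniscope

    caiman_params = {  # hypothetical values, for illustration only
        "fr": 30,
        "decay_time": 0.4,
        "pw_rigid": False,
    }

    miniscope.ProcessingParamSet.insert_new_params(
        processing_method="caiman",
        paramset_idx=0,
        paramset_desc="Default CaImAn parameters for miniscope analysis",
        params=caiman_params,
    )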
@@ -447,17 +512,23 @@ class MaskType(dj.Lookup):

 @schema
 class ProcessingTask(dj.Manual):
-    """Table marking manual or automatic processing task.
+    """A pairing of processing params and recordings to be loaded or triggered.
+
+    This table defines a miniscope recording processing task for a combination of
+    `Recording` and `ProcessingParamSet` entries, including all the inputs (recording, method,
+    method's parameters). The task defined here is then run in the downstream table
+    `Processing`. This table supports definitions of both loading of pre-generated results
+    and the triggering of new analysis for all supported analysis methods.

     Attributes:
-        RecordingInfo (foreign key): Recording info primary key.
-        ProcessingParamSet (foreign key): Processing param set primary key.
-        processing_output_dir (varchar(255) ): relative output data directory for processed files.
-        task_mode (enum): `Load` existing results or `trigger` new processing task.
+        RecordingInfo (foreign key): Primary key from RecordingInfo.
+        ProcessingParamSet (foreign key): Primary key from ProcessingParamSet.
+        processing_output_dir (str): Output directory of the processed recording relative to the root data directory.
+        task_mode (str): One of 'load' (load computed analysis results) or 'trigger'
+            (trigger computation).
     """

-    definition = """
-    # Manual table marking a processing task to be triggered or manually processed
+    definition = """# Manual table for defining a processing task ready to be run
     -> RecordingInfo
     -> ProcessingParamSet
     ---
     processing_output_dir : varchar(255)     # relative to the root data directory
     task_mode='load'      : enum('load', 'trigger')  # 'load': load existing results
@@ -466,30 +537,143 @@ class ProcessingTask(dj.Manual):
                                                      # 'trigger': trigger procedure
     """

+    @classmethod
+    def infer_output_dir(cls, key, relative=False, mkdir=False):
+        """Infer an output directory for an entry in ProcessingTask table.
+
+        Args:
+            key (dict): Primary key from the ProcessingTask table.
+            relative (bool): If True, processing_output_dir is returned relative to
+                the miniscope root data directory. Default False.
+            mkdir (bool): If True, create the processing_output_dir directory.
+                Default False.
+
+        Returns:
+            dir (str): A default output directory for the processed results (processed_output_dir
+                in ProcessingTask) based on the following convention:
+                processed_dir / recording_dir / {processing_method}_{paramset_idx}
+                e.g.: sub4/sess1/caiman_0
+        """
+        acq_software = (Recording & key).fetch1("acq_software")
+        recording_dir = find_full_path(
+            get_miniscope_root_data_dir(),
+            get_session_directory(key)[0],
+        )
+        root_dir = find_root_directory(get_miniscope_root_data_dir(), recording_dir)
+
+        method = (
+            (ProcessingParamSet & key).fetch1("processing_method").replace(".", "-")
+        )
+
+        processed_dir = pathlib.Path(get_processed_root_data_dir())
+        output_dir = (
+            processed_dir
+            / recording_dir.relative_to(root_dir)
+            / f'{method}_{key["paramset_idx"]}'
+        )
+
+        if mkdir:
+            output_dir.mkdir(parents=True, exist_ok=True)
+
+        return output_dir.relative_to(processed_dir) if relative else output_dir
+
+    @classmethod
+    def generate(cls, recording_key, paramset_idx=0):
+        """Generate a ProcessingTask for a Recording using an existing ProcessingParamSet.
+
+        Generate an entry in the ProcessingTask table for a particular recording using an
+        existing parameter set from the ProcessingParamSet table.
+
+        Args:
+            recording_key (dict): Primary key from Recording.
+            paramset_idx (int): Unique parameter set ID.
+        """
+        key = {**recording_key, "paramset_idx": paramset_idx}
+
+        processed_dir = get_processed_root_data_dir()
+        output_dir = cls.infer_output_dir(key, relative=False, mkdir=True)
+
+        method = (ProcessingParamSet & {"paramset_idx": paramset_idx}).fetch1(
+            "processing_method"
+        )
+
+        try:
+            if method == "caiman":
+                from element_interface import caiman_loader
+
+                caiman_loader.CaImAn(output_dir)
+            else:
+                raise NotImplementedError(
+                    "Unknown/unimplemented method: {}".format(method)
+                )
+        except FileNotFoundError:
+            task_mode = "trigger"
+        else:
+            task_mode = "load"
+
+        cls.insert1(
+            {
+                **key,
+                "processing_output_dir": output_dir.relative_to(
+                    processed_dir
+                ).as_posix(),
+                "task_mode": task_mode,
+            }
+        )
+
+    auto_generate_entries = generate
+

 @schema
 class Processing(dj.Computed):
-    """Automatic table that beings the miniscope processing pipeline.
+    """Perform the computation of an entry (task) defined in the ProcessingTask table.
+
+    The computation is performed only on the recordings with RecordingInfo inserted.

     Attributes:
-        ProcessingTask (foreign key): Processing task primary key.
-        processing_time (datetime): Generates time of the processed results.
-        package_version (varchar(16) ): Package version information.
+        ProcessingTask (foreign key): Primary key from ProcessingTask.
+        processing_time (datetime): Process completion datetime.
+        package_version (str, optional): Version of the analysis package used in processing the data.
     """

     definition = """
     -> ProcessingTask
     ---
-    processing_time     : datetime  # generation time of processed, segmented results
+    processing_time     : datetime  # generation time of processed results
     package_version=''  : varchar(16)
     """

+    @property
+    def key_source(self):
+        return ProcessingTask & RecordingInfo
+
     def make(self, key):
-        """Triggers processing and populates Processing table."""
-        task_mode = (ProcessingTask & key).fetch1("task_mode")
+        """
+        Execute the miniscope analysis defined by the ProcessingTask.
+        - task_mode: 'load', confirm that the results are already computed.
+        - task_mode: 'trigger' runs the analysis.
+        """
+        task_mode, output_dir = (ProcessingTask & key).fetch1(
+            "task_mode", "processing_output_dir"
+        )

-        output_dir = (ProcessingTask & key).fetch1("processing_output_dir")
-        output_dir = find_full_path(get_miniscope_root_data_dir(), output_dir)
+        if not output_dir:
+            output_dir = ProcessingTask.infer_output_dir(key, relative=True, mkdir=True)
+            # update processing_output_dir
+            ProcessingTask.update1(
+                {**key, "processing_output_dir": output_dir.as_posix()}
+            )
+        try:
+            output_dir = find_full_path(
+                get_miniscope_root_data_dir(), output_dir
+            ).as_posix()
+        except FileNotFoundError as e:
+            if task_mode == "trigger":
+                processed_dir = pathlib.Path(get_processed_root_data_dir())
+                output_dir = processed_dir / output_dir
+                output_dir.mkdir(parents=True, exist_ok=True)
+            else:
+                raise e

         if task_mode == "load":
             method, loaded_result = get_loader_result(key, ProcessingTask)
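Illustrative usage (not part of the patch): the new `generate`/`auto_generate_entries` classmethod plus the `key_source` restriction allow tasks to be created per recording and then computed with `populate`; when `processing_output_dir` is left empty, `make` infers it via `infer_output_dir`. A minimal sketch:

    from element_miniscope import miniscope

    # create a task for every recording that has RecordingInfo but no task yet,
    # using parameter set 0; 'load' vs 'trigger' is decided from existing results
    for recording_key in (miniscope.RecordingInfo - miniscope.ProcessingTask).fetch("KEY"):
        miniscope.ProcessingTask.generate(recording_key, paramset_idx=0)

    miniscope.Processing.populate(display_progress=True)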
@@ -569,11 +753,11 @@ class Curation(dj.Manual):

     Attributes:
         Processing (foreign key): Processing primary key.
-        curation_id (foreign key, int): Unique curation ID.
+        curation_id (int): Unique curation ID.
         curation_time (datetime): Time of generation of curated results.
-        curation_output_dir (varchar(255) ): Output directory for curated results.
+        curation_output_dir (str): Output directory for curated results.
         manual_curation (bool): If True, manual curation has been performed.
-        curation_note (varchar(2000) ): Optional description of the curation procedure.
+        curation_note (str, optional): Optional description of the curation procedure.
     """

     definition = """
@@ -643,100 +827,115 @@ class MotionCorrection(dj.Imported):
     """

     class RigidMotionCorrection(dj.Part):
-        """Automated table with ridge motion correction data.
+        """Details of rigid motion correction performed on the imaging data.

         Attributes:
-            MotionCorrection (foreign key): MotionCorrection primary key.
-            outlier_frames (longblob): Mask with true for frames with outlier shifts.
-            y_shifts (longblob): y motion correction shifts, pixels.
-            x_shifts (longblob): x motion correction shifts, pixels.
-            y_std (float): Standard deviation of y shifts across all frames, pixels.
-            x_std (float): Standard deviation of x shifts across all frames, pixels.
+            MotionCorrection (foreign key): Primary key from MotionCorrection.
+            outlier_frames (longblob): Mask with true for frames with outlier shifts
+                (already corrected).
+            y_shifts (longblob): y motion correction shifts (pixels).
+            x_shifts (longblob): x motion correction shifts (pixels).
+            z_shifts (longblob, optional): z motion correction shifts (z-drift, pixels).
+            y_std (float): standard deviation of y shifts across all frames (pixels).
+            x_std (float): standard deviation of x shifts across all frames (pixels).
+            z_std (float, optional): standard deviation of z shifts across all frames
+                (pixels).
         """

-        definition = """
+        definition = """# Details of rigid motion correction performed on the imaging data
         -> master
         ---
-        outlier_frames=null : longblob  # mask with true for frames with outlier shifts
-                                        # (already corrected)
+        outlier_frames=null : longblob  # mask with true for frames with outlier shifts (already corrected)
         y_shifts            : longblob  # (pixels) y motion correction shifts
         x_shifts            : longblob  # (pixels) x motion correction shifts
-        y_std               : float     # (pixels) standard deviation of
-                                        # y shifts across all frames
-        x_std               : float     # (pixels) standard deviation of
-                                        # x shifts across all frames
+        z_shifts=null       : longblob  # (pixels) z motion correction shifts (z-drift)
+        y_std               : float     # (pixels) standard deviation of y shifts across all frames
+        x_std               : float     # (pixels) standard deviation of x shifts across all frames
+        z_std=null          : float     # (pixels) standard deviation of z shifts across all frames
         """

     class NonRigidMotionCorrection(dj.Part):
-        """Automated table with piece-wise rigid motion correction data.
+        """Piece-wise rigid motion correction - tile the FOV into multiple 3D
+        blocks/patches.

         Attributes:
-            MotionCorrection (foreign key): MotionCorrection primary key.
-            outlier_frames (longblob): Mask with true for frames with outlier shifts (already corrected).
-            block_height (int): Height in pixels.
-            block_width (int): Width in pixels.
+            MotionCorrection (foreign key): Primary key from MotionCorrection.
+            outlier_frames (longblob, null): Mask with true for frames with outlier
+                shifts (already corrected).
+            block_height (int): Block height in pixels.
+            block_width (int): Block width in pixels.
+            block_depth (int): Block depth in pixels.
             block_count_y (int): Number of blocks tiled in the y direction.
             block_count_x (int): Number of blocks tiled in the x direction.
+            block_count_z (int): Number of blocks tiled in the z direction.
         """

-        definition = """
+        definition = """# Details of non-rigid motion correction performed on the imaging data
         -> master
         ---
-        outlier_frames=null : longblob  # mask with true for frames with
-                                        # outlier shifts (already corrected)
-        block_height        : int       # (pixels)
-        block_width         : int       # (pixels)
-        block_count_y       : int       # number of blocks tiled in the
-                                        # y direction
-        block_count_x       : int       # number of blocks tiled in the
-                                        # x direction
+        outlier_frames=null : longblob  # mask with true for frames with outlier shifts (already corrected)
+        block_height        : int       # (pixels)
+        block_width         : int       # (pixels)
+        block_depth         : int       # (pixels)
+        block_count_y       : int       # number of blocks tiled in the y direction
+        block_count_x       : int       # number of blocks tiled in the x direction
+        block_count_z       : int       # number of blocks tiled in the z direction
         """

     class Block(dj.Part):
-        """Automated table with data for blocks used in non-rigid motion correction.
+        """FOV-tiled blocks used for non-rigid motion correction.

         Attributes:
-            master.NonRigidMotionCorrection (foreign key): NonRigidMotionCorrection primary key.
-            block_id (foreign key, int): Unique ID for each block.
-            block_y (longblob): y_start and y_end of this block in pixels.
-            block_x (longblob): x_start and x_end of this block in pixels.
-            y_shifts (longblob): y motion correction shifts for every frame in pixels.
-            x_shifts (longblob): x motion correction shifts for every frame in pixels.
+            NonRigidMotionCorrection (foreign key): Primary key from
+                NonRigidMotionCorrection.
+            block_id (int): Unique block ID.
+            block_y (longblob): y_start and y_end in pixels for this block
+            block_x (longblob): x_start and x_end in pixels for this block
+            block_z (longblob): z_start and z_end in pixels for this block
+            y_shifts (longblob): y motion correction shifts for every frame in pixels
+            x_shifts (longblob): x motion correction shifts for every frame in pixels
+            z_shifts (longblob, optional): z motion correction shifts for every frame
+                in pixels
+            y_std (float): standard deviation of y shifts across all frames in pixels
+            x_std (float): standard deviation of x shifts across all frames in pixels
+            z_std (float, optional): standard deviation of z shifts across all frames
+                in pixels
         """

-        definition = """  # FOV-tiled blocks used for non-rigid motion correction
+        definition = """# FOV-tiled blocks used for non-rigid motion correction
         -> master.NonRigidMotionCorrection
-        block_id        : int
+        block_id        : int
         ---
-        block_y         : longblob  # (y_start, y_end) in pixel of this block
-        block_x         : longblob  # (x_start, x_end) in pixel of this block
-        y_shifts        : longblob  # (pixels) y motion correction shifts for every frame
-        x_shifts        : longblob  # (pixels) x motion correction shifts for every frame
-        y_std           : float     # (pixels) standard deviation of y shifts across all frames
-        x_std           : float     # (pixels) standard deviation of x shifts across all frames
+        block_y         : longblob  # (y_start, y_end) in pixel of this block
+        block_x         : longblob  # (x_start, x_end) in pixel of this block
+        block_z         : longblob  # (z_start, z_end) in pixel of this block
+        y_shifts        : longblob  # (pixels) y motion correction shifts for every frame
+        x_shifts        : longblob  # (pixels) x motion correction shifts for every frame
+        z_shifts=null   : longblob  # (pixels) z motion correction shifts for every frame
+        y_std           : float     # (pixels) standard deviation of y shifts across all frames
+        x_std           : float     # (pixels) standard deviation of x shifts across all frames
+        z_std=null      : float     # (pixels) standard deviation of z shifts across all frames
         """

     class Summary(dj.Part):
-        """A summary image for each field and channel after motion correction.
+        """Summary images for each field and channel after corrections.

         Attributes:
-            MotionCorrection (foreign key): MotionCorrection primary key.
-            ref_image (longblob): Image used as the alignment template.
+            MotionCorrection (foreign key): Primary key from MotionCorrection.
+            ref_image (longblob): Image used as alignment template.
             average_image (longblob): Mean of registered frames.
-            correlation_image (longblob): Correlation map computed during cell detection.
-            max_proj_image (longblob): Maximum of registered frames.
+            correlation_image (longblob, optional): Correlation map (computed during
+                cell detection).
+            max_proj_image (longblob, optional): Max of registered frames.
         """

-        definition = """  # summary images for each field and channel after corrections
+        definition = """# Summary images for each field and channel after corrections
         -> master
         ---
-        ref_image=null          : longblob  # image used as alignment template
-        average_image           : longblob  # mean of registered frames
-        correlation_image=null  : longblob  # correlation map
-                                            # (computed during cell detection)
-        max_proj_image=null     : longblob  # max of registered frames
+        ref_image               : longblob  # image used as alignment template
+        average_image           : longblob  # mean of registered frames
+        correlation_image=null  : longblob  # correlation map (computed during cell detection)
+        max_proj_image=null     : longblob  # max of registered frames
         """

     def make(self, key):
@@ -744,104 +943,38 @@ def make(self, key):
         method, loaded_result = get_loader_result(key, ProcessingTask)

         if method == "caiman":
-            loaded_caiman = loaded_result
+            caiman_dataset = loaded_result

             self.insert1(
-                {**key, "motion_correct_channel": loaded_caiman.alignment_channel}
+                {**key, "motion_correct_channel": caiman_dataset.alignment_channel}
             )

             # -- rigid motion correction --
-            if not loaded_caiman.params.motion["pw_rigid"]:
-                rigid_correction = {
-                    **key,
-                    "x_shifts": loaded_caiman.motion_correction["shifts_rig"][:, 0],
-                    "y_shifts": loaded_caiman.motion_correction["shifts_rig"][:, 1],
-                    "x_std": np.nanstd(
-                        loaded_caiman.motion_correction["shifts_rig"][:, 0]
-                    ),
-                    "y_std": np.nanstd(
-                        loaded_caiman.motion_correction["shifts_rig"][:, 1]
-                    ),
-                    "outlier_frames": None,
-                }
-
-                self.RigidMotionCorrection.insert1(rigid_correction)
-
-            # -- non-rigid motion correction --
-            else:
-                nonrigid_correction = {
-                    **key,
-                    "block_height": (
-                        loaded_caiman.params.motion["strides"][0]
-                        + loaded_caiman.params.motion["overlaps"][0]
-                    ),
-                    "block_width": (
-                        loaded_caiman.params.motion["strides"][1]
-                        + loaded_caiman.params.motion["overlaps"][1]
-                    ),
-                    "block_count_x": len(
-                        set(loaded_caiman.motion_correction["coord_shifts_els"][:, 0])
-                    ),
-                    "block_count_y": len(
-                        set(loaded_caiman.motion_correction["coord_shifts_els"][:, 2])
-                    ),
-                    "outlier_frames": None,
-                }
-
-                nonrigid_blocks = []
-                for b_id in range(
-                    len(loaded_caiman.motion_correction["x_shifts_els"][0, :])
-                ):
-                    nonrigid_blocks.append(
-                        {
-                            **key,
-                            "block_id": b_id,
-                            "block_x": np.arange(
-                                *loaded_caiman.motion_correction["coord_shifts_els"][
-                                    b_id, 0:2
-                                ]
-                            ),
-                            "block_y": np.arange(
-                                *loaded_caiman.motion_correction["coord_shifts_els"][
-                                    b_id, 2:4
-                                ]
-                            ),
-                            "x_shifts": loaded_caiman.motion_correction["x_shifts_els"][
-                                :, b_id
-                            ],
-                            "y_shifts": loaded_caiman.motion_correction["y_shifts_els"][
-                                :, b_id
-                            ],
-                            "x_std": np.nanstd(
-                                loaded_caiman.motion_correction["x_shifts_els"][:, b_id]
-                            ),
-                            "y_std": np.nanstd(
-                                loaded_caiman.motion_correction["y_shifts_els"][:, b_id]
-                            ),
-                        }
-                    )
-
+            if caiman_dataset.is_pw_rigid:
+                # -- non-rigid motion correction --
+                (
+                    nonrigid_correction,
+                    nonrigid_blocks,
+                ) = caiman_dataset.extract_pw_rigid_mc()
+                nonrigid_correction.update(**key)
+                nonrigid_blocks.update(**key)
                 self.NonRigidMotionCorrection.insert1(nonrigid_correction)
                 self.Block.insert(nonrigid_blocks)
+            else:
+                # -- rigid motion correction --
+                rigid_correction = caiman_dataset.extract_rigid_mc()
+                rigid_correction.update(**key)
+                self.RigidMotionCorrection.insert1(rigid_correction)

             # -- summary images --
             summary_images = {
                 **key,
-                "ref_image": loaded_caiman.motion_correction["reference_image"][...][
-                    np.newaxis, ...
-                ],
-                "average_image": loaded_caiman.motion_correction["average_image"][...][
-                    np.newaxis, ...
-                ],
-                "correlation_image": loaded_caiman.motion_correction[
-                    "correlation_image"
-                ][...][np.newaxis, ...],
-                "max_proj_image": loaded_caiman.motion_correction["max_image"][...][
-                    np.newaxis, ...
-                ],
+                "ref_image": caiman_dataset.ref_image,
+                "average_image": caiman_dataset.mean_image,
+                "correlation_image": caiman_dataset.correlation_map,
+                "max_proj_image": caiman_dataset.max_proj_image,
             }
-
-            self.Summary.insert1(summary_images)
+            self.Summary.insert1(summary_images)
         else:
             raise NotImplementedError("Unknown/unimplemented method: {}".format(method))

@@ -863,31 +996,36 @@ class Segmentation(dj.Computed):
     """

     class Mask(dj.Part):
-        """Image masks produced during segmentation.
+        """Details of the masks identified from the Segmentation procedure.

         Attributes:
-            Segmentation (foreign key): Segmentation primary key.
-            mask (smallint): Unique ID for each mask.
-            channel.proj(segmentation_channel='channel') (query): Channel to be used for segmentation.
-            mask_npix (int): Number of pixels in the mask.
-            mask_center_x (int): Center x coordinate in pixels.
-            mask_center_y (int): Center y coordinate in pixels.
-            mask_xpix (longblob): x coordinates of the mask in pixels.
-            mask_ypix (longblob): y coordinates of the mask in pixels.
-            mask_weights (longblob): weights of the mask at the indices above.
+            Segmentation (foreign key): Primary key from Segmentation.
+            mask (int): Unique mask ID.
+            Channel.proj(segmentation_channel='channel') (foreign key): Channel
+                used for segmentation.
+            mask_npix (int): Number of pixels in ROIs.
+            mask_center_x (int): Center x coordinate in pixel.
+            mask_center_y (int): Center y coordinate in pixel.
+            mask_center_z (int): Center z coordinate in pixel.
+            mask_xpix (longblob): X coordinates in pixels.
+            mask_ypix (longblob): Y coordinates in pixels.
+            mask_zpix (longblob): Z coordinates in pixels.
+            mask_weights (longblob): Weights of the mask at the indices above.
        """

         definition = """  # A mask produced by segmentation.
         -> master
-        mask            : smallint
+        mask            : smallint
         ---
         -> Channel.proj(segmentation_channel='channel')  # channel used for segmentation
-        mask_npix       : int       # number of pixels in this mask
-        mask_center_x=null : int    # (pixels) center x coordinate
-        mask_center_y=null : int    # (pixels) center y coordinate
-        mask_xpix=null  : longblob  # (pixels) x coordinates
-        mask_ypix=null  : longblob  # (pixels) y coordinates
-        mask_weights    : longblob  # weights of the mask at the indices above
+        mask_npix       : int       # number of pixels in ROIs
+        mask_center_x   : int       # center x coordinate in pixel
+        mask_center_y   : int       # center y coordinate in pixel
+        mask_center_z=null : int    # center z coordinate in pixel
+        mask_xpix       : longblob  # x coordinates in pixels
+        mask_ypix       : longblob  # y coordinates in pixels
+        mask_zpix=null  : longblob  # z coordinates in pixels
+        mask_weights    : longblob  # weights of the mask at the indices above
         """

     def make(self, key):
@@ -895,53 +1033,50 @@ def make(self, key):
         method, loaded_result = get_loader_result(key, Curation)

         if method == "caiman":
-            loaded_caiman = loaded_result
+            caiman_dataset = loaded_result

-            # infer `segmentation_channel` from `params` if available,
-            # else from caiman loader
+            # infer "segmentation_channel" - from params if available, else from caiman loader
             params = (ProcessingParamSet * ProcessingTask & key).fetch1("params")
             segmentation_channel = params.get(
-                "segmentation_channel", loaded_caiman.segmentation_channel
+                "segmentation_channel", caiman_dataset.segmentation_channel
             )

             masks, cells = [], []
-            for mask in loaded_caiman.masks:
-                # Sample data had _id key, not mask. Permitting both
-                mask_id = mask.get("mask", mask["mask_id"])
+            for mask in caiman_dataset.masks:
                 masks.append(
                     {
                         **key,
                         "segmentation_channel": segmentation_channel,
-                        "mask": mask_id,
+                        "mask": mask["mask_id"],
                         "mask_npix": mask["mask_npix"],
                         "mask_center_x": mask["mask_center_x"],
                         "mask_center_y": mask["mask_center_y"],
+                        "mask_center_z": mask["mask_center_z"],
                         "mask_xpix": mask["mask_xpix"],
                         "mask_ypix": mask["mask_ypix"],
+                        "mask_zpix": mask["mask_zpix"],
                         "mask_weights": mask["mask_weights"],
                     }
                 )
-
-                if loaded_caiman.cnmf.estimates.idx_components is not None:
-                    if mask_id in loaded_caiman.cnmf.estimates.idx_components:
-                        cells.append(
-                            {
-                                **key,
-                                "mask_classification_method": "caiman_default_classifier",
-                                "mask": mask_id,
-                                "mask_type": "soma",
-                            }
-                        )
-
-            if not all([all(m.values()) for m in masks]):
-                logger.warning("Could not load all pixel values for at least one mask")
+                if mask["accepted"]:
+                    cells.append(
+                        {
+                            **key,
+                            "mask_classification_method": "caiman_default_classifier",
+                            "mask": mask["mask_id"],
+                            "mask_type": "soma",
+                        }
+                    )

             self.insert1(key)
             self.Mask.insert(masks, ignore_extra_fields=True)

             if cells:
                 MaskClassification.insert1(
-                    {**key, "mask_classification_method": "caiman_default_classifier"},
+                    {
+                        **key,
+                        "mask_classification_method": "caiman_default_classifier",
+                    },
                     allow_direct_insert=True,
                 )
                 MaskClassification.MaskType.insert(
@@ -1048,27 +1183,27 @@ def make(self, key):
         method, loaded_result = get_loader_result(key, Curation)

         if method == "caiman":
-            loaded_caiman = loaded_result
+            caiman_dataset = loaded_result

-            # infer `segmentation_channel` from `params` if available,
-            # else from caiman loader
+            # infer "segmentation_channel" - from params if available, else from caiman loader
             params = (ProcessingParamSet * ProcessingTask & key).fetch1("params")
             segmentation_channel = params.get(
-                "segmentation_channel", loaded_caiman.segmentation_channel
+                "segmentation_channel", caiman_dataset.segmentation_channel
             )

-            self.insert1(key)
-            self.Trace.insert(
-                [
+            fluo_traces = []
+            for mask in caiman_dataset.masks:
+                fluo_traces.append(
                     {
                         **key,
-                        "mask": mask.get("mask", mask["mask_id"]),
-                        "fluorescence_channel": segmentation_channel,
+                        "mask": mask["mask_id"],
+                        "fluo_channel": segmentation_channel,
                         "fluorescence": mask["inferred_trace"],
                     }
-                    for mask in loaded_caiman.masks
-                ]
-            )
+                )
+
+            self.insert1(key)
+            self.Trace.insert(fluo_traces)
         else:
             raise NotImplementedError("Unknown/unimplemented method: {}".format(method))

@@ -1138,31 +1273,32 @@ def make(self, key):
         method, loaded_result = get_loader_result(key, Curation)

         if method == "caiman":
-            loaded_caiman = loaded_result
-
-            if key["extraction_method"] in ("caiman_deconvolution", "caiman_dff"):
-                attr_mapper = {"caiman_deconvolution": "spikes", "caiman_dff": "dff"}
+            caiman_dataset = loaded_result
+
+            if key["extraction_method"] in (
+                "caiman_deconvolution",
+                "caiman_dff",
+            ):
+                attr_mapper = {
+                    "caiman_deconvolution": "spikes",
+                    "caiman_dff": "dff",
+                }

-                # infer `segmentation_channel` from `params` if available,
-                # else from caiman loader
+                # infer "segmentation_channel" - from params if available, else from caiman loader
                 params = (ProcessingParamSet * ProcessingTask & key).fetch1("params")
                 segmentation_channel = params.get(
-                    "segmentation_channel", loaded_caiman.segmentation_channel
+                    "segmentation_channel", caiman_dataset.segmentation_channel
                 )

                 self.insert1(key)
                 self.Trace.insert(
-                    [
-                        {
-                            **key,
-                            "mask": mask.get("mask", mask["mask_id"]),
-                            "fluorescence_channel": segmentation_channel,
-                            "activity_trace": mask[
-                                attr_mapper[key["extraction_method"]]
-                            ],
-                        }
-                        for mask in loaded_caiman.masks
-                    ]
+                    dict(
+                        key,
+                        mask=mask["mask_id"],
+                        fluo_channel=segmentation_channel,
+                        activity_trace=mask[attr_mapper[key["extraction_method"]]],
+                    )
+                    for mask in caiman_dataset.masks
                 )

         else:
diff --git a/install_caiman.py b/install_caiman.py
new file mode 100644
index 0000000..91d5d44
--- /dev/null
+++ b/install_caiman.py
@@ -0,0 +1,39 @@
+import os
+import subprocess
+import sys
+
+def run_command(command):
+    """Run a system command and ensure it completes successfully."""
+    result = subprocess.run(command, shell=True)
+    if result.returncode != 0:
+        print(f"Command failed with return code {result.returncode}: {command}")
+        sys.exit(result.returncode)
+
+def main(env_name="element-miniscope-env"):
+    conda_executable = 'conda'
+    mamba_executable = 'mamba'
+
+    # Step 1: Create the Conda Environment
+    print(f"Creating conda environment: {env_name}")
+    run_command(f"{conda_executable} create -n {env_name} -y")
+
+    # Step 2: Target the environment for the installs that follow
+    # (`conda activate` run through subprocess does not persist, so the
+    # commands below address the environment explicitly with `-n`/`conda run`)
+    print(f"Installing into conda environment: {env_name}")
+
+    # Step 3: Install CaImAn and its dependencies
+    print("Installing CaImAn and its dependencies")
+    run_command(f"{conda_executable} install -c conda-forge mamba -y")
+    run_command(f"{mamba_executable} install -n {env_name} -c conda-forge python==3.10 -y")
+    run_command(f"{mamba_executable} install -n {env_name} -c conda-forge caiman -y")
+    run_command(f"{conda_executable} run -n {env_name} pip install keras==2.15.0")
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Install CaImAn and element-miniscope dependencies.")
+    parser.add_argument('-e', '--env', type=str, default="element-miniscope-env", help="Name of the conda environment to create and use. Default is 'element-miniscope-env'.")
+    args = parser.parse_args()
+
+    main(args.env)
diff --git a/setup.py b/setup.py
index 6a27897..6db2c22 100644
--- a/setup.py
+++ b/setup.py
@@ -12,14 +12,6 @@
 with open(path.join(here, pkg_name, "version.py")) as f:
     exec(f.read())

-with urllib.request.urlopen(
-    "https://raw.githubusercontent.com/flatironinstitute/CaImAn/master/requirements.txt"
-) as f:
-    caiman_requirements = f.read().decode("UTF-8").split("\n")
-
-caiman_requirements.remove("")
-caiman_requirements.append("future")
-
 setup(
     name=pkg_name.replace("_", "-"),
     version=__version__,  # noqa: F821
@@ -39,11 +31,9 @@
         "ipywidgets",
         "plotly",
         "opencv-python",
-        "element-interface @ git+https://github.com/datajoint/element-interface.git",
+        "element-interface @ git+https://github.com/datajoint/element-interface.git@staging",
     ],
     extras_require={
-        "caiman_requirements": [caiman_requirements],
-        "caiman": ["caiman @ git+https://github.com/datajoint/CaImAn.git"],
         "elements": [
             "element-animal @ git+https://github.com/datajoint/element-animal.git",
             "element-event @ git+https://github.com/datajoint/element-event.git",