diff --git a/examples/data/AAL2_atlas_data/AAL2.nii b/examples/data/AAL2_atlas_data/AAL2.nii new file mode 100755 index 00000000..48aab949 Binary files /dev/null and b/examples/data/AAL2_atlas_data/AAL2.nii differ diff --git a/examples/data/AAL2_atlas_data/AAL2.xml b/examples/data/AAL2_atlas_data/AAL2.xml new file mode 100755 index 00000000..05853e26 --- /dev/null +++ b/examples/data/AAL2_atlas_data/AAL2.xml @@ -0,0 +1,137 @@ + + +
+ aal2 + 1.0 + ROI_MNI_V5 + + + MNI + Label + + AAL2.nii + +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
import os
import logging
from xml.etree import ElementTree

import numpy as np
import matplotlib.pyplot as plt

import nibabel as nib
import mne
from mne.datasets import eegbci
from mne.datasets import fetch_fsaverage

from neurolib.utils.atlases import AutomatedAnatomicalParcellation2


class LeadfieldGenerator:
    """Compute an EEG lead-field matrix downsampled onto the AAL2 atlas.

    Authors: Mohammad Orabe, Zixuan Liu

    The default subject is the MNE template ``fsaverage``. A surface source
    space is built first; the forward solution is then downsampled by
    averaging, per EEG channel, over all dipoles whose positions fall into
    the same AAL2 region. The downsampling step needs the NIfTI volume and
    the XML label file of the AAL2 atlas.

    Parameters
    ==========
    subject (str): The name of the subject (e.g. ``"fsaverage"``).

    Attributes
    ==========
    fs_dir (str | None): Path to the downloaded ``fsaverage`` directory.
    subjects_dir (str | None): Directory containing the subject data.
    trans (str | None): Path to the coregistration transformation file.

    Methods
    =======
    load_data(subjects_dir, subject):
        Load subject data; ``fsaverage`` is fetched automatically. For
        user-specific data a coregistration GUI is opened to create a
        transformation file.
    load_transformation_file(trans_path, subject):
        Load the transformation file; ``fsaverage`` ships a default one.
    build_BEM(conductivity, visualization, brain_surfaces, orientation, slices):
        Construct the BEM solution for the subject's head model.
    generate_surface_source_space(plot_bem_kwargs, spacing, add_dist, visualization):
        Build the surface source space.
    EEG_coregistration(src, configuration, visualization):
        Align the chosen EEG montage with the subject.
    calculate_general_forward_solution(raw, src, bem, eeg, mindist):
        Compute the full (per-dipole) forward solution.
    compute_downsampled_leadfield(fwd, atlas_nii_path, atlas_xml_path, ...):
        Average the lead field over AAL2 regions.
    check_atlas_missing_regions(atlas_xml_path, unique_labels):
        Report atlas regions that received no dipole.
    """

    def __init__(self, subject):
        self.subject = subject
        self.fs_dir = None
        self.subjects_dir = None
        self.trans = None

    def load_data(self, subjects_dir=None, subject="fsaverage"):
        """
        Load subject data.

        Parameters:
        ==========
        subjects_dir (str): The directory of the subject. Ignored for
            ``fsaverage`` (the template is downloaded via MNE).
        subject (str): The name of the subject, default ``'fsaverage'``.
        """
        if subject == "fsaverage":
            # Download (or reuse) the MNE template data 'fsaverage'.
            self.fs_dir = fetch_fsaverage(verbose=True)
            self.subjects_dir = os.path.dirname(self.fs_dir)
            print("Load template data 'fsaverage'")
        else:
            self.subjects_dir = subjects_dir
            # Open the coregistration GUI so the user can create a -trans.fif file; see
            # https://mne.tools/stable/generated/mne.gui.coregistration.html
            mne.gui.coregistration(subject=subject, subjects_dir=subjects_dir)

    def load_transformation_file(self, trans_path, subject="fsaverage"):
        """
        Load the head<->MRI transformation file.

        Parameters:
        ==========
        trans_path (str): Path to the transformation file (ignored for
            ``fsaverage``, which uses the bundled default).
        subject (str): The name of the subject, default ``'fsaverage'``.
        """
        if subject == "fsaverage":
            # NOTE(review): uses self.subject (constructor argument), not the
            # `subject` parameter, to build the path — kept as-is.
            self.trans = os.path.join(self.subjects_dir, self.subject, "bem", "fsaverage-trans.fif")
            print("Load default transformation file of 'fsaverage'")
        else:
            self.trans = trans_path

    def build_BEM(
        self,
        conductivity=(0.3, 0.006, 0.3),
        visualization=True,
        brain_surfaces="white",
        orientation="coronal",
        slices=None,
    ):
        """
        Create the Boundary Element Model (BEM) solution for the subject using
        the linear collocation approach.

        Parameters:
        ==========
        conductivity (tuple of float, shape (3,) or (1,)): Conductivity per
            shell. One element for a single-layer model, three for a
            three-layer model. Defaults to ``(0.3, 0.006, 0.3)``.
        visualization (bool): If True, plot the BEM surfaces.
        brain_surfaces (str): Surface set passed to ``mne.viz.plot_bem``.
        orientation (str): Slice orientation for the plot.
        slices (list of int | None): MRI slices to plot; defaults to
            ``[50, 100, 150, 200]``. (A ``None`` default avoids the
            mutable-default-argument pitfall of the original code.)

        Returns:
        =======
        bem (mne.bem.ConductorModel): BEM solution for the head model.
        plot_bem_kwargs (dict): Plotting arguments describing the MRI data.
        """
        # BUGFIX: `slices` was a mutable default argument; bind it per call.
        if slices is None:
            slices = [50, 100, 150, 200]

        model = mne.make_bem_model(
            subject=self.subject,
            ico=4,
            conductivity=conductivity,
            subjects_dir=self.subjects_dir,
        )
        bem = mne.make_bem_solution(model)

        plot_bem_kwargs = dict(
            subject=self.subject,
            subjects_dir=self.subjects_dir,
            brain_surfaces=brain_surfaces,
            orientation=orientation,
            slices=slices,
        )

        if visualization:
            mne.viz.plot_bem(**plot_bem_kwargs)

        return bem, plot_bem_kwargs

    def generate_surface_source_space(self, plot_bem_kwargs, spacing="ico4", add_dist="patch", visualization=True):
        """
        Generate the overall surface source model.

        Parameters:
        ==========
        plot_bem_kwargs (dict): Plotting arguments as returned by ``build_BEM``.
        spacing (str): Source spacing: 'ico#', 'oct#', 'all', or an integer
            for approximate distance-based spacing (mm). Ignored for
            ``fsaverage`` (the shipped ico-5 source space file is used).
        add_dist (bool | str): Add distance/patch information to the source space.
        visualization (bool): If True, plot the source space on the BEM.

        Returns:
        =======
        src (mne.SourceSpaces | str): Surface source space (or path to the
            bundled fsaverage source-space file).
        """
        if self.subject == "fsaverage":
            src = os.path.join(self.fs_dir, "bem", "fsaverage-ico-5-src.fif")
        else:
            src = mne.setup_source_space(
                subject=self.subject,
                spacing=spacing,
                add_dist=add_dist,
                subjects_dir=self.subjects_dir,
            )

        if visualization:
            mne.viz.plot_bem(src=src, **plot_bem_kwargs)

        return src

    def EEG_coregistration(self, src, configuration="standard_1020", visualization=True):
        """
        Align the selected EEG configuration with the subject.

        Parameters:
        ==========
        src (mne.SourceSpaces): Source space object.
        configuration (str): EEG electrode layout, defaults to 'standard_1020'.
        visualization (bool): If True, plot the sensor/MRI alignment.

        Returns:
        =======
        raw (mne.io.Raw): Raw data coregistered with the EEG montage.
        """
        # Load one EEGBCI run as a channel template.
        (raw_fname,) = eegbci.load_data(subject=1, runs=[6])
        raw = mne.io.read_raw_edf(raw_fname, preload=True)

        # Clean channel names to match the standard 1020 montage naming.
        new_names = dict(
            (ch_name, ch_name.rstrip(".").upper().replace("Z", "z").replace("FP", "Fp")) for ch_name in raw.ch_names
        )
        raw.rename_channels(new_names)

        # Electrode locations of standard_1020 are already in fsaverage (MNI) space.
        montage = mne.channels.make_standard_montage(configuration)
        raw.set_montage(montage)
        raw.set_eeg_reference(projection=True)  # needed for inverse modeling

        # Check electrode locations against the MRI.
        if visualization:
            mne.viz.plot_alignment(
                raw.info,
                src=src,
                eeg=["original", "projected"],
                trans=self.trans,
                show_axes=True,
                mri_fiducials=True,
                dig="fiducials",
            )

        return raw

    def calculate_general_forward_solution(self, raw, src, bem, eeg=True, mindist=5.0):
        """
        Calculate the general (per-dipole) forward solution.

        Parameters:
        ==========
        raw (mne.io.Raw): Raw data coregistered with EEG.
        src (mne.SourceSpaces): Surface source space object.
        bem (mne.bem.ConductorModel): BEM of the given head model.
        eeg (bool): Include EEG channels.
        mindist (float): Minimum distance of sources from the inner skull (mm).

        Returns:
        =======
        fwd (mne.Forward): The general forward solution.
        """
        fwd = mne.make_forward_solution(
            raw.info,
            trans=self.trans,
            src=src,
            bem=bem,
            eeg=eeg,
            mindist=mindist,
            n_jobs=None,
        )
        print("The general forward solution:", fwd)
        print("=====================================================")

        return fwd

    def __create_label_lut(self, path: str) -> dict:
        """
        Create a lookup table mapping atlas region codes to anatomical acronyms.
        Adds an empty label for code "0" if the atlas does not define one.

        Parameters:
        ==========
        path (str): Path to the XML file containing label information.

        Returns:
        =======
        dict: Keys are the (string) integer codes of regions, values are
            anatomical acronyms.
        """
        tree = ElementTree.parse(path)
        root = tree.getroot()
        label_lut = {}
        for region in root.find("data").findall("label"):
            label_lut[region.find("index").text] = region.find("name").text

        # BUGFIX: the original checked for the key "0 " (trailing space), which can
        # never match, so an atlas-provided "0" entry was always overwritten by "".
        if "0" not in label_lut:
            label_lut["0"] = ""
        return label_lut

    def __get_backprojection(
        self, point_expanded: np.ndarray, affine: np.ndarray, affine_inverse: np.ndarray
    ) -> np.ndarray:
        """
        Transform an MNI-mm point into a 'voxel coordinate'.

        Parameters:
        ==========
        point_expanded (np.ndarray): Shape (4,): 3D point in MNI space (mm)
            plus a trailing 1 for the affine offset.
        affine (np.ndarray): 4x4 voxel-to-MNI transform.
        affine_inverse (np.ndarray): 4x4 MNI-to-voxel transform.

        Returns:
        =======
        np.ndarray: The point projected back into voxel-number space
            (rounded to the voxel grid), last element 1. Shape (4,).
        """
        # Project the point from MNI space to voxel space.
        back_proj = affine_inverse @ point_expanded

        # Round to voxel resolution; multiplying by the diagonal of the inverse is
        # equivalent to dividing by the diagonal of the affine here.
        back_proj_rounded = np.round(np.diag(affine_inverse) * back_proj, 0) * np.diag(affine)

        return back_proj_rounded

    def __filter_for_regions(self, label_strings: list[str], regions: list[str]) -> list[bool]:
        """
        Flag which labels belong to a subset of regions of interest.

        Parameters:
        ==========
        label_strings (list[str]): Labels that dipoles were assigned to.
        regions (list[str]): Acronyms of the regions of interest.

        Returns:
        =======
        list[bool]: For each label, True if it is in ``regions``.
        """
        # Outside this function, codes/labels of uninteresting dipoles can then be
        # zeroed out so that downsampling works smoothly.
        regions_set = set(regions)  # O(1) membership tests
        return [label in regions_set for label in label_strings]

    def __get_labels_of_points(
        self,
        points: np.ndarray,
        nii_file: nib.Nifti1Image,
        xml_file: dict,
        atlas="aal2_cortical",
        cortex_parts="only_cortical_parts",
    ) -> tuple[list[bool], np.ndarray, list[str]]:
        """
        Give labels of the atlas regions the points fall into.

        Parameters:
        ==========
        points (np.ndarray): Nx3 array of points in MNI space (mm).
        nii_file (nibabel.Nifti1Image): NIfTI volume of the anatomical atlas.
        xml_file (dict): Region-code -> acronym lookup table.
        atlas (str): Atlas specification; currently only "aal2_cortical".
        cortex_parts (str): One of "full_cortex", "only_cortical_parts",
            "subcortical_parts".

        Returns:
        =======
        tuple[list[bool], np.ndarray, list[str]]:
            - Per point: whether a valid atlas assignment was found.
            - Per point: assigned label code (NaN if outside the atlas volume).
            - Per point: anatomical acronym ("" if filtered out, "invalid" if
              outside the volume).

        Raises:
        ======
        ValueError: If ``points`` is not Nx3 or ``cortex_parts`` is unknown.
        """
        # Validate the input shape BEFORE using it (original checked afterwards
        # and raised a message-less ValueError).
        if points.ndim != 2 or points.shape[1] != 3:
            raise ValueError("`points` must be an Nx3 array of MNI coordinates (mm).")

        n_points = points.shape[0]
        label_codes = np.zeros(n_points)
        label_strings = [None] * n_points
        points_found = [None] * n_points

        # Append a column of ones so 4x4 affines can be applied directly.
        points_expanded = np.ones((n_points, 4))
        points_expanded[:, 0:3] = points

        atlas_img = nii_file
        atlas_labels_lut = xml_file

        affine = atlas_img.affine  # voxel -> MNI (mm)
        affine_inverse = np.linalg.inv(affine)  # MNI (mm) -> voxel

        codes = atlas_img.get_fdata()
        for point_idx, point in enumerate(points_expanded):
            back_proj = self.__get_backprojection(point, affine, affine_inverse)

            try:
                label_codes[point_idx] = codes[int(back_proj[0]), int(back_proj[1]), int(back_proj[2])]
            except IndexError:
                # Point lies outside the volume covered by the atlas.
                label_codes[point_idx] = np.nan

            if np.isnan(label_codes[point_idx]):
                points_found[point_idx] = False
                label_strings[point_idx] = "invalid"
            else:
                points_found[point_idx] = True
                label_strings[point_idx] = atlas_labels_lut[str(int(label_codes[point_idx]))]

        if sum(points_found) < n_points:
            # Lazy %-style args instead of eager f-string/% mix of the original.
            logging.error(
                "The atlas does not specify valid labels for all the given points.\n"
                "Total number of points: (%s) out of which (%s) were validly assigned.",
                n_points,
                sum(points_found),
            )

        if atlas == "aal2_cortical":
            aal_2 = AutomatedAnatomicalParcellation2()

            # Select the requested part of the cortex; fail loudly on typos
            # (the original silently fell through and iterated the string).
            part_selection = {
                "full_cortex": aal_2.cortex + aal_2.subcortical,
                "only_cortical_parts": aal_2.cortex,
                "subcortical_parts": aal_2.subcortical,
            }
            if cortex_parts not in part_selection:
                raise ValueError(f"Unknown cortex_parts: {cortex_parts!r}")

            regions = [aal_2.aal2[r + 1] for r in part_selection[cortex_parts]]
            in_regions = self.__filter_for_regions(label_strings, regions)

            # Zero out dipoles that are not in the regions of interest.
            for idx_point in range(len(points_found)):
                if not in_regions[idx_point]:
                    label_codes[idx_point] = 0
                    label_strings[idx_point] = ""

        return points_found, label_codes, label_strings

    def __downsample_leadfield_matrix(
        self, leadfield: np.ndarray, label_codes: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """
        Downsample the lead-field matrix by averaging across all dipoles that
        fall into the same region. Assumes a one-to-one correspondence between
        source positions and dipoles (surface source space with fixed,
        surface-normal orientations).

        Parameters:
        ==========
        leadfield (np.ndarray): Channels x Dipoles lead-field matrix.
        label_codes (np.ndarray): 1D array of region labels per source location.

        Returns:
        =======
        tuple[np.ndarray, np.ndarray]:
            - Label codes of every region that received at least one dipole.
            - Channels x Regions lead-field matrix; rows (channels) keep their
              order, columns are sorted like the unique-labels array.

        Raises:
        ======
        ValueError: If the number of columns differs from the number of labels.
        """
        n_channels = leadfield.shape[0]

        if leadfield.shape[1] != label_codes.size:
            raise ValueError(
                "The lead field matrix does not have the expected number of columns. \n"
                "Number of columns differs from labels (equal number dipoles)."
            )

        unique_labels = np.unique(label_codes)
        # NaN marks points outside the atlas volume; 0 is non-brain tissue (e.g. CSF).
        unique_labels = unique_labels[~np.isnan(unique_labels)]
        unique_labels = unique_labels[unique_labels != 0]

        downsampled_leadfield = np.zeros((n_channels, unique_labels.size))

        for label_idx, label in enumerate(unique_labels):  # iterate through regions
            indices_label = np.where(label_codes == label)[0]
            downsampled_leadfield[:, label_idx] = np.mean(leadfield[:, indices_label], axis=1)

        return unique_labels, downsampled_leadfield

    def compute_downsampled_leadfield(
        self,
        fwd,
        atlas_nii_path,
        atlas_xml_path,
        atlas="aal2_cortical",
        cortex_parts="only_cortical_parts",
        path_to_save=None,
    ):
        """
        Compute the lead-field matrix downsampled onto atlas regions.

        Parameters:
        ==========
        fwd (mne.Forward): General forward solution.
        atlas_nii_path (str): Path to the NIfTI file of the atlas.
        atlas_xml_path (str): Path to the XML file of the atlas.
        atlas (str): Atlas specification, defaults to "aal2_cortical".
        cortex_parts (str): Cortex-part selection, defaults to "only_cortical_parts".
        path_to_save (str | None): Directory to save the matrix as a NumPy
            .npy file; nothing is saved when None.

        Returns:
        =======
        tuple[np.ndarray, np.ndarray]:
            - Channels x Regions lead-field matrix.
            - Label codes of every region that received at least one dipole.
        """
        # Fix dipole orientations to the surface normals so each source position
        # maps to exactly one lead-field column.
        fwd_fixed = mne.convert_forward_solution(fwd, surf_ori=True, force_fixed=True, use_cps=True)
        leadfield_fixed = fwd_fixed["sol"]["data"]

        atlas_nii_file = nib.load(atlas_nii_path)
        atlas_xml_file = self.__create_label_lut(atlas_xml_path)

        # Collect dipole positions of both hemispheres (head coordinates, m).
        lh = fwd_fixed["src"][0]
        dip_pos_lh = np.vstack(lh["rr"][lh["vertno"]])
        rh = fwd_fixed["src"][1]
        dip_pos_rh = np.vstack(rh["rr"][rh["vertno"]])
        dip_pos = np.vstack((dip_pos_lh, dip_pos_rh))

        # Transform dipole positions into MNI space (mm) for the atlas lookup.
        trans_info = mne.read_trans(self.trans)
        dip_pos_mni = mne.head_to_mni(dip_pos, subject=self.subject, mri_head_t=trans_info)

        points_found, label_codes, label_strings = self.__get_labels_of_points(
            dip_pos_mni,
            atlas_nii_file,
            atlas_xml_file,
            atlas=atlas,
            cortex_parts=cortex_parts,
        )

        unique_labels, leadfield_downsampled = self.__downsample_leadfield_matrix(leadfield_fixed, label_codes)

        # BUGFIX: columns of the downsampled matrix are regions, not dipoles.
        print("Lead-field matrix's size : %d sensors x %d regions" % leadfield_downsampled.shape)
        print("=====================================================")

        print("Downsampled lead-field matrix:", leadfield_downsampled)
        print("=====================================================")

        # Export the lead-field matrix as a binary file in NumPy .npy format.
        if path_to_save is not None:
            np.save(
                os.path.join(path_to_save, "leadfield_downsampled"),
                leadfield_downsampled,
            )
            print(f"The leadfield matrix is saved as a binary file in NumPy .npy format at {path_to_save}")
            print("=====================================================")

        return leadfield_downsampled, unique_labels

    def check_atlas_missing_regions(self, atlas_xml_path, unique_labels):
        """
        Investigate which atlas regions received no dipole.

        Parameters:
        ==========
        atlas_xml_path (str): Path to the XML file containing label information.
        unique_labels (np.ndarray): Label codes of every region that received
            at least one dipole.

        Returns:
        =======
        None
        """
        aal_2 = AutomatedAnatomicalParcellation2()
        full_cortex = aal_2.cortex + aal_2.subcortical
        total_region_quantity = len(full_cortex)
        missed_region_quantity = len(full_cortex) - np.asarray(unique_labels).shape[0]

        print("total region quantity:", total_region_quantity)
        print("missed region quantity: ", missed_region_quantity)
        print("=====================================================")

        atlas_xml_file = self.__create_label_lut(atlas_xml_path)

        # Exclude the non-brain code "0" explicitly instead of the original's
        # fragile `[:-1]`, which assumed "0" was always the last dict key.
        label_numbers = np.array([int(key) for key in atlas_xml_file if key != "0"])
        missed_region_labels = np.setdiff1d(label_numbers, unique_labels)
        print("missed region labels:", missed_region_labels)
        print("=====================================================")

        missed_region_labels_str = missed_region_labels.astype(str)
        missed_region_values = [
            atlas_xml_file[label] for label in missed_region_labels_str if label in atlas_xml_file
        ]
        print("missed region names:", missed_region_values)
        print("=====================================================")

        subset = set(missed_region_labels)
        missed_region_indices = np.array([i + 1 for i, e in enumerate(label_numbers) if e in subset])
        print("missed region indices:", missed_region_indices)
        print("=====================================================")