-
Notifications
You must be signed in to change notification settings - Fork 82
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
modified example leadfield matrix notebook and leadfield function scr…
…ipt to fix some Python version incompatibility, also adding some new features.
- Loading branch information
1 parent
efc51a4
commit ffdac07
Showing
2 changed files
with
426 additions
and
30 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
import os | ||
import numpy as np | ||
import matplotlib.pyplot as plt | ||
import typing | ||
import pandas as pd | ||
|
||
import nibabel as nib | ||
import mne | ||
|
@@ -15,8 +17,10 @@ | |
class LeadfieldGenerator: | ||
|
||
""" | ||
Authors: Mohammad Orabe <[email protected]> | ||
Zixuan liu <[email protected]> | ||
Authors: | ||
Zixuan liu <[email protected]> | ||
Mohammad Orabe <[email protected]> | ||
A class to compute the lead-field matrix and perform related operations. | ||
The default loaded data is the template data 'fsaverage'. | ||
|
@@ -318,7 +322,7 @@ def __get_backprojection( | |
|
||
return back_proj_rounded | ||
|
||
def __filter_for_regions(self, label_strings: list[str], regions: list[str]) -> list[bool]: | ||
def __filter_for_regions(self, label_strings: typing.List[str], regions: typing.List[str]) -> typing.List[bool]: | ||
""" | ||
Create a list of bools indicating if the label_strings are in the regions list. | ||
This function can be used if one is only interested in a subset of regions defined by an atlas. | ||
|
@@ -349,7 +353,7 @@ def __get_labels_of_points( | |
xml_file: dict, | ||
atlas="aal2_cortical", | ||
cortex_parts="only_cortical_parts", | ||
) -> tuple[list[bool], np.ndarray, list[str]]: | ||
) -> typing.Tuple[typing.List[bool], np.ndarray, typing.List[str]]: | ||
""" | ||
Gives labels of regions the points fall into. | ||
|
@@ -443,7 +447,7 @@ def __get_labels_of_points( | |
|
||
def __downsample_leadfield_matrix( | ||
self, leadfield: np.ndarray, label_codes: np.ndarray | ||
) -> tuple[np.ndarray, np.ndarray]: | ||
) -> typing.Tuple[np.ndarray, np.ndarray]: | ||
""" | ||
Downsample the leadfield matrix by computing the average across all dipoles falling within specific regions. This process assumes a one-to-one correspondence between source positions and dipoles, as commonly found in a surface source space where the dipoles' orientations are aligned with the surface normals. | ||
|
@@ -611,3 +615,197 @@ def check_atlas_missing_regions(self, atlas_xml_path, unique_labels): | |
missed_region_indices = np.array([i + 1 for i, e in enumerate(label_numbers) if e in subset]) | ||
print("missed region indices:", missed_region_indices) | ||
print("=====================================================") | ||
|
||
def view_all_region_names(self,atlas_xml_path): | ||
""" | ||
take a view of all the region names in the atlas. | ||
Parameters: | ||
========== | ||
atlas_xml_path (str): Path to the XML file containing label information. | ||
Returns: | ||
======= | ||
None | ||
""" | ||
xml_file = self.__create_label_lut(atlas_xml_path) | ||
label_numbers = np.array(list(map(int, xml_file.keys())))[:-1] # Convert the keys to integers | ||
empty_set = [] | ||
all_region_labels = np.setdiff1d(label_numbers, empty_set) | ||
#print("all region labels:", all_region_labels) | ||
|
||
all_region_labels_str = all_region_labels.astype(str) | ||
all_region_values = list(xml_file[label] for label in all_region_labels_str if label in xml_file) | ||
print("all region names:", all_region_values) | ||
print("=====================================================") | ||
|
||
def find_region_corresponding_index(self,atlas_xml_path, region_name): | ||
""" | ||
find the index of given region name of the atlas. | ||
Parameters: | ||
========== | ||
atlas_xml_path (str): Path to the XML file containing label information. | ||
region_name (str): The name of the region of the atlas. | ||
Returns: | ||
======= | ||
region_index (np.ndarray): The index of given region name in the atlas. | ||
""" | ||
xml_file = self.__create_label_lut(atlas_xml_path) | ||
label_numbers = np.array(list(map(int, xml_file.keys())))[:-1] # Convert the keys to integers | ||
empty_set = [] | ||
all_region_labels = np.setdiff1d(label_numbers, empty_set) | ||
#print("all region labels:", all_region_labels) | ||
|
||
all_region_labels_str = all_region_labels.astype(str) | ||
all_region_values = list(xml_file[label] for label in all_region_labels_str if label in xml_file) | ||
#print("all region names:", all_region_values) | ||
|
||
all_subset = set(all_region_labels) | ||
all_region_index = np.array([i+1 for i, e in enumerate(label_numbers) if e in all_subset]) | ||
#print("all region index:", all_region_index) | ||
|
||
region_index = all_region_values.index(region_name) | ||
|
||
print("Index for %s:" %region_name, region_index) | ||
print("=====================================================") | ||
|
||
return region_index | ||
|
||
def simulated_source_data( | ||
self, | ||
leadfield_downsampled, | ||
timepoints_number=1000, | ||
frequency_parameter=(5,10), | ||
time_parameter=(0, 1) | ||
): | ||
""" | ||
Generate simulated source data. | ||
Parameters: | ||
========== | ||
leadfield_downsampled (np.ndarray): Channels x Regions leadfield matrix. | ||
timepoints (np.ndarray): Number of timepoints of generated data. | ||
frequency_parameter (np.ndarray): The parameter of random frequencies for each dipole. | ||
time_paremter (np.ndarray): The total time of generated data. | ||
Returns: | ||
======= | ||
simulated_source_data (np.ndarray): The generated source data with the dimension regions x timepoints number | ||
time (np.ndarray): The total timepoints of generated data. | ||
""" | ||
|
||
n_dipoles = leadfield_downsampled.shape[1] # Number of dipoles | ||
n_timepoints = timepoints_number # Number of time points in the simulated data | ||
frequencies = np.random.uniform(frequency_parameter[0], frequency_parameter[1], n_dipoles) # Random frequencies for each dipole | ||
time = np.linspace(time_parameter[0], time_parameter[1], n_timepoints) # 1 second of data | ||
|
||
# Create source time-series: [n_dipoles x n_timepoints] | ||
simulated_source_data = np.array([np.sin(2 * np.pi * f * time) for f in frequencies]) | ||
|
||
return simulated_source_data, time | ||
|
||
def plot_eeg_data(self, eeg_data, time, title, offset_per_channel): | ||
""" | ||
Plot the calculated EEG data. | ||
Parameters: | ||
========== | ||
eeg_data (np.ndarray): The calculated EEG data based on source data and the lead field matrix. | ||
time (np.ndarray): The total timepoints of generated data | ||
title (str): The title of the simulated EEG plot. | ||
offset_per_channel (np.ndarray): The offset of the simulated EEG plot. | ||
csv_file_title (str): Title name of the csv file | ||
path_to_save_csv (str): The path to save the csv file. | ||
Returns: | ||
======= | ||
None | ||
""" | ||
channel_offsets = np.arange(eeg_data.shape[0]) * offset_per_channel | ||
for i, channel_data in enumerate(eeg_data): | ||
plt.plot(time, channel_data + channel_offsets[i], label=f'Channel {i}') | ||
plt.yticks(channel_offsets, [f'Channel {i}' for i in range(eeg_data.shape[0])]) | ||
plt.title(title) | ||
plt.xlabel("Time (s)") | ||
plt.ylabel("Channels") | ||
plt.tight_layout() | ||
plt.show() # Explicitly display the plot | ||
|
||
def simulated_eeg_data( | ||
self, | ||
simulated_source_data, | ||
leadfield_downsampled, | ||
time, | ||
visualization=True, | ||
plot_title="Simulated EEG Data", | ||
plot_offset=None, | ||
plot_size=(8, 16), | ||
csv_file_name="simulated_eeg_data.csv", | ||
folder_to_save_csv="examples/data/AAL2_atlas_data" | ||
): | ||
""" | ||
Calculate simulated EEG data based on generated simulated source data, generate the plot and csv file. | ||
Parameters: | ||
========== | ||
simulated_source_data (np.ndarray): The generated source data with the dimension regions x timepoints | ||
leadfield_downsampled (np.ndarray): Channels x Regions leadfield matrix. | ||
time (np.ndarray): The total timepoints of generated data | ||
plot_title (str): The title of the simulated EEG plot. | ||
plot_offset (np.ndarray): The offset of the simulated EEG plot. | ||
plot_size (np.ndarray): The size of the simulated EEG plot. | ||
csv_file_name (str): Saved name of the csv file. | ||
folder_to_save_csv (str): The folder to save the csv file. | ||
Returns: | ||
======= | ||
simulated_eeg_data (np.ndarray): The calculated EEG data with the dimension channels x timepoints | ||
""" | ||
# Simulate EEG data: [n_sensors x n_timepoints] | ||
simulated_eeg_data = np.dot(leadfield_downsampled, simulated_source_data) | ||
|
||
# Plot EEG data | ||
if visualization == True: | ||
# Define an offset between each channel's plot | ||
if plot_offset is None: | ||
offset_per_channel = np.max(np.abs(simulated_eeg_data)) * 1.5 | ||
else: | ||
offset_per_channel = plot_offset | ||
|
||
plt.figure(figsize=plot_size) | ||
self.plot_eeg_data(simulated_eeg_data, time, plot_title, offset_per_channel) | ||
|
||
# List to store individual dataframes for each channel | ||
dfs = [] | ||
# Loop through each channel and create individual dataframes | ||
for i in range(simulated_eeg_data.shape[0]): | ||
df_channel = pd.DataFrame({ | ||
f'Channel_{i+1}': simulated_eeg_data[i, :], | ||
}) | ||
dfs.append(df_channel) | ||
|
||
# Concatenate individual dataframes along columns axis | ||
df_pairwise = pd.concat(dfs, axis=1) | ||
|
||
# Save to CSV | ||
if folder_to_save_csv is not None: | ||
|
||
folder_path = folder_to_save_csv | ||
if not os.path.exists(folder_path): | ||
os.makedirs(folder_path) | ||
|
||
path_to_save_csv = os.path.join(folder_path, csv_file_name) | ||
|
||
df_pairwise.to_csv(path_to_save_csv, index=False) | ||
|
||
print(f"The simulated EEG data is saved as a csv file at {path_to_save_csv}") | ||
print("=====================================================") | ||
|
||
return simulated_eeg_data | ||
|