Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify files regarding lead-field matrix #263

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 223 additions & 25 deletions examples/example-0.8-leadfield-matrix.ipynb

Large diffs are not rendered by default.

208 changes: 203 additions & 5 deletions neurolib/utils/leadfield.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import numpy as np
import matplotlib.pyplot as plt
import typing
import pandas as pd

import nibabel as nib
import mne
Expand All @@ -15,8 +17,10 @@
class LeadfieldGenerator:

"""
Authors: Mohammad Orabe <[email protected]>
Zixuan liu <[email protected]>
Authors:
Zixuan liu <[email protected]>
Mohammad Orabe <[email protected]>


A class to compute the lead-field matrix and perform related operations.
The default loaded data is the template data 'fsaverage'.
Expand Down Expand Up @@ -318,7 +322,7 @@ def __get_backprojection(

return back_proj_rounded

def __filter_for_regions(self, label_strings: list[str], regions: list[str]) -> list[bool]:
def __filter_for_regions(self, label_strings: typing.List[str], regions: typing.List[str]) -> typing.List[bool]:
"""
Create a list of bools indicating if the label_strings are in the regions list.
This function can be used if one is only interested in a subset of regions defined by an atlas.
Expand Down Expand Up @@ -349,7 +353,7 @@ def __get_labels_of_points(
xml_file: dict,
atlas="aal2_cortical",
cortex_parts="only_cortical_parts",
) -> tuple[list[bool], np.ndarray, list[str]]:
) -> typing.Tuple[typing.List[bool], np.ndarray, typing.List[str]]:
"""
Gives labels of regions the points fall into.

Expand Down Expand Up @@ -443,7 +447,7 @@ def __get_labels_of_points(

def __downsample_leadfield_matrix(
self, leadfield: np.ndarray, label_codes: np.ndarray
) -> tuple[np.ndarray, np.ndarray]:
) -> typing.Tuple[np.ndarray, np.ndarray]:
"""
Downsample the leadfield matrix by computing the average across all dipoles falling within specific regions. This process assumes a one-to-one correspondence between source positions and dipoles, as commonly found in a surface source space where the dipoles' orientations are aligned with the surface normals.

Expand Down Expand Up @@ -611,3 +615,197 @@ def check_atlas_missing_regions(self, atlas_xml_path, unique_labels):
missed_region_indices = np.array([i + 1 for i, e in enumerate(label_numbers) if e in subset])
print("missed region indices:", missed_region_indices)
print("=====================================================")

def view_all_region_names(self,atlas_xml_path):
"""
take a view of all the region names in the atlas.

Parameters:
==========
atlas_xml_path (str): Path to the XML file containing label information.

Returns:
=======
None

"""
xml_file = self.__create_label_lut(atlas_xml_path)
label_numbers = np.array(list(map(int, xml_file.keys())))[:-1] # Convert the keys to integers
empty_set = []
all_region_labels = np.setdiff1d(label_numbers, empty_set)
#print("all region labels:", all_region_labels)

all_region_labels_str = all_region_labels.astype(str)
all_region_values = list(xml_file[label] for label in all_region_labels_str if label in xml_file)
print("all region names:", all_region_values)
print("=====================================================")

def find_region_corresponding_index(self,atlas_xml_path, region_name):
"""
find the index of given region name of the atlas.

Parameters:
==========
atlas_xml_path (str): Path to the XML file containing label information.
region_name (str): The name of the region of the atlas.

Returns:
=======
region_index (np.ndarray): The index of given region name in the atlas.

"""
xml_file = self.__create_label_lut(atlas_xml_path)
label_numbers = np.array(list(map(int, xml_file.keys())))[:-1] # Convert the keys to integers
empty_set = []
all_region_labels = np.setdiff1d(label_numbers, empty_set)
#print("all region labels:", all_region_labels)

all_region_labels_str = all_region_labels.astype(str)
all_region_values = list(xml_file[label] for label in all_region_labels_str if label in xml_file)
#print("all region names:", all_region_values)

all_subset = set(all_region_labels)
all_region_index = np.array([i+1 for i, e in enumerate(label_numbers) if e in all_subset])
#print("all region index:", all_region_index)

region_index = all_region_values.index(region_name)

print("Index for %s:" %region_name, region_index)
print("=====================================================")

return region_index

def simulated_source_data(
self,
leadfield_downsampled,
timepoints_number=1000,
frequency_parameter=(5,10),
time_parameter=(0, 1)
):
"""
Generate simulated source data.

Parameters:
==========
leadfield_downsampled (np.ndarray): Channels x Regions leadfield matrix.
timepoints (np.ndarray): Number of timepoints of generated data.
frequency_parameter (np.ndarray): The parameter of random frequencies for each dipole.
time_paremter (np.ndarray): The total time of generated data.

Returns:
=======
simulated_source_data (np.ndarray): The generated source data with the dimension regions x timepoints number
time (np.ndarray): The total timepoints of generated data.

"""

n_dipoles = leadfield_downsampled.shape[1] # Number of dipoles
n_timepoints = timepoints_number # Number of time points in the simulated data
frequencies = np.random.uniform(frequency_parameter[0], frequency_parameter[1], n_dipoles) # Random frequencies for each dipole
time = np.linspace(time_parameter[0], time_parameter[1], n_timepoints) # 1 second of data

# Create source time-series: [n_dipoles x n_timepoints]
simulated_source_data = np.array([np.sin(2 * np.pi * f * time) for f in frequencies])

return simulated_source_data, time

    def plot_eeg_data(self, eeg_data, time, title, offset_per_channel):
        """
        Plot the calculated EEG data, one vertically offset trace per channel.

        Parameters:
        ==========
        eeg_data (np.ndarray): The calculated EEG data (channels x timepoints) based on
            source data and the lead field matrix.
        time (np.ndarray): The timepoints of the generated data.
        title (str): The title of the simulated EEG plot.
        offset_per_channel (float): Vertical offset between consecutive channel traces.

        Returns:
        =======
        None

        """
        # Shift each channel's trace by a constant vertical offset so they do not overlap.
        channel_offsets = np.arange(eeg_data.shape[0]) * offset_per_channel
        for i, channel_data in enumerate(eeg_data):
            plt.plot(time, channel_data + channel_offsets[i], label=f'Channel {i}')
        # Label the y-axis ticks with the channel names at each trace's baseline.
        plt.yticks(channel_offsets, [f'Channel {i}' for i in range(eeg_data.shape[0])])
        plt.title(title)
        plt.xlabel("Time (s)")
        plt.ylabel("Channels")
        plt.tight_layout()
        plt.show() # Explicitly display the plot

def simulated_eeg_data(
self,
simulated_source_data,
leadfield_downsampled,
time,
visualization=True,
plot_title="Simulated EEG Data",
plot_offset=None,
plot_size=(8, 16),
csv_file_name="simulated_eeg_data.csv",
folder_to_save_csv="examples/data/AAL2_atlas_data"
):
"""
Calculate simulated EEG data based on generated simulated source data, generate the plot and csv file.

Parameters:
==========
simulated_source_data (np.ndarray): The generated source data with the dimension regions x timepoints
leadfield_downsampled (np.ndarray): Channels x Regions leadfield matrix.
time (np.ndarray): The total timepoints of generated data
plot_title (str): The title of the simulated EEG plot.
plot_offset (np.ndarray): The offset of the simulated EEG plot.
plot_size (np.ndarray): The size of the simulated EEG plot.
csv_file_name (str): Saved name of the csv file.
folder_to_save_csv (str): The folder to save the csv file.

Returns:
=======
simulated_eeg_data (np.ndarray): The calculated EEG data with the dimension channels x timepoints
"""
# Simulate EEG data: [n_sensors x n_timepoints]
simulated_eeg_data = np.dot(leadfield_downsampled, simulated_source_data)

# Plot EEG data
if visualization == True:
# Define an offset between each channel's plot
if plot_offset is None:
offset_per_channel = np.max(np.abs(simulated_eeg_data)) * 1.5
else:
offset_per_channel = plot_offset

plt.figure(figsize=plot_size)
self.plot_eeg_data(simulated_eeg_data, time, plot_title, offset_per_channel)

# List to store individual dataframes for each channel
dfs = []
# Loop through each channel and create individual dataframes
for i in range(simulated_eeg_data.shape[0]):
df_channel = pd.DataFrame({
f'Channel_{i+1}': simulated_eeg_data[i, :],
})
dfs.append(df_channel)

# Concatenate individual dataframes along columns axis
df_pairwise = pd.concat(dfs, axis=1)

# Save to CSV
if folder_to_save_csv is not None:

folder_path = folder_to_save_csv
if not os.path.exists(folder_path):
os.makedirs(folder_path)

path_to_save_csv = os.path.join(folder_path, csv_file_name)

df_pairwise.to_csv(path_to_save_csv, index=False)

print(f"The simulated EEG data is saved as a csv file at {path_to_save_csv}")
print("=====================================================")

return simulated_eeg_data

Loading