Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor gscontrol module #1086

Merged
merged 15 commits into from
Aug 13, 2024
2 changes: 1 addition & 1 deletion docs/approach.rst
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ presented at MRITogether 2022 for a hands-on tutorial.
Removal of spatially diffuse noise (optional)
*********************************************

:func:`tedana.gscontrol.gscontrol_raw`, :func:`tedana.gscontrol.gscontrol_mmix`
:func:`tedana.gscontrol.gscontrol_raw`, :func:`tedana.gscontrol.minimum_image_regression`

Due to the constraints of ICA, TEDICA is able to identify and remove spatially
localized noise components, but it cannot identify components that are spread
Expand Down
280 changes: 174 additions & 106 deletions tedana/gscontrol.py

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions tedana/metrics/dependence.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ def calculate_betas(
betas : (M [x E] x C) array_like
Unstandardized parameter estimates
"""
if len(data.shape) == 2:
if data.ndim == 2:
data_optcom = data
assert data_optcom.shape[1] == mixing.shape[0]
# mean-center optimally-combined data
data_optcom_dm = data_optcom - data_optcom.mean(axis=-1, keepdims=True)
# betas are the result of a normal OLS fit of the mixing matrix
# against the mean-center data
# betas are from a normal OLS fit of the mixing matrix against the mean-centered data
betas = get_coeffs(data_optcom_dm, mixing)
return betas

else:
betas = np.zeros([data.shape[0], data.shape[1], mixing.shape[1]])
for n_echo in range(data.shape[1]):
betas[:, n_echo, :] = get_coeffs(data[:, n_echo, :], mixing)
return betas

return betas


def calculate_psc(
Expand Down
33 changes: 17 additions & 16 deletions tedana/reporting/static_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,24 @@ def carpet_plot(
)
)

mir_denoised_img = io_generator.get_name("ICA accepted mir denoised img")
fig, ax = plt.subplots(figsize=(14, 7))
plotting.plot_carpet(
mir_denoised_img,
mask_img,
figure=fig,
axes=ax,
title="High-Kappa Data (Post-MIR)",
)
fig.tight_layout()
fig.savefig(
os.path.join(
io_generator.out_dir,
"figures",
f"{io_generator.prefix}carpet_accepted_mir.svg",
if io_generator.verbose:
mir_denoised_img = io_generator.get_name("ICA accepted mir denoised img")
fig, ax = plt.subplots(figsize=(14, 7))
plotting.plot_carpet(
mir_denoised_img,
mask_img,
figure=fig,
axes=ax,
title="High-Kappa Data (Post-MIR)",
)
fig.tight_layout()
fig.savefig(
os.path.join(
io_generator.out_dir,
"figures",
f"{io_generator.prefix}carpet_accepted_mir.svg",
)
)
)


def plot_component(
Expand Down
3 changes: 1 addition & 2 deletions tedana/tests/data/reclassify_debug_out.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@ sub-testymctestface_references.bib
sub-testymctestface_report.txt
sub-testymctestface_betas_OC.nii.gz
sub-testymctestface_betas_hik_OC.nii.gz
sub-testymctestface_betas_hik_OC_MIR.nii.gz
sub-testymctestface_dataset_description.json
sub-testymctestface_dn_ts_OC.nii.gz
sub-testymctestface_dn_ts_OC_MIR.nii.gz
sub-testymctestface_feats_OC2.nii.gz
sub-testymctestface_hik_ts_OC_MIR.nii.gz
sub-testymctestface_ica_components.nii.gz
sub-testymctestface_ica_cross_component_metrics.json
sub-testymctestface_ica_decision_tree.json
Expand All @@ -22,3 +20,4 @@ sub-testymctestface_ica_orth_mixing.tsv
sub-testymctestface_ica_status_table.tsv
sub-testymctestface_registry.json
sub-testymctestface_sphis_hik.nii.gz
sub-testymctestface_confounds.tsv
33 changes: 21 additions & 12 deletions tedana/tests/test_gscontrol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,7 @@


def test_break_gscontrol_raw():
"""
Ensure that gscontrol_raw fails when input data do not have the right.

shapes.
"""
"""Ensure that gscontrol_raw fails when input data do not have the right shapes."""
n_samples, n_echos, n_vols = 10000, 4, 100
catd = np.empty((n_samples, n_echos, n_vols))
optcom = np.empty((n_samples, n_vols))
Expand All @@ -27,28 +23,41 @@ def test_break_gscontrol_raw():
catd = np.empty((n_samples + 1, n_echos, n_vols))
with pytest.raises(ValueError) as e_info:
gsc.gscontrol_raw(
catd=catd, optcom=optcom, n_echos=n_echos, io_generator=io_generator, dtrank=4
data_cat=catd,
data_optcom=optcom,
n_echos=n_echos,
io_generator=io_generator,
dtrank=4,
)
assert str(e_info.value) == (
f"First dimensions of catd ({catd.shape[0]}) and optcom ({optcom.shape[0]}) do not match"
f"First dimensions of data_cat ({catd.shape[0]}) and data_optcom ({optcom.shape[0]}) "
"do not match"
)

catd = np.empty((n_samples, n_echos + 1, n_vols))
with pytest.raises(ValueError) as e_info:
gsc.gscontrol_raw(
catd=catd, optcom=optcom, n_echos=n_echos, io_generator=io_generator, dtrank=4
data_cat=catd,
data_optcom=optcom,
n_echos=n_echos,
io_generator=io_generator,
dtrank=4,
)
assert str(e_info.value) == (
f"Second dimension of catd ({catd.shape[1]}) does not match n_echos ({n_echos})"
f"Second dimension of data_cat ({catd.shape[1]}) does not match n_echos ({n_echos})"
)

catd = np.empty((n_samples, n_echos, n_vols))
optcom = np.empty((n_samples, n_vols + 1))
with pytest.raises(ValueError) as e_info:
gsc.gscontrol_raw(
catd=catd, optcom=optcom, n_echos=n_echos, io_generator=io_generator, dtrank=4
data_cat=catd,
data_optcom=optcom,
n_echos=n_echos,
io_generator=io_generator,
dtrank=4,
)
assert str(e_info.value) == (
f"Third dimension of catd ({catd.shape[2]}) does not match "
f"second dimension of optcom ({optcom.shape[1]})"
f"Third dimension of data_cat ({catd.shape[2]}) does not match "
f"second dimension of data_optcom ({optcom.shape[1]})"
)
9 changes: 8 additions & 1 deletion tedana/workflows/ica_reclassify.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,14 @@ def ica_reclassify_workflow(

if mir:
io_generator.overwrite = True
gsc.minimum_image_regression(data_oc, mmix, mask_denoise, comptable, io_generator)
gsc.minimum_image_regression(
data_optcom=data_oc,
mixing=mmix,
mask=mask_denoise,
comptable=comptable,
classification_tags=selector.classification_tags,
io_generator=io_generator,
)
io_generator.overwrite = False

# Write out BIDS-compatible description file
Expand Down
16 changes: 14 additions & 2 deletions tedana/workflows/tedana.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,12 @@ def tedana_workflow(

if "gsr" in gscontrol:
# regress out global signal
catd, data_oc = gsc.gscontrol_raw(catd, data_oc, n_echos, io_generator)
catd, data_oc = gsc.gscontrol_raw(
data_cat=catd,
data_optcom=data_oc,
n_echos=n_echos,
io_generator=io_generator,
)

fout = io_generator.save_file(data_oc, "combined img")
LGR.info(f"Writing optimally combined data set: {fout}")
Expand Down Expand Up @@ -886,7 +891,14 @@ def tedana_workflow(
)

if "mir" in gscontrol:
gsc.minimum_image_regression(data_oc, mmix, mask_denoise, comptable, io_generator)
gsc.minimum_image_regression(
data_optcom=data_oc,
mixing=mmix,
mask=mask_denoise,
comptable=comptable,
classification_tags=selector.classification_tags,
io_generator=io_generator,
)

if verbose:
io.writeresults_echoes(catd, mmix, mask_denoise, comptable, io_generator)
Expand Down