Refactor SCORE/SCRUB and BASIL #253

Draft
wants to merge 33 commits into main

Commits (33)
67cebd9
Make --basil and --scorescrub mutually exclusive.
tsalo Mar 29, 2023
994e491
Merge remote-tracking branch 'upstream/main' into basil-mutex
tsalo Apr 4, 2023
d5b03ef
Rename CBF computation workflows.
tsalo Apr 4, 2023
8acc889
Clean up SCORE code.
tsalo Apr 4, 2023
337274d
Fix misc. bugs.
tsalo Apr 4, 2023
adb7ace
Rename SCORE/SCRUB outputs.
tsalo Apr 4, 2023
535814a
Patch connections.
tsalo Apr 4, 2023
0cd690a
Patch more connections.
tsalo Apr 4, 2023
292b0c0
Rename and document BASIL inputs and outputs.
tsalo Apr 4, 2023
3e35cf1
Drop mutual exclusion for this PR.
tsalo Apr 4, 2023
4602ef0
Update parser.py
tsalo Apr 4, 2023
28caafa
Split off QC interfaces.
tsalo Apr 4, 2023
b24b572
Merge remote-tracking branch 'upstream/main' into refactor-score-scrub
tsalo Apr 4, 2023
051d125
Start refactoring _scrub_csf.
tsalo Apr 4, 2023
f44257c
Fix bug I introduced in SCORE.
tsalo Apr 4, 2023
7588d34
Update .gitignore
tsalo Apr 5, 2023
984bcbc
Simplify _get_wfun_weight.
tsalo Apr 5, 2023
6fb6272
Continue, not break.
tsalo Apr 6, 2023
5bfd929
Merge remote-tracking branch 'upstream/main' into refactor-score-scrub
tsalo May 25, 2023
66ee1fc
Reimplement refactor of SCRUB/SCORE code.
tsalo May 25, 2023
bb4c600
Rename the functions.
tsalo May 25, 2023
ed3a569
Update cbf.py
tsalo May 25, 2023
658e1d8
Fix connection.
tsalo May 25, 2023
5316259
Update cbf.py
tsalo May 25, 2023
643e352
Update.
tsalo May 25, 2023
6e1456b
Update cbf.py
tsalo May 26, 2023
4fa33ba
Rename wavelet_function to cost_function.
tsalo May 26, 2023
4108625
Keep working on SCRUB.
tsalo May 26, 2023
144f783
Merge remote-tracking branch 'upstream/main' into refactor-score-scrub
tsalo Jun 1, 2023
8cbcc00
Merge remote-tracking branch 'upstream/main' into refactor-score-scrub
tsalo Dec 12, 2023
d099a5e
Address style issues.
tsalo Dec 12, 2023
ac2356b
Remove unused import.
tsalo Dec 12, 2023
a8e0c07
Merge remote-tracking branch 'upstream/main' into refactor-score-scrub
tsalo Jan 9, 2024
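
The first commit makes --basil and --scorescrub mutually exclusive, an approach later dropped within this PR (commit 3e35cf1). A minimal argparse sketch of that pattern, assuming plain store_true flags rather than ASLPrep's actual parser wiring:

import argparse

parser = argparse.ArgumentParser(description="Toy parser illustrating mutually exclusive CBF denoising flags.")

# argparse enforces that at most one flag in the group is passed.
group = parser.add_mutually_exclusive_group()
group.add_argument("--scorescrub", action="store_true", help="Run SCORE/SCRUB CBF denoising.")
group.add_argument("--basil", action="store_true", help="Run BASIL CBF fitting.")

args = parser.parse_args(["--scorescrub"])  # passing both flags would exit with an error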
130 changes: 74 additions & 56 deletions aslprep/interfaces/cbf.py
@@ -25,8 +25,8 @@
pcasl_or_pasl,
)
from aslprep.utils.cbf import (
_getcbfscore,
_scrubcbf,
_score_cbf,
_scrub_cbf,
estimate_cbf_pcasl_multipld,
estimate_t1,
)
@@ -313,12 +313,10 @@ class _ComputeCBFInputSpec(BaseInterfaceInputSpec):
),
)
metadata = traits.Dict(
exists=True,
mandatory=True,
desc="Metadata for the raw CBF file, taken from the raw ASL data's sidecar JSON file.",
)
m0_scale = traits.Float(
exists=True,
mandatory=True,
desc="Relative scale between ASL and M0.",
)
@@ -611,31 +609,37 @@ def _run_interface(self, runtime):


class _ScoreAndScrubCBFInputSpec(BaseInterfaceInputSpec):
cbf_ts = File(exists=True, mandatory=True, desc="Computed CBF from ComputeCBF.")
mask = File(exists=True, mandatory=True, desc="mask")
cbf_ts = File(exists=True, mandatory=True, desc="computed CBF from ComputeCBF")
gm_tpm = File(exists=True, mandatory=True, desc="Gray matter tissue probability map.")
wm_tpm = File(exists=True, mandatory=True, desc="White matter tissue probability map.")
csf_tpm = File(exists=True, mandatory=True, desc="CSF tissue probability map.")
csf_tpm = File(exists=True, mandatory=True, desc="Cerebrospinal fluid tissue probability map.")
brain_mask = File(exists=True, mandatory=True, desc="Brain mask.")
tpm_threshold = traits.Float(
default_value=0.7,
usedefault=True,
mandatory=False,
desc="Tissue probability threshold for binarizing GM, WM, and CSF masks.",
)
wavelet_function = traits.Str(
default_value="huber",
usedefault=True,
desc="Threshold for tissue probability maps.",
)
cost_function = traits.Str(
mandatory=False,
default_value="huber",
option=["bisquare", "andrews", "cauchy", "fair", "logistics", "ols", "talwar", "welsch"],
desc="Wavelet function",
desc="Wavelet function used in SCRUB.",
usedefault=True,
)


class _ScoreAndScrubCBFOutputSpec(TraitedSpec):
cbf_ts_score = File(exists=False, mandatory=False, desc="score timeseries data")
mean_cbf_score = File(exists=False, mandatory=False, desc="average score")
mean_cbf_scrub = File(exists=False, mandatory=False, desc="average scrub")
score_outlier_index = File(exists=False, mandatory=False, desc="index of volume remove ")
cbf_ts_score = File(
exists=True,
desc="CBF time series after removing outlier volumes flagged by SCORE algorithm.",
)
score_outlier_index = File(exists=True, desc="Index of removed volumes, in a CSV file.")
mean_cbf_score = File(
exists=True,
desc="Mean CBF image calculated from SCORE-censored CBF time series.",
)
mean_cbf_scrub = File(exists=True, desc="average scrub")


class ScoreAndScrubCBF(SimpleInterface):
@@ -657,38 +661,41 @@ class ScoreAndScrubCBF(SimpleInterface):
output_spec = _ScoreAndScrubCBFOutputSpec

def _run_interface(self, runtime):
cbf_ts = nb.load(self.inputs.cbf_ts).get_fdata()
mask = nb.load(self.inputs.mask).get_fdata()
greym = nb.load(self.inputs.gm_tpm).get_fdata()
whitem = nb.load(self.inputs.wm_tpm).get_fdata()
csf = nb.load(self.inputs.csf_tpm).get_fdata()
if cbf_ts.ndim > 3:
cbf_scorets, index_score = _getcbfscore(
cbfts=cbf_ts,
wm=whitem,
gm=greym,
csf=csf,
mask=mask,
cbf_img = nb.load(self.inputs.cbf_ts)
cbf_arr = cbf_img.get_fdata()

if cbf_arr.ndim > 3:
mask_arr = nb.load(self.inputs.brain_mask).get_fdata()
gm_tpm_arr = nb.load(self.inputs.gm_tpm).get_fdata()
wm_tpm_arr = nb.load(self.inputs.wm_tpm).get_fdata()
csf_tpm_arr = nb.load(self.inputs.csf_tpm).get_fdata()

cbf_ts_score, score_outliers_idx = _score_cbf(
cbf_ts=cbf_arr,
gm=gm_tpm_arr,
wm=wm_tpm_arr,
csf=csf_tpm_arr,
mask=mask_arr,
thresh=self.inputs.tpm_threshold,
)
cbfscrub = _scrubcbf(
cbf_ts=cbf_scorets,
gm=greym,
wm=whitem,
csf=csf,
mask=mask,
wfun=self.inputs.wavelet_function,
mean_cbf_score = np.mean(cbf_ts_score, axis=3)
mean_cbf_scrub = _scrub_cbf(
cbf_ts=cbf_ts_score,
gm=gm_tpm_arr,
wm=wm_tpm_arr,
csf=csf_tpm_arr,
mask=mask_arr,
cost_function=self.inputs.cost_function,
thresh=self.inputs.tpm_threshold,
)
mean_cbf_score = np.mean(cbf_scorets, axis=3)
else:
config.loggers.interface.warning(
f"CBF time series is only {cbf_ts.ndim}D. Skipping SCORE and SCRUB."
f"CBF time series is only {cbf_arr.ndim}D. Skipping SCORE and SCRUB."
)
cbf_scorets = cbf_ts
index_score = np.array([0])
cbfscrub = cbf_ts
mean_cbf_score = cbf_ts
cbf_ts_score = cbf_arr
mean_cbf_score = cbf_arr
score_outliers_idx = np.array([0])
mean_cbf_scrub = cbf_arr

self._results["cbf_ts_score"] = fname_presuffix(
self.inputs.cbf_ts,
@@ -711,25 +718,25 @@ def _run_interface(self, runtime):
newpath=runtime.cwd,
use_ext=False,
)
samplecbf = nb.load(self.inputs.mask)

nb.Nifti1Image(
dataobj=cbf_scorets,
affine=samplecbf.affine,
header=samplecbf.header,
dataobj=cbf_ts_score,
affine=cbf_img.affine,
header=cbf_img.header,
).to_filename(self._results["cbf_ts_score"])
nb.Nifti1Image(
dataobj=mean_cbf_score,
affine=samplecbf.affine,
header=samplecbf.header,
affine=cbf_img.affine,
header=cbf_img.header,
).to_filename(self._results["mean_cbf_score"])
np.savetxt(self._results["score_outlier_index"], score_outliers_idx, delimiter=",")
nb.Nifti1Image(
dataobj=cbfscrub,
affine=samplecbf.affine,
header=samplecbf.header,
dataobj=mean_cbf_scrub,
affine=cbf_img.affine,
header=cbf_img.header,
).to_filename(self._results["mean_cbf_scrub"])

score_outlier_df = pd.DataFrame(columns=["score_outlier_index"], data=index_score)
score_outlier_df = pd.DataFrame(columns=["score_outlier_index"], data=score_outliers_idx)
score_outlier_df.to_csv(self._results["score_outlier_index"], sep="\t", index=False)

return runtime
@@ -810,6 +817,7 @@ class _BASILCBFInputSpec(FSLCommandInputSpec):
mandatory=False,
argstr="--pvcorr",
default_value=True,
usedefault=True,
)
gm_tpm = File(
exists=True,
@@ -834,13 +842,22 @@ class _BASILCBFInputSpec(FSLCommandInputSpec):


class _BASILCBFOutputSpec(TraitedSpec):
mean_cbf_basil = File(exists=True, desc="cbf with spatial correction")
mean_cbf_gm_basil = File(exists=True, desc="cbf with spatial correction")
mean_cbf_basil = File(exists=True, desc="CBF map in absolute units (ml/100g/min).")
mean_cbf_gm_basil = File(
exists=True,
desc=(
"CBF map with gray matter partial volume correction. "
"This means that the map contains 'pure' GM perfusion estimates, in absolute units."
),
)
mean_cbf_wm_basil = File(
exists=True,
desc="cbf with spatial partial volume white matter correction",
desc=(
"CBF map with white matter partial volume correction. "
"This means that the map contains 'pure' WM perfusion estimates, in absolute units."
),
)
att_basil = File(exists=True, desc="arterial transit time")
att_basil = File(exists=True, desc="Arterial transit time map.")


class BASILCBF(FSLCommand):
@@ -866,6 +883,7 @@ def _run_interface(self, runtime):
def _gen_outfilename(self, suffix):
if isdefined(self.inputs.deltam):
out_file = self._gen_fname(self.inputs.deltam, suffix=suffix)

return os.path.abspath(out_file)

def _list_outputs(self):
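For orientation, a minimal usage sketch of the refactored ScoreAndScrubCBF interface, using only the input and output names introduced in the diff above; the file paths are hypothetical and the defaults mirror the input spec:

from aslprep.interfaces.cbf import ScoreAndScrubCBF

# Inputs follow the renamed spec: brain_mask replaces mask,
# cost_function replaces wavelet_function.
score_scrub = ScoreAndScrubCBF(
    cbf_ts="sub-01_desc-timeseries_cbf.nii.gz",   # hypothetical path
    brain_mask="sub-01_desc-brain_mask.nii.gz",   # hypothetical path
    gm_tpm="sub-01_label-GM_probseg.nii.gz",      # hypothetical path
    wm_tpm="sub-01_label-WM_probseg.nii.gz",      # hypothetical path
    csf_tpm="sub-01_label-CSF_probseg.nii.gz",    # hypothetical path
    tpm_threshold=0.7,      # default from the input spec
    cost_function="huber",  # default from the input spec
)
results = score_scrub.run()

# Outputs follow the renamed output spec:
# results.outputs.cbf_ts_score, results.outputs.mean_cbf_score,
# results.outputs.mean_cbf_scrub, results.outputs.score_outlier_index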
4 changes: 2 additions & 2 deletions aslprep/interfaces/plotting.py
@@ -129,8 +129,8 @@ def _run_interface(self, runtime):

class _CBFSummaryPlotInputSpec(BaseInterfaceInputSpec):
cbf = File(exists=True, mandatory=True, desc="")
label = traits.Str(exists=True, mandatory=True, desc="label")
vmax = traits.Int(exists=True, default_value=90, mandatory=True, desc="max value of asl")
label = traits.Str(mandatory=True, desc="label")
vmax = traits.Int(default_value=90, mandatory=True, desc="max value of asl")
ref_vol = File(exists=True, mandatory=True, desc="")


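The plotting change applies the same trait cleanup as the spec changes in cbf.py: exists is metadata for File (and Directory) traits, so it is dropped from Str and Int traits, and a default_value only takes effect when usedefault=True accompanies it (the reason the BASIL pvcorr input gains usedefault=True above). A small standalone sketch of that behavior with hypothetical trait names:

from nipype.interfaces.base import BaseInterfaceInputSpec, File, traits

class _ExampleInputSpec(BaseInterfaceInputSpec):
    # exists=True is meaningful here: it checks that the path exists on disk.
    in_file = File(exists=True, mandatory=True, desc="Input image.")
    # Non-File traits should not carry exists=True; the default is only
    # applied because usedefault=True accompanies default_value.
    vmax = traits.Int(default_value=90, usedefault=True, desc="Upper display limit.")
    label = traits.Str(mandatory=True, desc="Figure label.")

spec = _ExampleInputSpec()
print(spec.vmax)  # 90, because usedefault=True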