Skip to content

Commit

Permalink
Compute brain gm/wm mask on subject's VBM data
Browse files Browse the repository at this point in the history
  • Loading branch information
fraimondo committed Nov 25, 2024
1 parent deb8306 commit 6d5cde9
Showing 1 changed file with 51 additions and 18 deletions.
69 changes: 51 additions & 18 deletions junifer/data/masks/_masks.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def compute_brain_mask(
warp_data: Optional[dict[str, Any]] = None,
mask_type: str = "brain",
threshold: float = 0.5,
source: str = "template",
extra_input: Optional[dict[str, Any]] = None,
) -> "Nifti1Image":
"""Compute the whole-brain, grey-matter or white-matter mask.
Expand All @@ -72,6 +74,10 @@ def compute_brain_mask(
(default "brain").
threshold : float, optional
The value under which the template is cut off (default 0.5).
source : {"subject", "template"}, optional
The source of the mask. If "subject", the mask is computed from the
subject's data (VBM_GM or VBM_WM). If "template", the mask is computed
from the template data (default "template").
Returns
-------
Expand All @@ -90,6 +96,12 @@ def compute_brain_mask(
if mask_type not in ["brain", "gm", "wm"]:
raise_error(f"Unknown mask type: {mask_type}")

if source not in ["subject", "template"]:
raise_error(f"Unknown mask source: {source}")

if source == "subject" and mask_type not in ["gm", "wm"]:
raise_error(f"Unknown mask type: {mask_type} for subject space")

# Check pre-requirements for space manipulation
if target_data["space"] == "native":
# Warp data check
Expand All @@ -101,25 +113,43 @@ def compute_brain_mask(
# Set space to fetch template using
target_std_space = target_data["space"]

# Fetch template in closest resolution
template = get_template(
space=target_std_space,
target_data=target_data,
extra_input=None,
template_type=mask_type,
)

# Resample and warp template if target space is native
if target_data["space"] == "native":
resampled_template = ANTsMaskWarper().warp(
mask_name=f"template_{target_std_space}_for_compute_brain_mask",
# use template here
mask_img=template,
src=target_std_space,
dst="native",
if source == "subject":
key = f"VBM_{mask_type.upper()}"
if key not in extra_input:
raise_error(
f"Cannot compute {mask_type} from subject's data. "
f"Missing {key} in extra input."
)
template = extra_input[key]["data"]
template_space = extra_input[key]["space"]
else:
# Fetch template in closest resolution
template = get_template(
space=target_std_space,
target_data=target_data,
warp_data=warp_data,
extra_input=None,
template_type=mask_type,
)
template_space = target_std_space
# Resample and warp template if target space is native
if target_data["space"] == "native" and template_space != "native":
if warp_data["warper"] == "fsl":
resampled_template = FSLMaskWarper().warp(
mask_name=f"template_{target_std_space}_for_compute_brain_mask",
mask_img=template,
target_data=target_data,
warp_data=warp_data,
)
elif warp_data["warper"] == "ants":
resampled_template = ANTsMaskWarper().warp(
mask_name=f"template_{target_std_space}_for_compute_brain_mask",
# use template here
mask_img=template,
src=target_std_space,
dst="native",
target_data=target_data,
warp_data=warp_data,
)
# Resample template to target image
else:
resampled_template = resample_to_img(
Expand Down Expand Up @@ -503,7 +533,10 @@ def get( # noqa: C901
# custom compute_brain_mask
elif mask_name == "compute_brain_mask":
mask_img = mask_object(
target_data, warper_spec, **mask_params
target_data,
warper_spec,
extra_input=extra_input,
**mask_params,
)
# custom registered; arm kept for clarity
else:
Expand Down

0 comments on commit 6d5cde9

Please sign in to comment.