From 6d5cde9beb36702110de4b9098b7d727ac9cd8a1 Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Mon, 25 Nov 2024 18:59:56 +0100 Subject: [PATCH] Compute brain gm/wm mask on subject's VBM data --- junifer/data/masks/_masks.py | 69 ++++++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 18 deletions(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index 050166b60..2c79ecaba 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -47,6 +47,8 @@ def compute_brain_mask( warp_data: Optional[dict[str, Any]] = None, mask_type: str = "brain", threshold: float = 0.5, + source: str = "template", + extra_input: Optional[dict[str, Any]] = None, ) -> "Nifti1Image": """Compute the whole-brain, grey-matter or white-matter mask. @@ -72,6 +74,10 @@ def compute_brain_mask( (default "brain"). threshold : float, optional The value under which the template is cut off (default 0.5). + source : {"subject", "template"}, optional + The source of the mask. If "subject", the mask is computed from the + subject's data (VBM_GM or VBM_WM). If "template", the mask is computed + from the template data (default "template"). Returns ------- @@ -90,6 +96,12 @@ def compute_brain_mask( if mask_type not in ["brain", "gm", "wm"]: raise_error(f"Unknown mask type: {mask_type}") + if source not in ["subject", "template"]: + raise_error(f"Unknown mask source: {source}") + + if source == "subject" and mask_type not in ["gm", "wm"]: + raise_error(f"Unknown mask type: {mask_type} for subject space") + # Check pre-requirements for space manipulation if target_data["space"] == "native": # Warp data check @@ -101,25 +113,43 @@ def compute_brain_mask( # Set space to fetch template using target_std_space = target_data["space"] - # Fetch template in closest resolution - template = get_template( - space=target_std_space, - target_data=target_data, - extra_input=None, - template_type=mask_type, - ) - - # Resample and warp template if target space is native - if target_data["space"] == "native": - resampled_template = ANTsMaskWarper().warp( - mask_name=f"template_{target_std_space}_for_compute_brain_mask", - # use template here - mask_img=template, - src=target_std_space, - dst="native", + if source == "subject": + key = f"VBM_{mask_type.upper()}" + if key not in extra_input: + raise_error( + f"Cannot compute {mask_type} from subject's data. " + f"Missing {key} in extra input." + ) + template = extra_input[key]["data"] + template_space = extra_input[key]["space"] + else: + # Fetch template in closest resolution + template = get_template( + space=target_std_space, target_data=target_data, - warp_data=warp_data, + extra_input=None, + template_type=mask_type, ) + template_space = target_std_space + # Resample and warp template if target space is native + if target_data["space"] == "native" and template_space != "native": + if warp_data["warper"] == "fsl": + resampled_template = FSLMaskWarper().warp( + mask_name=f"template_{target_std_space}_for_compute_brain_mask", + mask_img=template, + target_data=target_data, + warp_data=warp_data, + ) + elif warp_data["warper"] == "ants": + resampled_template = ANTsMaskWarper().warp( + mask_name=f"template_{target_std_space}_for_compute_brain_mask", + # use template here + mask_img=template, + src=target_std_space, + dst="native", + target_data=target_data, + warp_data=warp_data, + ) # Resample template to target image else: resampled_template = resample_to_img( @@ -503,7 +533,10 @@ def get( # noqa: C901 # custom compute_brain_mask elif mask_name == "compute_brain_mask": mask_img = mask_object( - target_data, warper_spec, **mask_params + target_data, + warper_spec, + extra_input=extra_input, + **mask_params, ) # custom registered; arm kept for clarity else: