From 353da6a50f672c379b8ce4ac78079c27b7abfeed Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Fri, 15 Nov 2024 17:01:48 +0100 Subject: [PATCH 1/6] refactor: improve interface and logic for compute_brain_mask --- junifer/data/masks/_masks.py | 63 ++++++++++++++------------ junifer/data/masks/tests/test_masks.py | 1 - 2 files changed, 35 insertions(+), 29 deletions(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index 779aa3eab..4c7748c1b 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -44,23 +44,24 @@ def compute_brain_mask( target_data: dict[str, Any], - extra_input: Optional[dict[str, Any]] = None, + warp_data: Optional[dict[str, Any]] = None, mask_type: str = "brain", threshold: float = 0.5, ) -> "Nifti1Image": """Compute the whole-brain, grey-matter or white-matter mask. This mask is calculated using the template space and resolution as found - in the ``target_data``. + in the ``target_data``. If target space is native, then the template is + warped to native and then thresholded. Parameters ---------- target_data : dict The corresponding item of the data object for which mask will be loaded. - extra_input : dict, optional - The other fields in the data object. Useful for accessing other data - types (default None). + warp_data : dict or None, optional + The warp data item of the data object. Needs to be provided if + ``target_data`` is in native space (default None). mask_type : {"brain", "gm", "wm"}, optional Type of mask to be computed: @@ -81,7 +82,7 @@ def compute_brain_mask( ------ ValueError If ``mask_type`` is invalid or - if ``extra_input`` is None when ``target_data``'s space is native. + if ``warp_data`` is None when ``target_data``'s space is native. """ logger.debug(f"Computing {mask_type} mask") @@ -90,39 +91,45 @@ def compute_brain_mask( raise_error(f"Unknown mask type: {mask_type}") # Check pre-requirements for space manipulation - target_space = target_data["space"] - # Set target standard space to target space - target_std_space = target_space - # Extra data type requirement check if target space is native - if target_space == "native": - # Check for extra inputs - if extra_input is None: - raise_error( - "No extra input provided, requires `Warp` " - "data type to infer target template space." - ) - # Set target standard space to warp file space source - for entry in extra_input["Warp"]: - if entry["dst"] == "native": - target_std_space = entry["src"] + if target_data["space"] == "native": + # Warp data check + if warp_data is None: + raise_error("No `warp_data` provided") + # Set space to fetch template using + target_std_space = warp_data["src"] + else: + # Set space to fetch template using + target_std_space = target_data["space"] # Fetch template in closest resolution template = get_template( space=target_std_space, target_data=target_data, - extra_input=extra_input, + extra_input=None, template_type=mask_type, ) + + # Resample and warp template if target space is native + if target_data["space"] == "native": + resampled_template = ANTsMaskWarper().warp( + mask_name=f"template_{target_std_space}_for_compute_brain_mask", + # use template here + mask_img=template, + src=target_std_space, + dst="native", + target_data=target_data, + warp_data=warp_data, + ) # Resample template to target image - target_img = target_data["data"] - resampled_template = resample_to_img( - source_img=template, target_img=target_img - ) + else: + resampled_template = resample_to_img( + source_img=template, target_img=target_data["data"] + ) - # Threshold and get mask + # Threshold resampled template and get mask mask = (get_data(resampled_template) >= threshold).astype("int8") - return new_img_like(target_img, mask) # type: ignore + return new_img_like(target_data["data"], mask) # type: ignore class MaskRegistry(BasePipelineDataRegistry, metaclass=Singleton): diff --git a/junifer/data/masks/tests/test_masks.py b/junifer/data/masks/tests/test_masks.py index 65e0e8a43..a51380062 100644 --- a/junifer/data/masks/tests/test_masks.py +++ b/junifer/data/masks/tests/test_masks.py @@ -64,7 +64,6 @@ def test_compute_brain_mask(mask_type: str, threshold: float) -> None: element_data = DefaultDataReader().fit_transform(dg["sub-01"]) mask = compute_brain_mask( target_data=element_data["BOLD"], - extra_input=None, mask_type=mask_type, ) assert isinstance(mask, nib.nifti1.Nifti1Image) From c9183a99715a9439e38430085267047759b8ac01 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Fri, 15 Nov 2024 17:02:24 +0100 Subject: [PATCH 2/6] refactor: improve space warping logic for MaskRegistry.get --- junifer/data/masks/_masks.py | 123 +++++++++++++++++++---------------- 1 file changed, 67 insertions(+), 56 deletions(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index 4c7748c1b..d8927eb6c 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -393,6 +393,9 @@ def get( # noqa: C901 # Set target standard space to warp file space source target_std_space = warper_spec["src"] else: + # Set warper_spec so that compute_brain_mask does not fail when + # target space is non-native + warper_spec = None # Set target standard space to target space target_std_space = target_space @@ -405,31 +408,33 @@ def get( # noqa: C901 masks = [masks] # Check that masks passed as dicts have only one key - invalid_elements = [ + invalid_mask_specs = [ x for x in masks if isinstance(x, dict) and len(x) != 1 ] - if len(invalid_elements) > 0: + if invalid_mask_specs: raise_error( "Each of the masks dictionary must have only one key, " "the name of the mask. The following dictionaries are " - f"invalid: {invalid_elements}" + f"invalid: {invalid_mask_specs}" ) - # Check params for the intersection function + # Store params for nilearn.masking.intersect_mask() intersect_params = {} - true_masks = [] + # Store all mask specs for further operations + mask_specs = [] for t_mask in masks: if isinstance(t_mask, dict): + # Get params to pass to nilearn.masking.intersect_mask() if "threshold" in t_mask: intersect_params["threshold"] = t_mask["threshold"] continue - elif "connected" in t_mask: + if "connected" in t_mask: intersect_params["connected"] = t_mask["connected"] continue - # All the other elements are masks - true_masks.append(t_mask) + # Add mask spec + mask_specs.append(t_mask) - if len(true_masks) == 0: + if not mask_specs: raise_error("No mask was passed. At least one mask is required.") # Get the nested mask data type for the input data type @@ -437,7 +442,7 @@ def get( # noqa: C901 # Get all the masks all_masks = [] - for t_mask in true_masks: + for t_mask in mask_specs: if isinstance(t_mask, dict): mask_name = next(iter(t_mask.keys())) mask_params = t_mask[mask_name] @@ -461,20 +466,26 @@ def get( # noqa: C901 mask_object, _, mask_space = self.load( mask_name, path_only=False, resolution=resolution ) - # Replace mask space with target space if mask's space is - # inherit - if mask_space == "inherit": - mask_space = target_std_space - # If mask is callable like from nilearn + # If mask is callable like from nilearn; space will be inherit + # so no check for that if callable(mask_object): if mask_params is None: mask_params = {} # From nilearn - if mask_name != "compute_brain_mask": + if mask_name in [ + "compute_epi_mask", + "compute_background_mask", + ]: mask_img = mask_object(target_img, **mask_params) - # Not from nilearn + # custom compute_brain_mask + elif mask_name == "compute_brain_mask": + mask_img = mask_object( + target_data, warper_spec, **mask_params + ) + # custom registered; arm kept for clarity else: - mask_img = mask_object(target_data, **mask_params) + mask_img = mask_object(target_img, **mask_params) + # Mask is a Nifti1Image else: # Mask params provided @@ -484,23 +495,43 @@ def get( # noqa: C901 "Cannot pass callable params to a non-callable " "mask." ) - # Resample mask to target image - mask_img = resample_to_img( - source_img=mask_object, - target_img=target_img, - interpolation="nearest", - copy=True, - ) - # Convert mask space if required - if mask_space != target_std_space: - mask_img = ANTsMaskWarper().warp( - mask_name=mask_name, - mask_img=mask_img, - src=mask_space, - dst=target_std_space, - target_data=target_data, - warp_data=None, - ) + # Resample and warp mask if target data is native + if target_space == "native": + mask_name = f"{mask_name}_to_native" + # extra_input check done earlier and warper_spec exists + if warper_spec["warper"] == "fsl": + mask_img = FSLMaskWarper().warp( + mask_name=mask_name, + mask_img=mask_object, + target_data=target_data, + warp_data=warper_spec, + ) + elif warper_spec["warper"] == "ants": + mask_img = ANTsMaskWarper().warp( + mask_name=mask_name, + mask_img=mask_object, + src="", + dst="native", + target_data=target_data, + warp_data=warper_spec, + ) + else: + # Resample and warp mask + if mask_space != target_std_space: + mask_img = ANTsMaskWarper().warp( + mask_name=mask_name, + mask_img=mask_object, + src=mask_space, + dst=target_std_space, + target_data=target_data, + warp_data=warper_spec, + ) + # Resample mask to target image + else: + mask_img = resample_to_img( + source_img=mask_object, + target_img=target_data["data"], + ) all_masks.append(mask_img) @@ -510,7 +541,7 @@ def get( # noqa: C901 mask_img = intersect_masks(all_masks, **intersect_params) # Single mask else: - if len(intersect_params) > 0: + if intersect_params: # Yes, I'm this strict! raise_error( "Cannot pass parameters to the intersection function " @@ -518,26 +549,6 @@ def get( # noqa: C901 ) mask_img = all_masks[0] - # Warp mask if target data is native - if target_space == "native": - # extra_input check done earlier and warper_spec exists - if warper_spec["warper"] == "fsl": - mask_img = FSLMaskWarper().warp( - mask_name="native", - mask_img=mask_img, - target_data=target_data, - warp_data=warper_spec, - ) - elif warper_spec["warper"] == "ants": - mask_img = ANTsMaskWarper().warp( - mask_name="native", - mask_img=mask_img, - src="", - dst="native", - target_data=target_data, - warp_data=warper_spec, - ) - return mask_img From 2993346de5b385e5512615881c671e42d596bd96 Mon Sep 17 00:00:00 2001 From: Synchon Mandal Date: Mon, 18 Nov 2024 14:49:18 +0100 Subject: [PATCH 3/6] refactor: fix mask logic --- junifer/data/masks/_masks.py | 43 ++++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index d8927eb6c..3eb27be67 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -495,6 +495,28 @@ def get( # noqa: C901 "Cannot pass callable params to a non-callable " "mask." ) + + # Resample and warp mask to standard space + if mask_space != target_std_space: + mask_img = ANTsMaskWarper().warp( + mask_name=mask_name, + mask_img=mask_object, + src=mask_space, + dst=target_std_space, + target_data=target_data, + warp_data=warper_spec, + ) + + else: + # Resample mask to target image; no further warping + if target_space != "native": + mask_img = resample_to_img( + source_img=mask_object, + target_img=target_data["data"], + ) + # Set mask_img in case no warping happens before this + else: + mask_img = mask_object # Resample and warp mask if target data is native if target_space == "native": mask_name = f"{mask_name}_to_native" @@ -502,36 +524,19 @@ def get( # noqa: C901 if warper_spec["warper"] == "fsl": mask_img = FSLMaskWarper().warp( mask_name=mask_name, - mask_img=mask_object, + mask_img=mask_img, target_data=target_data, warp_data=warper_spec, ) elif warper_spec["warper"] == "ants": mask_img = ANTsMaskWarper().warp( mask_name=mask_name, - mask_img=mask_object, + mask_img=mask_img, src="", dst="native", target_data=target_data, warp_data=warper_spec, ) - else: - # Resample and warp mask - if mask_space != target_std_space: - mask_img = ANTsMaskWarper().warp( - mask_name=mask_name, - mask_img=mask_object, - src=mask_space, - dst=target_std_space, - target_data=target_data, - warp_data=warper_spec, - ) - # Resample mask to target image - else: - mask_img = resample_to_img( - source_img=mask_object, - target_img=target_data["data"], - ) all_masks.append(mask_img) From 8a59908b6725e83ced25298e96031459cdfb8661 Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Mon, 25 Nov 2024 11:23:10 +0100 Subject: [PATCH 4/6] Fix docstring + add debug logging --- junifer/data/masks/_ants_mask_warper.py | 11 ++++--- junifer/data/masks/_masks.py | 33 ++++++++++++++++++++ junifer/data/parcellations/_parcellations.py | 19 ++++++++++- 3 files changed, 57 insertions(+), 6 deletions(-) diff --git a/junifer/data/masks/_ants_mask_warper.py b/junifer/data/masks/_ants_mask_warper.py index 8e6a9b7f8..64556d256 100644 --- a/junifer/data/masks/_ants_mask_warper.py +++ b/junifer/data/masks/_ants_mask_warper.py @@ -46,18 +46,19 @@ def warp( The mask image to transform. src : str The data type or template space to warp from. - It should be empty string if ``dst="T1w"``. + It should be empty string if ``dst="native"``. dst : str The data type or template space to warp to. - `"T1w"` is the only allowed data type and it uses the resampled T1w - found in ``target_data.reference_path``. The ``"reference_path"`` - key is added when :class:`.SpaceWarper` is used. + `"native"` is the only allowed data type and it uses the resampled + T1w found in ``target_data.reference_path``. The + ``"reference_path"`` key is added when :class:`.SpaceWarper` is + used. target_data : dict The corresponding item of the data object to which the mask will be applied. warp_data : dict or None The warp data item of the data object. The value is unused if - ``dst!="T1w"``. + ``dst!="native"``. Returns ------- diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index 3eb27be67..bb7d743bd 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -376,6 +376,8 @@ def get( # noqa: C901 """ # Check pre-requirements for space manipulation target_space = target_data["space"] + logger.debug(f"Getting masks: {masks} in {target_space} space") + # Extra data type requirement check if target space is native if target_space == "native": # Check for extra inputs @@ -392,6 +394,9 @@ def get( # noqa: C901 ) # Set target standard space to warp file space source target_std_space = warper_spec["src"] + logger.debug( + f"Target space is native. Will warp from {target_std_space}" + ) else: # Set warper_spec so that compute_brain_mask does not fail when # target space is non-native @@ -453,22 +458,40 @@ def get( # noqa: C901 # If mask is being inherited from the datagrabber or a # preprocessor, check that it's accessible if mask_name == "inherit": + logger.debug("Using inherited mask.") if inherited_mask_item is None: raise_error( "Cannot inherit mask from the target data. Either the " "DataGrabber or a Preprocessor does not provide " "`mask` for the target data type." ) + logger.debug( + f"Inherited mask is in {inherited_mask_item['space']} " + "space." + ) mask_img = inherited_mask_item["data"] + + if inherited_mask_item["space"] != target_space: + raise_error( + "Inherited mask space does not match target space." + ) + logger.debug("Resampling inherited mask to target image.") + # Resample inherited mask to target image + mask_img = resample_to_img( + source_img=mask_img, + target_img=target_data["data"], + ) # Starting with new mask else: # Load mask + logger.debug(f"Loading parcellation {t_mask}.") mask_object, _, mask_space = self.load( mask_name, path_only=False, resolution=resolution ) # If mask is callable like from nilearn; space will be inherit # so no check for that if callable(mask_object): + logger.debug("Computing mask (callable).") if mask_params is None: mask_params = {} # From nilearn @@ -498,6 +521,10 @@ def get( # noqa: C901 # Resample and warp mask to standard space if mask_space != target_std_space: + logger.debug( + f"Warping {t_mask} to {target_std_space} space " + "using ants." + ) mask_img = ANTsMaskWarper().warp( mask_name=mask_name, mask_img=mask_object, @@ -509,6 +536,7 @@ def get( # noqa: C901 else: # Resample mask to target image; no further warping + logger.debug(f"Resampling {t_mask} to target image.") if target_space != "native": mask_img = resample_to_img( source_img=mask_object, @@ -519,6 +547,10 @@ def get( # noqa: C901 mask_img = mask_object # Resample and warp mask if target data is native if target_space == "native": + logger.debug( + "Warping mask to native space using " + f"{warper_spec['warper']}." + ) mask_name = f"{mask_name}_to_native" # extra_input check done earlier and warper_spec exists if warper_spec["warper"] == "fsl": @@ -543,6 +575,7 @@ def get( # noqa: C901 # Multiple masks, need intersection / union if len(all_masks) > 1: # Intersect / union of masks + logger.debug("Intersecting masks.") mask_img = intersect_masks(all_masks, **intersect_params) # Single mask else: diff --git a/junifer/data/parcellations/_parcellations.py b/junifer/data/parcellations/_parcellations.py index dfe79add3..c339dac12 100644 --- a/junifer/data/parcellations/_parcellations.py +++ b/junifer/data/parcellations/_parcellations.py @@ -400,6 +400,7 @@ def get( """ # Check pre-requirements for space manipulation target_space = target_data["space"] + logger.debug(f"Getting {parcellations} in{target_space} space.") # Extra data type requirement check if target space is native if target_space == "native": # Check for extra inputs @@ -416,6 +417,9 @@ def get( ) # Set target standard space to warp file space source target_std_space = warper_spec["src"] + logger.debug( + f"Target space is native. Will warp from {target_std_space}" + ) else: # Set target standard space to target space target_std_space = target_space @@ -433,6 +437,7 @@ def get( all_labels = [] for name in parcellations: # Load parcellation + logger.debug(f"Loading parcellation {name}") img, labels, _, space = self.load( name=name, resolution=resolution, @@ -441,6 +446,9 @@ def get( # Convert parcellation spaces if required; # cannot be "native" due to earlier check if space != target_std_space: + logger.debug( + f"Warping {name} to {target_std_space} space using ants." + ) raw_img = ANTsParcellationWarper().warp( parcellation_name=name, parcellation_img=img, @@ -452,6 +460,7 @@ def get( # Remove extra dimension added by ANTs img = image.math_img("np.squeeze(img)", img=raw_img) + logger.debug(f"Resampling {name} to target image.") # Resample parcellation to target image img_to_merge = image.resample_to_img( source_img=img, @@ -469,6 +478,7 @@ def get( labels = all_labels[0] # Parcellations are already transformed to target standard space else: + logger.debug("Merging parcellations.") resampled_parcellation_img, labels = merge_parcellations( parcellations_list=all_parcellations, parcellations_names=parcellations, @@ -477,6 +487,10 @@ def get( # Warp parcellation if target space is native if target_space == "native": + logger.debug( + "Warping parcellation to native space using " + f"{warper_spec['warper']}." + ) # extra_input check done earlier and warper_spec exists if warper_spec["warper"] == "fsl": resampled_parcellation_img = FSLParcellationWarper().warp( @@ -1194,7 +1208,10 @@ def _retrieve_aicha( # Load labels labels = pd.read_csv( - parcellation_lname, sep="\t", header=None, skiprows=[0] # type: ignore + parcellation_lname, + sep="\t", + header=None, + skiprows=[0], # type: ignore )[0].to_list() return parcellation_fname, labels From 8a6d35e0b358909532e8a240ce853cfbaf9e2e8d Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Mon, 25 Nov 2024 15:38:44 +0100 Subject: [PATCH 5/6] Convert reference to a datatype --- junifer/data/coordinates/_fsl_coordinates_warper.py | 2 +- junifer/data/masks/_ants_mask_warper.py | 12 ++++++++---- junifer/data/masks/_fsl_mask_warper.py | 2 +- .../data/parcellations/_ants_parcellation_warper.py | 11 ++++++++--- .../data/parcellations/_fsl_parcellation_warper.py | 2 +- junifer/datagrabber/pattern_validation_mixin.py | 2 ++ junifer/preprocess/warping/_ants_warper.py | 4 ++-- junifer/preprocess/warping/_fsl_warper.py | 4 ++-- 8 files changed, 25 insertions(+), 14 deletions(-) diff --git a/junifer/data/coordinates/_fsl_coordinates_warper.py b/junifer/data/coordinates/_fsl_coordinates_warper.py index b12519e32..910f2e95a 100644 --- a/junifer/data/coordinates/_fsl_coordinates_warper.py +++ b/junifer/data/coordinates/_fsl_coordinates_warper.py @@ -69,7 +69,7 @@ def warp( f"{pretransform_coordinates_path.resolve()}", "| img2imgcoord -mm", f"-src {target_data['path'].resolve()}", - f"-dest {target_data['reference_path'].resolve()}", + f"-dest {target_data['reference']["path"].resolve()}", f"-warp {warp_data['path'].resolve()}", f"> {transformed_coords_path.resolve()};", f"sed -i 1d {transformed_coords_path.resolve()}", diff --git a/junifer/data/masks/_ants_mask_warper.py b/junifer/data/masks/_ants_mask_warper.py index 64556d256..3a739e59e 100644 --- a/junifer/data/masks/_ants_mask_warper.py +++ b/junifer/data/masks/_ants_mask_warper.py @@ -50,9 +50,9 @@ def warp( dst : str The data type or template space to warp to. `"native"` is the only allowed data type and it uses the resampled - T1w found in ``target_data.reference_path``. The - ``"reference_path"`` key is added when :class:`.SpaceWarper` is - used. + T1w found in ``target_data.reference``. The + ``"reference"`` key is added when :class:`.SpaceWarper` is + used or if the data is provided native space. target_data : dict The corresponding item of the data object to which the mask will be applied. @@ -88,6 +88,10 @@ def warp( # Warp data check if warp_data is None: raise_error("No `warp_data` provided") + if "reference" not in target_data: + raise_error("No `reference` provided") + if "path" not in target_data["reference"]: + raise_error("No `path` provided in `reference`") logger.debug("Using ANTs for mask transformation") @@ -105,7 +109,7 @@ def warp( "-n 'GenericLabel[NearestNeighbor]'", f"-i {prewarp_mask_path.resolve()}", # use resampled reference - f"-r {target_data['reference_path'].resolve()}", + f"-r {target_data['reference']["path"].resolve()}", f"-t {warp_data['path'].resolve()}", f"-o {warped_mask_path.resolve()}", ] diff --git a/junifer/data/masks/_fsl_mask_warper.py b/junifer/data/masks/_fsl_mask_warper.py index cfbcce684..3be873b43 100644 --- a/junifer/data/masks/_fsl_mask_warper.py +++ b/junifer/data/masks/_fsl_mask_warper.py @@ -74,7 +74,7 @@ def warp( "--interp=nn", f"-i {prewarp_mask_path.resolve()}", # use resampled reference - f"-r {target_data['reference_path'].resolve()}", + f"-r {target_data['reference']["path"].resolve()}", f"-w {warp_data['path'].resolve()}", f"-o {warped_mask_path.resolve()}", ] diff --git a/junifer/data/parcellations/_ants_parcellation_warper.py b/junifer/data/parcellations/_ants_parcellation_warper.py index 755183821..ed2907a6e 100644 --- a/junifer/data/parcellations/_ants_parcellation_warper.py +++ b/junifer/data/parcellations/_ants_parcellation_warper.py @@ -50,8 +50,9 @@ def warp( dst : str The data type or template space to warp to. `"T1w"` is the only allowed data type and it uses the resampled T1w - found in ``target_data.reference_path``. The ``"reference_path"`` - key is added when :class:`.SpaceWarper` is used. + found in ``target_data.reference``. The ``"reference"`` + key is added if the :class:`.SpaceWarper` is used or if the + data is provided in native space. target_data : dict The corresponding item of the data object to which the parcellation will be applied. @@ -87,6 +88,10 @@ def warp( # Warp data check if warp_data is None: raise_error("No `warp_data` provided") + if "reference" not in target_data: + raise_error("No `reference` provided") + if "path" not in target_data["reference"]: + raise_error("No `path` provided in `reference`") logger.debug("Using ANTs for parcellation transformation") @@ -108,7 +113,7 @@ def warp( "-n 'GenericLabel[NearestNeighbor]'", f"-i {prewarp_parcellation_path.resolve()}", # use resampled reference - f"-r {target_data['reference_path'].resolve()}", + f"-r {target_data['reference']["path"].resolve()}", f"-t {warp_data['path'].resolve()}", f"-o {warped_parcellation_path.resolve()}", ] diff --git a/junifer/data/parcellations/_fsl_parcellation_warper.py b/junifer/data/parcellations/_fsl_parcellation_warper.py index d991458d5..78c5f4b48 100644 --- a/junifer/data/parcellations/_fsl_parcellation_warper.py +++ b/junifer/data/parcellations/_fsl_parcellation_warper.py @@ -78,7 +78,7 @@ def warp( "--interp=nn", f"-i {prewarp_parcellation_path.resolve()}", # use resampled reference - f"-r {target_data['reference_path'].resolve()}", + f"-r {target_data['reference']["path"].resolve()}", f"-w {warp_data['path'].resolve()}", f"-o {warped_parcellation_path.resolve()}", ] diff --git a/junifer/datagrabber/pattern_validation_mixin.py b/junifer/datagrabber/pattern_validation_mixin.py index 6acf8fab1..24d5cf743 100644 --- a/junifer/datagrabber/pattern_validation_mixin.py +++ b/junifer/datagrabber/pattern_validation_mixin.py @@ -33,6 +33,8 @@ "mandatory": ["pattern", "format"], "optional": ["mappings"], }, + "reference": {"mandatory": ["pattern"], "optional": []}, + "prewarp_space" : {"mandatory": [], "optional": []}, }, }, "Warp": { diff --git a/junifer/preprocess/warping/_ants_warper.py b/junifer/preprocess/warping/_ants_warper.py index 789a2495e..f1bbbe2b5 100644 --- a/junifer/preprocess/warping/_ants_warper.py +++ b/junifer/preprocess/warping/_ants_warper.py @@ -59,7 +59,7 @@ def preprocess( ------- dict The ``input`` dictionary with modified ``data`` and ``space`` key - values and new ``reference_path`` key whose value points to the + values and new ``reference`` key whose value points to the reference file used for warping. Raises @@ -129,7 +129,7 @@ def preprocess( # Load nifti input["data"] = nib.load(apply_transforms_out_path) # Save resampled reference path - input["reference_path"] = resample_image_out_path + input["reference"] = {"path": resample_image_out_path} # Keep pre-warp space for further operations input["prewarp_space"] = input["space"] # Use reference input's space as warped input's space diff --git a/junifer/preprocess/warping/_fsl_warper.py b/junifer/preprocess/warping/_fsl_warper.py index 3f25cd92f..d504c5694 100644 --- a/junifer/preprocess/warping/_fsl_warper.py +++ b/junifer/preprocess/warping/_fsl_warper.py @@ -55,7 +55,7 @@ def preprocess( ------- dict The ``input`` dictionary with modified ``data`` and ``space`` key - values and new ``reference_path`` key whose value points to the + values and new ``reference`` key whose value points to the reference file used for warping. Raises @@ -116,7 +116,7 @@ def preprocess( # Load nifti input["data"] = nib.load(applywarp_out_path) # Save resampled reference path - input["reference_path"] = flirt_out_path + input["reference"] = {"path": flirt_out_path} # Keep pre-warp space for further operations input["prewarp_space"] = input["space"] # Use reference input's space as warped input's space From 0a84c2e1a0bafa18c8f077a16b397ec480b34289 Mon Sep 17 00:00:00 2001 From: Fede Raimondo Date: Mon, 25 Nov 2024 18:59:56 +0100 Subject: [PATCH 6/6] Compute brain gm/wm mask on subject's VBM data --- junifer/data/masks/_masks.py | 69 ++++++++++++++++++++++++++---------- 1 file changed, 51 insertions(+), 18 deletions(-) diff --git a/junifer/data/masks/_masks.py b/junifer/data/masks/_masks.py index bb7d743bd..e75f86cbe 100644 --- a/junifer/data/masks/_masks.py +++ b/junifer/data/masks/_masks.py @@ -47,6 +47,8 @@ def compute_brain_mask( warp_data: Optional[dict[str, Any]] = None, mask_type: str = "brain", threshold: float = 0.5, + source: str = "template", + extra_input: Optional[dict[str, Any]] = None, ) -> "Nifti1Image": """Compute the whole-brain, grey-matter or white-matter mask. @@ -72,6 +74,10 @@ def compute_brain_mask( (default "brain"). threshold : float, optional The value under which the template is cut off (default 0.5). + source : {"subject", "template"}, optional + The source of the mask. If "subject", the mask is computed from the + subject's data (VBM_GM or VBM_WM). If "template", the mask is computed + from the template data (default "template"). Returns ------- @@ -90,6 +96,12 @@ def compute_brain_mask( if mask_type not in ["brain", "gm", "wm"]: raise_error(f"Unknown mask type: {mask_type}") + if source not in ["subject", "template"]: + raise_error(f"Unknown mask source: {source}") + + if source == "subject" and mask_type not in ["gm", "wm"]: + raise_error(f"Unknown mask type: {mask_type} for subject space") + # Check pre-requirements for space manipulation if target_data["space"] == "native": # Warp data check @@ -101,25 +113,43 @@ def compute_brain_mask( # Set space to fetch template using target_std_space = target_data["space"] - # Fetch template in closest resolution - template = get_template( - space=target_std_space, - target_data=target_data, - extra_input=None, - template_type=mask_type, - ) - - # Resample and warp template if target space is native - if target_data["space"] == "native": - resampled_template = ANTsMaskWarper().warp( - mask_name=f"template_{target_std_space}_for_compute_brain_mask", - # use template here - mask_img=template, - src=target_std_space, - dst="native", + if source == "subject": + key = f"VBM_{mask_type.upper()}" + if key not in extra_input: + raise_error( + f"Cannot compute {mask_type} from subject's data. " + f"Missing {key} in extra input." + ) + template = extra_input[key]["data"] + template_space = extra_input[key]["space"] + else: + # Fetch template in closest resolution + template = get_template( + space=target_std_space, target_data=target_data, - warp_data=warp_data, + extra_input=None, + template_type=mask_type, ) + template_space = target_std_space + # Resample and warp template if target space is native + if target_data["space"] == "native" and template_space != "native": + if warp_data["warper"] == "fsl": + resampled_template = FSLMaskWarper().warp( + mask_name=f"template_{target_std_space}_for_compute_brain_mask", + mask_img=template, + target_data=target_data, + warp_data=warp_data, + ) + elif warp_data["warper"] == "ants": + resampled_template = ANTsMaskWarper().warp( + mask_name=f"template_{target_std_space}_for_compute_brain_mask", + # use template here + mask_img=template, + src=target_std_space, + dst="native", + target_data=target_data, + warp_data=warp_data, + ) # Resample template to target image else: resampled_template = resample_to_img( @@ -503,7 +533,10 @@ def get( # noqa: C901 # custom compute_brain_mask elif mask_name == "compute_brain_mask": mask_img = mask_object( - target_data, warper_spec, **mask_params + target_data, + warper_spec, + extra_input=extra_input, + **mask_params, ) # custom registered; arm kept for clarity else: