
Add CropOrPadAtCenter transform class #1233

@themantalope

Description


🚀 Feature
A crop-or-pad transform that lets the user crop or pad an image while specifying which point of the input image should become the center of the new image. For example, say an organ is centered at voxel [45, 62, 101] in a volume of size [368, 512, 128], and I want a new image of size [128, 128, 128] whose center (i.e. [64, 64, 64]) maps to the point [45, 62, 101] in the original image.
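For the example above, the required padding and cropping can be worked out per axis. The sketch below is a minimal illustration, assuming voxel coordinates and the (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin) ordering that tio.Pad and tio.Crop accept; center_window_bounds is a hypothetical helper, not part of TorchIO. The output window in input coordinates is [c - t//2, c - t//2 + t), and whatever part of it falls outside the image is padded.

```python
def center_window_bounds(center, shape, target_shape):
    """Hypothetical helper: return (pad, crop) 6-tuples in the
    (w_ini, w_fin, h_ini, h_fin, d_ini, d_fin) order, so that voxel `center`
    of an image with spatial size `shape` lands at index t // 2 of each axis."""
    pad, crop = [], []
    for c, s, t in zip(center, shape, target_shape):
        lower_half, upper_half = t // 2, t - t // 2
        # Pad only the part of [c - lower_half, c + upper_half) outside [0, s)
        pad_lower = max(0, lower_half - c)
        pad_upper = max(0, c + upper_half - s)
        padded_size = s + pad_lower + pad_upper
        # After padding, the center has moved by pad_lower voxels
        start = c + pad_lower - lower_half
        pad.extend([pad_lower, pad_upper])
        crop.extend([start, padded_size - start - t])
    return tuple(pad), tuple(crop)


pad, crop = center_window_bounds((45, 62, 101), (368, 512, 128), (128, 128, 128))
print(pad)   # (19, 0, 2, 0, 0, 37)
print(crop)  # (0, 259, 0, 386, 37, 0)
```

So the organ's first two coordinates are too close to the lower edge (19 and 2 voxels of padding), and the third is too close to the upper edge of a 128-slice volume (37 voxels of padding).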

Motivation

I am developing datasets of tumors and need to crop and pad the images around each tumor.

Pitch

Described above. Proposed code below.

Alternatives

N/A

Additional context
Here is my proposed code. I've tested it locally, but it probably needs more robust testing.

import torchio as tio


class CropOrPadAtCenter(tio.Transform):
    """Crop or pad a subject so that voxel `center` (in voxel coordinates) of
    the reference image lands at index t // 2 along each output axis."""

    def __init__(self, center, target_shape, image_or_mask_name=None, **kwargs):
        super().__init__(**kwargs)
        self.center = tuple(int(c) for c in center)
        self.target_shape = tuple(int(t) for t in target_shape)
        self.image_or_mask_name = image_or_mask_name

    def apply_transform(self, subject):
        # If no image or mask name is given, use the first image in the subject
        # as the reference. Use local variables so that applying the transform
        # does not modify its own attributes between calls.
        name = self.image_or_mask_name
        if name is None:
            name = next(iter(subject.get_images_dict(intensity_only=False)))
        spatial_shape = subject[name].spatial_shape

        # The center must lie inside the reference image
        if not all(0 <= c < s for c, s in zip(self.center, spatial_shape)):
            raise ValueError(
                f'Center {self.center} is outside the image of shape {spatial_shape}'
            )

        # Compute how many voxels to pad on each side of each dimension.
        # The output window is [c - t//2, c - t//2 + t), which also handles odd t.
        pad = []
        for c, s, t in zip(self.center, spatial_shape, self.target_shape):
            lower_half = t // 2
            upper_half = t - lower_half
            lower = max(0, lower_half - c)
            upper = max(0, c + upper_half - s)
            pad.extend([lower, upper])
        subject = tio.Pad(tuple(pad))(subject)

        # Now crop. tio.Crop expects, per dimension, the number of voxels to
        # remove from the start and from the end of that axis.
        padded_shape = subject[name].spatial_shape
        crop = []
        for c, s, t, lower in zip(self.center, padded_shape, self.target_shape, pad[::2]):
            start = c + lower - t // 2  # the center shifts by the lower padding
            end = s - (start + t)
            crop.extend([start, end])
        subject = tio.Crop(tuple(crop))(subject)
        return subject
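One way to test the pad-then-crop logic more robustly, without building full 3D images, is to check it along a single axis on plain Python lists against a brute-force definition of the expected output. This is a sketch under that simplification: crop_or_pad_1d and check_random_cases are hypothetical helpers that mirror the per-axis arithmetic above, not TorchIO functions, and zero padding stands in for tio.Pad's default constant fill.

```python
import random


def crop_or_pad_1d(values, center, target, fill=0):
    """Pad, then crop, a list so that values[center] lands at index target // 2."""
    lower_half = target // 2
    upper_half = target - lower_half
    pad_lower = max(0, lower_half - center)
    pad_upper = max(0, center + upper_half - len(values))
    padded = [fill] * pad_lower + list(values) + [fill] * pad_upper
    # After padding, the center has moved right by pad_lower elements
    start = center + pad_lower - lower_half
    return padded[start:start + target]


def check_random_cases(num_cases=1000, seed=0):
    """Compare crop_or_pad_1d with a direct index mapping on random inputs."""
    rng = random.Random(seed)
    for _ in range(num_cases):
        size = rng.randint(1, 50)
        target = rng.randint(1, 60)  # includes odd sizes and targets > size
        center = rng.randrange(size)
        values = list(range(1, size + 1))  # nonzero, so padding is distinguishable
        out = crop_or_pad_1d(values, center, target)
        assert len(out) == target
        assert out[target // 2] == values[center]
        # Output index i corresponds to input index center - target // 2 + i
        for i, v in enumerate(out):
            j = center - target // 2 + i
            assert v == (values[j] if 0 <= j < size else 0)
    return True


if __name__ == "__main__":
    print(check_random_cases())
```

The same per-axis check applies to each of the three spatial dimensions, since padding and cropping act on them independently.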

        

Metadata


Assignees: No one assigned
Labels: enhancement (New feature or request)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
