Update a Subject's affine matrix during RandomAffine transformation #1267

@Jesse-Phitidis

🚀 Feature

When applying the RandomAffine transformation to a Subject, the affine matrix should be updated.

Motivation

While the affine matrix is often not used when training a model, there are cases in which it is needed. For example, if we want to train a model that is conditioned on the voxel spacing, then we must track the voxel spacing through any data augmentation we apply (non-linear transforms do not allow this, but affine transforms do). The simplest and most versatile way to achieve this is to update the affine matrix after each spatial augmentation transform, as is already done in the Resample preprocessing transform.

Pitch

Improve the RandomAffine class so as to update the affine matrix of the subject under any randomly sampled combination of scaling, rotation, and translation.

Alternatives

Here is my attempt. As far as I can tell, it does the job. I have named my classes MyRandomAffine and MyAffine; they inherit from RandomAffine and Affine respectively.

The important changes are in the MyAffine class, which is instantiated and called in the apply_transform method of MyRandomAffine. Specifically, there is a new method, get_new_affine_matrix, which is called in the apply_transform method of MyAffine once each image has been resampled. In short, get_new_affine_matrix works as follows:

  1. Get the randomly sampled scaling, rotation and translation parameters (which have already been used in the SimpleITK transformations by this point).
  2. The SimpleITK transformations map the original voxel indices to new voxel indices. We need to find the point about which this mapping occurs and make it the origin in voxel space. We create the offset_voxels matrix for this.
  3. We then calculate the forward transformation that was performed by the SimpleITK transforms. The rotation and translation are applied first, and the scaling is applied afterwards. Notice how we pre-multiply and post-multiply the transformation matrices by the reset_voxels and offset_voxels matrices respectively (strictly, I think this may not be necessary for the rotation and translation transform). We do this because we want to perform these transformations about the same point that SimpleITK did, and then recover our original coordinates, since they are what the original affine matrix was based upon.
  4. We get the backward transformation as the inverse of the forward transformation and pre-multiply it by the original affine. This gives us a mapping from the output voxel indices to world space (i.e. this is our new affine matrix).
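
The four steps above can be sketched in plain NumPy, independent of TorchIO and SimpleITK. This is only an illustration under stated assumptions: the helper names (about_point, updated_affine, voxel_spacing) and the toy affine are hypothetical, the rotation is passed in as a ready-made 3x3 matrix, and the translation is assumed to be already expressed in voxel units.

```python
import numpy as np


def about_point(matrix, point):
    """Apply a 4x4 voxel-space `matrix` about `point` instead of voxel (0, 0, 0)."""
    offset_voxels = np.eye(4)
    offset_voxels[:3, 3] = -np.asarray(point, dtype=float)
    reset_voxels = np.linalg.inv(offset_voxels)
    return reset_voxels @ matrix @ offset_voxels


def updated_affine(original_affine, scaling, rotation, translation_voxels, center_voxel):
    """Steps 2-4: build the forward voxel transform, invert it, compose with the affine."""
    rot_trans = np.eye(4)
    rot_trans[:3, :3] = rotation
    rot_trans[:3, 3] = translation_voxels
    scale = np.eye(4)
    scale[:3, :3] = np.diag(scaling)
    # Rotation and translation are applied first, scaling afterwards
    forward = about_point(scale, center_voxel) @ about_point(rot_trans, center_voxel)
    return original_affine @ np.linalg.inv(forward)


def voxel_spacing(affine):
    """Voxel spacing is the norm of each column of the 3x3 linear part."""
    return np.sqrt((affine[:3, :3] ** 2).sum(axis=0))


# Toy affine (hypothetical values): spacing (1, 1, 2) mm plus an offset
original = np.diag([1.0, 1.0, 2.0, 1.0])
original[:3, 3] = [-90.0, -126.0, -72.0]
center = np.array([10.0, 20.0, 30.0])

# Pure scaling by (2, 2, 1) about `center`, with no rotation or translation
new = updated_affine(original, [2.0, 2.0, 1.0], np.eye(3), np.zeros(3), center)

print(voxel_spacing(original))
print(voxel_spacing(new))
```

In this toy case the spacing goes from (1, 1, 2) to (0.5, 0.5, 2), and the voxel at `center` maps to the same world coordinate under both affines, which is what transforming about that point should give.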

Here is the code:

import numpy as np
import torchio as tio
import torch

from torchio import Subject
from torchio.constants import TYPE, INTENSITY
from torchio.transforms.augmentation.spatial.random_affine import RandomAffine, Affine, get_borders_mean
from torchio.data.io import nib_to_sitk
from numbers import Number


def get_pixdim_from_affine(affine: np.ndarray) -> np.ndarray:
    # Voxel spacing is the Euclidean norm of each column of the 3x3 linear part
    rot = affine[:-1,:-1]
    return np.sqrt(np.sum(rot**2, axis=0))


class MyRandomAffine(RandomAffine):
    
    def __init__(self, update_affine: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.update_affine = update_affine
    
    def apply_transform(self, subject: Subject) -> Subject:
            
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        
        arguments = {
            'scales': scaling_params.tolist(),
            'degrees': rotation_params.tolist(),
            'translation': translation_params.tolist(),
            'center': self.center,
            'default_pad_value': self.default_pad_value,
            'image_interpolation': self.image_interpolation,
            'label_interpolation': self.label_interpolation,
            'check_shape': self.check_shape,
        }
        transform = MyAffine(update_affine=self.update_affine, **self.add_include_exclude(arguments))
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed
    
    
class MyAffine(Affine):
    
    def __init__(self, update_affine: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.update_affine = update_affine
    
    def apply_transform(self, subject: Subject) -> Subject:
        if self.check_shape:
            subject.check_consistent_spatial_shape()
        default_value: float
        for image in self.get_images(subject):
            transform = self.get_affine_transform(image)
            transformed_tensors = []
            for tensor in image.data:
                sitk_image = nib_to_sitk(
                    tensor[np.newaxis],
                    image.affine,
                    force_3d=True,
                )
                if image[TYPE] != INTENSITY:
                    interpolation = self.label_interpolation
                    default_value = 0
                else:
                    interpolation = self.image_interpolation
                    if self.default_pad_value == 'minimum':
                        default_value = tensor.min().item()
                    elif self.default_pad_value == 'mean':
                        default_value = get_borders_mean(
                            sitk_image,
                            filter_otsu=False,
                        )
                    elif self.default_pad_value == 'otsu':
                        default_value = get_borders_mean(
                            sitk_image,
                            filter_otsu=True,
                        )
                    else:
                        assert isinstance(self.default_pad_value, Number)
                        default_value = float(self.default_pad_value)
                transformed_tensor = self.apply_affine_transform(
                    sitk_image,
                    transform,
                    interpolation,
                    default_value,
                )
                transformed_tensors.append(transformed_tensor)
            image.set_data(torch.stack(transformed_tensors))
            if self.update_affine:
                new_affine = self.get_new_affine_matrix(image)
                image.affine = new_affine
        return subject
    
    def get_new_affine_matrix(self, image: tio.Image) -> np.ndarray:
        # get the scaling, rotation and translation parameters
        scaling = np.asarray(self.scales).copy()
        rotation = np.asarray(self.degrees).copy()
        translation = np.asarray(self.translation).copy()
        # get the original affine matrix
        original_affine = image.affine
        # get matrix to offset voxel indices so that the voxel origin ([0,0,0] in voxel space) 
        # is at the location about which the transformation is applied
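        if self.center == "image":
            voxel_origin = np.array(image.spatial_shape) / 2
        elif self.center == "origin":
            voxel_origin = (np.linalg.inv(original_affine) @ np.array([[0, 0, 0, 1]]).T).flatten()[:3]
        else:
            # Guard against voxel_origin being undefined below
            raise ValueError(f"Unsupported center: {self.center!r}")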
        offset_voxels = np.eye(4)
        offset_voxels[:3,-1] = -voxel_origin
        reset_voxels = np.linalg.inv(offset_voxels)
        # forward transform of voxels
        rot = self.get_rotation_matrix(rotation)
        trans = translation / get_pixdim_from_affine(original_affine) # convert translation to voxel units
        rot_trans_mat = np.hstack([rot, trans.reshape(-1,1)])
        rot_trans_mat = np.vstack([rot_trans_mat, [0, 0, 0, 1]])
        rot_trans_mat = reset_voxels @ rot_trans_mat @ offset_voxels
        scale = np.eye(4)
        scale[:3,:3] = np.diag(scaling)
        scale_mat = reset_voxels @ scale @ offset_voxels
        forward_voxel_transform = scale_mat @ rot_trans_mat
        # inverse transform of voxels
        backward_voxel_transform = np.linalg.inv(forward_voxel_transform)
        new_affine = original_affine @ backward_voxel_transform
        return new_affine   
    
    @staticmethod
    def get_rotation_matrix(deg):
        # Rotation angles about the x, y and z axes, converted to radians
        x_rad, y_rad, z_rad = np.deg2rad(deg)
        cx, sx = np.cos(x_rad), np.sin(x_rad)
        cy, sy = np.cos(y_rad), np.sin(y_rad)
        cz, sz = np.cos(z_rad), np.sin(z_rad)
        x_mat = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
        y_mat = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
        z_mat = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
        return x_mat @ y_mat @ z_mat
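
As a quick sanity check on the rotation part, the following standalone sketch rebuilds the same x @ y @ z composition and confirms that it is a proper rotation, and that composing its inverse with an isotropic affine leaves the voxel spacing unchanged. The function name rotation_matrix, the angles and the 1.5 mm spacing are illustrative assumptions, not part of TorchIO.

```python
import numpy as np


def rotation_matrix(degrees):
    """Compose rotations about x, y and z (angles in degrees) as x @ y @ z."""
    x, y, z = np.deg2rad(degrees)
    x_mat = np.array([[1, 0, 0],
                      [0, np.cos(x), -np.sin(x)],
                      [0, np.sin(x), np.cos(x)]])
    y_mat = np.array([[np.cos(y), 0, np.sin(y)],
                      [0, 1, 0],
                      [-np.sin(y), 0, np.cos(y)]])
    z_mat = np.array([[np.cos(z), -np.sin(z), 0],
                      [np.sin(z), np.cos(z), 0],
                      [0, 0, 1]])
    return x_mat @ y_mat @ z_mat


rotation = rotation_matrix([10.0, 20.0, 30.0])

# A proper rotation is orthonormal and has determinant +1
is_orthonormal = np.allclose(rotation @ rotation.T, np.eye(3))
determinant = np.linalg.det(rotation)

# Toy isotropic affine (hypothetical 1.5 mm spacing, no offset)
isotropic_affine = np.diag([1.5, 1.5, 1.5, 1.0])
backward = np.eye(4)
backward[:3, :3] = rotation.T  # inverse of a rotation is its transpose
rotated_affine = isotropic_affine @ backward

# Column norms of the linear part give the voxel spacing
spacing = np.sqrt((rotated_affine[:3, :3] ** 2).sum(axis=0))
print(is_orthonormal, determinant, spacing)
```

This check only covers isotropic voxels; it says nothing about how a voxel-space rotation behaves when the spacing differs between axes.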

Here is some code to visually confirm that it works:

import matplotlib.pyplot as plt

subject_in = tio.datasets.Colin27()

# Play around with these params
transform = MyRandomAffine(scales=(2,2,1,1,1,1), degrees=(0,0,0,0,47,47), translation=(8,8,-30,-30,0,0), update_affine=True)

# Do the transform
subject_out = transform(subject_in)

# Get original and new affines
original_affine = subject_in["t1"].affine
new_affine = subject_out["t1"].affine
print("Original affine")
print(np.round(original_affine,4), "\n")
print("New affine")
print(np.round(new_affine,4))

# Set point of interest in voxel indices of original image
index1 = np.array([[63, 75, 90, 1]]).T
world = original_affine @ index1 # original index to world

# Get index in transformed image
index2 = np.linalg.inv(new_affine) @ world # world to new index


############################## Plot ##############################

fig, ax = plt.subplots(1,2, figsize=(20,10))

### Image in

original_origin_index = np.linalg.inv(original_affine) @ np.array([[0,0,0,1]]).T

# Image data
image1 = subject_in["t1"]["data"][0, :, :, index1[2][0]]
ax[0].imshow(image1,cmap="gray")
# Point of interest
ax[0].scatter(index1[1][0], index1[0][0], s=20, c='red', marker='o')
ax[0].text(index1[1][0] + 15, index1[0][0] - 15, f"{index1[0][0]}, {index1[1][0]}", fontsize=12, color='red')
# Origin
ax[0].scatter(original_origin_index[1][0], original_origin_index[0][0], s=20, c='green', marker='o')
ax[0].text(original_origin_index[1][0] + 15, original_origin_index[0][0] - 15, f"{original_origin_index[0][0]}, {original_origin_index[1][0]}", fontsize=12, color='green')
# Formatting
ax[0].minorticks_on()
ax[0].xaxis.set_minor_locator(plt.MultipleLocator(5))
ax[0].yaxis.set_minor_locator(plt.MultipleLocator(5))
ax[0].grid(which='both', color='blue', linestyle='-', linewidth=0.5)
ax[0].grid(which='minor', color='cyan', linestyle='-', linewidth=0.5, alpha=0.25)

### Image out
new_origin_index = np.linalg.inv(new_affine) @ np.array([[0,0,0,1]]).T

# Image data
image2 = subject_out["t1"]["data"][0, :, :, int(index2[2][0])]
ax[1].imshow(image2 ,cmap="gray")
# Point of interest
ax[1].scatter(index2[1][0], index2[0][0], s=20, c='red', marker='o')
ax[1].text(index2[1][0] + 15, index2[0][0] - 15, f"{index2[0][0]:.2f}, {index2[1][0]:.2f}", fontsize=12, color='red')
# Origin
ax[1].scatter(new_origin_index[1][0], new_origin_index[0][0], s=20, c='green', marker='o')
ax[1].text(new_origin_index[1][0] + 15, new_origin_index[0][0] - 15, f"{new_origin_index[0][0]:.2f}, {new_origin_index[1][0]:.2f}", fontsize=12, color='green')
# Formatting
ax[1].minorticks_on()
ax[1].xaxis.set_minor_locator(plt.MultipleLocator(5))
ax[1].yaxis.set_minor_locator(plt.MultipleLocator(5))
ax[1].grid(which='both', color='blue', linestyle='-', linewidth=0.5)
ax[1].grid(which='minor', color='cyan', linestyle='-', linewidth=0.5, alpha=0.25)

plt.show()

Which outputs this (point of interest in red and origin in green):

Image

Additional context

Let me just say that I absolutely love the TorchIO library and use it in my pipelines wherever possible!

I think another feature missing from the RandomAffine transform is the option to specify scales as a target voxel spacing instead of as scaling factors. This would be very easy to implement, needing just a couple of lines of code. I didn't think this was worth raising as a separate issue, but I thought I would mention it since it is something I use in my pipeline.
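
A minimal sketch of that conversion, assuming the convention used in get_new_affine_matrix above, where a scaling factor s divides the voxel spacing along its axis (new spacing = old spacing / s). The helper name scales_from_target_spacing is hypothetical and not part of TorchIO.

```python
import numpy as np


def scales_from_target_spacing(affine, target_spacing):
    """Return per-axis scaling factors that would turn the current spacing into the target.

    Assumes new_spacing = current_spacing / scale, as in get_new_affine_matrix.
    """
    current_spacing = np.sqrt((affine[:3, :3] ** 2).sum(axis=0))
    return current_spacing / np.asarray(target_spacing, dtype=float)


# Toy affine (hypothetical values) with spacing (1, 1, 2) mm
affine = np.diag([1.0, 1.0, 2.0, 1.0])
scales = scales_from_target_spacing(affine, target_spacing=(0.5, 0.5, 2.0))

# Repeat each factor to get the six-value (min, max) per-axis form used for
# `scales` in the example above, so that each axis is sampled at a fixed value
scales_arg = tuple(np.repeat(scales, 2).tolist())
print(scales, scales_arg)
```

Here the factors come out as (2, 2, 1), and scales_arg has the same six-value layout as the scales=(2,2,1,1,1,1) argument passed to MyRandomAffine earlier.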

Labels: enhancement (New feature or request)
