## 🚀 Feature

When applying the `RandomAffine` transformation to a `Subject`, the affine matrix should be updated.
## Motivation

While the affine matrix is often not used when training a model, there are cases in which it might be. For example, to train a model that is conditioned on the voxel spacing, we must track the voxel spacing through any data augmentation we apply (non-linear transforms do not allow this, of course, but affine transforms do). The simplest and most versatile way to achieve this is to update the affine matrix after each spatial augmentation transform, as is already done in the `Resample` preprocessing transform.
## Pitch

Improve the `RandomAffine` class so that it updates the affine matrix of the subject under any randomly sampled combination of scaling, rotation, and translation.
## Alternatives

Here is my attempt; as far as I can tell, it does the job. My classes inherit from `RandomAffine` and `Affine`, and I have named them `MyRandomAffine` and `MyAffine` respectively.
The important changes are in the `MyAffine` class, which is instantiated in the `apply_transform` method of the `MyRandomAffine` transform. Specifically, a new method, `get_new_affine_matrix`, is called at the end of the `apply_transform` method of `MyAffine`. In short, `get_new_affine_matrix` works like so:
- Get the randomly sampled scaling, rotation and translation parameters (which, by this point, have already been used in the SimpleITK transformations).
- The SimpleITK transformations map the original voxel indices to new voxel indices. We need to find the point about which this mapping occurs and make it the origin in voxel space; we create the `offset_voxels` matrix for this.
- We then calculate the forward transformation performed by the SimpleITK transforms. The rotation and translation are applied first, and the scaling afterwards. Notice how we pre- and post-multiply the transformation matrices by the `reset_voxels` and `offset_voxels` matrices respectively (strictly, I think this may not be necessary for the rotation-and-translation transform): we want to perform these transformations about the same point that SimpleITK did, but then recover our original coordinates, since they are what the original affine matrix was based on.
- We get the backward transformation as the inverse of the forward transformation and pre-multiply it by the original affine. This gives us a mapping from the output voxel indices to world space, i.e. our new affine matrix; the composition is summarized below.
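In matrix form (a sketch using my own notation: $C$ is the `offset_voxels` matrix, $C^{-1}$ the `reset_voxels` matrix, $RT$ the homogeneous rotation-plus-translation matrix, $S$ the homogeneous scaling matrix, and $A$ the original affine):

$$
F = \left(C^{-1} S C\right)\left(C^{-1} RT\,C\right) = C^{-1} S\,RT\,C,
\qquad
A_{\mathrm{new}} = A\,F^{-1}
$$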
Here is the code:
```python
import numpy as np
import torch
import torchio as tio
from numbers import Number

from torchio import Subject
from torchio.constants import TYPE, INTENSITY
from torchio.data.io import nib_to_sitk
from torchio.transforms.augmentation.spatial.random_affine import (
    Affine,
    RandomAffine,
    get_borders_mean,
)


def get_pixdim_from_affine(affine: np.ndarray) -> np.ndarray:
    # The voxel spacing is the column-wise norm of the affine's linear part
    rot = affine[:-1, :-1]
    return np.sqrt(np.sum(rot**2, axis=0))

class MyRandomAffine(RandomAffine):
    def __init__(self, update_affine: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.update_affine = update_affine

    def apply_transform(self, subject: Subject) -> Subject:
        scaling_params, rotation_params, translation_params = self.get_params(
            self.scales,
            self.degrees,
            self.translation,
            self.isotropic,
        )
        arguments = {
            'scales': scaling_params.tolist(),
            'degrees': rotation_params.tolist(),
            'translation': translation_params.tolist(),
            'center': self.center,
            'default_pad_value': self.default_pad_value,
            'image_interpolation': self.image_interpolation,
            'label_interpolation': self.label_interpolation,
            'check_shape': self.check_shape,
        }
        transform = MyAffine(
            update_affine=self.update_affine,
            **self.add_include_exclude(arguments),
        )
        transformed = transform(subject)
        assert isinstance(transformed, Subject)
        return transformed

class MyAffine(Affine):
    def __init__(self, update_affine: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.update_affine = update_affine

    def apply_transform(self, subject: Subject) -> Subject:
        if self.check_shape:
            subject.check_consistent_spatial_shape()
        default_value: float
        for image in self.get_images(subject):
            transform = self.get_affine_transform(image)
            transformed_tensors = []
            for tensor in image.data:
                sitk_image = nib_to_sitk(
                    tensor[np.newaxis],
                    image.affine,
                    force_3d=True,
                )
                if image[TYPE] != INTENSITY:
                    interpolation = self.label_interpolation
                    default_value = 0
                else:
                    interpolation = self.image_interpolation
                    if self.default_pad_value == 'minimum':
                        default_value = tensor.min().item()
                    elif self.default_pad_value == 'mean':
                        default_value = get_borders_mean(
                            sitk_image,
                            filter_otsu=False,
                        )
                    elif self.default_pad_value == 'otsu':
                        default_value = get_borders_mean(
                            sitk_image,
                            filter_otsu=True,
                        )
                    else:
                        assert isinstance(self.default_pad_value, Number)
                        default_value = float(self.default_pad_value)
                transformed_tensor = self.apply_affine_transform(
                    sitk_image,
                    transform,
                    interpolation,
                    default_value,
                )
                transformed_tensors.append(transformed_tensor)
            image.set_data(torch.stack(transformed_tensors))
            if self.update_affine:
                # New behavior: keep the affine consistent with the resampling
                new_affine = self.get_new_affine_matrix(image)
                image.affine = new_affine
        return subject
    def get_new_affine_matrix(self, image: tio.Image) -> np.ndarray:
        # Get the sampled scaling, rotation and translation parameters
        scaling = np.asarray(self.scales).copy()
        rotation = np.asarray(self.degrees).copy()
        translation = np.asarray(self.translation).copy()
        # Get the original affine matrix
        original_affine = image.affine
        # Build the matrix that offsets voxel indices so that the voxel origin
        # ([0, 0, 0] in voxel space) is at the point about which the
        # transformation is applied
        if self.center == "image":
            voxel_origin = np.array(image.spatial_shape) / 2
        elif self.center == "origin":
            voxel_origin = (
                np.linalg.inv(original_affine) @ np.array([[0, 0, 0, 1]]).T
            ).flatten()[:3]
        offset_voxels = np.eye(4)
        offset_voxels[:3, -1] = -voxel_origin
        reset_voxels = np.linalg.inv(offset_voxels)
        # Forward transform of the voxel indices: rotation + translation first,
        # then scaling, both applied about the chosen center
        rot = self.get_rotation_matrix(rotation)
        # Convert the translation from mm to voxel units
        trans = translation / get_pixdim_from_affine(original_affine)
        rot_trans_mat = np.hstack([rot, trans.reshape(-1, 1)])
        rot_trans_mat = np.vstack([rot_trans_mat, [0, 0, 0, 1]])
        rot_trans_mat = reset_voxels @ rot_trans_mat @ offset_voxels
        scale = np.eye(4)
        scale[:3, :3] = np.diag(scaling)
        scale_mat = reset_voxels @ scale @ offset_voxels
        forward_voxel_transform = scale_mat @ rot_trans_mat
        # Backward transform maps output voxel indices back to input indices;
        # composing it with the original affine gives the new affine
        backward_voxel_transform = np.linalg.inv(forward_voxel_transform)
        new_affine = original_affine @ backward_voxel_transform
        return new_affine
    @staticmethod
    def get_rotation_matrix(deg):
        # Compose elemental rotations about the x, y and z voxel axes
        x_rad, y_rad, z_rad = np.deg2rad(deg)
        cx, sx = np.cos(x_rad), np.sin(x_rad)
        cy, sy = np.cos(y_rad), np.sin(y_rad)
        cz, sz = np.cos(z_rad), np.sin(z_rad)
        x_mat = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
        y_mat = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
        z_mat = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
        return x_mat @ y_mat @ z_mat
```
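As a quick numerical sanity check (my own sketch, not part of the classes above; it assumes a pure, fixed scaling with no rotation or translation, so the column norms of the affine are exactly the voxel spacings), doubling the scale should halve the spacing recovered from the updated affine:

```python
# Sketch of a numerical sanity check: pure scaling, no rotation/translation
subject = tio.datasets.Colin27()
check = MyRandomAffine(
    scales=(2, 2, 2, 2, 2, 2),  # degenerate ranges -> deterministic factor of 2
    degrees=0,
    translation=0,
    update_affine=True,
)
out = check(subject)
old_spacing = get_pixdim_from_affine(subject["t1"].affine)
new_spacing = get_pixdim_from_affine(out["t1"].affine)
# Zooming in by a factor of 2 should halve the effective voxel spacing
assert np.allclose(new_spacing, old_spacing / 2)
```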
Here is some code to visually confirm that it works:
```python
import matplotlib.pyplot as plt

subject_in = tio.datasets.Colin27()

# Play around with these params
transform = MyRandomAffine(
    scales=(2, 2, 1, 1, 1, 1),
    degrees=(0, 0, 0, 0, 47, 47),
    translation=(8, 8, -30, -30, 0, 0),
    update_affine=True,
)
# Do the transform
subject_out = transform(subject_in)

# Get original and new affines
original_affine = subject_in["t1"].affine
new_affine = subject_out["t1"].affine
print("Original affine")
print(np.round(original_affine, 4), "\n")
print("New affine")
print(np.round(new_affine, 4))

# Set point of interest in voxel indices of the original image
index1 = np.array([[63, 75, 90, 1]]).T
world = original_affine @ index1  # original index to world
# Get the corresponding index in the transformed image
index2 = np.linalg.inv(new_affine) @ world  # world to new index

############################## Plot ##############################
fig, ax = plt.subplots(1, 2, figsize=(20, 10))

### Image in
original_origin_index = np.linalg.inv(original_affine) @ np.array([[0, 0, 0, 1]]).T
# Image data
image1 = subject_in["t1"]["data"][0, :, :, index1[2][0]]
ax[0].imshow(image1, cmap="gray")
# Point of interest
ax[0].scatter(index1[1][0], index1[0][0], s=20, c='red', marker='o')
ax[0].text(index1[1][0] + 15, index1[0][0] - 15, f"{index1[0][0]}, {index1[1][0]}", fontsize=12, color='red')
# Origin
ax[0].scatter(original_origin_index[1][0], original_origin_index[0][0], s=20, c='green', marker='o')
ax[0].text(original_origin_index[1][0] + 15, original_origin_index[0][0] - 15, f"{original_origin_index[0][0]}, {original_origin_index[1][0]}", fontsize=12, color='green')
# Formatting
ax[0].minorticks_on()
ax[0].xaxis.set_minor_locator(plt.MultipleLocator(5))
ax[0].yaxis.set_minor_locator(plt.MultipleLocator(5))
ax[0].grid(which='both', color='blue', linestyle='-', linewidth=0.5)
ax[0].grid(which='minor', color='cyan', linestyle='-', linewidth=0.5, alpha=0.25)

### Image out
new_origin_index = np.linalg.inv(new_affine) @ np.array([[0, 0, 0, 1]]).T
# Image data
image2 = subject_out["t1"]["data"][0, :, :, int(index2[2][0])]
ax[1].imshow(image2, cmap="gray")
# Point of interest
ax[1].scatter(index2[1][0], index2[0][0], s=20, c='red', marker='o')
ax[1].text(index2[1][0] + 15, index2[0][0] - 15, f"{index2[0][0]:.2f}, {index2[1][0]:.2f}", fontsize=12, color='red')
# Origin
ax[1].scatter(new_origin_index[1][0], new_origin_index[0][0], s=20, c='green', marker='o')
ax[1].text(new_origin_index[1][0] + 15, new_origin_index[0][0] - 15, f"{new_origin_index[0][0]:.2f}, {new_origin_index[1][0]:.2f}", fontsize=12, color='green')
# Formatting
ax[1].minorticks_on()
ax[1].xaxis.set_minor_locator(plt.MultipleLocator(5))
ax[1].yaxis.set_minor_locator(plt.MultipleLocator(5))
ax[1].grid(which='both', color='blue', linestyle='-', linewidth=0.5)
ax[1].grid(which='minor', color='cyan', linestyle='-', linewidth=0.5, alpha=0.25)

plt.show()
```
This outputs the following (point of interest in red, origin in green):
## Additional context
Let me just say that I absolutely love the TorchIO library and use it in my pipelines wherever possible!
I think another feature missing from the `RandomAffine` transform is the option to specify `scales` as a target voxel spacing instead of scaling factors. This would be very easy to implement (just a couple of lines of code). I didn't think it was worth raising a separate issue, but I mention it here since it is something I use in my pipeline.
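For what it's worth, here is a minimal sketch of that conversion (the `spacing_to_scales` helper is hypothetical, not existing TorchIO API): since scaling by a factor $s$ divides the effective voxel spacing by $s$, the scale factors equivalent to a target spacing are simply the current spacing divided by the target spacing.

```python
def spacing_to_scales(image: tio.Image, target_spacing) -> tuple:
    # Hypothetical helper: convert a target voxel spacing (mm) into the
    # equivalent scale factors for RandomAffine / Affine
    current_spacing = get_pixdim_from_affine(image.affine)
    factors = current_spacing / np.asarray(target_spacing, dtype=float)
    # Repeat each factor as a (min, max) pair so the sampled scale is fixed
    return tuple(np.repeat(factors, 2))
```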