Skip to content

Commit

Permalink
On the fly rescaling (GPU) (#64)
Browse files Browse the repository at this point in the history
* Remove local variables display from pretty print

* On the fly rescaling

* Read pixel size from header by default

* Fix voxel size from header loading

* add rescaling option to CLI

* remove print statements

* Move rescaling to torch GPU

* remove hard-coded GPU requirement

* Fix test time augmentation with rescaling SWInferer

* read device from model by default
  • Loading branch information
LorenzLamm authored May 23, 2024
1 parent 1317d16 commit 92ec8ca
Show file tree
Hide file tree
Showing 5 changed files with 279 additions and 15 deletions.
1 change: 1 addition & 0 deletions src/membrain_seg/segmentation/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def list_commands(self, ctx: Context):
add_completion=False,
no_args_is_help=True,
rich_markup_mode="rich",
pretty_exceptions_show_locals=False
)
OPTION_PROMPT_KWARGS = {"prompt": True, "prompt_required": True}
PKWARGS = OPTION_PROMPT_KWARGS
Expand Down
16 changes: 16 additions & 0 deletions src/membrain_seg/segmentation/cli/segment_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,19 @@ def segment(
out_folder: str = Option( # noqa: B008
"./predictions", help="Path to the folder where segmentations should be stored."
),
rescale_patches: bool = Option( # noqa: B008
False, help="Should patches be rescaled on-the-fly during inference?"
),
in_pixel_size: float = Option( # noqa: B008
None,
help="Pixel size of the input tomogram in Angstrom. \
(default: 10 Angstrom)",
),
out_pixel_size: float = Option( # noqa: B008
10.,
help="Pixel size of the output segmentation in Angstrom. \
(default: 10 Angstrom; should normally stay at 10 Angstrom)",
),
store_probabilities: bool = Option( # noqa: B008
False, help="Should probability maps be output in addition to segmentations?"
),
Expand Down Expand Up @@ -66,6 +79,9 @@ def segment(
tomogram_path=tomogram_path,
ckpt_path=ckpt_path,
out_folder=out_folder,
rescale_patches=rescale_patches,
in_pixel_size=in_pixel_size,
out_pixel_size=out_pixel_size,
store_probabilities=store_probabilities,
store_connected_components=store_connected_components,
connected_component_thres=connected_component_thres,
Expand Down
109 changes: 109 additions & 0 deletions src/membrain_seg/segmentation/networks/inference_unet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from typing import Tuple

import torch
import torch.nn.functional as F

from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.tomo_preprocessing.matching_utils.px_matching_utils import (
fourier_cropping_torch,
fourier_extend_torch,
)


def rescale_tensor(
sample: torch.Tensor, target_size: tuple, mode="trilinear"
) -> torch.Tensor:
"""
Rescales the input tensor by given factors using interpolation.
Parameters
----------
sample : torch.Tensor
The input data as a torch tensor.
target_size : tuple
The target size of the rescaled tensor.
mode : str, optional
The mode of interpolation ('nearest', 'linear', 'bilinear',
'bicubic', or 'trilinear'). Default is 'trilinear'.
Returns
-------
torch.Tensor
The rescaled tensor.
"""
# Add batch and channel dimensions
sample = sample.unsqueeze(0).unsqueeze(0)

# Apply interpolation
rescaled_sample = F.interpolate(
sample, size=target_size, mode=mode, align_corners=False
)

return rescaled_sample.squeeze(0).squeeze(0)


class PreprocessedSemanticSegmentationUnet(SemanticSegmentationUnet):
"""U-Net with rescaling preprocessing.
This class extends the SemanticSegmentationUnet class by adding
preprocessing and postprocessing steps. The preprocessing step
rescales the input to the target shape, and the postprocessing
step rescales the output to the original shape.
All of this is done on the GPU if available.
"""

def __init__(
self,
*args,
rescale_patches: bool = False, # Should patches be rescaled?
target_shape: Tuple[int, int, int] = (160, 160, 160),
**kwargs,
):
super().__init__(*args, **kwargs)
# Store the preprocessing parameters
self.rescale_patches = rescale_patches
self.target_shape = target_shape

def preprocess(self, x):
"""Preprocess the input to the network.
In this case, we rescale the input to the target shape.
"""
rescaled_samples = []
for sample in x:
sample = sample[0] # only use the first channel
if self.rescale_patches:
if sample.shape[0] > self.target_shape[0]:
sample = fourier_cropping_torch(
data=sample, new_shape=self.target_shape, device=self.device
)
elif sample.shape[0] < self.target_shape[0]:
sample = fourier_extend_torch(
data=sample, new_shape=self.target_shape, device=self.device
)
rescaled_samples.append(sample.unsqueeze(0))
rescaled_samples = torch.stack(rescaled_samples, dim=0)
return rescaled_samples

def postprocess(self, x, orig_shape):
"""Postprocess the output of the network.
In this case, we rescale the output to the original shape.
"""
rescaled_samples = []
for sample in x:
sample = sample[0] # only use first channel
if self.rescale_patches:
sample = rescale_tensor(sample, orig_shape, mode="trilinear")
rescaled_samples.append(sample.unsqueeze(0))
rescaled_samples = torch.stack(rescaled_samples, dim=0)
return rescaled_samples

def forward(self, x):
"""Forward pass through the network."""
orig_shape = x.shape[2:]
preprocessed_x = self.preprocess(x)
predicted = super().forward(preprocessed_x)
postprocessed_predicted = self.postprocess(predicted[0], orig_shape)
# Return list to be compatible with deep supervision outputs
return [postprocessed_predicted]
69 changes: 54 additions & 15 deletions src/membrain_seg/segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
import torch
from monai.inferers import SlidingWindowInferer

from membrain_seg.segmentation.networks.unet import SemanticSegmentationUnet
from membrain_seg.segmentation.networks.inference_unet import (
PreprocessedSemanticSegmentationUnet,
)
from membrain_seg.tomo_preprocessing.matching_utils.px_matching_utils import (
determine_output_shape,
)

from .dataloading.data_utils import (
load_data_for_inference,
Expand All @@ -16,6 +21,9 @@ def segment(
tomogram_path,
ckpt_path,
out_folder,
rescale_patches=False,
in_pixel_size=None,
out_pixel_size=10.0,
store_probabilities=False,
sw_roi_size=160,
store_connected_components=False,
Expand All @@ -40,6 +48,12 @@ def segment(
Path to the trained model checkpoint file.
out_folder : str
Path to the folder where the output segmentations should be stored.
rescale_patches : bool, optional
If True, rescale the patches to the output pixel size (default is False).
in_pixel_size : float, optional
Pixel size of the input tomogram in Angstrom (default is None).
out_pixel_size : float, optional
Pixel size of the output segmentation in Angstrom (default is 10.0).
store_probabilities : bool, optional
If True, store the predicted probabilities along with the segmentations
(default is False).
Expand Down Expand Up @@ -78,10 +92,13 @@ def segment(
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the model and load trained weights from checkpoint
pl_model = SemanticSegmentationUnet.load_from_checkpoint(
pl_model = PreprocessedSemanticSegmentationUnet.load_from_checkpoint(
model_checkpoint, map_location=device, strict=False
)
pl_model.to(device)
if sw_roi_size % 32 != 0:
raise OSError("Sliding window size must be multiple of 32°!")
pl_model.target_shape = (sw_roi_size, sw_roi_size, sw_roi_size)

# Preprocess the new data
new_data_path = tomogram_path
Expand All @@ -91,12 +108,34 @@ def segment(
)
new_data = new_data.to(torch.float32)

if rescale_patches:
# Rescale patches if necessary
if in_pixel_size is None:
in_pixel_size = voxel_size.x
if in_pixel_size == 0.0:
raise ValueError(
"Input pixel size is 0.0. Please specify the pixel size manually."
)
if in_pixel_size == 1.0:
print(
"WARNING: Input pixel size is 1.0. Looks like a corrupt header.",
"Please specify the pixel size manually.",
)
pl_model.rescale_patches = in_pixel_size != out_pixel_size

# Determine the sliding window size according to the input and output pixel size
sw_roi_size = determine_output_shape(
# switch in and out pixel size to get SW shape
pixel_size_in=out_pixel_size,
pixel_size_out=in_pixel_size,
orig_shape=(sw_roi_size, sw_roi_size, sw_roi_size),
)
sw_roi_size = sw_roi_size[0]

# Put the model into evaluation mode
pl_model.eval()

# Perform sliding window inference on the new data
if sw_roi_size % 32 != 0:
raise OSError("Sliding window size must be multiple of 32°!")
roi_size = (sw_roi_size, sw_roi_size, sw_roi_size)
sw_batch_size = 1
inferer = SlidingWindowInferer(
Expand All @@ -110,20 +149,20 @@ def segment(

# Perform test time augmentation (8-fold mirroring)
predictions = torch.zeros_like(new_data)
print("Performing 8-fold test-time augmentation.")
if test_time_augmentation:
print(
"Performing 8-fold test-time augmentation.",
"I.e. the following bar will run 8 times.",
)
for m in range(8 if test_time_augmentation else 1):
with torch.no_grad():
with torch.cuda.amp.autocast():
predictions += (
get_mirrored_img(
inferer(
get_mirrored_img(new_data.clone(), m).to(device), pl_model
)[0],
m,
)
.detach()
.cpu()
)
mirrored_input = get_mirrored_img(new_data.clone(), m).to(device)
mirrored_pred = inferer(mirrored_input, pl_model)
if not isinstance(mirrored_pred, list):
mirrored_pred = [mirrored_pred]
correct_pred = get_mirrored_img(mirrored_pred[0], m)
predictions += correct_pred.detach().cpu()
if test_time_augmentation:
predictions /= 8.0

Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,109 @@
from typing import Tuple, Union

import numpy as np
import torch
import torch.fft
from scipy.fft import fftn, ifftn
from scipy.ndimage import distance_transform_edt


def fourier_cropping_torch(
data: torch.Tensor, new_shape: tuple, device: torch.device = None
) -> torch.Tensor:
"""
Fourier cropping adapted for PyTorch and GPU, without smoothing functionality.
Parameters
----------
data : torch.Tensor
The input data as a 3D torch tensor on GPU.
new_shape : tuple
The target shape for the cropped data as a tuple (x, y, z).
device : torch.device, optional
The device to use for the computation. If None, the device is set to "cuda" if
available; otherwise, it is set to "cpu".
Returns
-------
torch.Tensor
The resized data as a 3D torch tensor.
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

data = data.to(device)

# Calculate the FFT of the input data
data_fft = torch.fft.fftn(data)
data_fft = torch.fft.fftshift(data_fft)

# Calculate the cropping indices
original_shape = torch.tensor(data.shape, device=device)
new_shape = torch.tensor(new_shape, device=device)
start_indices = (original_shape - new_shape) // 2
end_indices = start_indices + new_shape

# Crop the filtered FFT data
cropped_fft = data_fft[
start_indices[0] : end_indices[0],
start_indices[1] : end_indices[1],
start_indices[2] : end_indices[2],
]

unshifted_cropped_fft = torch.fft.ifftshift(cropped_fft)

# Calculate the inverse FFT of the cropped data
resized_data = torch.real(torch.fft.ifftn(unshifted_cropped_fft))

return resized_data


def fourier_extend_torch(
data: torch.Tensor, new_shape: tuple, device: torch.device = None
) -> torch.Tensor:
"""
Fourier padding adapted for PyTorch and GPU, without smoothing functionality.
Parameters
----------
data : torch.Tensor
The input data as a 3D torch tensor on GPU.
new_shape : tuple
The target shape for the extended data as a tuple (x, y, z).
device : torch.device, optional
The device to use for the computation. If None, the device is set to "cuda" if
available; otherwise, it is set to "cpu".
Returns
-------
torch.Tensor
The resized data as a 3D torch tensor.
"""
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"

data = data.to(device)

data_fft = torch.fft.fftn(data)
data_fft = torch.fft.fftshift(data_fft)

padding = [
(new_dim - old_dim) // 2 for old_dim, new_dim in zip(data.shape, new_shape)
]
padded_fft = torch.nn.functional.pad(
data_fft,
pad=[pad for pair in zip(padding, padding) for pad in pair],
mode="constant",
)

unshifted_padded_fft = torch.fft.ifftshift(padded_fft)

# Calculate the inverse FFT of the cropped data
resized_data = torch.real(torch.fft.ifftn(unshifted_padded_fft))

return resized_data


def smooth_cosine_dropoff(mask: np.ndarray, decay_width: float) -> np.ndarray:
"""
Apply a smooth cosine drop-off to a given mask.
Expand Down

0 comments on commit 92ec8ca

Please sign in to comment.