Skip to content

Commit

Permalink
Add sliding window size option for segmentation
Browse files Browse the repository at this point in the history
  • Loading branch information
LorenzLamm committed Jul 4, 2023
1 parent 1969338 commit b7df4ea
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 7 deletions.
6 changes: 6 additions & 0 deletions src/membrain_seg/cli/segment_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ def segment(
store_probabilities: bool = Option( # noqa: B008
False, help="Should probability maps be output in addition to segmentations?"
),
sliding_window_size: int = Option( # noqa: B008
160,
help="Sliding window size used for inference. Smaller values than 160 \
consume less GPU, but also lead to worse segmentation results!",
),
):
"""Segment tomograms using a trained model.
Expand All @@ -33,4 +38,5 @@ def segment(
ckpt_path=ckpt_path,
out_folder=out_folder,
store_probabilities=store_probabilities,
sw_roi_size=sliding_window_size,
)
36 changes: 29 additions & 7 deletions src/membrain_seg/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from .dataloading.memseg_augmentation import get_mirrored_img, get_prediction_transforms


def segment(tomogram_path, ckpt_path, out_folder, store_probabilities=False):
def segment(
tomogram_path, ckpt_path, out_folder, store_probabilities=False, sw_roi_size=160
):
"""
Segment tomograms using a trained model.
Expand All @@ -33,6 +35,10 @@ def segment(tomogram_path, ckpt_path, out_folder, store_probabilities=False):
store_probabilities : bool, optional
If True, store the predicted probabilities along with the segmentations
(default is False).
sw_roi_size: int, optional
Sliding window size used for inference. Smaller values than 160 consume less
GPU, but also lead to worse segmentation results!
Must be a multiple of 32.
Returns
-------
Expand All @@ -58,26 +64,42 @@ def segment(tomogram_path, ckpt_path, out_folder, store_probabilities=False):
# Preprocess the new data
new_data_path = tomogram_path
transforms = get_prediction_transforms()
new_data = load_data_for_inference(new_data_path, transforms, device)
new_data = load_data_for_inference(
new_data_path, transforms, device=torch.device("cpu")
)
new_data = new_data.to(torch.float32)

# Put the model into evaluation mode
pl_model.eval()

# Perform sliding window inference on the new data
roi_size = (160, 160, 160)
sw_batch_size = 2
if sw_roi_size % 32 != 0:
        raise OSError("Sliding window size must be multiple of 32!")
roi_size = (sw_roi_size, sw_roi_size, sw_roi_size)
sw_batch_size = 1
inferer = SlidingWindowInferer(
roi_size, sw_batch_size, overlap=0.5, progress=True, mode="gaussian"
roi_size,
sw_batch_size,
overlap=0.5,
progress=True,
mode="gaussian",
device=torch.device("cpu"),
)

# Perform test time augmentation (8-fold mirroring)
predictions = torch.zeros_like(new_data)
print("Performing 8-fold test-time augmentation.")
for m in range(8):
with torch.no_grad():
predictions += get_mirrored_img(
inferer(get_mirrored_img(new_data.clone(), m), pl_model)[0], m
predictions += (
get_mirrored_img(
inferer(get_mirrored_img(new_data.clone(), m).to(device), pl_model)[
0
],
m,
)
.detach()
.cpu()
)
predictions /= 8.0

Expand Down

0 comments on commit b7df4ea

Please sign in to comment.