From b7df4ea071a207fb41ad11a85f5a6becb2e15d42 Mon Sep 17 00:00:00 2001 From: LorenzLamm Date: Wed, 5 Jul 2023 00:05:22 +0200 Subject: [PATCH] Add sliding window size option for segmentation --- src/membrain_seg/cli/segment_cli.py | 6 +++++ src/membrain_seg/segment.py | 36 +++++++++++++++++++++++------ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/membrain_seg/cli/segment_cli.py b/src/membrain_seg/cli/segment_cli.py index bd39a85..57ac4a7 100644 --- a/src/membrain_seg/cli/segment_cli.py +++ b/src/membrain_seg/cli/segment_cli.py @@ -19,6 +19,11 @@ def segment( store_probabilities: bool = Option( # noqa: B008 False, help="Should probability maps be output in addition to segmentations?" ), + sliding_window_size: int = Option( # noqa: B008 + 160, + help="Sliding window size used for inference. Smaller values than 160 \ + consume less GPU, but also lead to worse segmentation results!", + ), ): """Segment tomograms using a trained model. @@ -33,4 +38,5 @@ def segment( ckpt_path=ckpt_path, out_folder=out_folder, store_probabilities=store_probabilities, + sw_roi_size=sliding_window_size, ) diff --git a/src/membrain_seg/segment.py b/src/membrain_seg/segment.py index cbe02f0..3fef8c1 100644 --- a/src/membrain_seg/segment.py +++ b/src/membrain_seg/segment.py @@ -12,7 +12,9 @@ from .dataloading.memseg_augmentation import get_mirrored_img, get_prediction_transforms -def segment(tomogram_path, ckpt_path, out_folder, store_probabilities=False): +def segment( + tomogram_path, ckpt_path, out_folder, store_probabilities=False, sw_roi_size=160 +): """ Segment tomograms using a trained model. @@ -33,6 +35,10 @@ def segment(tomogram_path, ckpt_path, out_folder, store_probabilities=False): store_probabilities : bool, optional If True, store the predicted probabilities along with the segmentations (default is False). + sw_roi_size: int, optional + Sliding window size used for inference. Smaller values than 160 consume less + GPU, but also lead to worse segmentation results! + Must be a multiple of 32. Returns ------- @@ -58,17 +64,26 @@ def segment(tomogram_path, ckpt_path, out_folder, store_probabilities=False): # Preprocess the new data new_data_path = tomogram_path transforms = get_prediction_transforms() - new_data = load_data_for_inference(new_data_path, transforms, device) + new_data = load_data_for_inference( + new_data_path, transforms, device=torch.device("cpu") + ) new_data = new_data.to(torch.float32) # Put the model into evaluation mode pl_model.eval() # Perform sliding window inference on the new data - roi_size = (160, 160, 160) - sw_batch_size = 2 + if sw_roi_size % 32 != 0: + raise OSError("Sliding window size must be multiple of 32°!") + roi_size = (sw_roi_size, sw_roi_size, sw_roi_size) + sw_batch_size = 1 inferer = SlidingWindowInferer( - roi_size, sw_batch_size, overlap=0.5, progress=True, mode="gaussian" + roi_size, + sw_batch_size, + overlap=0.5, + progress=True, + mode="gaussian", + device=torch.device("cpu"), ) # Perform test time augmentation (8-fold mirroring) @@ -76,8 +91,15 @@ def segment(tomogram_path, ckpt_path, out_folder, store_probabilities=False): print("Performing 8-fold test-time augmentation.") for m in range(8): with torch.no_grad(): - predictions += get_mirrored_img( - inferer(get_mirrored_img(new_data.clone(), m), pl_model)[0], m + predictions += ( + get_mirrored_img( + inferer(get_mirrored_img(new_data.clone(), m).to(device), pl_model)[ + 0 + ], + m, + ) + .detach() + .cpu() ) predictions /= 8.0