diff --git a/docs/Usage/Preprocessing.md b/docs/Usage/Preprocessing.md index 6686716..084d3c8 100644 --- a/docs/Usage/Preprocessing.md +++ b/docs/Usage/Preprocessing.md @@ -51,11 +51,11 @@ tomo_preprocessing --help - **match_pixel_size**: Tomogram rescaling to specified pixel size. Example: -`tomo_preprocessing match_pixel_size --input_tomogram --output_path --pixel_size_out 10.0 --pixel_size_in ` +`tomo_preprocessing match_pixel_size --input-tomogram --output-path --pixel-size-out 10.0 --pixel-size-in ` - **match_seg_to_tomo**: Segmentation rescaling to fit to target tomogram's shape. Example: -`tomo_preprocessing match_seg_to_tomo --seg_path --orig_tomo_path --output_path ` +`tomo_preprocessing match_seg_to_tomo --seg-path --orig-tomo-path --output-path ` - **extract_spectrum**: Extracts the radially averaged amplitude spectrum from the input tomogram. Example: -`tomo_preprocessing extract_spectrum --input_path --output_path ` +`tomo_preprocessing extract_spectrum --input-path --output-path ` - **match_spectrum**: Match amplitude of Fourier spectrum from input tomogram to target spectrum. Example: `tomo_preprocessing match_spectrum --input --target --output ` @@ -64,18 +64,18 @@ tomo_preprocessing --help Pixel size matching is recommended when your tomogram pixel sizes differs strongly from the training pixel size range (roughly 10-14Å). You can perform it using the command ```shell -tomo_preprocessing match_pixel_size --input_tomogram --output_path --pixel_size_out 10.0 --pixel_size_in +tomo_preprocessing match_pixel_size --input-tomogram --output-path --pixel-size-out 10.0 --pixel-size-in ``` after adjusting the paths to your respective tomograms. -Afterwards, you can perform MemBrain's segmentation on the rescaled tomogram (i.e. the one specified in `--output_path`). +Afterwards, you can perform MemBrain's segmentation on the rescaled tomogram (i.e. the one specified in `--output-path`). Now, this new segmentation does not have the same shape as the original non-pixel-size-matched tomogram. To rescale the new segmentation to the original tomogram again, you can use ```shell -tomo_preprocessing match_seg_to_tomo --seg_path --orig_tomo_path --output_path +tomo_preprocessing match_seg_to_tomo --seg_path --orig-tomo-path --output-path ``` -where the `--seg_path`is the segmentation created by MemBrain and the `--orig_tomo_path`is the original tomogram before rescaling to the new pixel size. +where the `--seg-path`is the segmentation created by MemBrain and the `--orig-tomo-path`is the original tomogram before rescaling to the new pixel size. The output of this function will be MemBrain's segmentation, but matched to the pixel size of the original tomogram. @@ -84,7 +84,7 @@ Fourier amplitude matching is performed in two steps: 1. Extraction of the target Fourier spectrum: ```shell -tomo_preprocessing extract_spectrum --input_path --output_path +tomo_preprocessing extract_spectrum --input-path --output-path ``` This extracts the radially averaged Fourier spectrum and stores it into a .tsv file. 2. Matching of the input tomogram to the extracted spectrum: diff --git a/docs/Usage/Segmentation.md b/docs/Usage/Segmentation.md index 32c5f37..993c57a 100644 --- a/docs/Usage/Segmentation.md +++ b/docs/Usage/Segmentation.md @@ -83,6 +83,10 @@ You can also compute the connected components [after you have segmented your tom **--connected-component-thres**: Threshold for connected components. Components smaller than this will be removed from the segmentation. [default: None] +**--test-time-augmentation / --no-test-time-augmentation**: Should 8-fold test time augmentation be used? If activated (default), segmentations tendo be slightly better, but runtime is increased. + +**--segmentation-threshold**: Set a custom threshold for thresholding your membrane scoremap to increase / decrease segmented membranes (default: 0.0). + **--sliding-window-size** INTEGER Sliding window size used for inference. Smaller values than 160 consume less GPU, but also lead to worse segmentation results! [default: 160] **--help** Show this message and exit. @@ -102,10 +106,17 @@ If you have segmented your tomograms already, but would still like to extract th ```shell membrain components --segmentation-path --connected-component-thres 50 --out-folder ``` - ### Note: Computing the connected components, and particularly also removing the small components can be quite compute intensive and take a while. +## Custom thresholding +In some cases, the standard threshold ($0.0$) may not be the ideal value for segmenting your tomograms. In order to explore what threshold may be best, you can use the above segmentation command with the flag `--store-probabilities`. This will store a membrane scoremap that you can threshold using different values using the command: + +``` +membrain thresholds --scoremap-path + --thresholds -1.5 --thresholds -0.5 --thresholds 0.0 --thresholds 0.5 +``` +In this way, you can pass as many thresholds as you would like and the function will output one segmentation for each. ## Post-Processing diff --git a/src/membrain_seg/annotations/extract_patches.py b/src/membrain_seg/annotations/extract_patches.py index deec0e1..d9628ab 100644 --- a/src/membrain_seg/annotations/extract_patches.py +++ b/src/membrain_seg/annotations/extract_patches.py @@ -193,7 +193,7 @@ def extract_patches( min_coords[0] : min_coords[0] + 160, min_coords[1] : min_coords[1] + 160, min_coords[2] : min_coords[2] + 160, - ] + ].copy() cur_patch_labels = pad_labels( cur_patch_labels, padding, pad_value=pad_value ) diff --git a/src/membrain_seg/annotations/merge_corrections.py b/src/membrain_seg/annotations/merge_corrections.py index 8286633..361c64a 100644 --- a/src/membrain_seg/annotations/merge_corrections.py +++ b/src/membrain_seg/annotations/merge_corrections.py @@ -40,8 +40,11 @@ def get_corrections_from_folder(folder_name, orig_pred_file): for filename in os.listdir(folder_name): if not ( filename.startswith("Add") + or filename.startswith("add") or filename.startswith("Remove") + or filename.startswith("remove") or filename.startswith("Ignore") + or filename.startswith("ignore") ): print("ATTENTION! Not processing", filename) print("Is this intended?") @@ -49,11 +52,13 @@ def get_corrections_from_folder(folder_name, orig_pred_file): readdata = sitk.GetArrayFromImage( sitk.ReadImage(os.path.join(folder_name, filename)) ) - if filename.startswith("Add"): + print("Adding file", filename, "<--") + + if filename.startswith("Add") or filename.startswith("add"): add_patch += readdata - if filename.startswith("Remove"): + if filename.startswith("Remove") or filename.startswith("remove"): remove_patch += readdata - if filename.startswith("Ignore"): + if filename.startswith("Ignore") or filename.startswith("ignore"): ignore_patch += readdata correction_count += 1 @@ -112,6 +117,8 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir): token = label_file[:-7] elif label_file.endswith(".nrrd"): token = label_file[:-5] + elif label_file.endswith(".mrc"): + token = label_file[:-4] found_flag = 0 for filename in os.listdir(corrections_dir): if not os.path.isdir(os.path.join(corrections_dir, filename)): @@ -121,7 +128,7 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir): merged_corrections = get_corrections_from_folder( cur_patch_corrections_folder, os.path.join(labels_dir, label_file) ) - out_file = os.path.join(out_dir, label_file) + out_file = os.path.join(out_dir, token + ".nii.gz") print("Storing corrections in", out_file) write_nifti(out_file, merged_corrections) found_flag = 1 diff --git a/src/membrain_seg/segmentation/cli/segment_cli.py b/src/membrain_seg/segmentation/cli/segment_cli.py index 1378d11..5125e8e 100644 --- a/src/membrain_seg/segmentation/cli/segment_cli.py +++ b/src/membrain_seg/segmentation/cli/segment_cli.py @@ -1,4 +1,5 @@ import os +from typing import List from typer import Option @@ -37,6 +38,17 @@ def segment( help="Threshold for connected components. Components smaller than this will \ be removed from the segmentation.", ), + test_time_augmentation: bool = Option( # noqa: B008 + True, + help="Use 8-fold test time augmentation (TTA)? TTA improves segmentation \ + quality slightly, but also increases runtime.", + ), + segmentation_threshold: float = Option( # noqa: B008 + 0.0, + help="Threshold for the membrane segmentation. Only voxels with a \ + membrane score higher than this threshold will be segmented. \ + (default: 0.0)", + ), sliding_window_size: int = Option( # noqa: B008 160, help="Sliding window size used for inference. Smaller values than 160 \ @@ -58,6 +70,8 @@ def segment( store_connected_components=store_connected_components, connected_component_thres=connected_component_thres, sw_roi_size=sliding_window_size, + test_time_augmentation=test_time_augmentation, + segmentation_threshold=segmentation_threshold, ) @@ -95,3 +109,53 @@ def components( os.path.splitext(os.path.basename(segmentation_path))[0] + "_components.mrc", ) store_tomogram(filename=out_file, tomogram=segmentation) + + +@cli.command(name="thresholds", no_args_is_help=True) +def thresholds( + scoremap_path: str = Option( # noqa: B008 + help="Path to the membrane scoremap to be processed.", **PKWARGS + ), + thresholds: List[float] = Option( # noqa: B008 + ..., + help="List of thresholds. Provide multiple by repeating the option.", + ), + out_folder: str = Option( # noqa: B008 + "./predictions", + help="Path to the folder where thresholdedsegmentations \ + should be stored.", + ), +): + """Process the provided scoremap using given thresholds. + + Given a membrane scoremap, this function thresholds the scoremap data + using the provided threshold(s). The thresholded scoremaps are then stored + in the specified output folder. If multiple thresholds are provided, + separate thresholded scoremaps will be generated for each threshold. + + Example + ------- + membrain thresholds --scoremap-path + --thresholds -1.5 --thresholds -0.5 --thresholds 0.0 --thresholds 0.5 + + This will generate thresholded scoremaps for the provided scoremap at + thresholds -1.5, -0.5, 0.0 and 0.5.The results will be saved with filenames + indicating the threshold values in the default 'predictions' folder or + in the folder specified by the user. + """ + scoremap = load_tomogram(scoremap_path) + score_data = scoremap.data + if not isinstance(thresholds, list): + thresholds = [thresholds] + for threshold in thresholds: + print("Thresholding at", threshold) + thresholded_data = score_data > threshold + segmentation = scoremap + segmentation.data = thresholded_data + out_file = os.path.join( + out_folder, + os.path.splitext(os.path.basename(scoremap_path))[0] + + f"_threshold_{threshold}.mrc", + ) + store_tomogram(filename=out_file, tomogram=segmentation) + print("Saved thresholded scoremap to", out_file) diff --git a/src/membrain_seg/segmentation/dataloading/data_utils.py b/src/membrain_seg/segmentation/dataloading/data_utils.py index 28109d9..6662708 100644 --- a/src/membrain_seg/segmentation/dataloading/data_utils.py +++ b/src/membrain_seg/segmentation/dataloading/data_utils.py @@ -131,6 +131,7 @@ def store_segmented_tomograms( connected_component_thres: int = None, mrc_header: np.recarray = None, voxel_size: float = None, + segmentation_threshold: float = 0.0, ) -> None: """ Helper function for storing output segmentations. @@ -163,6 +164,8 @@ def store_segmented_tomograms( voxel_size: float, optional If given, this will be the voxel size stored in the header of the output segmentation. + segmentation_threshold : float, optional + Threshold for the segmentation. Default is 0.0. """ # Create out directory if it doesn't exist yet make_directory_if_not_exists(out_folder) @@ -178,7 +181,9 @@ def store_segmented_tomograms( data=predictions_np, header=mrc_header, voxel_size=voxel_size ) store_tomogram(out_file, out_tomo) - predictions_np_thres = predictions.squeeze(0).squeeze(0).cpu().numpy() > 0.0 + predictions_np_thres = ( + predictions.squeeze(0).squeeze(0).cpu().numpy() > segmentation_threshold + ) out_file_thres = os.path.join( out_folder, os.path.basename(orig_data_path)[:-4] + "_" + ckpt_token + "_segmented.mrc", diff --git a/src/membrain_seg/segmentation/segment.py b/src/membrain_seg/segmentation/segment.py index 5d34b88..e44a327 100644 --- a/src/membrain_seg/segmentation/segment.py +++ b/src/membrain_seg/segmentation/segment.py @@ -21,6 +21,7 @@ def segment( store_connected_components=False, connected_component_thres=None, test_time_augmentation=True, + segmentation_threshold=0.0, ): """ Segment tomograms using a trained model. @@ -55,6 +56,9 @@ def segment( test_time_augmentation: bool, optional If True, test-time augmentation is performed, i.e. data is rotated into eight different orientations and predictions are averaged. + segmentation_threshold: float, optional + Threshold for the membrane segmentation. Only voxels with a membrane + score higher than this threshold will be segmented. (default: 0.0) Returns ------- @@ -133,5 +137,6 @@ def segment( connected_component_thres=connected_component_thres, mrc_header=mrc_header, voxel_size=voxel_size, + segmentation_threshold=segmentation_threshold, ) return segmentation_file