
Commit

Misc features (#36)
* Fix patch merge bugs and header mode issues

* change faulty float32 conversion to float16 conversion

* change order of header and data insertion

* Change voxel size argument from segmentation storing

* Fix issues with patch extraction and add CLI functionalities for segmentation threshold and TTA

* add command for testing multiple thresholds on existing scoremap

* Fix typos in documentation and update with new functionalities

* change order of mandatory arguments

* fix tta typo

* Fix bug from thresholding2
LorenzLamm authored Oct 5, 2023
1 parent 15bc808 commit c132a79
Showing 7 changed files with 107 additions and 15 deletions.
16 changes: 8 additions & 8 deletions docs/Usage/Preprocessing.md
@@ -51,11 +51,11 @@ tomo_preprocessing <command> --help


- **match_pixel_size**: Tomogram rescaling to specified pixel size. Example:
`tomo_preprocessing match_pixel_size --input_tomogram <path-to-tomo> --output_path <path-to-output> --pixel_size_out 10.0 --pixel_size_in <your-px-size>`
`tomo_preprocessing match_pixel_size --input-tomogram <path-to-tomo> --output-path <path-to-output> --pixel-size-out 10.0 --pixel-size-in <your-px-size>`
- **match_seg_to_tomo**: Segmentation rescaling to fit to target tomogram's shape. Example:
`tomo_preprocessing match_seg_to_tomo --seg_path <path-to-seg> --orig_tomo_path <path-to-tomo> --output_path <path-to-output>`
`tomo_preprocessing match_seg_to_tomo --seg-path <path-to-seg> --orig-tomo-path <path-to-tomo> --output-path <path-to-output>`
- **extract_spectrum**: Extracts the radially averaged amplitude spectrum from the input tomogram. Example:
`tomo_preprocessing extract_spectrum --input_path <path-to-tomo> --output_path <path-to-output>`
`tomo_preprocessing extract_spectrum --input-path <path-to-tomo> --output-path <path-to-output>`
- **match_spectrum**: Match amplitude of Fourier spectrum from input tomogram to target spectrum. Example:
`tomo_preprocessing match_spectrum --input <path-to-tomo> --target <path-to-spectrum> --output <path-to-output>`

@@ -64,18 +64,18 @@ tomo_preprocessing <command> --help
Pixel size matching is recommended when your tomogram's pixel size differs strongly from the training pixel size range (roughly 10-14&Aring;). You can perform it using the command

```shell
tomo_preprocessing match_pixel_size --input_tomogram <path-to-tomo> --output_path <path-to-output> --pixel_size_out 10.0 --pixel_size_in <your-px-size>
tomo_preprocessing match_pixel_size --input-tomogram <path-to-tomo> --output-path <path-to-output> --pixel-size-out 10.0 --pixel-size-in <your-px-size>
```

after adjusting the paths to your respective tomograms.
Afterwards, you can perform MemBrain's segmentation on the rescaled tomogram (i.e. the one specified in `--output_path`).
Afterwards, you can perform MemBrain's segmentation on the rescaled tomogram (i.e. the one specified in `--output-path`).
Now, this new segmentation does not have the same shape as the original non-pixel-size-matched tomogram. To rescale the new segmentation to the original tomogram again, you can use

```shell
tomo_preprocessing match_seg_to_tomo --seg_path <path-to-seg> --orig_tomo_path <path-to-tomo> --output_path <path-to-output>
tomo_preprocessing match_seg_to_tomo --seg-path <path-to-seg> --orig-tomo-path <path-to-tomo> --output-path <path-to-output>
```

where the `--seg_path`is the segmentation created by MemBrain and the `--orig_tomo_path`is the original tomogram before rescaling to the new pixel size.
where `--seg-path` is the segmentation created by MemBrain and `--orig-tomo-path` is the original tomogram before rescaling to the new pixel size.
The output of this function will be MemBrain's segmentation, but matched to the pixel size of the original tomogram.
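
To illustrate why the rescaled tomogram and its segmentation no longer share the original shape, here is a minimal sketch of the shape arithmetic. It is not MemBrain's implementation: the helper name `rescaled_shape` and the rounding rule are assumptions made for illustration only.

```python
def rescaled_shape(shape, pixel_size_in, pixel_size_out):
    """Estimate the shape of a tomogram after pixel size matching.

    Hypothetical helper: assumes every axis is scaled by
    pixel_size_in / pixel_size_out and rounded to the nearest voxel.
    """
    scale = pixel_size_in / pixel_size_out
    return tuple(max(1, round(dim * scale)) for dim in shape)


if __name__ == "__main__":
    # A tomogram recorded at 20 Angstrom per pixel, rescaled to 10 Angstrom
    # per pixel, doubles along every axis.
    print(rescaled_shape((100, 256, 256), 20.0, 10.0))
```

Because the segmentation is computed on the rescaled grid, it has to be mapped back with `match_seg_to_tomo` before it can be overlaid on the original tomogram.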


@@ -84,7 +84,7 @@ Fourier amplitude matching is performed in two steps:

1. Extraction of the target Fourier spectrum:
```shell
tomo_preprocessing extract_spectrum --input_path <path-to-tomo> --output_path <path-to-output>
tomo_preprocessing extract_spectrum --input-path <path-to-tomo> --output-path <path-to-output>
```
This extracts the radially averaged Fourier spectrum and stores it into a .tsv file.
2. Matching of the input tomogram to the extracted spectrum:
13 changes: 12 additions & 1 deletion docs/Usage/Segmentation.md
@@ -83,6 +83,10 @@ You can also compute the connected components [after you have segmented your tom

**--connected-component-thres**: Threshold for connected components. Components smaller than this will be removed from the segmentation. [default: None]

**--test-time-augmentation / --no-test-time-augmentation**: Should 8-fold test time augmentation be used? If activated (default), segmentations tend to be slightly better, but runtime is increased.

**--segmentation-threshold**: Set a custom threshold for thresholding your membrane scoremap to increase / decrease segmented membranes (default: 0.0).

**--sliding-window-size** INTEGER Sliding window size used for inference. Smaller values than 160 consume less GPU, but also lead to worse segmentation results! [default: 160]

**--help** Show this message and exit.
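
As a rough illustration of what 8-fold test time augmentation means, the sketch below averages predictions over eight orientations of a volume. It assumes the eight orientations are the 2**3 combinations of flips along the three axes, and `predict` is a placeholder for any function mapping a volume to a scoremap of the same shape. MemBrain's actual augmentations may differ.

```python
import itertools

import numpy as np


def predict_with_flip_tta(predict, volume):
    """Average predictions over all 8 flip combinations of a 3D volume.

    Assumption: the eight orientations are the 2**3 combinations of flips
    along the three axes; the real implementation may use other transforms.
    """
    accumulated = np.zeros(volume.shape, dtype=np.float64)
    for n_axes in range(4):
        for axes in itertools.combinations(range(3), n_axes):
            flipped = np.flip(volume, axis=axes) if axes else volume
            prediction = predict(flipped)
            # Undo the flip so that all predictions are aligned before averaging.
            accumulated += np.flip(prediction, axis=axes) if axes else prediction
    return accumulated / 8.0
```

The eight forward passes are why runtime increases with the option enabled; `--no-test-time-augmentation` corresponds to calling `predict` once.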
@@ -102,10 +106,17 @@ If you have segmented your tomograms already, but would still like to extract th
```shell
membrain components --segmentation-path <path-to-your-segmentation> --connected-component-thres 50 --out-folder <folder-to-store-components>
```

### Note:
Computing the connected components, and especially removing the small ones, can be quite compute-intensive and may take a while.

## Custom thresholding
In some cases, the standard threshold ($0.0$) may not be the ideal value for segmenting your tomograms. In order to explore what threshold may be best, you can use the above segmentation command with the flag `--store-probabilities`. This will store a membrane scoremap that you can threshold using different values using the command:

```shell
membrain thresholds --scoremap-path <path-to-scoremap> \
    --thresholds -1.5 --thresholds -0.5 --thresholds 0.0 --thresholds 0.5
```
In this way, you can pass as many thresholds as you would like and the function will output one segmentation for each.


## Post-Processing
2 changes: 1 addition & 1 deletion src/membrain_seg/annotations/extract_patches.py
@@ -193,7 +193,7 @@ def extract_patches(
min_coords[0] : min_coords[0] + 160,
min_coords[1] : min_coords[1] + 160,
min_coords[2] : min_coords[2] + 160,
]
].copy()
cur_patch_labels = pad_labels(
cur_patch_labels, padding, pad_value=pad_value
)
15 changes: 11 additions & 4 deletions src/membrain_seg/annotations/merge_corrections.py
@@ -40,20 +40,25 @@ def get_corrections_from_folder(folder_name, orig_pred_file):
for filename in os.listdir(folder_name):
if not (
filename.startswith("Add")
or filename.startswith("add")
or filename.startswith("Remove")
or filename.startswith("remove")
or filename.startswith("Ignore")
or filename.startswith("ignore")
):
print("ATTENTION! Not processing", filename)
print("Is this intended?")
continue
readdata = sitk.GetArrayFromImage(
sitk.ReadImage(os.path.join(folder_name, filename))
)
if filename.startswith("Add"):
print("Adding file", filename, "<--")

if filename.startswith("Add") or filename.startswith("add"):
add_patch += readdata
if filename.startswith("Remove"):
if filename.startswith("Remove") or filename.startswith("remove"):
remove_patch += readdata
if filename.startswith("Ignore"):
if filename.startswith("Ignore") or filename.startswith("ignore"):
ignore_patch += readdata
correction_count += 1
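
The repeated `startswith` checks in this hunk could be condensed, since `str.startswith` accepts a tuple of prefixes. The sketch below is one possible refactoring, not code from the commit; `correction_kind` and `CORRECTION_KINDS` are hypothetical names. It accepts exactly the capitalized and lowercase prefixes the diff handles.

```python
CORRECTION_KINDS = ("add", "remove", "ignore")


def correction_kind(filename):
    """Return "add", "remove" or "ignore" for a correction file, else None.

    Matches only the prefixes handled in the diff: all-lowercase or
    capitalized (e.g. "add..." or "Add...", but not "ADD...").
    """
    for kind in CORRECTION_KINDS:
        if filename.startswith((kind, kind.capitalize())):
            return kind
    return None


if __name__ == "__main__":
    for name in ("Add_1.nrrd", "remove_2.nrrd", "notes.txt"):
        print(name, "->", correction_kind(name))
```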

@@ -112,6 +117,8 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir):
token = label_file[:-7]
elif label_file.endswith(".nrrd"):
token = label_file[:-5]
elif label_file.endswith(".mrc"):
token = label_file[:-4]
found_flag = 0
for filename in os.listdir(corrections_dir):
if not os.path.isdir(os.path.join(corrections_dir, filename)):
@@ -121,7 +128,7 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir):
merged_corrections = get_corrections_from_folder(
cur_patch_corrections_folder, os.path.join(labels_dir, label_file)
)
out_file = os.path.join(out_dir, label_file)
out_file = os.path.join(out_dir, token + ".nii.gz")
print("Storing corrections in", out_file)
write_nifti(out_file, merged_corrections)
found_flag = 1
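
The token logic in these two hunks strips a known extension from the label file and always writes the merged result as `.nii.gz`. A minimal sketch of that behaviour follows, assuming the first branch (`[:-7]`) corresponds to `.nii.gz`, which the hunk does not show. The helper names are hypothetical, and returning `None` for unsupported files is an illustrative choice, not the commit's behaviour.

```python
SUPPORTED_EXTENSIONS = (".nii.gz", ".nrrd", ".mrc")


def label_token(label_file):
    """Strip a supported extension from a label file name, else return None."""
    for ext in SUPPORTED_EXTENSIONS:
        if label_file.endswith(ext):
            return label_file[: -len(ext)]
    return None


def merged_output_name(label_file):
    """Name of the merged corrections file, always stored as NIfTI."""
    token = label_token(label_file)
    return None if token is None else token + ".nii.gz"


if __name__ == "__main__":
    print(merged_output_name("patch_01.mrc"))
```

Deriving the suffix length from the extension itself avoids the hard-coded `-7`, `-5` and `-4` offsets getting out of sync when a new format is added.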
64 changes: 64 additions & 0 deletions src/membrain_seg/segmentation/cli/segment_cli.py
@@ -1,4 +1,5 @@
import os
from typing import List

from typer import Option

@@ -37,6 +38,17 @@ def segment(
help="Threshold for connected components. Components smaller than this will \
be removed from the segmentation.",
),
test_time_augmentation: bool = Option( # noqa: B008
True,
help="Use 8-fold test time augmentation (TTA)? TTA improves segmentation \
quality slightly, but also increases runtime.",
),
segmentation_threshold: float = Option( # noqa: B008
0.0,
help="Threshold for the membrane segmentation. Only voxels with a \
membrane score higher than this threshold will be segmented. \
(default: 0.0)",
),
sliding_window_size: int = Option( # noqa: B008
160,
help="Sliding window size used for inference. Smaller values than 160 \
@@ -58,6 +70,8 @@ def segment(
store_connected_components=store_connected_components,
connected_component_thres=connected_component_thres,
sw_roi_size=sliding_window_size,
test_time_augmentation=test_time_augmentation,
segmentation_threshold=segmentation_threshold,
)


@@ -95,3 +109,53 @@ def components(
os.path.splitext(os.path.basename(segmentation_path))[0] + "_components.mrc",
)
store_tomogram(filename=out_file, tomogram=segmentation)


@cli.command(name="thresholds", no_args_is_help=True)
def thresholds(
    scoremap_path: str = Option(  # noqa: B008
        help="Path to the membrane scoremap to be processed.", **PKWARGS
    ),
    thresholds: List[float] = Option(  # noqa: B008
        ...,
        help="List of thresholds. Provide multiple by repeating the option.",
    ),
    out_folder: str = Option(  # noqa: B008
        "./predictions",
        help="Path to the folder where thresholded segmentations \
        should be stored.",
    ),
):
    """Process the provided scoremap using given thresholds.

    Given a membrane scoremap, this function thresholds the scoremap data
    using the provided threshold(s). The thresholded scoremaps are then stored
    in the specified output folder. If multiple thresholds are provided,
    separate thresholded scoremaps will be generated for each threshold.

    Example
    -------
    membrain thresholds --scoremap-path <path-to-scoremap>
    --thresholds -1.5 --thresholds -0.5 --thresholds 0.0 --thresholds 0.5

    This will generate thresholded scoremaps for the provided scoremap at
    thresholds -1.5, -0.5, 0.0 and 0.5. The results will be saved with filenames
    indicating the threshold values in the default 'predictions' folder or
    in the folder specified by the user.
    """
    scoremap = load_tomogram(scoremap_path)
    score_data = scoremap.data
    if not isinstance(thresholds, list):
        thresholds = [thresholds]
    for threshold in thresholds:
        print("Thresholding at", threshold)
        thresholded_data = score_data > threshold
        segmentation = scoremap
        segmentation.data = thresholded_data
        out_file = os.path.join(
            out_folder,
            os.path.splitext(os.path.basename(scoremap_path))[0]
            + f"_threshold_{threshold}.mrc",
        )
        store_tomogram(filename=out_file, tomogram=segmentation)
        print("Saved thresholded scoremap to", out_file)
7 changes: 6 additions & 1 deletion src/membrain_seg/segmentation/dataloading/data_utils.py
@@ -131,6 +131,7 @@ def store_segmented_tomograms(
connected_component_thres: int = None,
mrc_header: np.recarray = None,
voxel_size: float = None,
segmentation_threshold: float = 0.0,
) -> None:
"""
Helper function for storing output segmentations.
@@ -163,6 +164,8 @@ def store_segmented_tomograms(
voxel_size: float, optional
If given, this will be the voxel size stored in the header of the
output segmentation.
segmentation_threshold : float, optional
Threshold for the segmentation. Default is 0.0.
"""
# Create out directory if it doesn't exist yet
make_directory_if_not_exists(out_folder)
@@ -178,7 +181,9 @@ def store_segmented_tomograms(
data=predictions_np, header=mrc_header, voxel_size=voxel_size
)
store_tomogram(out_file, out_tomo)
predictions_np_thres = predictions.squeeze(0).squeeze(0).cpu().numpy() > 0.0
predictions_np_thres = (
predictions.squeeze(0).squeeze(0).cpu().numpy() > segmentation_threshold
)
out_file_thres = os.path.join(
out_folder,
os.path.basename(orig_data_path)[:-4] + "_" + ckpt_token + "_segmented.mrc",
5 changes: 5 additions & 0 deletions src/membrain_seg/segmentation/segment.py
@@ -21,6 +21,7 @@ def segment(
store_connected_components=False,
connected_component_thres=None,
test_time_augmentation=True,
segmentation_threshold=0.0,
):
"""
Segment tomograms using a trained model.
@@ -55,6 +56,9 @@ def segment(
test_time_augmentation: bool, optional
If True, test-time augmentation is performed, i.e. data is rotated
into eight different orientations and predictions are averaged.
segmentation_threshold: float, optional
Threshold for the membrane segmentation. Only voxels with a membrane
score higher than this threshold will be segmented. (default: 0.0)
Returns
-------
@@ -133,5 +137,6 @@ def segment(
connected_component_thres=connected_component_thres,
mrc_header=mrc_header,
voxel_size=voxel_size,
segmentation_threshold=segmentation_threshold,
)
return segmentation_file
