Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Misc features #36

Merged
merged 11 commits into from
Oct 5, 2023
16 changes: 8 additions & 8 deletions docs/Usage/Preprocessing.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ tomo_preprocessing <command> --help


- **match_pixel_size**: Tomogram rescaling to specified pixel size. Example:
`tomo_preprocessing match_pixel_size --input_tomogram <path-to-tomo> --output_path <path-to-output> --pixel_size_out 10.0 --pixel_size_in <your-px-size>`
`tomo_preprocessing match_pixel_size --input-tomogram <path-to-tomo> --output-path <path-to-output> --pixel-size-out 10.0 --pixel-size-in <your-px-size>`
- **match_seg_to_tomo**: Segmentation rescaling to fit to target tomogram's shape. Example:
`tomo_preprocessing match_seg_to_tomo --seg_path <path-to-seg> --orig_tomo_path <path-to-tomo> --output_path <path-to-output>`
`tomo_preprocessing match_seg_to_tomo --seg-path <path-to-seg> --orig-tomo-path <path-to-tomo> --output-path <path-to-output>`
- **extract_spectrum**: Extracts the radially averaged amplitude spectrum from the input tomogram. Example:
`tomo_preprocessing extract_spectrum --input_path <path-to-tomo> --output_path <path-to-output>`
`tomo_preprocessing extract_spectrum --input-path <path-to-tomo> --output-path <path-to-output>`
- **match_spectrum**: Match amplitude of Fourier spectrum from input tomogram to target spectrum. Example:
`tomo_preprocessing match_spectrum --input <path-to-tomo> --target <path-to-spectrum> --output <path-to-output>`

Expand All @@ -64,18 +64,18 @@ tomo_preprocessing <command> --help
Pixel size matching is recommended when your tomogram pixel sizes differs strongly from the training pixel size range (roughly 10-14&Aring;). You can perform it using the command

```shell
tomo_preprocessing match_pixel_size --input_tomogram <path-to-tomo> --output_path <path-to-output> --pixel_size_out 10.0 --pixel_size_in <your-px-size>
tomo_preprocessing match_pixel_size --input-tomogram <path-to-tomo> --output-path <path-to-output> --pixel-size-out 10.0 --pixel-size-in <your-px-size>
```

after adjusting the paths to your respective tomograms.
Afterwards, you can perform MemBrain's segmentation on the rescaled tomogram (i.e. the one specified in `--output_path`).
Afterwards, you can perform MemBrain's segmentation on the rescaled tomogram (i.e. the one specified in `--output-path`).
Now, this new segmentation does not have the same shape as the original non-pixel-size-matched tomogram. To rescale the new segmentation to the original tomogram again, you can use

```shell
tomo_preprocessing match_seg_to_tomo --seg_path <path-to-seg> --orig_tomo_path <path-to-tomo> --output_path <path-to-output>
tomo_preprocessing match_seg_to_tomo --seg_path <path-to-seg> --orig-tomo-path <path-to-tomo> --output-path <path-to-output>
```

where the `--seg_path`is the segmentation created by MemBrain and the `--orig_tomo_path`is the original tomogram before rescaling to the new pixel size.
where the `--seg-path`is the segmentation created by MemBrain and the `--orig-tomo-path`is the original tomogram before rescaling to the new pixel size.
The output of this function will be MemBrain's segmentation, but matched to the pixel size of the original tomogram.


Expand All @@ -84,7 +84,7 @@ Fourier amplitude matching is performed in two steps:

1. Extraction of the target Fourier spectrum:
```shell
tomo_preprocessing extract_spectrum --input_path <path-to-tomo> --output_path <path-to-output>
tomo_preprocessing extract_spectrum --input-path <path-to-tomo> --output-path <path-to-output>
```
This extracts the radially averaged Fourier spectrum and stores it into a .tsv file.
2. Matching of the input tomogram to the extracted spectrum:
Expand Down
13 changes: 12 additions & 1 deletion docs/Usage/Segmentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ You can also compute the connected components [after you have segmented your tom

**--connected-component-thres**: Threshold for connected components. Components smaller than this will be removed from the segmentation. [default: None]

**--test-time-augmentation / --no-test-time-augmentation**: Should 8-fold test time augmentation be used? If activated (default), segmentations tendo be slightly better, but runtime is increased.

**--segmentation-threshold**: Set a custom threshold for thresholding your membrane scoremap to increase / decrease segmented membranes (default: 0.0).

**--sliding-window-size** INTEGER Sliding window size used for inference. Smaller values than 160 consume less GPU, but also lead to worse segmentation results! [default: 160]

**--help** Show this message and exit.
Expand All @@ -102,10 +106,17 @@ If you have segmented your tomograms already, but would still like to extract th
```shell
membrain components --segmentation-path <path-to-your-segmentation> --connected-component-thres 50 --out-folder <folder-to-store-components>
```

### Note:
Computing the connected components, and particularly also removing the small components can be quite compute intensive and take a while.

## Custom thresholding
In some cases, the standard threshold ($0.0$) may not be the ideal value for segmenting your tomograms. In order to explore what threshold may be best, you can use the above segmentation command with the flag `--store-probabilities`. This will store a membrane scoremap that you can threshold using different values using the command:

```
membrain thresholds --scoremap-path <path-to-scoremap>
--thresholds -1.5 --thresholds -0.5 --thresholds 0.0 --thresholds 0.5
```
In this way, you can pass as many thresholds as you would like and the function will output one segmentation for each.


## Post-Processing
Expand Down
2 changes: 1 addition & 1 deletion src/membrain_seg/annotations/extract_patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def extract_patches(
min_coords[0] : min_coords[0] + 160,
min_coords[1] : min_coords[1] + 160,
min_coords[2] : min_coords[2] + 160,
]
].copy()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When not adding the "copy", overlapping extracted patches will have the ignore-label padding in each other, because the original array is padded instead of the small patches.

cur_patch_labels = pad_labels(
cur_patch_labels, padding, pad_value=pad_value
)
Expand Down
15 changes: 11 additions & 4 deletions src/membrain_seg/annotations/merge_corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,25 @@ def get_corrections_from_folder(folder_name, orig_pred_file):
for filename in os.listdir(folder_name):
if not (
filename.startswith("Add")
or filename.startswith("add")
or filename.startswith("Remove")
or filename.startswith("remove")
or filename.startswith("Ignore")
or filename.startswith("ignore")
):
print("ATTENTION! Not processing", filename)
print("Is this intended?")
continue
readdata = sitk.GetArrayFromImage(
sitk.ReadImage(os.path.join(folder_name, filename))
)
if filename.startswith("Add"):
print("Adding file", filename, "<--")

if filename.startswith("Add") or filename.startswith("add"):
add_patch += readdata
if filename.startswith("Remove"):
if filename.startswith("Remove") or filename.startswith("remove"):
remove_patch += readdata
if filename.startswith("Ignore"):
if filename.startswith("Ignore") or filename.startswith("ignore"):
ignore_patch += readdata
correction_count += 1

Expand Down Expand Up @@ -112,6 +117,8 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir):
token = label_file[:-7]
elif label_file.endswith(".nrrd"):
token = label_file[:-5]
elif label_file.endswith(".mrc"):
token = label_file[:-4]
found_flag = 0
for filename in os.listdir(corrections_dir):
if not os.path.isdir(os.path.join(corrections_dir, filename)):
Expand All @@ -121,7 +128,7 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir):
merged_corrections = get_corrections_from_folder(
cur_patch_corrections_folder, os.path.join(labels_dir, label_file)
)
out_file = os.path.join(out_dir, label_file)
out_file = os.path.join(out_dir, token + ".nii.gz")
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If not specifying .nii.gz format, it will take to data format of the inputs (which can also be e.g. .mrc)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure I follow. It looks like the output file path is hardcoded to .nii.gz.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, exactly. I hard-coded it to nii.gz s.t. it does not save out e.g. .mrc files. The input to the training is still nii.gz. We may want to change that at some point to something more common in Cryo-ET. This change here is to make everything compatible with the current training setup.

print("Storing corrections in", out_file)
write_nifti(out_file, merged_corrections)
found_flag = 1
Expand Down
64 changes: 64 additions & 0 deletions src/membrain_seg/segmentation/cli/segment_cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from typing import List

from typer import Option

Expand Down Expand Up @@ -37,6 +38,17 @@ def segment(
help="Threshold for connected components. Components smaller than this will \
be removed from the segmentation.",
),
test_time_augmentation: bool = Option( # noqa: B008
True,
help="Use 8-fold test time augmentation (TTA)? TTA improves segmentation \
quality slightly, but also increases runtime.",
),
segmentation_threshold: float = Option( # noqa: B008
0.0,
help="Threshold for the membrane segmentation. Only voxels with a \
membrane score higher than this threshold will be segmented. \
(default: 0.0)",
),
sliding_window_size: int = Option( # noqa: B008
160,
help="Sliding window size used for inference. Smaller values than 160 \
Expand All @@ -58,6 +70,8 @@ def segment(
store_connected_components=store_connected_components,
connected_component_thres=connected_component_thres,
sw_roi_size=sliding_window_size,
test_time_augmentation=test_time_augmentation,
segmentation_threshold=segmentation_threshold,
)


Expand Down Expand Up @@ -95,3 +109,53 @@ def components(
os.path.splitext(os.path.basename(segmentation_path))[0] + "_components.mrc",
)
store_tomogram(filename=out_file, tomogram=segmentation)


@cli.command(name="thresholds", no_args_is_help=True)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a separate command for a function that

  1. takes as input a scoremap produced by MemBrain-seg, together with multiple float values
  2. thresholds the scoremap at each given float value and stores the resulting segmentation mask

def thresholds(
scoremap_path: str = Option( # noqa: B008
help="Path to the membrane scoremap to be processed.", **PKWARGS
),
thresholds: List[float] = Option( # noqa: B008
...,
help="List of thresholds. Provide multiple by repeating the option.",
),
out_folder: str = Option( # noqa: B008
"./predictions",
help="Path to the folder where thresholdedsegmentations \
should be stored.",
),
):
"""Process the provided scoremap using given thresholds.

Given a membrane scoremap, this function thresholds the scoremap data
using the provided threshold(s). The thresholded scoremaps are then stored
in the specified output folder. If multiple thresholds are provided,
separate thresholded scoremaps will be generated for each threshold.

Example
-------
membrain thresholds --scoremap-path <path-to-scoremap>
--thresholds -1.5 --thresholds -0.5 --thresholds 0.0 --thresholds 0.5

This will generate thresholded scoremaps for the provided scoremap at
thresholds -1.5, -0.5, 0.0 and 0.5.The results will be saved with filenames
indicating the threshold values in the default 'predictions' folder or
in the folder specified by the user.
"""
scoremap = load_tomogram(scoremap_path)
score_data = scoremap.data
if not isinstance(thresholds, list):
thresholds = [thresholds]
for threshold in thresholds:
print("Thresholding at", threshold)
thresholded_data = score_data > threshold
segmentation = scoremap
segmentation.data = thresholded_data
out_file = os.path.join(
out_folder,
os.path.splitext(os.path.basename(scoremap_path))[0]
+ f"_threshold_{threshold}.mrc",
)
store_tomogram(filename=out_file, tomogram=segmentation)
print("Saved thresholded scoremap to", out_file)
7 changes: 6 additions & 1 deletion src/membrain_seg/segmentation/dataloading/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def store_segmented_tomograms(
connected_component_thres: int = None,
mrc_header: np.recarray = None,
voxel_size: float = None,
segmentation_threshold: float = 0.0,
) -> None:
"""
Helper function for storing output segmentations.
Expand Down Expand Up @@ -163,6 +164,8 @@ def store_segmented_tomograms(
voxel_size: float, optional
If given, this will be the voxel size stored in the header of the
output segmentation.
segmentation_threshold : float, optional
Threshold for the segmentation. Default is 0.0.
"""
# Create out directory if it doesn't exist yet
make_directory_if_not_exists(out_folder)
Expand All @@ -178,7 +181,9 @@ def store_segmented_tomograms(
data=predictions_np, header=mrc_header, voxel_size=voxel_size
)
store_tomogram(out_file, out_tomo)
predictions_np_thres = predictions.squeeze(0).squeeze(0).cpu().numpy() > 0.0
predictions_np_thres = (
predictions.squeeze(0).squeeze(0).cpu().numpy() > segmentation_threshold
)
out_file_thres = os.path.join(
out_folder,
os.path.basename(orig_data_path)[:-4] + "_" + ckpt_token + "_segmented.mrc",
Expand Down
5 changes: 5 additions & 0 deletions src/membrain_seg/segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def segment(
store_connected_components=False,
connected_component_thres=None,
test_time_augmentation=True,
segmentation_threshold=0.0,
):
"""
Segment tomograms using a trained model.
Expand Down Expand Up @@ -55,6 +56,9 @@ def segment(
test_time_augmentation: bool, optional
If True, test-time augmentation is performed, i.e. data is rotated
into eight different orientations and predictions are averaged.
segmentation_threshold: float, optional
Threshold for the membrane segmentation. Only voxels with a membrane
score higher than this threshold will be segmented. (default: 0.0)

Returns
-------
Expand Down Expand Up @@ -133,5 +137,6 @@ def segment(
connected_component_thres=connected_component_thres,
mrc_header=mrc_header,
voxel_size=voxel_size,
segmentation_threshold=segmentation_threshold,
)
return segmentation_file