-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Misc features #36
Misc features #36
Changes from all commits
342270e
b1cd6c7
4e6679f
784d965
a625b75
d0f2b72
b7f462e
4ace2be
d1ea30e
89b544f
490f730
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,20 +40,25 @@ def get_corrections_from_folder(folder_name, orig_pred_file): | |
for filename in os.listdir(folder_name): | ||
if not ( | ||
filename.startswith("Add") | ||
or filename.startswith("add") | ||
or filename.startswith("Remove") | ||
or filename.startswith("remove") | ||
or filename.startswith("Ignore") | ||
or filename.startswith("ignore") | ||
): | ||
print("ATTENTION! Not processing", filename) | ||
print("Is this intended?") | ||
continue | ||
readdata = sitk.GetArrayFromImage( | ||
sitk.ReadImage(os.path.join(folder_name, filename)) | ||
) | ||
if filename.startswith("Add"): | ||
print("Adding file", filename, "<--") | ||
|
||
if filename.startswith("Add") or filename.startswith("add"): | ||
add_patch += readdata | ||
if filename.startswith("Remove"): | ||
if filename.startswith("Remove") or filename.startswith("remove"): | ||
remove_patch += readdata | ||
if filename.startswith("Ignore"): | ||
if filename.startswith("Ignore") or filename.startswith("ignore"): | ||
ignore_patch += readdata | ||
correction_count += 1 | ||
|
||
|
@@ -112,6 +117,8 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir): | |
token = label_file[:-7] | ||
elif label_file.endswith(".nrrd"): | ||
token = label_file[:-5] | ||
elif label_file.endswith(".mrc"): | ||
token = label_file[:-4] | ||
found_flag = 0 | ||
for filename in os.listdir(corrections_dir): | ||
if not os.path.isdir(os.path.join(corrections_dir, filename)): | ||
|
@@ -121,7 +128,7 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir): | |
merged_corrections = get_corrections_from_folder( | ||
cur_patch_corrections_folder, os.path.join(labels_dir, label_file) | ||
) | ||
out_file = os.path.join(out_dir, label_file) | ||
out_file = os.path.join(out_dir, token + ".nii.gz") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If not specifying .nii.gz format, it will take to data format of the inputs (which can also be e.g. .mrc) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure I follow. It looks like the output file path is hardcoded to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, exactly. I hard-coded it to |
||
print("Storing corrections in", out_file) | ||
write_nifti(out_file, merged_corrections) | ||
found_flag = 1 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import os | ||
from typing import List | ||
|
||
from typer import Option | ||
|
||
|
@@ -37,6 +38,17 @@ def segment( | |
help="Threshold for connected components. Components smaller than this will \ | ||
be removed from the segmentation.", | ||
), | ||
test_time_augmentation: bool = Option( # noqa: B008 | ||
True, | ||
help="Use 8-fold test time augmentation (TTA)? TTA improves segmentation \ | ||
quality slightly, but also increases runtime.", | ||
), | ||
segmentation_threshold: float = Option( # noqa: B008 | ||
0.0, | ||
help="Threshold for the membrane segmentation. Only voxels with a \ | ||
membrane score higher than this threshold will be segmented. \ | ||
(default: 0.0)", | ||
), | ||
sliding_window_size: int = Option( # noqa: B008 | ||
160, | ||
help="Sliding window size used for inference. Smaller values than 160 \ | ||
|
@@ -58,6 +70,8 @@ def segment( | |
store_connected_components=store_connected_components, | ||
connected_component_thres=connected_component_thres, | ||
sw_roi_size=sliding_window_size, | ||
test_time_augmentation=test_time_augmentation, | ||
segmentation_threshold=segmentation_threshold, | ||
) | ||
|
||
|
||
|
@@ -95,3 +109,53 @@ def components( | |
os.path.splitext(os.path.basename(segmentation_path))[0] + "_components.mrc", | ||
) | ||
store_tomogram(filename=out_file, tomogram=segmentation) | ||
|
||
|
||
@cli.command(name="thresholds", no_args_is_help=True) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a separate command for a function that
|
||
def thresholds( | ||
scoremap_path: str = Option( # noqa: B008 | ||
help="Path to the membrane scoremap to be processed.", **PKWARGS | ||
), | ||
thresholds: List[float] = Option( # noqa: B008 | ||
..., | ||
help="List of thresholds. Provide multiple by repeating the option.", | ||
), | ||
out_folder: str = Option( # noqa: B008 | ||
"./predictions", | ||
help="Path to the folder where thresholdedsegmentations \ | ||
should be stored.", | ||
), | ||
): | ||
"""Process the provided scoremap using given thresholds. | ||
|
||
Given a membrane scoremap, this function thresholds the scoremap data | ||
using the provided threshold(s). The thresholded scoremaps are then stored | ||
in the specified output folder. If multiple thresholds are provided, | ||
separate thresholded scoremaps will be generated for each threshold. | ||
|
||
Example | ||
------- | ||
membrain thresholds --scoremap-path <path-to-scoremap> | ||
--thresholds -1.5 --thresholds -0.5 --thresholds 0.0 --thresholds 0.5 | ||
|
||
This will generate thresholded scoremaps for the provided scoremap at | ||
thresholds -1.5, -0.5, 0.0 and 0.5.The results will be saved with filenames | ||
indicating the threshold values in the default 'predictions' folder or | ||
in the folder specified by the user. | ||
""" | ||
scoremap = load_tomogram(scoremap_path) | ||
score_data = scoremap.data | ||
if not isinstance(thresholds, list): | ||
thresholds = [thresholds] | ||
for threshold in thresholds: | ||
print("Thresholding at", threshold) | ||
thresholded_data = score_data > threshold | ||
segmentation = scoremap | ||
segmentation.data = thresholded_data | ||
out_file = os.path.join( | ||
out_folder, | ||
os.path.splitext(os.path.basename(scoremap_path))[0] | ||
+ f"_threshold_{threshold}.mrc", | ||
) | ||
store_tomogram(filename=out_file, tomogram=segmentation) | ||
print("Saved thresholded scoremap to", out_file) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When not adding the "copy", overlapping extracted patches will have the ignore-label padding in each other, because the original array is padded instead of the small patches.