Skip to content

Commit

Permalink
Header issues (#31)
Browse files Browse the repository at this point in the history
* Fix patch merge bugs and header mode issues

* change faulty float32 conversion to float16 conversion

* change order of header and data insertion

* Change voxel size argument for segmentation storing
  • Loading branch information
LorenzLamm authored Aug 10, 2023
1 parent 63d1c1b commit 48e647d
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/membrain_seg/annotations/extract_patch_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def extract_patches(
pad_value = 2.0 # Currently still hard-coded because other values not
# compatible with training routine yet.
if coords_file is not None:
coords = get_csv_data(csv_path=coords_file)
coords = np.array(get_csv_data(csv_path=coords_file), dtype=int)
else:
coords = [np.array((x, y, z))]
_extract_patches(
Expand Down
11 changes: 8 additions & 3 deletions src/membrain_seg/annotations/merge_corrections.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,13 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir):
for label_file in os.listdir(labels_dir):
if not os.path.isfile(os.path.join(labels_dir, label_file)):
continue
print("")
print("Finding correction files for", label_file)
token = os.path.splitext(label_file)[0]
# token = os.path.splitext(label_file)[0]
if label_file.endswith(".nii.gz"):
token = label_file[:-7]
elif label_file.endswith(".nrrd"):
token = label_file[:-5]
found_flag = 0
for filename in os.listdir(corrections_dir):
if not os.path.isdir(os.path.join(corrections_dir, filename)):
Expand All @@ -120,5 +125,5 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir):
print("Storing corrections in", out_file)
write_nifti(out_file, merged_corrections)
found_flag = 1
if found_flag == 0:
print("No corrections folder found for patch", token)
if found_flag == 0:
print("No corrections folder found for patch", token)
43 changes: 20 additions & 23 deletions src/membrain_seg/segmentation/dataloading/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def load_data_for_inference(data_path: str, transforms: Callable, device: device
new_data = transforms(new_data)
new_data = new_data.unsqueeze(0) # Add batch dimension
new_data = new_data.to(device)
return new_data, tomogram.header
return new_data, tomogram.header, tomogram.voxel_size


def store_segmented_tomograms(
Expand All @@ -130,6 +130,7 @@ def store_segmented_tomograms(
store_connected_components: bool = False,
connected_component_thres: int = None,
mrc_header: np.recarray = None,
voxel_size: float = None,
) -> None:
"""
Helper function for storing output segmentations.
Expand Down Expand Up @@ -159,6 +160,9 @@ def store_segmented_tomograms(
If given, the mrc header will be used to retain header information
from another tomogram. This way, pixel sizes and other header
information is not lost.
voxel_size: float, optional
If given, this will be the voxel size stored in the header of the
output segmentation.
"""
# Create out directory if it doesn't exist yet
make_directory_if_not_exists(out_folder)
Expand All @@ -170,7 +174,9 @@ def store_segmented_tomograms(
out_file = os.path.join(
out_folder, os.path.basename(orig_data_path)[:-4] + "_scores.mrc"
)
out_tomo = Tomogram(data=predictions_np, header=mrc_header)
out_tomo = Tomogram(
data=predictions_np, header=mrc_header, voxel_size=voxel_size
)
store_tomogram(out_file, out_tomo)
predictions_np_thres = predictions.squeeze(0).squeeze(0).cpu().numpy() > 0.0
out_file_thres = os.path.join(
Expand All @@ -181,7 +187,9 @@ def store_segmented_tomograms(
predictions_np_thres = connected_components(
predictions_np_thres, size_thres=connected_component_thres
)
out_tomo = Tomogram(data=predictions_np_thres, header=mrc_header)
out_tomo = Tomogram(
data=predictions_np_thres, header=mrc_header, voxel_size=voxel_size
)
store_tomogram(out_file_thres, out_tomo)
print("MemBrain has finished segmenting your tomogram.")
return out_file_thres
Expand Down Expand Up @@ -334,7 +342,7 @@ def convert_dtype(tomogram: np.ndarray) -> np.ndarray:
if (
tomogram.min() >= np.finfo("float16").min
and tomogram.max() <= np.finfo("float16").max
):
) and np.allclose(tomogram, tomogram.astype("float16")):
return tomogram.astype("float16")
elif (
tomogram.min() >= np.finfo("float32").min
Expand Down Expand Up @@ -368,33 +376,22 @@ def store_tomogram(
if isinstance(tomogram, Tomogram):
data = tomogram.data
header = tomogram.header
if voxel_size is None:
voxel_size = tomogram.voxel_size
else:
data = tomogram
header = None
data = convert_dtype(data)
data = np.transpose(data, (2, 1, 0))
dtype_mode = _dtype_to_mode[data.dtype]
out_mrc.set_data(data)

if header is not None:
attributes = header.dtype.names
for attr in attributes:
# skip density and shape attribues
if attr in [
"mode",
"dmean",
"dmin",
"dmax",
"rms",
"nx",
"ny",
"nz",
"mx",
"my",
"mz",
]:
if attr not in ["nlabl", "label"]:
continue
setattr(out_mrc.header, attr, getattr(header, attr))
out_mrc.header.mode = dtype_mode

data = convert_dtype(data)
data = np.transpose(data, (2, 1, 0))
out_mrc.set_data(data)
if voxel_size is not None:
out_mrc.voxel_size = voxel_size

Expand Down
7 changes: 4 additions & 3 deletions src/membrain_seg/segmentation/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def segment(
sw_roi_size=160,
store_connected_components=False,
connected_component_thres=None,
test_time_augmentation=True
test_time_augmentation=True,
):
"""
Segment tomograms using a trained model.
Expand Down Expand Up @@ -81,7 +81,7 @@ def segment(
# Preprocess the new data
new_data_path = tomogram_path
transforms = get_prediction_transforms()
new_data, mrc_header = load_data_for_inference(
new_data, mrc_header, voxel_size = load_data_for_inference(
new_data_path, transforms, device=torch.device("cpu")
)
new_data = new_data.to(torch.float32)
Expand All @@ -106,7 +106,7 @@ def segment(
# Perform test time augmentation (8-fold mirroring)
predictions = torch.zeros_like(new_data)
print("Performing 8-fold test-time augmentation.")
for m in range((8 if test_time_augmentation else 1)):
for m in range(8 if test_time_augmentation else 1):
with torch.no_grad():
with torch.cuda.amp.autocast():
predictions += (
Expand All @@ -132,5 +132,6 @@ def segment(
store_connected_components=store_connected_components,
connected_component_thres=connected_component_thres,
mrc_header=mrc_header,
voxel_size=voxel_size,
)
return segmentation_file

0 comments on commit 48e647d

Please sign in to comment.