diff --git a/src/membrain_seg/annotations/extract_patch_cli.py b/src/membrain_seg/annotations/extract_patch_cli.py index c984642..339e476 100644 --- a/src/membrain_seg/annotations/extract_patch_cli.py +++ b/src/membrain_seg/annotations/extract_patch_cli.py @@ -84,7 +84,7 @@ def extract_patches( pad_value = 2.0 # Currently still hard-coded because other values not # compatible with training routine yet. if coords_file is not None: - coords = get_csv_data(csv_path=coords_file) + coords = np.array(get_csv_data(csv_path=coords_file), dtype=int) else: coords = [np.array((x, y, z))] _extract_patches( diff --git a/src/membrain_seg/annotations/merge_corrections.py b/src/membrain_seg/annotations/merge_corrections.py index 951083e..8286633 100644 --- a/src/membrain_seg/annotations/merge_corrections.py +++ b/src/membrain_seg/annotations/merge_corrections.py @@ -105,8 +105,13 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir): for label_file in os.listdir(labels_dir): if not os.path.isfile(os.path.join(labels_dir, label_file)): continue + print("") print("Finding correction files for", label_file) - token = os.path.splitext(label_file)[0] + # token = os.path.splitext(label_file)[0] + if label_file.endswith(".nii.gz"): + token = label_file[:-7] + elif label_file.endswith(".nrrd"): + token = label_file[:-5] found_flag = 0 for filename in os.listdir(corrections_dir): if not os.path.isdir(os.path.join(corrections_dir, filename)): @@ -120,5 +125,5 @@ def convert_single_nrrd_files(labels_dir, corrections_dir, out_dir): print("Storing corrections in", out_file) write_nifti(out_file, merged_corrections) found_flag = 1 - if found_flag == 0: - print("No corrections folder found for patch", token) + if found_flag == 0: + print("No corrections folder found for patch", token) diff --git a/src/membrain_seg/segmentation/dataloading/data_utils.py b/src/membrain_seg/segmentation/dataloading/data_utils.py index e307784..28109d9 100644 --- a/src/membrain_seg/segmentation/dataloading/data_utils.py +++ b/src/membrain_seg/segmentation/dataloading/data_utils.py @@ -118,7 +118,7 @@ def load_data_for_inference(data_path: str, transforms: Callable, device: device new_data = transforms(new_data) new_data = new_data.unsqueeze(0) # Add batch dimension new_data = new_data.to(device) - return new_data, tomogram.header + return new_data, tomogram.header, tomogram.voxel_size def store_segmented_tomograms( @@ -130,6 +130,7 @@ def store_segmented_tomograms( store_connected_components: bool = False, connected_component_thres: int = None, mrc_header: np.recarray = None, + voxel_size: float = None, ) -> None: """ Helper function for storing output segmentations. @@ -159,6 +160,9 @@ def store_segmented_tomograms( If given, the mrc header will be used to retain header information from another tomogram. This way, pixel sizes and other header information is not lost. + voxel_size: float, optional + If given, this will be the voxel size stored in the header of the + output segmentation. """ # Create out directory if it doesn't exist yet make_directory_if_not_exists(out_folder) @@ -170,7 +174,9 @@ def store_segmented_tomograms( out_file = os.path.join( out_folder, os.path.basename(orig_data_path)[:-4] + "_scores.mrc" ) - out_tomo = Tomogram(data=predictions_np, header=mrc_header) + out_tomo = Tomogram( + data=predictions_np, header=mrc_header, voxel_size=voxel_size + ) store_tomogram(out_file, out_tomo) predictions_np_thres = predictions.squeeze(0).squeeze(0).cpu().numpy() > 0.0 out_file_thres = os.path.join( @@ -181,7 +187,9 @@ def store_segmented_tomograms( predictions_np_thres = connected_components( predictions_np_thres, size_thres=connected_component_thres ) - out_tomo = Tomogram(data=predictions_np_thres, header=mrc_header) + out_tomo = Tomogram( + data=predictions_np_thres, header=mrc_header, voxel_size=voxel_size + ) store_tomogram(out_file_thres, out_tomo) print("MemBrain has finished segmenting your tomogram.") return out_file_thres @@ -334,7 +342,7 @@ def convert_dtype(tomogram: np.ndarray) -> np.ndarray: if ( tomogram.min() >= np.finfo("float16").min and tomogram.max() <= np.finfo("float16").max - ): + ) and np.allclose(tomogram, tomogram.astype("float16")): return tomogram.astype("float16") elif ( tomogram.min() >= np.finfo("float32").min @@ -368,33 +376,22 @@ def store_tomogram( if isinstance(tomogram, Tomogram): data = tomogram.data header = tomogram.header + if voxel_size is None: + voxel_size = tomogram.voxel_size else: data = tomogram header = None - data = convert_dtype(data) - data = np.transpose(data, (2, 1, 0)) - dtype_mode = _dtype_to_mode[data.dtype] - out_mrc.set_data(data) + if header is not None: attributes = header.dtype.names for attr in attributes: - # skip density and shape attribues - if attr in [ - "mode", - "dmean", - "dmin", - "dmax", - "rms", - "nx", - "ny", - "nz", - "mx", - "my", - "mz", - ]: + if attr not in ["nlabl", "label"]: continue setattr(out_mrc.header, attr, getattr(header, attr)) - out_mrc.header.mode = dtype_mode + + data = convert_dtype(data) + data = np.transpose(data, (2, 1, 0)) + out_mrc.set_data(data) if voxel_size is not None: out_mrc.voxel_size = voxel_size diff --git a/src/membrain_seg/segmentation/segment.py b/src/membrain_seg/segmentation/segment.py index 513a806..5d34b88 100644 --- a/src/membrain_seg/segmentation/segment.py +++ b/src/membrain_seg/segmentation/segment.py @@ -20,7 +20,7 @@ def segment( sw_roi_size=160, store_connected_components=False, connected_component_thres=None, - test_time_augmentation=True + test_time_augmentation=True, ): """ Segment tomograms using a trained model. @@ -81,7 +81,7 @@ def segment( # Preprocess the new data new_data_path = tomogram_path transforms = get_prediction_transforms() - new_data, mrc_header = load_data_for_inference( + new_data, mrc_header, voxel_size = load_data_for_inference( new_data_path, transforms, device=torch.device("cpu") ) new_data = new_data.to(torch.float32) @@ -106,7 +106,7 @@ def segment( # Perform test time augmentation (8-fold mirroring) predictions = torch.zeros_like(new_data) print("Performing 8-fold test-time augmentation.") - for m in range((8 if test_time_augmentation else 1)): + for m in range(8 if test_time_augmentation else 1): with torch.no_grad(): with torch.cuda.amp.autocast(): predictions += ( @@ -132,5 +132,6 @@ def segment( store_connected_components=store_connected_components, connected_component_thres=connected_component_thres, mrc_header=mrc_header, + voxel_size=voxel_size, ) return segmentation_file