diff --git a/src/membrain_seg/segmentation/cli/ske_cli.py b/src/membrain_seg/segmentation/cli/ske_cli.py index dd4a458..212773a 100644 --- a/src/membrain_seg/segmentation/cli/ske_cli.py +++ b/src/membrain_seg/segmentation/cli/ske_cli.py @@ -7,6 +7,7 @@ store_tomogram, ) + from ..skeletonize import skeletonization as _skeletonization from .cli import cli @@ -53,12 +54,14 @@ def skeletonize( --batch-size """ # Assuming _skeletonization function is already defined and can handle batch_size + segmentation = load_tomogram(label_path) ske = _skeletonization(segmentation=segmentation.data, batch_size=batch_size) # Update the segmentation data with the skeletonized output while preserving the original header and voxel_size segmentation.data = ske + if not os.path.exists(out_folder): os.makedirs(out_folder) @@ -66,6 +69,6 @@ def skeletonize( out_folder, os.path.splitext(os.path.basename(label_path))[0] + "_skel.mrc", ) - + store_tomogram(filename=out_file, tomogram=segmentation) print("Skeleton saved to ", out_file) diff --git a/src/membrain_seg/segmentation/skeletonize.py b/src/membrain_seg/segmentation/skeletonize.py index 9d845dd..f4048d5 100644 --- a/src/membrain_seg/segmentation/skeletonize.py +++ b/src/membrain_seg/segmentation/skeletonize.py @@ -12,6 +12,7 @@ import scipy.ndimage as ndimage import torch + from membrain_seg.segmentation.skeletonization.diff3d import ( compute_gradients, compute_hessian, @@ -23,6 +24,7 @@ from membrain_seg.segmentation.training.surface_dice import apply_gaussian_filter + def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray: """ Perform skeletonization on a tomogram segmentation. @@ -36,6 +38,7 @@ def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray: segmentation : ndarray Tomogram segmentation as a numpy array, where non-zero values represent the structures of interest. + batch_size : int The number of elements to process in one batch during eigen decomposition. Useful for managing memory usage. @@ -58,6 +61,7 @@ def skeletonization(segmentation: np.ndarray, batch_size: int) -> np.ndarray: --batch-size 1000000 This command runs the skeletonization process from the command line. """ + # Convert non-zero segmentation values to 1.0 labels = (segmentation > 0) * 1.0