Skip to content

Commit 01711cf

Browse files
Fix channel-first indices buffer for distance_transform_edt (return_indices=True) #8656 (#8657)
Fixes #8656 Distance_transform_edt indices preallocation to use channel-first (C, spatial_dims, ...) layout for both torch/cuCIM and NumPy/SciPy paths, resolving “indices array has wrong shape” errors when return_indices=True. ### Description ``` import torch from monai.transforms.utils import distance_transform_edt img = torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.float32) # shape (1, 3, 3) # Previously raised: RuntimeError: indices array has wrong shape indices = distance_transform_edt(img, return_distances=False, return_indices=True) print(indices.shape) # now: (1, 2, 3, 3) ``` ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: alexanderjaus <[email protected]> Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 4014c84 commit 01711cf

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

monai/transforms/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2498,7 +2498,7 @@ def distance_transform_edt(
24982498
if return_indices:
24992499
dtype = torch.int32
25002500
if indices is None:
2501-
indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore
2501+
indices = torch.zeros((img.shape[0],) + (img.dim() - 1,) + img.shape[1:], dtype=dtype) # type: ignore
25022502
else:
25032503
if not isinstance(indices, torch.Tensor) and indices.device != img.device:
25042504
raise TypeError("indices must be a torch.Tensor on the same device as img")
@@ -2532,7 +2532,7 @@ def distance_transform_edt(
25322532
raise TypeError("distances must be a numpy.ndarray of dtype float64")
25332533
if return_indices:
25342534
if indices is None:
2535-
indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32)
2535+
indices = np.zeros((img_.shape[0],) + (img_.ndim - 1,) + img_.shape[1:], dtype=np.int32)
25362536
else:
25372537
if not isinstance(indices, np.ndarray):
25382538
raise TypeError("indices must be a numpy.ndarray")

0 commit comments

Comments
 (0)