Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2498,7 +2498,7 @@ def distance_transform_edt(
if return_indices:
dtype = torch.int32
if indices is None:
indices = torch.zeros((img.dim(),) + img.shape, dtype=dtype) # type: ignore
indices = torch.zeros((img.shape[0],) + (img.dim() - 1,) + img.shape[1:], dtype=dtype) # type: ignore
else:
if not isinstance(indices, torch.Tensor) and indices.device != img.device:
raise TypeError("indices must be a torch.Tensor on the same device as img")
Expand Down Expand Up @@ -2532,7 +2532,7 @@ def distance_transform_edt(
raise TypeError("distances must be a numpy.ndarray of dtype float64")
if return_indices:
if indices is None:
indices = np.zeros((img_.ndim,) + img_.shape, dtype=np.int32)
indices = np.zeros((img_.shape[0],) + (img_.ndim - 1,) + img_.shape[1:], dtype=np.int32)
else:
if not isinstance(indices, np.ndarray):
raise TypeError("indices must be a numpy.ndarray")
Expand Down
Loading