diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2963c8a2f8..8491e4739c 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -66,7 +66,6 @@ optional_import, ) from monai.utils.enums import TransformBackends -from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype PILImageImage, has_pil = optional_import("PIL.Image", name="Image") @@ -939,19 +938,10 @@ def __call__( data = img[[*select_labels]] else: where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): - data = where(in1d(img, select_labels), True, False).reshape(img.shape) - # pre pytorch 1.8.0, need to use 1/0 instead of True/False - else: - data = where( - in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device) - ).reshape(img.shape) + data = where(in1d(img, select_labels), True, False).reshape(img.shape) if merge_channels or self.merge_channels: - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): - return data.any(0)[None] - # pre pytorch 1.8.0 compatibility - return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore + return data.any(0)[None] return data diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 365bd1eab5..8f22d00674 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -18,7 +18,6 @@ import torch from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor -from monai.utils.misc import is_module_ver_at_least from monai.utils.type_conversion import convert_data_type, convert_to_dst_type __all__ = [ @@ -215,10 +214,9 @@ def floor_divide(a: NdarrayOrTensor, b) -> NdarrayOrTensor: Element-wise floor division between two arrays/tensors. """ if isinstance(a, torch.Tensor): - if is_module_ver_at_least(torch, (1, 8, 0)): - return torch.div(a, b, rounding_mode="floor") return torch.floor_divide(a, b) - return np.floor_divide(a, b) + else: + return np.floor_divide(a, b) def unravel_index(idx, shape) -> NdarrayOrTensor: