|
66 | 66 | optional_import, |
67 | 67 | ) |
68 | 68 | from monai.utils.enums import TransformBackends |
69 | | -from monai.utils.misc import is_module_ver_at_least |
70 | 69 | from monai.utils.type_conversion import convert_to_dst_type, get_dtype_string, get_equivalent_dtype |
71 | 70 |
|
72 | 71 | PILImageImage, has_pil = optional_import("PIL.Image", name="Image") |
@@ -939,19 +938,10 @@ def __call__( |
939 | 938 | data = img[[*select_labels]] |
940 | 939 | else: |
941 | 940 | where: Callable = np.where if isinstance(img, np.ndarray) else torch.where # type: ignore |
942 | | - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): |
943 | | - data = where(in1d(img, select_labels), True, False).reshape(img.shape) |
944 | | - # pre pytorch 1.8.0, need to use 1/0 instead of True/False |
945 | | - else: |
946 | | - data = where( |
947 | | - in1d(img, select_labels), torch.tensor(1, device=img.device), torch.tensor(0, device=img.device) |
948 | | - ).reshape(img.shape) |
| 941 | + data = where(in1d(img, select_labels), True, False).reshape(img.shape) |
949 | 942 |
|
950 | 943 | if merge_channels or self.merge_channels: |
951 | | - if isinstance(img, np.ndarray) or is_module_ver_at_least(torch, (1, 8, 0)): |
952 | | - return data.any(0)[None] |
953 | | - # pre pytorch 1.8.0 compatibility |
954 | | - return data.to(torch.uint8).any(0)[None].to(bool) # type: ignore |
| 944 | + return data.any(0)[None] |
955 | 945 |
|
956 | 946 | return data |
957 | 947 |
|
|
0 commit comments