From 2b37b94f4d11a8f99aa2150fea57107aa044e4f3 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 25 Mar 2024 16:10:45 +0100 Subject: [PATCH 01/21] Implement TorchIO transforms wrapper analogous to TorchVision transforms wrapper and test case Signed-off-by: Fabian Klopfer --- monai/transforms/__init__.py | 1 + monai/transforms/utility/array.py | 37 +++++++++++++++++++++ tests/test_torchio.py | 53 +++++++++++++++++++++++++++++++ 3 files changed, 91 insertions(+) create mode 100644 tests/test_torchio.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 2aa8fbf8a1..3dc65f742e 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -505,6 +505,7 @@ ToDevice, ToNumpy, ToPIL, + TorchIO, TorchVision, ToTensor, Transpose, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 5dfbcb0e91..0e3b01f512 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -98,6 +98,7 @@ "ConvertToMultiChannelBasedOnBratsClasses", "AddExtremePointsChannel", "TorchVision", + "TorchIO", "MapLabelValue", "IntensityStats", "ToDevice", @@ -1163,6 +1164,42 @@ def __call__(self, img: NdarrayOrTensor): return out +class TorchIO: + """ + This is a wrapper transform for TorchIO transforms based on the specified transform name and args. + As most of the TorchIO transforms only work for PyTorch Tensor, this transform expects input + data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor. + + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: PyTorch Tensor data for the TorchIO transform. + + """ + img_t, *_ = convert_data_type(img, torch.Tensor) + + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out + + class MapLabelValue: """ Utility to map label values to another set of values. diff --git a/tests/test_torchio.py b/tests/test_torchio.py new file mode 100644 index 0000000000..f235d3250f --- /dev/null +++ b/tests/test_torchio.py @@ -0,0 +1,53 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +from parameterized import parameterized +import numpy as np +import torch + +from monai.transforms import TorchIO +from monai.utils import set_determinism + +TEST_DIMS = [3, 128, 160, 160] +TESTS = [ + [{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)], + [{"name": "ZNormalization"}, torch.rand(TEST_DIMS)], + [{"name": "RandomAffine"}, torch.rand(TEST_DIMS)], + [{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)], + [{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)], + [{"name": "RandomMotion"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSpike"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBlur"}, torch.rand(TEST_DIMS)], + [{"name": "RandomNoise"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSwap"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGamma"}, torch.rand(TEST_DIMS)], + ] + + +class TestTorchIO(unittest.TestCase): + + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data): + set_determinism(seed=0) + result = TorchIO(**input_param)(input_data) + self.assertIsNotNone(result) + self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), + f'{input_param} failed') + + +if __name__ == "__main__": + unittest.main() From 96955c6edae80f11a522a0713540044c0a4e569c Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 25 Mar 2024 18:53:04 +0100 Subject: [PATCH 02/21] Add torchio to dependencies Signed-off-by: Fabian Klopfer --- environment-dev.yml | 1 + requirements-dev.txt | 1 + setup.cfg | 3 +++ 3 files changed, 5 insertions(+) diff --git a/environment-dev.yml b/environment-dev.yml index d23958baba..20427d5d5c 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -7,6 +7,7 @@ channels: dependencies: - numpy>=1.20 - pytorch>=1.9 + - torchio - torchvision - pytorch-cuda=11.6 - pip diff --git a/requirements-dev.txt b/requirements-dev.txt index af1b8b89d5..f7f9a6db45 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -24,6 +24,7 @@ pytype>=2020.6.1; platform_system != "Windows" types-pkg_resources mypy>=1.5.0 ninja +torchio torchvision psutil cucim-cu12; platform_system == "Linux" and python_version >= "3.9" and python_version <= "3.10" diff --git a/setup.cfg b/setup.cfg index d7cb703d25..9e7a8fdada 100644 --- a/setup.cfg +++ b/setup.cfg @@ -54,6 +54,7 @@ all = tensorboard gdown>=4.7.3 pytorch-ignite==0.4.11 + torchio torchvision itk>=5.2 tqdm>=4.47.0 @@ -100,6 +101,8 @@ gdown = gdown>=4.7.3 ignite = pytorch-ignite==0.4.11 +torchio = + torchio torchvision = torchvision itk = From e478ef23bf0e0435fe0c6983683043058c3b6450 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 25 Mar 2024 20:01:35 +0100 Subject: [PATCH 03/21] Fixup import order in test Signed-off-by: Fabian Klopfer --- tests/test_torchio.py | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/tests/test_torchio.py b/tests/test_torchio.py index f235d3250f..fee7e36b85 100644 --- a/tests/test_torchio.py +++ b/tests/test_torchio.py @@ -13,29 +13,29 @@ import unittest -from parameterized import parameterized import numpy as np import torch +from parameterized import parameterized from monai.transforms import TorchIO from monai.utils import set_determinism TEST_DIMS = [3, 128, 160, 160] TESTS = [ - [{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)], - [{"name": "ZNormalization"}, torch.rand(TEST_DIMS)], - [{"name": "RandomAffine"}, torch.rand(TEST_DIMS)], - [{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)], - [{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)], - [{"name": "RandomMotion"}, torch.rand(TEST_DIMS)], - [{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)], - [{"name": "RandomSpike"}, torch.rand(TEST_DIMS)], - [{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)], - [{"name": "RandomBlur"}, torch.rand(TEST_DIMS)], - [{"name": "RandomNoise"}, torch.rand(TEST_DIMS)], - [{"name": "RandomSwap"}, torch.rand(TEST_DIMS)], - [{"name": "RandomGamma"}, torch.rand(TEST_DIMS)], - ] + [{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)], + [{"name": "ZNormalization"}, torch.rand(TEST_DIMS)], + [{"name": "RandomAffine"}, torch.rand(TEST_DIMS)], + [{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)], + [{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)], + [{"name": "RandomMotion"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSpike"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBlur"}, torch.rand(TEST_DIMS)], + [{"name": "RandomNoise"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSwap"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGamma"}, torch.rand(TEST_DIMS)], +] class TestTorchIO(unittest.TestCase): @@ -45,8 +45,7 @@ def test_value(self, input_param, input_data): set_determinism(seed=0) result = TorchIO(**input_param)(input_data) self.assertIsNotNone(result) - self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), - f'{input_param} failed') + self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed") if __name__ == "__main__": From a3cfde19ce5d799cbd293a259b315977eb676a04 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Tue, 26 Mar 2024 14:25:10 +0100 Subject: [PATCH 04/21] Add skipUnless annotation to torchio transform wrapper test Signed-off-by: Fabian Klopfer --- tests/test_torchio.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/test_torchio.py b/tests/test_torchio.py index fee7e36b85..00fff616ee 100644 --- a/tests/test_torchio.py +++ b/tests/test_torchio.py @@ -12,13 +12,16 @@ from __future__ import annotations import unittest +from unittest import skipUnless import numpy as np import torch from parameterized import parameterized from monai.transforms import TorchIO -from monai.utils import set_determinism +from monai.utils import set_determinism, optional_import + +_, has_torchio = optional_import("torchio") TEST_DIMS = [3, 128, 160, 160] TESTS = [ @@ -38,6 +41,7 @@ ] +@skipUnless(has_torchio, "Requires torchio") class TestTorchIO(unittest.TestCase): @parameterized.expand(TESTS) From e19dc3c40cb68895e9e9ba9111f2a27669d987bb Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Tue, 26 Mar 2024 14:35:02 +0100 Subject: [PATCH 05/21] fixup imports Signed-off-by: Fabian Klopfer --- tests/test_torchio.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_torchio.py b/tests/test_torchio.py index 00fff616ee..1d9281b8db 100644 --- a/tests/test_torchio.py +++ b/tests/test_torchio.py @@ -19,7 +19,7 @@ from parameterized import parameterized from monai.transforms import TorchIO -from monai.utils import set_determinism, optional_import +from monai.utils import optional_import, set_determinism _, has_torchio = optional_import("torchio") From de491af6d4f15b552041769f00b8f45a65768476 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Thu, 28 Mar 2024 13:16:33 +0100 Subject: [PATCH 06/21] add TorchIOd wrapper, add Transform and RandomizableTrait as base classes to TorchIO and Transform to TorchVision. document conversion for torchvision and remove conversion for torchio. Add dtypes to docstring for tio Signed-off-by: Fabian Klopfer --- monai/transforms/__init__.py | 3 ++ monai/transforms/utility/array.py | 20 ++++------- monai/transforms/utility/dictionary.py | 35 +++++++++++++++++++ tests/test_torchiod.py | 48 ++++++++++++++++++++++++++ 4 files changed, 92 insertions(+), 14 deletions(-) create mode 100644 tests/test_torchiod.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 43b3bd5028..ef76862617 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -628,6 +628,9 @@ ToPILd, ToPILD, ToPILDict, + TorchIOd, + TorchIOD, + TorchIODict, TorchVisiond, TorchVisionD, TorchVisionDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 0e3b01f512..d0411ed1b9 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1128,12 +1128,10 @@ def __call__( return concatenate((img, points_image), axis=0) -class TorchVision: +class TorchVision(Transform): """ This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. - As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input - data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor. - + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. """ backend = [TransformBackends.TORCH] @@ -1164,12 +1162,9 @@ def __call__(self, img: NdarrayOrTensor): return out -class TorchIO: +class TorchIO(Transform, RandomizableTrait): """ This is a wrapper transform for TorchIO transforms based on the specified transform name and args. - As most of the TorchIO transforms only work for PyTorch Tensor, this transform expects input - data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor. - """ backend = [TransformBackends.TORCH] @@ -1190,14 +1185,11 @@ def __init__(self, name: str, *args, **kwargs) -> None: def __call__(self, img: NdarrayOrTensor): """ Args: - img: PyTorch Tensor data for the TorchIO transform. + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values """ - img_t, *_ = convert_data_type(img, torch.Tensor) - - out = self.trans(img_t) - out, *_ = convert_to_dst_type(src=out, dst=img) - return out + return self.trans(img) class MapLabelValue: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7e3a7b0454..10b29e1ae4 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -59,6 +59,7 @@ ToDevice, ToNumpy, ToPIL, + TorchIO, TorchVision, ToTensor, Transpose, @@ -171,6 +172,9 @@ "ToTensorD", "ToTensorDict", "ToTensord", + "TorchIOD", + "TorchIODict", + "TorchIOd", "TorchVisionD", "TorchVisionDict", "TorchVisiond", @@ -1419,6 +1423,36 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class TorchIOd(MapTransform, RandomizableTrait): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for transforms. + All transforms in TorchIO can be applied randomly with probability p by specifying the `p=` argument. + """ + + backend = TorchIO.backend + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.trans(d[key]) + return d + + class MapLabelValued(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.MapLabelValue`. @@ -1771,6 +1805,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N ) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld TorchVisionD = TorchVisionDict = TorchVisiond +TorchIOD = TorchIODict = TorchIOd RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond RandLambdaD = RandLambdaDict = RandLambdad MapLabelValueD = MapLabelValueDict = MapLabelValued diff --git a/tests/test_torchiod.py b/tests/test_torchiod.py new file mode 100644 index 0000000000..13341da9ac --- /dev/null +++ b/tests/test_torchiod.py @@ -0,0 +1,48 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import torch +from parameterized import parameterized + +from monai.transforms import TorchIOd +from monai.utils import optional_import, set_determinism +from tests.utils import assert_allclose + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TEST_TENSOR = torch.rand(TEST_DIMS) +TESTS = [ + [ + {"keys": "img", "name": "RescaleIntensity", "out_min_max": (0, 42)}, + {"img": TEST_TENSOR}, + ((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42, + ] +] + + +@skipUnless(has_torchio, "Requires torchio") +class TestTorchVisiond(unittest.TestCase): + + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data, expected_value): + set_determinism(seed=0) + result = TorchIOd(**input_param)(input_data) + assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4, type_test=False) + + +if __name__ == "__main__": + unittest.main() From 50cd7ecc4249fcabd18f2ed6180058622021f9d9 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Thu, 28 Mar 2024 15:10:23 +0100 Subject: [PATCH 07/21] add TorchIO and TorchIOd to docs Signed-off-by: Fabian Klopfer --- docs/source/transforms.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index bd3feb3497..6d2c1d8b21 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1150,6 +1150,13 @@ Utility :members: :special-members: __call__ +`TorchIO` +""""""""""""" +.. autoclass:: TorchIO + :members: + :special-members: __call__ + + `MapLabelValue` """"""""""""""" .. autoclass:: MapLabelValue @@ -2193,6 +2200,12 @@ Utility (Dict) :members: :special-members: __call__ +`TorchIOd` +"""""""""""""" +.. autoclass:: TorchIOd + :members: + :special-members: __call__ + `MapLabelValued` """""""""""""""" .. autoclass:: MapLabelValued From 701c83ea47ad108ebbbebd55e6bd9ae6cea93148 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Thu, 28 Mar 2024 17:34:02 +0100 Subject: [PATCH 08/21] add flag to apply the same random transform to all elements in the dict. add tags file generated by ctags to gitignore, make `prob` and `p` kwargs mutually exclusive for TorchIO transforms and initialize `p` with `prob` if the latter was provided. Add test cases for applying the same transform Signed-off-by: Fabian Klopfer --- .gitignore | 3 +++ monai/transforms/utility/array.py | 11 +++++++++++ monai/transforms/utility/dictionary.py | 23 +++++++++++++++++++--- tests/test_torchiod.py | 27 +++++++++++++++++++++++--- 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/.gitignore b/.gitignore index 437677d2bb..76c6ab0d12 100644 --- a/.gitignore +++ b/.gitignore @@ -149,6 +149,9 @@ tests/testing_data/nrrd_example.nrrd # clang format tool .clang-format-bin/ +# ctags +tags + # VSCode .vscode/ *.zip diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index d0411ed1b9..85895f6daf 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1165,6 +1165,7 @@ def __call__(self, img: NdarrayOrTensor): class TorchIO(Transform, RandomizableTrait): """ This is a wrapper transform for TorchIO transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. """ backend = [TransformBackends.TORCH] @@ -1176,10 +1177,20 @@ def __init__(self, name: str, *args, **kwargs) -> None: args: parameters for the TorchIO transform. kwargs: parameters for the TorchIO transform. + Note: + The `p=` kwarg of TorchIO transforms control set the probability with which the transform is applied. + You can specify the probability of applying the transform by passing either `prob` ot `p` in kwargs but' + ' not both. """ super().__init__() self.name = name transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + + if "prob" in kwargs: + if "p" in kwargs: + raise ValueError("Cannot specify both 'prob' and 'p' in kwargs.") + kwargs["p"] = kwargs.pop("prob") + self.trans = transform(*args, **kwargs) def __call__(self, img: NdarrayOrTensor): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 10b29e1ae4..92cfba3cdb 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1431,12 +1431,21 @@ class TorchIOd(MapTransform, RandomizableTrait): backend = TorchIO.backend - def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: + def __init__( + self, + keys: KeysCollection, + name: str, + apply_same_transform: bool = False, + allow_missing_keys: bool = False, + *args, + **kwargs, + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` name: The transform name in TorchIO package. + apply_same_transform: whether to apply the same transform for all the items specified by `keys`. allow_missing_keys: don't raise exception if key is missing. args: parameters for the TorchIO transform. kwargs: parameters for the TorchIO transform. @@ -1444,12 +1453,20 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F """ super().__init__(keys, allow_missing_keys) self.name = name + self.apply_same_transform = apply_same_transform + + if self.apply_same_transform: + kwargs["include"] = self.keys + self.trans = TorchIO(name, *args, **kwargs) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: d = dict(data) - for key in self.key_iterator(d): - d[key] = self.trans(d[key]) + if self.apply_same_transform: + d = self.trans(d) + else: + for key in self.key_iterator(d): + d[key] = self.trans(d[key]) return d diff --git a/tests/test_torchiod.py b/tests/test_torchiod.py index 13341da9ac..bf4ea9ec5d 100644 --- a/tests/test_torchiod.py +++ b/tests/test_torchiod.py @@ -14,6 +14,7 @@ import unittest from unittest import skipUnless +import numpy as np import torch from parameterized import parameterized @@ -25,24 +26,44 @@ TEST_DIMS = [3, 128, 160, 160] TEST_TENSOR = torch.rand(TEST_DIMS) -TESTS = [ +TEST1 = [ [ {"keys": "img", "name": "RescaleIntensity", "out_min_max": (0, 42)}, {"img": TEST_TENSOR}, ((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42, ] ] +TEST2 = [ + [ + {"keys": ["img1", "img2"], "name": "RandomAffine", "apply_same_transform": True}, + {"img1": TEST_TENSOR, "img2": TEST_TENSOR}, + ] +] +TEST3 = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]] @skipUnless(has_torchio, "Requires torchio") -class TestTorchVisiond(unittest.TestCase): +class TestTorchIOd(unittest.TestCase): - @parameterized.expand(TESTS) + @parameterized.expand(TEST1) def test_value(self, input_param, input_data, expected_value): set_determinism(seed=0) result = TorchIOd(**input_param)(input_data) assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4, type_test=False) + @parameterized.expand(TEST2) + def test_common_random_transform(self, input_param, input_data): + set_determinism(seed=0) + result = TorchIOd(**input_param)(input_data) + assert_allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4, type_test=False) + + @parameterized.expand(TEST3) + def test_different_random_transform(self, input_param, input_data): + set_determinism(seed=0) + result = TorchIOd(**input_param)(input_data) + equal = np.allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4) + self.assertFalse(equal) + if __name__ == "__main__": unittest.main() From 74b6b41e387f98a18dc80ebf83f5092aec7fdee2 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 10 Jun 2024 19:20:10 +0200 Subject: [PATCH 09/21] Remove trailing quotes docs/source/transforms.rst Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Fabian Klopfer --- docs/source/transforms.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 6d2c1d8b21..2f0e9895a3 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1151,7 +1151,7 @@ Utility :special-members: __call__ `TorchIO` -""""""""""""" +""""""""" .. autoclass:: TorchIO :members: :special-members: __call__ From b8966043cfc21639554ddfb2570df50f926b6ac1 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 10 Jun 2024 19:20:21 +0200 Subject: [PATCH 10/21] Remove trailing quotes docs/source/transforms.rst Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> Signed-off-by: Fabian Klopfer --- docs/source/transforms.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 2f0e9895a3..21a7e5e44e 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -2201,7 +2201,7 @@ Utility (Dict) :special-members: __call__ `TorchIOd` -"""""""""""""" +"""""""""" .. autoclass:: TorchIOd :members: :special-members: __call__ From 09d1099dd0d68baa57aa58d7415f06434f1b2ff5 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 18 Nov 2024 20:08:50 +0100 Subject: [PATCH 11/21] TorchIO, RandTorchIO, TorchIOd and RandTorchIOd; add RandTorchVision as well Signed-off-by: Fabian Klopfer --- monai/transforms/utility/array.py | 80 ++++++++++++++++++++++---- monai/transforms/utility/dictionary.py | 60 ++++++++++++++----- 2 files changed, 113 insertions(+), 27 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9d67e69033..28b5f11142 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -105,6 +105,8 @@ "ToDevice", "CuCIM", "RandCuCIM", + "RandTorchIO", + "RandTorchVision", "ToCupy", "ImageFilter", "RandImageFilter", @@ -1139,7 +1141,7 @@ def __call__( class TorchVision(Transform): """ - This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. + This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args. Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. """ @@ -1171,9 +1173,43 @@ def __call__(self, img: NdarrayOrTensor): return out -class TorchIO(Transform, RandomizableTrait): +class RandTorchVision(Transform, RandomizableTrait): """ - This is a wrapper transform for TorchIO transforms based on the specified transform name and args. + This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args. + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchVision package. + args: parameters for the TorchVision transform. + kwargs: parameters for the TorchVision transform. + + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: PyTorch Tensor data for the TorchVision transform. + + """ + img_t, *_ = convert_data_type(img, torch.Tensor) + + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out + + +class TorchIO(Transform): + """ + This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args. See https://torchio.readthedocs.io/transforms/transforms.html for more details. """ @@ -1185,21 +1221,41 @@ def __init__(self, name: str, *args, **kwargs) -> None: name: The transform name in TorchIO package. args: parameters for the TorchIO transform. kwargs: parameters for the TorchIO transform. - - Note: - The `p=` kwarg of TorchIO transforms control set the probability with which the transform is applied. - You can specify the probability of applying the transform by passing either `prob` ot `p` in kwargs but' - ' not both. """ super().__init__() self.name = name transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) - if "prob" in kwargs: - if "p" in kwargs: - raise ValueError("Cannot specify both 'prob' and 'p' in kwargs.") - kwargs["p"] = kwargs.pop("prob") + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values + """ + return self.trans(img) + +class RandTorchIO(Transform, RandomizableTrait): + """ + This is a wrapper for TorchIO randomized transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. + Use this wrapper for all TorchIO transform inheriting from RandomTransform: + https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) self.trans = transform(*args, **kwargs) def __call__(self, img: NdarrayOrTensor): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index f29119d348..14409fc0e6 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -134,6 +134,9 @@ "RandCuCIMD", "RandCuCIMDict", "RandImageFilterd", + "RandTorchIOd", + "RandTorchIOD", + "RandTorchIODict", "RandLambdaD", "RandLambdaDict", "RandLambdad", @@ -1449,10 +1452,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d -class TorchIOd(MapTransform, RandomizableTrait): +class TorchIOd(MapTransform): """ - Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for transforms. - All transforms in TorchIO can be applied randomly with probability p by specifying the `p=` argument. + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms. + For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`. """ backend = TorchIO.backend @@ -1461,7 +1464,6 @@ def __init__( self, keys: KeysCollection, name: str, - apply_same_transform: bool = False, allow_missing_keys: bool = False, *args, **kwargs, @@ -1479,21 +1481,49 @@ def __init__( """ super().__init__(keys, allow_missing_keys) self.name = name - self.apply_same_transform = apply_same_transform + kwargs["include"] = self.keys + + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]: + return self.trans(dict(data)) - if self.apply_same_transform: - kwargs["include"] = self.keys + +class RandTorchIOd(MapTransform, RandomizableTrait): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms. + For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`. + """ + + backend = TorchIO.backend + + def __init__( + self, + keys: KeysCollection, + name: str, + allow_missing_keys: bool = False, + *args, + **kwargs, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + apply_same_transform: whether to apply the same transform for all the items specified by `keys`. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + kwargs["include"] = self.keys self.trans = TorchIO(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: - d = dict(data) - if self.apply_same_transform: - d = self.trans(d) - else: - for key in self.key_iterator(d): - d[key] = self.trans(d[key]) - return d + def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]: + return self.trans(dict(data)) class MapLabelValued(MapTransform): From 1c9334cc890e813f9240d9f249cb21eb2a86c393 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 18 Nov 2024 20:17:33 +0100 Subject: [PATCH 12/21] Fixup alias for RandTorchIOd Signed-off-by: Fabian Klopfer --- monai/transforms/utility/dictionary.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 14409fc0e6..b291206d1b 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1952,9 +1952,10 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch ConvertToMultiChannelBasedOnBratsClassesd ) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld -TorchVisionD = TorchVisionDict = TorchVisiond TorchIOD = TorchIODict = TorchIOd +TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond +RandTorchIOd = RandTorchIODict = RandTorchIOD RandLambdaD = RandLambdaDict = RandLambdad MapLabelValueD = MapLabelValueDict = MapLabelValued IntensityStatsD = IntensityStatsDict = IntensityStatsd From 63d7579e84d6365ca49e63bf501f9c6c6d0e761d Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 18 Nov 2024 20:08:50 +0100 Subject: [PATCH 13/21] TorchIO, RandTorchIO, TorchIOd and RandTorchIOd; add RandTorchVision as well Signed-off-by: Fabian Klopfer --- monai/transforms/utility/array.py | 80 ++++++++++++++++++++++---- monai/transforms/utility/dictionary.py | 63 ++++++++++++++------ 2 files changed, 115 insertions(+), 28 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 9d67e69033..28b5f11142 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -105,6 +105,8 @@ "ToDevice", "CuCIM", "RandCuCIM", + "RandTorchIO", + "RandTorchVision", "ToCupy", "ImageFilter", "RandImageFilter", @@ -1139,7 +1141,7 @@ def __call__( class TorchVision(Transform): """ - This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args. + This is a wrapper transform for PyTorch TorchVision non-randomized transform based on the specified transform name and args. Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. """ @@ -1171,9 +1173,43 @@ def __call__(self, img: NdarrayOrTensor): return out -class TorchIO(Transform, RandomizableTrait): +class RandTorchVision(Transform, RandomizableTrait): """ - This is a wrapper transform for TorchIO transforms based on the specified transform name and args. + This is a wrapper transform for PyTorch TorchVision randomized transform based on the specified transform name and args. + Data is converted to a torch.tensor before applying the transform and then converted back to the original data type. + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchVision package. + args: parameters for the TorchVision transform. + kwargs: parameters for the TorchVision transform. + + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchvision.transforms", "0.8.0", min_version, name=name) + self.trans = transform(*args, **kwargs) + + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: PyTorch Tensor data for the TorchVision transform. + + """ + img_t, *_ = convert_data_type(img, torch.Tensor) + + out = self.trans(img_t) + out, *_ = convert_to_dst_type(src=out, dst=img) + return out + + +class TorchIO(Transform): + """ + This is a wrapper for TorchIO non-randomized transforms based on the specified transform name and args. See https://torchio.readthedocs.io/transforms/transforms.html for more details. """ @@ -1185,21 +1221,41 @@ def __init__(self, name: str, *args, **kwargs) -> None: name: The transform name in TorchIO package. args: parameters for the TorchIO transform. kwargs: parameters for the TorchIO transform. - - Note: - The `p=` kwarg of TorchIO transforms control set the probability with which the transform is applied. - You can specify the probability of applying the transform by passing either `prob` ot `p` in kwargs but' - ' not both. """ super().__init__() self.name = name transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) + self.trans = transform(*args, **kwargs) - if "prob" in kwargs: - if "p" in kwargs: - raise ValueError("Cannot specify both 'prob' and 'p' in kwargs.") - kwargs["p"] = kwargs.pop("prob") + def __call__(self, img: NdarrayOrTensor): + """ + Args: + img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, + or dict containing 4D tensors as values + """ + return self.trans(img) + +class RandTorchIO(Transform, RandomizableTrait): + """ + This is a wrapper for TorchIO randomized transforms based on the specified transform name and args. + See https://torchio.readthedocs.io/transforms/transforms.html for more details. + Use this wrapper for all TorchIO transform inheriting from RandomTransform: + https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform + """ + + backend = [TransformBackends.TORCH] + + def __init__(self, name: str, *args, **kwargs) -> None: + """ + Args: + name: The transform name in TorchIO package. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + """ + super().__init__() + self.name = name + transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) self.trans = transform(*args, **kwargs) def __call__(self, img: NdarrayOrTensor): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index f29119d348..b6b81700b1 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -137,6 +137,9 @@ "RandLambdaD", "RandLambdaDict", "RandLambdad", + "RandTorchIOd", + "RandTorchIOD", + "RandTorchIODict", "RandTorchVisionD", "RandTorchVisionDict", "RandTorchVisiond", @@ -1449,10 +1452,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d -class TorchIOd(MapTransform, RandomizableTrait): +class TorchIOd(MapTransform): """ - Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for transforms. - All transforms in TorchIO can be applied randomly with probability p by specifying the `p=` argument. + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for non-randomized transforms. + For randomized transforms of TorchIO use :py:class:`monai.transforms.RandTorchIOd`. """ backend = TorchIO.backend @@ -1461,7 +1464,6 @@ def __init__( self, keys: KeysCollection, name: str, - apply_same_transform: bool = False, allow_missing_keys: bool = False, *args, **kwargs, @@ -1479,21 +1481,49 @@ def __init__( """ super().__init__(keys, allow_missing_keys) self.name = name - self.apply_same_transform = apply_same_transform + kwargs["include"] = self.keys - if self.apply_same_transform: - kwargs["include"] = self.keys + self.trans = TorchIO(name, *args, **kwargs) + + def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]: + return self.trans(dict(data)) + + +class RandTorchIOd(MapTransform, RandomizableTrait): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.TorchIO` for randomized transforms. + For non-randomized transforms of TorchIO use :py:class:`monai.transforms.TorchIOd`. + """ + + backend = TorchIO.backend + + def __init__( + self, + keys: KeysCollection, + name: str, + allow_missing_keys: bool = False, + *args, + **kwargs, + ) -> None: + """ + Args: + keys: keys of the corresponding items to be transformed. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: The transform name in TorchIO package. + apply_same_transform: whether to apply the same transform for all the items specified by `keys`. + allow_missing_keys: don't raise exception if key is missing. + args: parameters for the TorchIO transform. + kwargs: parameters for the TorchIO transform. + + """ + super().__init__(keys, allow_missing_keys) + self.name = name + kwargs["include"] = self.keys self.trans = TorchIO(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: - d = dict(data) - if self.apply_same_transform: - d = self.trans(d) - else: - for key in self.key_iterator(d): - d[key] = self.trans(d[key]) - return d + def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]: + return self.trans(dict(data)) class MapLabelValued(MapTransform): @@ -1922,9 +1952,10 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch ConvertToMultiChannelBasedOnBratsClassesd ) AddExtremePointsChannelD = AddExtremePointsChannelDict = AddExtremePointsChanneld -TorchVisionD = TorchVisionDict = TorchVisiond TorchIOD = TorchIODict = TorchIOd +TorchVisionD = TorchVisionDict = TorchVisiond RandTorchVisionD = RandTorchVisionDict = RandTorchVisiond +RandTorchIOD = RandTorchIODict = RandTorchIOd RandLambdaD = RandLambdaDict = RandLambdad MapLabelValueD = MapLabelValueDict = MapLabelValued IntensityStatsD = IntensityStatsDict = IntensityStatsd From d27bb78ccffdff26d70078509931224cc8fa445f Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 18 Nov 2024 20:32:05 +0100 Subject: [PATCH 14/21] remove duplicate export Signed-off-by: Fabian Klopfer --- monai/transforms/utility/dictionary.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index c7188ed582..b6b81700b1 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -134,9 +134,6 @@ "RandCuCIMD", "RandCuCIMDict", "RandImageFilterd", - "RandTorchIOd", - "RandTorchIOD", - "RandTorchIODict", "RandLambdaD", "RandLambdaDict", "RandLambdad", From 99bc99383f3da51805262db610db748555c0639b Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 18 Nov 2024 21:03:49 +0100 Subject: [PATCH 15/21] fix formatting Signed-off-by: Fabian Klopfer --- monai/transforms/utility/array.py | 1 + monai/transforms/utility/dictionary.py | 18 ++---------------- 2 files changed, 3 insertions(+), 16 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 28b5f11142..ea33869fd9 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -1236,6 +1236,7 @@ def __call__(self, img: NdarrayOrTensor): """ return self.trans(img) + class RandTorchIO(Transform, RandomizableTrait): """ This is a wrapper for TorchIO randomized transforms based on the specified transform name and args. diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index b6b81700b1..173a832a49 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1460,14 +1460,7 @@ class TorchIOd(MapTransform): backend = TorchIO.backend - def __init__( - self, - keys: KeysCollection, - name: str, - allow_missing_keys: bool = False, - *args, - **kwargs, - ) -> None: + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: keys: keys of the corresponding items to be transformed. @@ -1497,14 +1490,7 @@ class RandTorchIOd(MapTransform, RandomizableTrait): backend = TorchIO.backend - def __init__( - self, - keys: KeysCollection, - name: str, - allow_missing_keys: bool = False, - *args, - **kwargs, - ) -> None: + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None: """ Args: keys: keys of the corresponding items to be transformed. From 98e8275d30daaaa9af229c67780168bc5d5cb3db Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 18 Nov 2024 22:21:36 +0100 Subject: [PATCH 16/21] remove apply same flag from test and remove redundant test, fix type annotations Signed-off-by: Fabian Klopfer --- monai/transforms/utility/dictionary.py | 4 ++-- tests/test_torchiod.py | 18 ++---------------- 2 files changed, 4 insertions(+), 18 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 173a832a49..57105fbe42 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1478,7 +1478,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F self.trans = TorchIO(name, *args, **kwargs) - def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: return self.trans(dict(data)) @@ -1508,7 +1508,7 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F self.trans = TorchIO(name, *args, **kwargs) - def __call__(self, data: Mapping[NdarrayOrTensor]) -> dict[NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: return self.trans(dict(data)) diff --git a/tests/test_torchiod.py b/tests/test_torchiod.py index bf4ea9ec5d..61a16c0c49 100644 --- a/tests/test_torchiod.py +++ b/tests/test_torchiod.py @@ -14,7 +14,6 @@ import unittest from unittest import skipUnless -import numpy as np import torch from parameterized import parameterized @@ -33,13 +32,7 @@ ((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42, ] ] -TEST2 = [ - [ - {"keys": ["img1", "img2"], "name": "RandomAffine", "apply_same_transform": True}, - {"img1": TEST_TENSOR, "img2": TEST_TENSOR}, - ] -] -TEST3 = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]] +TEST2 = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]] @skipUnless(has_torchio, "Requires torchio") @@ -52,18 +45,11 @@ def test_value(self, input_param, input_data, expected_value): assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4, type_test=False) @parameterized.expand(TEST2) - def test_common_random_transform(self, input_param, input_data): + def test_random_transform(self, input_param, input_data): set_determinism(seed=0) result = TorchIOd(**input_param)(input_data) assert_allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4, type_test=False) - @parameterized.expand(TEST3) - def test_different_random_transform(self, input_param, input_data): - set_determinism(seed=0) - result = TorchIOd(**input_param)(input_data) - equal = np.allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4) - self.assertFalse(equal) - if __name__ == "__main__": unittest.main() From f05aab5b6e7af592e4e27428a90ca77a4908216d Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 18 Nov 2024 23:08:06 +0100 Subject: [PATCH 17/21] fixup Signed-off-by: Fabian Klopfer --- monai/transforms/utility/dictionary.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 57105fbe42..7572d6b973 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1479,7 +1479,10 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F self.trans = TorchIO(name, *args, **kwargs) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: - return self.trans(dict(data)) + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.trans(d[key]) + return d class RandTorchIOd(MapTransform, RandomizableTrait): @@ -1509,7 +1512,11 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F self.trans = TorchIO(name, *args, **kwargs) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: - return self.trans(dict(data)) + d = dict(data) + for key in self.key_iterator(d): + d[key] = self.trans(d[key]) + return d + class MapLabelValued(MapTransform): From 27bd7feb405bdc630c24db2567d283715760474a Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Tue, 19 Nov 2024 19:15:52 +0100 Subject: [PATCH 18/21] Finally... Signed-off-by: Fabian Klopfer --- monai/transforms/__init__.py | 5 +++ monai/transforms/utility/array.py | 8 ++-- monai/transforms/utility/dictionary.py | 17 ++------ tests/test_rand_torchio.py | 54 ++++++++++++++++++++++++++ tests/test_rand_torchiod.py | 44 +++++++++++++++++++++ tests/test_torchio.py | 19 +-------- tests/test_torchiod.py | 16 ++------ 7 files changed, 117 insertions(+), 46 deletions(-) create mode 100644 tests/test_rand_torchio.py create mode 100644 tests/test_rand_torchiod.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 5065366ecf..d15042181b 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -531,6 +531,8 @@ RandIdentity, RandImageFilter, RandLambda, + RandTorchIO, + RandTorchVision, RemoveRepeatedChannel, RepeatChannel, SimulateDelay, @@ -621,6 +623,9 @@ RandLambdad, RandLambdaD, RandLambdaDict, + RandTorchIOd, + RandTorchIOD, + RandTorchIODict, RandTorchVisiond, RandTorchVisionD, RandTorchVisionDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ea33869fd9..bcfdaa4fd2 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -18,10 +18,10 @@ import sys import time import warnings -from collections.abc import Mapping, Sequence +from collections.abc import Hashable, Mapping, Sequence from copy import deepcopy from functools import partial -from typing import Any, Callable +from typing import Any, Callable, Union import numpy as np import torch @@ -1227,7 +1227,7 @@ def __init__(self, name: str, *args, **kwargs) -> None: transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) self.trans = transform(*args, **kwargs) - def __call__(self, img: NdarrayOrTensor): + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): """ Args: img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, @@ -1259,7 +1259,7 @@ def __init__(self, name: str, *args, **kwargs) -> None: transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name) self.trans = transform(*args, **kwargs) - def __call__(self, img: NdarrayOrTensor): + def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]): """ Args: img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image, diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7572d6b973..7dd2397a74 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -1466,7 +1466,6 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` name: The transform name in TorchIO package. - apply_same_transform: whether to apply the same transform for all the items specified by `keys`. allow_missing_keys: don't raise exception if key is missing. args: parameters for the TorchIO transform. kwargs: parameters for the TorchIO transform. @@ -1478,11 +1477,8 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F self.trans = TorchIO(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key in self.key_iterator(d): - d[key] = self.trans(d[key]) - return d + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + return dict(self.trans(data)) class RandTorchIOd(MapTransform, RandomizableTrait): @@ -1499,7 +1495,6 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` name: The transform name in TorchIO package. - apply_same_transform: whether to apply the same transform for all the items specified by `keys`. allow_missing_keys: don't raise exception if key is missing. args: parameters for the TorchIO transform. kwargs: parameters for the TorchIO transform. @@ -1511,12 +1506,8 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F self.trans = TorchIO(name, *args, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: - d = dict(data) - for key in self.key_iterator(d): - d[key] = self.trans(d[key]) - return d - + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + return dict(self.trans(data)) class MapLabelValued(MapTransform): diff --git a/tests/test_rand_torchio.py b/tests/test_rand_torchio.py new file mode 100644 index 0000000000..ab212d4a11 --- /dev/null +++ b/tests/test_rand_torchio.py @@ -0,0 +1,54 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandTorchIO +from monai.utils import optional_import, set_determinism + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TESTS = [ + [{"name": "RandomAffine"}, torch.rand(TEST_DIMS)], + [{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)], + [{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)], + [{"name": "RandomMotion"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSpike"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)], + [{"name": "RandomBlur"}, torch.rand(TEST_DIMS)], + [{"name": "RandomNoise"}, torch.rand(TEST_DIMS)], + [{"name": "RandomSwap"}, torch.rand(TEST_DIMS)], + [{"name": "RandomGamma"}, torch.rand(TEST_DIMS)], +] + + +@skipUnless(has_torchio, "Requires torchio") +class TestRandTorchIO(unittest.TestCase): + + @parameterized.expand(TESTS) + def test_value(self, input_param, input_data): + set_determinism(seed=0) + result = RandTorchIO(**input_param)(input_data) + self.assertIsNotNone(result) + self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rand_torchiod.py b/tests/test_rand_torchiod.py new file mode 100644 index 0000000000..52bcf7c576 --- /dev/null +++ b/tests/test_rand_torchiod.py @@ -0,0 +1,44 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest +from unittest import skipUnless + +import numpy as np +import torch +from parameterized import parameterized + +from monai.transforms import RandTorchIOd +from monai.utils import optional_import, set_determinism +from tests.utils import assert_allclose + +_, has_torchio = optional_import("torchio") + +TEST_DIMS = [3, 128, 160, 160] +TEST_TENSOR = torch.rand(TEST_DIMS) +TEST_PARAMS = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]] + + +@skipUnless(has_torchio, "Requires torchio") +class TestRandTorchIOd(unittest.TestCase): + + @parameterized.expand(TEST_PARAMS) + def test_random_transform(self, input_param, input_data): + set_determinism(seed=0) + result = RandTorchIOd(**input_param)(input_data) + self.assertFalse(np.allclose(input_data["img1"], result["img1"], atol=1e-6, rtol=1e-6)) + assert_allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_torchio.py b/tests/test_torchio.py index 1d9281b8db..d2d598ca4c 100644 --- a/tests/test_torchio.py +++ b/tests/test_torchio.py @@ -19,26 +19,12 @@ from parameterized import parameterized from monai.transforms import TorchIO -from monai.utils import optional_import, set_determinism +from monai.utils import optional_import _, has_torchio = optional_import("torchio") TEST_DIMS = [3, 128, 160, 160] -TESTS = [ - [{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)], - [{"name": "ZNormalization"}, torch.rand(TEST_DIMS)], - [{"name": "RandomAffine"}, torch.rand(TEST_DIMS)], - [{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)], - [{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)], - [{"name": "RandomMotion"}, torch.rand(TEST_DIMS)], - [{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)], - [{"name": "RandomSpike"}, torch.rand(TEST_DIMS)], - [{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)], - [{"name": "RandomBlur"}, torch.rand(TEST_DIMS)], - [{"name": "RandomNoise"}, torch.rand(TEST_DIMS)], - [{"name": "RandomSwap"}, torch.rand(TEST_DIMS)], - [{"name": "RandomGamma"}, torch.rand(TEST_DIMS)], -] +TESTS = [[{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)], [{"name": "ZNormalization"}, torch.rand(TEST_DIMS)]] @skipUnless(has_torchio, "Requires torchio") @@ -46,7 +32,6 @@ class TestTorchIO(unittest.TestCase): @parameterized.expand(TESTS) def test_value(self, input_param, input_data): - set_determinism(seed=0) result = TorchIO(**input_param)(input_data) self.assertIsNotNone(result) self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed") diff --git a/tests/test_torchiod.py b/tests/test_torchiod.py index 61a16c0c49..892287461c 100644 --- a/tests/test_torchiod.py +++ b/tests/test_torchiod.py @@ -18,37 +18,29 @@ from parameterized import parameterized from monai.transforms import TorchIOd -from monai.utils import optional_import, set_determinism +from monai.utils import optional_import from tests.utils import assert_allclose _, has_torchio = optional_import("torchio") TEST_DIMS = [3, 128, 160, 160] TEST_TENSOR = torch.rand(TEST_DIMS) -TEST1 = [ +TEST_PARAMS = [ [ {"keys": "img", "name": "RescaleIntensity", "out_min_max": (0, 42)}, {"img": TEST_TENSOR}, ((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42, ] ] -TEST2 = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]] @skipUnless(has_torchio, "Requires torchio") class TestTorchIOd(unittest.TestCase): - @parameterized.expand(TEST1) + @parameterized.expand(TEST_PARAMS) def test_value(self, input_param, input_data, expected_value): - set_determinism(seed=0) result = TorchIOd(**input_param)(input_data) - assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4, type_test=False) - - @parameterized.expand(TEST2) - def test_random_transform(self, input_param, input_data): - set_determinism(seed=0) - result = TorchIOd(**input_param)(input_data) - assert_allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4, type_test=False) + assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4) if __name__ == "__main__": From 133d391179e488d484f5a5cd39893a072e9af18e Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Fri, 22 Nov 2024 16:14:42 +0100 Subject: [PATCH 19/21] add docs Signed-off-by: Fabian Klopfer --- docs/source/transforms.rst | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 59a5ed9a26..d9825bd8b9 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1186,6 +1186,11 @@ Utility :members: :special-members: __call__ +`RandTorchIO` +""""""""" +.. autoclass:: RandTorchIO + :members: + :special-members: __call__ `MapLabelValue` """"""""""""""" @@ -2266,6 +2271,12 @@ Utility (Dict) :members: :special-members: __call__ +`RandTorchIOd` +"""""""""" +.. autoclass:: RandTorchIOd + :members: + :special-members: __call__ + `MapLabelValued` """""""""""""""" .. autoclass:: MapLabelValued From a1046d9f5e6df950a13369c7057beec51ce8a189 Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Mon, 25 Nov 2024 21:46:16 +0100 Subject: [PATCH 20/21] correct indentation of docs Signed-off-by: Fabian Klopfer --- docs/source/transforms.rst | 4 ++-- monai/transforms/utility/array.py | 26 +++++++------------------- 2 files changed, 9 insertions(+), 21 deletions(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index d9825bd8b9..d2585daf63 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -1187,7 +1187,7 @@ Utility :special-members: __call__ `RandTorchIO` -""""""""" +""""""""""""" .. autoclass:: RandTorchIO :members: :special-members: __call__ @@ -2272,7 +2272,7 @@ Utility (Dict) :special-members: __call__ `RandTorchIOd` -"""""""""" +"""""""""""""" .. autoclass:: RandTorchIOd :members: :special-members: __call__ diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index bcfdaa4fd2..ce94c1d071 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -695,9 +695,7 @@ def __init__( _logger.setLevel(logging.INFO) if logging.root.getEffectiveLevel() > logging.INFO: # Avoid duplicate stream handlers to be added when multiple DataStats are used in a chain. - has_console_handler = any( - hasattr(h, "is_data_stats_handler") and h.is_data_stats_handler for h in _logger.handlers - ) + has_console_handler = any(hasattr(h, "is_data_stats_handler") and h.is_data_stats_handler for h in _logger.handlers) if not has_console_handler: # if the root log level is higher than INFO, set a separate stream handler to record console = logging.StreamHandler(sys.stdout) @@ -807,9 +805,7 @@ class Lambda(InvertibleTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__( - self, func: Callable | None = None, inv_func: Callable = no_collation, track_meta: bool = True - ) -> None: + def __init__(self, func: Callable | None = None, inv_func: Callable = no_collation, track_meta: bool = True) -> None: if func is not None and not callable(func): raise TypeError(f"func must be None or callable but is {type(func).__name__}.") self.func = func @@ -1043,9 +1039,7 @@ def __call__( if output_shape is None: output_shape = self.output_shape indices: list[NdarrayOrTensor] - indices = map_classes_to_indices( - label, self.num_classes, image, self.image_threshold, self.max_samples_per_class - ) + indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold, self.max_samples_per_class) if output_shape is not None: indices = [unravel_indices(cls_indices, output_shape) for cls_indices in indices] @@ -1242,7 +1236,7 @@ class RandTorchIO(Transform, RandomizableTrait): This is a wrapper for TorchIO randomized transforms based on the specified transform name and args. See https://torchio.readthedocs.io/transforms/transforms.html for more details. Use this wrapper for all TorchIO transform inheriting from RandomTransform: - https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform + https://torchio.readthedocs.io/transforms/augmentation.html#randomtransform """ backend = [TransformBackends.TORCH] @@ -1657,9 +1651,7 @@ class ImageFilter(Transform): """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - supported_filters = sorted( - ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] - ) + supported_filters = sorted(["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"]) def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None, **kwargs) -> None: self._check_filter_format(filter, filter_size) @@ -1800,9 +1792,7 @@ class RandImageFilter(RandomizableTransform): backend = ImageFilter.backend - def __init__( - self, filter: str | NdarrayOrTensor, filter_size: int | None = None, prob: float = 0.1, **kwargs - ) -> None: + def __init__(self, filter: str | NdarrayOrTensor, filter_size: int | None = None, prob: float = 0.1, **kwargs) -> None: super().__init__(prob) self.filter = ImageFilter(filter, filter_size, **kwargs) @@ -1892,9 +1882,7 @@ def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tens return affine - def transform_coordinates( - self, data: torch.Tensor, affine: torch.Tensor | None = None - ) -> tuple[torch.Tensor, dict]: + def transform_coordinates(self, data: torch.Tensor, affine: torch.Tensor | None = None) -> tuple[torch.Tensor, dict]: """ Transform coordinates using an affine transformation matrix. From 2a7842d9e7c9b500641b64bfc3f9f37cd09b6a1f Mon Sep 17 00:00:00 2001 From: Fabian Klopfer Date: Tue, 26 Nov 2024 18:22:03 +0100 Subject: [PATCH 21/21] apply autofix and validate that docs still build Signed-off-by: Fabian Klopfer --- monai/transforms/utility/array.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ce94c1d071..84422a9ee5 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -695,7 +695,9 @@ def __init__( _logger.setLevel(logging.INFO) if logging.root.getEffectiveLevel() > logging.INFO: # Avoid duplicate stream handlers to be added when multiple DataStats are used in a chain. - has_console_handler = any(hasattr(h, "is_data_stats_handler") and h.is_data_stats_handler for h in _logger.handlers) + has_console_handler = any( + hasattr(h, "is_data_stats_handler") and h.is_data_stats_handler for h in _logger.handlers + ) if not has_console_handler: # if the root log level is higher than INFO, set a separate stream handler to record console = logging.StreamHandler(sys.stdout) @@ -805,7 +807,9 @@ class Lambda(InvertibleTransform): backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, func: Callable | None = None, inv_func: Callable = no_collation, track_meta: bool = True) -> None: + def __init__( + self, func: Callable | None = None, inv_func: Callable = no_collation, track_meta: bool = True + ) -> None: if func is not None and not callable(func): raise TypeError(f"func must be None or callable but is {type(func).__name__}.") self.func = func @@ -1039,7 +1043,9 @@ def __call__( if output_shape is None: output_shape = self.output_shape indices: list[NdarrayOrTensor] - indices = map_classes_to_indices(label, self.num_classes, image, self.image_threshold, self.max_samples_per_class) + indices = map_classes_to_indices( + label, self.num_classes, image, self.image_threshold, self.max_samples_per_class + ) if output_shape is not None: indices = [unravel_indices(cls_indices, output_shape) for cls_indices in indices] @@ -1651,7 +1657,9 @@ class ImageFilter(Transform): """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - supported_filters = sorted(["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"]) + supported_filters = sorted( + ["mean", "laplace", "elliptical", "sobel", "sharpen", "median", "gauss", "savitzky_golay"] + ) def __init__(self, filter: str | NdarrayOrTensor | nn.Module, filter_size: int | None = None, **kwargs) -> None: self._check_filter_format(filter, filter_size) @@ -1792,7 +1800,9 @@ class RandImageFilter(RandomizableTransform): backend = ImageFilter.backend - def __init__(self, filter: str | NdarrayOrTensor, filter_size: int | None = None, prob: float = 0.1, **kwargs) -> None: + def __init__( + self, filter: str | NdarrayOrTensor, filter_size: int | None = None, prob: float = 0.1, **kwargs + ) -> None: super().__init__(prob) self.filter = ImageFilter(filter, filter_size, **kwargs) @@ -1882,7 +1892,9 @@ def _compute_final_affine(self, affine: torch.Tensor, applied_affine: torch.Tens return affine - def transform_coordinates(self, data: torch.Tensor, affine: torch.Tensor | None = None) -> tuple[torch.Tensor, dict]: + def transform_coordinates( + self, data: torch.Tensor, affine: torch.Tensor | None = None + ) -> tuple[torch.Tensor, dict]: """ Transform coordinates using an affine transformation matrix.