Skip to content

Commit 27bd7fe

Browse files
committed
Finally...
Signed-off-by: Fabian Klopfer <[email protected]>
1 parent f05aab5 commit 27bd7fe

File tree

7 files changed

+117
-46
lines changed

7 files changed

+117
-46
lines changed

monai/transforms/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,8 @@
531531
RandIdentity,
532532
RandImageFilter,
533533
RandLambda,
534+
RandTorchIO,
535+
RandTorchVision,
534536
RemoveRepeatedChannel,
535537
RepeatChannel,
536538
SimulateDelay,
@@ -621,6 +623,9 @@
621623
RandLambdad,
622624
RandLambdaD,
623625
RandLambdaDict,
626+
RandTorchIOd,
627+
RandTorchIOD,
628+
RandTorchIODict,
624629
RandTorchVisiond,
625630
RandTorchVisionD,
626631
RandTorchVisionDict,

monai/transforms/utility/array.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
import sys
1919
import time
2020
import warnings
21-
from collections.abc import Mapping, Sequence
21+
from collections.abc import Hashable, Mapping, Sequence
2222
from copy import deepcopy
2323
from functools import partial
24-
from typing import Any, Callable
24+
from typing import Any, Callable, Union
2525

2626
import numpy as np
2727
import torch
@@ -1227,7 +1227,7 @@ def __init__(self, name: str, *args, **kwargs) -> None:
12271227
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
12281228
self.trans = transform(*args, **kwargs)
12291229

1230-
def __call__(self, img: NdarrayOrTensor):
1230+
def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
12311231
"""
12321232
Args:
12331233
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,
@@ -1259,7 +1259,7 @@ def __init__(self, name: str, *args, **kwargs) -> None:
12591259
transform, _ = optional_import("torchio.transforms", "0.18.0", min_version, name=name)
12601260
self.trans = transform(*args, **kwargs)
12611261

1262-
def __call__(self, img: NdarrayOrTensor):
1262+
def __call__(self, img: Union[NdarrayOrTensor, Mapping[Hashable, NdarrayOrTensor]]):
12631263
"""
12641264
Args:
12651265
img: an instance of torchio.Subject, torchio.Image, numpy.ndarray, torch.Tensor, SimpleITK.Image,

monai/transforms/utility/dictionary.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1466,7 +1466,6 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F
14661466
keys: keys of the corresponding items to be transformed.
14671467
See also: :py:class:`monai.transforms.compose.MapTransform`
14681468
name: The transform name in TorchIO package.
1469-
apply_same_transform: whether to apply the same transform for all the items specified by `keys`.
14701469
allow_missing_keys: don't raise exception if key is missing.
14711470
args: parameters for the TorchIO transform.
14721471
kwargs: parameters for the TorchIO transform.
@@ -1478,11 +1477,8 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F
14781477

14791478
self.trans = TorchIO(name, *args, **kwargs)
14801479

1481-
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
1482-
d = dict(data)
1483-
for key in self.key_iterator(d):
1484-
d[key] = self.trans(d[key])
1485-
return d
1480+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
1481+
return dict(self.trans(data))
14861482

14871483

14881484
class RandTorchIOd(MapTransform, RandomizableTrait):
@@ -1499,7 +1495,6 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F
14991495
keys: keys of the corresponding items to be transformed.
15001496
See also: :py:class:`monai.transforms.compose.MapTransform`
15011497
name: The transform name in TorchIO package.
1502-
apply_same_transform: whether to apply the same transform for all the items specified by `keys`.
15031498
allow_missing_keys: don't raise exception if key is missing.
15041499
args: parameters for the TorchIO transform.
15051500
kwargs: parameters for the TorchIO transform.
@@ -1511,12 +1506,8 @@ def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = F
15111506

15121507
self.trans = TorchIO(name, *args, **kwargs)
15131508

1514-
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
1515-
d = dict(data)
1516-
for key in self.key_iterator(d):
1517-
d[key] = self.trans(d[key])
1518-
return d
1519-
1509+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
1510+
return dict(self.trans(data))
15201511

15211512

15221513
class MapLabelValued(MapTransform):

tests/test_rand_torchio.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
from unittest import skipUnless
16+
17+
import numpy as np
18+
import torch
19+
from parameterized import parameterized
20+
21+
from monai.transforms import RandTorchIO
22+
from monai.utils import optional_import, set_determinism
23+
24+
_, has_torchio = optional_import("torchio")
25+
26+
TEST_DIMS = [3, 128, 160, 160]
27+
TESTS = [
28+
[{"name": "RandomAffine"}, torch.rand(TEST_DIMS)],
29+
[{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)],
30+
[{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)],
31+
[{"name": "RandomMotion"}, torch.rand(TEST_DIMS)],
32+
[{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)],
33+
[{"name": "RandomSpike"}, torch.rand(TEST_DIMS)],
34+
[{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)],
35+
[{"name": "RandomBlur"}, torch.rand(TEST_DIMS)],
36+
[{"name": "RandomNoise"}, torch.rand(TEST_DIMS)],
37+
[{"name": "RandomSwap"}, torch.rand(TEST_DIMS)],
38+
[{"name": "RandomGamma"}, torch.rand(TEST_DIMS)],
39+
]
40+
41+
42+
@skipUnless(has_torchio, "Requires torchio")
43+
class TestRandTorchIO(unittest.TestCase):
44+
45+
@parameterized.expand(TESTS)
46+
def test_value(self, input_param, input_data):
47+
set_determinism(seed=0)
48+
result = RandTorchIO(**input_param)(input_data)
49+
self.assertIsNotNone(result)
50+
self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed")
51+
52+
53+
if __name__ == "__main__":
54+
unittest.main()

tests/test_rand_torchiod.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
from unittest import skipUnless
16+
17+
import numpy as np
18+
import torch
19+
from parameterized import parameterized
20+
21+
from monai.transforms import RandTorchIOd
22+
from monai.utils import optional_import, set_determinism
23+
from tests.utils import assert_allclose
24+
25+
_, has_torchio = optional_import("torchio")
26+
27+
TEST_DIMS = [3, 128, 160, 160]
28+
TEST_TENSOR = torch.rand(TEST_DIMS)
29+
TEST_PARAMS = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]]
30+
31+
32+
@skipUnless(has_torchio, "Requires torchio")
33+
class TestRandTorchIOd(unittest.TestCase):
34+
35+
@parameterized.expand(TEST_PARAMS)
36+
def test_random_transform(self, input_param, input_data):
37+
set_determinism(seed=0)
38+
result = RandTorchIOd(**input_param)(input_data)
39+
self.assertFalse(np.allclose(input_data["img1"], result["img1"], atol=1e-6, rtol=1e-6))
40+
assert_allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4)
41+
42+
43+
if __name__ == "__main__":
44+
unittest.main()

tests/test_torchio.py

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,34 +19,19 @@
1919
from parameterized import parameterized
2020

2121
from monai.transforms import TorchIO
22-
from monai.utils import optional_import, set_determinism
22+
from monai.utils import optional_import
2323

2424
_, has_torchio = optional_import("torchio")
2525

2626
TEST_DIMS = [3, 128, 160, 160]
27-
TESTS = [
28-
[{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)],
29-
[{"name": "ZNormalization"}, torch.rand(TEST_DIMS)],
30-
[{"name": "RandomAffine"}, torch.rand(TEST_DIMS)],
31-
[{"name": "RandomElasticDeformation"}, torch.rand(TEST_DIMS)],
32-
[{"name": "RandomAnisotropy"}, torch.rand(TEST_DIMS)],
33-
[{"name": "RandomMotion"}, torch.rand(TEST_DIMS)],
34-
[{"name": "RandomGhosting"}, torch.rand(TEST_DIMS)],
35-
[{"name": "RandomSpike"}, torch.rand(TEST_DIMS)],
36-
[{"name": "RandomBiasField"}, torch.rand(TEST_DIMS)],
37-
[{"name": "RandomBlur"}, torch.rand(TEST_DIMS)],
38-
[{"name": "RandomNoise"}, torch.rand(TEST_DIMS)],
39-
[{"name": "RandomSwap"}, torch.rand(TEST_DIMS)],
40-
[{"name": "RandomGamma"}, torch.rand(TEST_DIMS)],
41-
]
27+
TESTS = [[{"name": "RescaleIntensity"}, torch.rand(TEST_DIMS)], [{"name": "ZNormalization"}, torch.rand(TEST_DIMS)]]
4228

4329

4430
@skipUnless(has_torchio, "Requires torchio")
4531
class TestTorchIO(unittest.TestCase):
4632

4733
@parameterized.expand(TESTS)
4834
def test_value(self, input_param, input_data):
49-
set_determinism(seed=0)
5035
result = TorchIO(**input_param)(input_data)
5136
self.assertIsNotNone(result)
5237
self.assertFalse(np.array_equal(result.numpy(), input_data.numpy()), f"{input_param} failed")

tests/test_torchiod.py

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,37 +18,29 @@
1818
from parameterized import parameterized
1919

2020
from monai.transforms import TorchIOd
21-
from monai.utils import optional_import, set_determinism
21+
from monai.utils import optional_import
2222
from tests.utils import assert_allclose
2323

2424
_, has_torchio = optional_import("torchio")
2525

2626
TEST_DIMS = [3, 128, 160, 160]
2727
TEST_TENSOR = torch.rand(TEST_DIMS)
28-
TEST1 = [
28+
TEST_PARAMS = [
2929
[
3030
{"keys": "img", "name": "RescaleIntensity", "out_min_max": (0, 42)},
3131
{"img": TEST_TENSOR},
3232
((TEST_TENSOR - TEST_TENSOR.min()) / (TEST_TENSOR.max() - TEST_TENSOR.min())) * 42,
3333
]
3434
]
35-
TEST2 = [[{"keys": ["img1", "img2"], "name": "RandomAffine"}, {"img1": TEST_TENSOR, "img2": TEST_TENSOR}]]
3635

3736

3837
@skipUnless(has_torchio, "Requires torchio")
3938
class TestTorchIOd(unittest.TestCase):
4039

41-
@parameterized.expand(TEST1)
40+
@parameterized.expand(TEST_PARAMS)
4241
def test_value(self, input_param, input_data, expected_value):
43-
set_determinism(seed=0)
4442
result = TorchIOd(**input_param)(input_data)
45-
assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4, type_test=False)
46-
47-
@parameterized.expand(TEST2)
48-
def test_random_transform(self, input_param, input_data):
49-
set_determinism(seed=0)
50-
result = TorchIOd(**input_param)(input_data)
51-
assert_allclose(result["img1"], result["img2"], atol=1e-4, rtol=1e-4, type_test=False)
43+
assert_allclose(result["img"], expected_value, atol=1e-4, rtol=1e-4)
5244

5345

5446
if __name__ == "__main__":

0 commit comments

Comments
 (0)