Commit 0d7f772

Authored by KumoLiu and ericspod
Ensure determinism in MixUp, CutMix, CutOut (#7813)

Fixes #7697

### Description

Route all randomness in MixUp, CutMix, and CutOut through the transform's own random state (`self.R`), forward `set_random_state` from the dictionary wrappers to their wrapped transforms, and add a `randomize` flag to `__call__` so a single parameter draw can be shared across keys.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

Signed-off-by: YunLiu <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
Parent: a0935d9 · Commit: 0d7f772

File tree

3 files changed: +146 -69 lines

- monai/transforms/regularization/array.py
- monai/transforms/regularization/dictionary.py
- tests/test_regularization.py
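Before the per-file diffs, here is a minimal sketch of the contract this commit establishes, assuming a MONAI build that includes it: two identically seeded instances of one of these transforms now produce identical outputs, because all randomness is drawn from the transform's own `self.R` state rather than the global generators.

```python
import torch

from monai.transforms import MixUp

batch = torch.rand(6, 3, 32, 32)  # batch of 6 samples, 3 channels

# Two instances seeded identically should now yield identical outputs.
m1 = MixUp(batch_size=6, alpha=1.0)
m1.set_random_state(seed=42)
m2 = MixUp(batch_size=6, alpha=1.0)
m2.set_random_state(seed=42)

assert torch.allclose(m1(batch), m2(batch))
```

The same pattern applies to CutMix and CutOut through their shared `Mixer` base class.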

monai/transforms/regularization/array.py

Lines changed: 35 additions & 17 deletions
@@ -16,6 +16,9 @@
 
 import torch
 
+from monai.data.meta_obj import get_track_meta
+from monai.utils.type_conversion import convert_to_dst_type, convert_to_tensor
+
 from ..transform import RandomizableTransform
 
 __all__ = ["MixUp", "CutMix", "CutOut", "Mixer"]
@@ -53,9 +56,11 @@ def randomize(self, data=None) -> None:
         as needed. You need to call this method everytime you apply the transform to a new
         batch.
         """
+        super().randomize(None)
         self._params = (
             torch.from_numpy(self.R.beta(self.alpha, self.alpha, self.batch_size)).type(torch.float32),
             self.R.permutation(self.batch_size),
+            [torch.from_numpy(self.R.randint(0, d, size=(1,))) for d in data.shape[2:]] if data is not None else [],
         )
 
 
@@ -69,7 +74,7 @@ class MixUp(Mixer):
     """
 
     def apply(self, data: torch.Tensor):
-        weight, perm = self._params
+        weight, perm, _ = self._params
         nsamples, *dims = data.shape
         if len(weight) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weight)}, but got {nsamples}")
@@ -80,11 +85,18 @@ def apply(self, data: torch.Tensor):
         mixweight = weight[(Ellipsis,) + (None,) * len(dims)]
         return mixweight * data + (1 - mixweight) * data[perm, ...]
 
-    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
-        self.randomize()
+    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
+        data_t = convert_to_tensor(data, track_meta=get_track_meta())
+        if labels is not None:
+            labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
+        if randomize:
+            self.randomize()
         if labels is None:
-            return self.apply(data)
-        return self.apply(data), self.apply(labels)
+            return convert_to_dst_type(self.apply(data_t), dst=data)[0]
+        return (
+            convert_to_dst_type(self.apply(data_t), dst=data)[0],
+            convert_to_dst_type(self.apply(labels_t), dst=labels)[0],
+        )
 
 
 class CutMix(Mixer):
@@ -113,33 +125,38 @@ class CutMix(Mixer):
     """
 
     def apply(self, data: torch.Tensor):
-        weights, perm = self._params
+        weights, perm, coords = self._params
         nsamples, _, *dims = data.shape
         if len(weights) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
 
         mask = torch.ones_like(data)
         for s, weight in enumerate(weights):
-            coords = [torch.randint(0, d, size=(1,)) for d in dims]
             lengths = [d * sqrt(1 - weight) for d in dims]
             idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
             mask[s][idx] = 0
 
         return mask * data + (1 - mask) * data[perm, ...]
 
     def apply_on_labels(self, labels: torch.Tensor):
-        weights, perm = self._params
+        weights, perm, _ = self._params
         nsamples, *dims = labels.shape
         if len(weights) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
 
         mixweight = weights[(Ellipsis,) + (None,) * len(dims)]
         return mixweight * labels + (1 - mixweight) * labels[perm, ...]
 
-    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None):
-        self.randomize()
-        augmented = self.apply(data)
-        return (augmented, self.apply_on_labels(labels)) if labels is not None else augmented
+    def __call__(self, data: torch.Tensor, labels: torch.Tensor | None = None, randomize=True):
+        data_t = convert_to_tensor(data, track_meta=get_track_meta())
+        if labels is not None:
+            labels_t = convert_to_tensor(labels, track_meta=get_track_meta())
+        if randomize:
+            self.randomize(data)
+        augmented = convert_to_dst_type(self.apply(data_t), dst=data)[0]
+        if labels is not None:
+            augmented_label = convert_to_dst_type(self.apply(labels_t), dst=labels)[0]
+        return (augmented, augmented_label) if labels is not None else augmented
 
 
 class CutOut(Mixer):
@@ -155,20 +172,21 @@ class CutOut(Mixer):
     """
 
     def apply(self, data: torch.Tensor):
-        weights, _ = self._params
+        weights, _, coords = self._params
         nsamples, _, *dims = data.shape
        if len(weights) != nsamples:
             raise ValueError(f"Expected batch of size: {len(weights)}, but got {nsamples}")
 
         mask = torch.ones_like(data)
         for s, weight in enumerate(weights):
-            coords = [torch.randint(0, d, size=(1,)) for d in dims]
             lengths = [d * sqrt(1 - weight) for d in dims]
             idx = [slice(None)] + [slice(c, min(ceil(c + ln), d)) for c, ln, d in zip(coords, lengths, dims)]
             mask[s][idx] = 0
 
         return mask * data
 
-    def __call__(self, data: torch.Tensor):
-        self.randomize()
-        return self.apply(data)
+    def __call__(self, data: torch.Tensor, randomize=True):
+        data_t = convert_to_tensor(data, track_meta=get_track_meta())
+        if randomize:
+            self.randomize(data)
+        return convert_to_dst_type(self.apply(data_t), dst=data)[0]
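The design point worth noting in `array.py`: `__call__` now separates parameter sampling from application. `randomize()` draws the weights, the batch permutation, and (for the cut-based transforms) the patch coordinates from `self.R`, and the new `randomize=False` flag lets a caller reuse one draw across several tensors. A sketch of that pattern, with illustrative variable names not taken from the commit:

```python
import torch

from monai.transforms import CutMix

cutmix = CutMix(batch_size=6, alpha=1.0)
cutmix.set_random_state(seed=0)

images = torch.rand(6, 3, 32, 32)
extra = torch.rand(6, 3, 32, 32)  # a second tensor to cut with the same patch

# One parameter draw; coords depend on the spatial shape, hence the data argument.
cutmix.randomize(images)
out_images = cutmix(images, randomize=False)
out_extra = cutmix(extra, randomize=False)  # same weights, permutation, and coords
```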

monai/transforms/regularization/dictionary.py

Lines changed: 56 additions & 24 deletions
@@ -11,16 +11,23 @@
 
 from __future__ import annotations
 
+from collections.abc import Hashable
+
+import numpy as np
+
 from monai.config import KeysCollection
+from monai.config.type_definitions import NdarrayOrTensor
+from monai.data.meta_obj import get_track_meta
+from monai.utils import convert_to_tensor
 from monai.utils.misc import ensure_tuple
 
-from ..transform import MapTransform
+from ..transform import MapTransform, RandomizableTransform
 from .array import CutMix, CutOut, MixUp
 
 __all__ = ["MixUpd", "MixUpD", "MixUpDict", "CutMixd", "CutMixD", "CutMixDict", "CutOutd", "CutOutD", "CutOutDict"]
 
 
-class MixUpd(MapTransform):
+class MixUpd(MapTransform, RandomizableTransform):
     """
     Dictionary-based version :py:class:`monai.transforms.MixUp`.
 
@@ -31,18 +38,24 @@ class MixUpd(MapTransform):
     def __init__(
         self, keys: KeysCollection, batch_size: int, alpha: float = 1.0, allow_missing_keys: bool = False
     ) -> None:
-        super().__init__(keys, allow_missing_keys)
+        MapTransform.__init__(self, keys, allow_missing_keys)
         self.mixup = MixUp(batch_size, alpha)
 
+    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> MixUpd:
+        super().set_random_state(seed, state)
+        self.mixup.set_random_state(seed, state)
+        return self
+
     def __call__(self, data):
-        self.mixup.randomize()
-        result = dict(data)
-        for k in self.keys:
-            result[k] = self.mixup.apply(data[k])
-        return result
+        d = dict(data)
+        # all the keys share the same random state
+        self.mixup.randomize(None)
+        for k in self.key_iterator(d):
+            d[k] = self.mixup(data[k], randomize=False)
+        return d
 
 
-class CutMixd(MapTransform):
+class CutMixd(MapTransform, RandomizableTransform):
     """
     Dictionary-based version :py:class:`monai.transforms.CutMix`.
 
@@ -63,17 +76,27 @@ def __init__(
         self.mixer = CutMix(batch_size, alpha)
         self.label_keys = ensure_tuple(label_keys) if label_keys is not None else []
 
-    def __call__(self, data):
-        self.mixer.randomize()
-        result = dict(data)
-        for k in self.keys:
-            result[k] = self.mixer.apply(data[k])
-        for k in self.label_keys:
-            result[k] = self.mixer.apply_on_labels(data[k])
-        return result
-
+    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutMixd:
+        super().set_random_state(seed, state)
+        self.mixer.set_random_state(seed, state)
+        return self
 
-class CutOutd(MapTransform):
+    def __call__(self, data):
+        d = dict(data)
+        first_key: Hashable = self.first_key(d)
+        if first_key == ():
+            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
+            return out
+        self.mixer.randomize(d[first_key])
+        for key, label_key in self.key_iterator(d, self.label_keys):
+            ret = self.mixer(data[key], data.get(label_key, None), randomize=False)
+            d[key] = ret[0]
+            if label_key in d:
+                d[label_key] = ret[1]
+        return d
+
+
+class CutOutd(MapTransform, RandomizableTransform):
     """
     Dictionary-based version :py:class:`monai.transforms.CutOut`.
 
@@ -84,12 +107,21 @@ def __init__(self, keys: KeysCollection, batch_size: int, allow_missing_keys: bo
         super().__init__(keys, allow_missing_keys)
         self.cutout = CutOut(batch_size)
 
+    def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> CutOutd:
+        super().set_random_state(seed, state)
+        self.cutout.set_random_state(seed, state)
+        return self
+
     def __call__(self, data):
-        result = dict(data)
-        self.cutout.randomize()
-        for k in self.keys:
-            result[k] = self.cutout(data[k])
-        return result
+        d = dict(data)
+        first_key: Hashable = self.first_key(d)
+        if first_key == ():
+            out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
+            return out
+        self.cutout.randomize(d[first_key])
+        for k in self.key_iterator(d):
+            d[k] = self.cutout(data[k], randomize=False)
+        return d
 
 
 MixUpD = MixUpDict = MixUpd
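The dictionary wrappers extend the same guarantee: `set_random_state` is forwarded to the wrapped array transform, and `__call__` randomizes once so every key is processed with the same parameters. A usage sketch, with hypothetical key names:

```python
import torch

from monai.transforms import MixUpd

t = torch.rand(6, 3, 32, 32)
sample = {"image": t, "aux": t}

mixupd = MixUpd(keys=["image", "aux"], batch_size=6)
mixupd.set_random_state(seed=0)  # forwarded to the inner MixUp
out = mixupd(sample)

# Both keys share one draw of weights and permutation, so identical inputs stay identical.
assert torch.allclose(out["image"], out["aux"])
```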

tests/test_regularization.py

Lines changed: 55 additions & 28 deletions
@@ -13,29 +13,31 @@
 
 import unittest
 
+import numpy as np
 import torch
 
-from monai.transforms import CutMix, CutMixd, CutOut, MixUp, MixUpd
-from monai.utils import set_determinism
+from monai.transforms import CutMix, CutMixd, CutOut, CutOutd, MixUp, MixUpd
+from tests.utils import assert_allclose
 
 
-@unittest.skip("Mixup is non-deterministic. Skip it temporarily")
 class TestMixup(unittest.TestCase):
 
-    def setUp(self) -> None:
-        set_determinism(seed=0)
-
-    def tearDown(self) -> None:
-        set_determinism(None)
-
     def test_mixup(self):
         for dims in [2, 3]:
             shape = (6, 3) + (32,) * dims
             sample = torch.rand(*shape, dtype=torch.float32)
             mixup = MixUp(6, 1.0)
+            mixup.set_random_state(seed=0)
             output = mixup(sample)
+            np.random.seed(0)
+            # simulate the randomize() of transform
+            np.random.random()
+            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+            perm = np.random.permutation(6)
             self.assertEqual(output.shape, sample.shape)
-            self.assertTrue(any(not torch.allclose(sample, mixup(sample)) for _ in range(10)))
+            mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
+            expected = mixweight * sample + (1 - mixweight) * sample[perm, ...]
+            assert_allclose(output, expected, type_test=False, atol=1e-7)
 
         with self.assertRaises(ValueError):
             MixUp(6, -0.5)
@@ -53,27 +55,32 @@ def test_mixupd(self):
             t = torch.rand(*shape, dtype=torch.float32)
             sample = {"a": t, "b": t}
             mixup = MixUpd(["a", "b"], 6)
+            mixup.set_random_state(seed=0)
             output = mixup(sample)
-            self.assertTrue(torch.allclose(output["a"], output["b"]))
+            np.random.seed(0)
+            # simulate the randomize() of transform
+            np.random.random()
+            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+            perm = np.random.permutation(6)
+            self.assertEqual(output["a"].shape, sample["a"].shape)
+            mixweight = weight[(Ellipsis,) + (None,) * (dims + 1)]
+            expected = mixweight * sample["a"] + (1 - mixweight) * sample["a"][perm, ...]
+            assert_allclose(output["a"], expected, type_test=False, atol=1e-7)
+            assert_allclose(output["a"], output["b"], type_test=False, atol=1e-7)
+            # self.assertTrue(torch.allclose(output["a"], output["b"]))
 
         with self.assertRaises(ValueError):
             MixUpd(["k1", "k2"], 6, -0.5)
 
 
-@unittest.skip("CutMix is non-deterministic. Skip it temporarily")
 class TestCutMix(unittest.TestCase):
 
-    def setUp(self) -> None:
-        set_determinism(seed=0)
-
-    def tearDown(self) -> None:
-        set_determinism(None)
-
     def test_cutmix(self):
         for dims in [2, 3]:
             shape = (6, 3) + (32,) * dims
             sample = torch.rand(*shape, dtype=torch.float32)
             cutmix = CutMix(6, 1.0)
+            cutmix.set_random_state(seed=0)
             output = cutmix(sample)
             self.assertEqual(output.shape, sample.shape)
             self.assertTrue(any(not torch.allclose(sample, cutmix(sample)) for _ in range(10)))
@@ -85,30 +92,50 @@ def test_cutmixd(self):
             label = torch.randint(0, 1, shape)
             sample = {"a": t, "b": t, "lbl1": label, "lbl2": label}
             cutmix = CutMixd(["a", "b"], 6, label_keys=("lbl1", "lbl2"))
+            cutmix.set_random_state(seed=123)
             output = cutmix(sample)
-            # croppings are different on each application
-            self.assertTrue(not torch.allclose(output["a"], output["b"]))
             # but mixing of labels is not affected by it
             self.assertTrue(torch.allclose(output["lbl1"], output["lbl2"]))
 
 
-@unittest.skip("CutOut is non-deterministic. Skip it temporarily")
 class TestCutOut(unittest.TestCase):
 
-    def setUp(self) -> None:
-        set_determinism(seed=0)
-
-    def tearDown(self) -> None:
-        set_determinism(None)
-
     def test_cutout(self):
         for dims in [2, 3]:
             shape = (6, 3) + (32,) * dims
             sample = torch.rand(*shape, dtype=torch.float32)
             cutout = CutOut(6, 1.0)
+            cutout.set_random_state(seed=123)
             output = cutout(sample)
+            np.random.seed(123)
+            # simulate the randomize() of transform
+            np.random.random()
+            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+            perm = np.random.permutation(6)
+            coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in sample.shape[2:]]
+            assert_allclose(weight, cutout._params[0])
+            assert_allclose(perm, cutout._params[1])
+            self.assertSequenceEqual(coords, cutout._params[2])
             self.assertEqual(output.shape, sample.shape)
-            self.assertTrue(any(not torch.allclose(sample, cutout(sample)) for _ in range(10)))
+
+    def test_cutoutd(self):
+        for dims in [2, 3]:
+            shape = (6, 3) + (32,) * dims
+            t = torch.rand(*shape, dtype=torch.float32)
+            sample = {"a": t, "b": t}
+            cutout = CutOutd(["a", "b"], 6, 1.0)
+            cutout.set_random_state(seed=123)
+            output = cutout(sample)
+            np.random.seed(123)
+            # simulate the randomize() of transform
+            np.random.random()
+            weight = torch.from_numpy(np.random.beta(1.0, 1.0, 6)).type(torch.float32)
+            perm = np.random.permutation(6)
+            coords = [torch.from_numpy(np.random.randint(0, d, size=(1,))) for d in t.shape[2:]]
+            assert_allclose(weight, cutout.cutout._params[0])
+            assert_allclose(perm, cutout.cutout._params[1])
+            self.assertSequenceEqual(coords, cutout.cutout._params[2])
+            self.assertEqual(output["a"].shape, sample["a"].shape)
 
 
 if __name__ == "__main__":
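One detail in these tests is easy to miss: each `# simulate the randomize() of transform` block calls `np.random.random()` once before drawing the beta weights. That extra draw mirrors the `super().randomize(None)` call added to `Mixer.randomize()`, which consumes a single uniform sample to set `_do_transform`; only then do the weights, permutation, and coordinates come off the stream. A short sketch of the replayed stream (seed 0, batch size 6, as in `test_mixup`):

```python
import numpy as np

np.random.seed(0)
_ = np.random.random()                # the draw consumed by RandomizableTransform.randomize()
weight = np.random.beta(1.0, 1.0, 6)  # per-sample mixing weights
perm = np.random.permutation(6)       # batch permutation
```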
