Skip to content

Commit 74741e7

Browse files
committed
Refactor test cases in test_pad_mode.py to use dict_product for cleaner test generation
Signed-off-by: R. Garcia-Dias <[email protected]>
1 parent 51a94c2 commit 74741e7

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

tests/utils/test_pad_mode.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,27 @@
1818

1919
from monai.transforms import CastToType, Pad
2020
from monai.utils import NumpyPadMode, PytorchPadMode
21-
from tests.test_utils import SkipIfBeforePyTorchVersion
21+
from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product
2222

2323

2424
@SkipIfBeforePyTorchVersion((1, 10, 1))
2525
class TestPadMode(unittest.TestCase):
2626
def test_pad(self):
2727
expected_shapes = {3: (1, 15, 10), 4: (1, 10, 6, 7)}
28-
for t in (float, int, np.uint8, np.int16, np.float32, bool):
29-
for d in ("cuda:0", "cpu") if torch.cuda.is_available() else ("cpu",):
30-
for s in ((1, 10, 10), (1, 5, 6, 7)):
31-
for m in list(PytorchPadMode) + list(NumpyPadMode):
32-
a = torch.rand(s)
33-
to_pad = [(0, 0), (2, 3)] if len(s) == 3 else [(0, 0), (2, 3), (0, 0), (0, 0)]
34-
out = Pad(to_pad=to_pad, mode=m)(CastToType(dtype=t)(a).to(d))
35-
self.assertEqual(out.shape, expected_shapes[len(s)])
28+
devices = ("cuda:0", "cpu") if torch.cuda.is_available() else ("cpu",)
29+
shapes = ((1, 10, 10), (1, 5, 6, 7))
30+
types = (float, int, np.uint8, np.int16, np.float32, bool)
31+
modes = list(PytorchPadMode) + list(NumpyPadMode)
32+
33+
for params in dict_product(t=types, d=devices, s=shapes, m=modes):
34+
t = params["t"]
35+
d = params["d"]
36+
s = params["s"]
37+
m = params["m"]
38+
a = torch.rand(s)
39+
to_pad = [(0, 0), (2, 3)] if len(s) == 3 else [(0, 0), (2, 3), (0, 0), (0, 0)]
40+
out = Pad(to_pad=to_pad, mode=m)(CastToType(dtype=t)(a).to(d))
41+
self.assertEqual(out.shape, expected_shapes[len(s)])
3642

3743

3844
if __name__ == "__main__":

0 commit comments

Comments
 (0)