|
18 | 18 |
|
19 | 19 | from monai.transforms import CastToType, Pad |
20 | 20 | from monai.utils import NumpyPadMode, PytorchPadMode |
21 | | -from tests.test_utils import SkipIfBeforePyTorchVersion |
| 21 | +from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product |
22 | 22 |
|
23 | 23 |
|
@SkipIfBeforePyTorchVersion((1, 10, 1))
class TestPadMode(unittest.TestCase):
    def test_pad(self):
        """Smoke-test ``Pad`` over every combination of dtype, device, input
        shape, and padding mode, checking only the padded output shape.

        Each 3-D input ``(1, 10, 10)`` is padded to ``(1, 15, 10)`` and each
        4-D input ``(1, 5, 6, 7)`` to ``(1, 10, 6, 7)``.
        """
        expected_shapes = {3: (1, 15, 10), 4: (1, 10, 6, 7)}
        devices = ("cuda:0", "cpu") if torch.cuda.is_available() else ("cpu",)
        shapes = ((1, 10, 10), (1, 5, 6, 7))
        types = (float, int, np.uint8, np.int16, np.float32, bool)
        modes = list(PytorchPadMode) + list(NumpyPadMode)

        # dict_product yields one dict per element of the Cartesian product.
        for params in dict_product(t=types, d=devices, s=shapes, m=modes):
            dtype, device, shape, mode = params["t"], params["d"], params["s"], params["m"]
            img = torch.rand(shape)
            # Pad 5 along the first spatial dim; leave all other dims alone.
            if len(shape) == 3:
                to_pad = [(0, 0), (2, 3)]
            else:
                to_pad = [(0, 0), (2, 3), (0, 0), (0, 0)]
            padded = Pad(to_pad=to_pad, mode=mode)(CastToType(dtype=dtype)(img).to(device))
            self.assertEqual(padded.shape, expected_shapes[len(shape)])
36 | 42 |
|
37 | 43 |
|
38 | 44 | if __name__ == "__main__": |
|
0 commit comments