|
15 | 15 |
|
16 | 16 | sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "../")) |
17 | 17 | import unittest |
| 18 | +from unittest import skipUnless |
| 19 | + |
18 | 20 |
|
19 | 21 | import torch |
20 | 22 | from parameterized import parameterized |
21 | 23 |
|
22 | 24 | from monai.networks import eval_mode |
23 | 25 | from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer |
| 26 | +from monai.utils import optional_import |
| 27 | + |
| 28 | +einops, has_einops = optional_import("einops") |
24 | 29 |
|
25 | 30 | TEST_CASES_TRANSFORMER = [ |
26 | 31 | # [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape] |
|
86 | 91 |
|
87 | 92 |
|
88 | 93 | class TestMDTATransformerBlock(unittest.TestCase): |
89 | | - |
| 94 | + |
| 95 | + @skipUnless(has_einops, "Requires einops") |
90 | 96 | @parameterized.expand(TEST_CASES_TRANSFORMER) |
91 | 97 | def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape): |
92 | 98 | block = MDTATransformerBlock( |
@@ -116,18 +122,21 @@ def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected |
116 | 122 |
|
117 | 123 | class TestRestormer(unittest.TestCase): |
118 | 124 |
|
| 125 | + @skipUnless(has_einops, "Requires einops") |
119 | 126 | @parameterized.expand(TEST_CASES_RESTORMER) |
120 | 127 | def test_shape(self, input_param, input_shape, expected_shape): |
121 | 128 | net = Restormer(**input_param) |
122 | 129 | with eval_mode(net): |
123 | 130 | result = net(torch.randn(input_shape)) |
124 | 131 | self.assertEqual(result.shape, expected_shape) |
125 | 132 |
|
| 133 | + @skipUnless(has_einops, "Requires einops") |
126 | 134 | def test_small_input_error_2d(self): |
127 | 135 | net = Restormer(spatial_dims=2, in_channels=1, out_channels=1) |
128 | 136 | with self.assertRaises(AssertionError): |
129 | 137 | net(torch.randn(1, 1, 8, 8)) |
130 | 138 |
|
| 139 | + @skipUnless(has_einops, "Requires einops") |
131 | 140 | def test_small_input_error_3d(self): |
132 | 141 | net = Restormer(spatial_dims=3, in_channels=1, out_channels=1) |
133 | 142 | with self.assertRaises(AssertionError): |
|
0 commit comments