Skip to content

Commit 1079d8c

Browse files
committed
require einops also for test_restormer
Signed-off-by: tisalon <[email protected]>
1 parent 30fad17 commit 1079d8c

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

tests/test_restormer.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515

1616
sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "../"))
1717
import unittest
18+
from unittest import skipUnless
19+
1820

1921
import torch
2022
from parameterized import parameterized
2123

2224
from monai.networks import eval_mode
2325
from monai.networks.nets.restormer import MDTATransformerBlock, OverlapPatchEmbed, Restormer
26+
from monai.utils import optional_import
27+
28+
einops, has_einops = optional_import("einops")
2429

2530
TEST_CASES_TRANSFORMER = [
2631
# [spatial_dims, dim, num_heads, ffn_factor, bias, layer_norm_use_bias, flash_attn, input_shape]
@@ -86,7 +91,8 @@
8691

8792

8893
class TestMDTATransformerBlock(unittest.TestCase):
89-
94+
95+
@skipUnless(has_einops, "Requires einops")
9096
@parameterized.expand(TEST_CASES_TRANSFORMER)
9197
def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_bias, flash, shape):
9298
block = MDTATransformerBlock(
@@ -116,18 +122,21 @@ def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected
116122

117123
class TestRestormer(unittest.TestCase):
118124

125+
@skipUnless(has_einops, "Requires einops")
119126
@parameterized.expand(TEST_CASES_RESTORMER)
120127
def test_shape(self, input_param, input_shape, expected_shape):
121128
net = Restormer(**input_param)
122129
with eval_mode(net):
123130
result = net(torch.randn(input_shape))
124131
self.assertEqual(result.shape, expected_shape)
125132

133+
@skipUnless(has_einops, "Requires einops")
126134
def test_small_input_error_2d(self):
127135
net = Restormer(spatial_dims=2, in_channels=1, out_channels=1)
128136
with self.assertRaises(AssertionError):
129137
net(torch.randn(1, 1, 8, 8))
130138

139+
@skipUnless(has_einops, "Requires einops")
131140
def test_small_input_error_3d(self):
132141
net = Restormer(spatial_dims=3, in_channels=1, out_channels=1)
133142
with self.assertRaises(AssertionError):

0 commit comments

Comments
 (0)