8888
8989class TestMDTATransformerBlock (unittest .TestCase ):
9090
91- @skipUnless (has_einops , "Requires einops" )
9291 @parameterized .expand (TEST_CASES_TRANSFORMER )
92+ @skipUnless (has_einops , "Requires einops" )
9393 def test_shape (self , spatial_dims , dim , heads , ffn_factor , bias , layer_norm_use_bias , flash , shape ):
9494 if flash and not torch .cuda .is_available ():
9595 self .skipTest ("Flash attention requires CUDA" )
@@ -111,6 +111,7 @@ def test_shape(self, spatial_dims, dim, heads, ffn_factor, bias, layer_norm_use_
111111class TestOverlapPatchEmbed (unittest .TestCase ):
112112
113113 @parameterized .expand (TEST_CASES_PATCHEMBED )
114+ @skipUnless (has_einops , "Requires einops" )
114115 def test_shape (self , spatial_dims , in_channels , embed_dim , input_shape , expected_shape ):
115116 net = OverlapPatchEmbed (spatial_dims = spatial_dims , in_channels = in_channels , embed_dim = embed_dim )
116117 with eval_mode (net ):
@@ -120,8 +121,8 @@ def test_shape(self, spatial_dims, in_channels, embed_dim, input_shape, expected
120121
121122class TestRestormer (unittest .TestCase ):
122123
123- @skipUnless (has_einops , "Requires einops" )
124124 @parameterized .expand (TEST_CASES_RESTORMER )
125+ @skipUnless (has_einops , "Requires einops" )
125126 def test_shape (self , input_param , input_shape , expected_shape ):
126127 if input_param .get ("flash_attention" , False ) and not torch .cuda .is_available ():
127128 self .skipTest ("Flash attention requires CUDA" )
0 commit comments