2020from monai .networks import eval_mode
2121from monai .networks .blocks .cablock import CABlock , FeedForward
2222from monai .utils import optional_import
23- from tests .test_utils import SkipIfBeforePyTorchVersion , assert_allclose
23+ from tests .test_utils import SkipIfBeforePyTorchVersion , assert_allclose , dict_product
2424
2525einops , has_einops = optional_import ("einops" )
2626
27-
28- TEST_CASES_CAB = []
29- for spatial_dims in [2 , 3 ]:
30- for dim in [32 , 64 , 128 ]:
31- for num_heads in [2 , 4 , 8 ]:
32- for bias in [True , False ]:
33- test_case = [
34- {
35- "spatial_dims" : spatial_dims ,
36- "dim" : dim ,
37- "num_heads" : num_heads ,
38- "bias" : bias ,
39- "flash_attention" : False ,
40- },
41- (2 , dim , * ([16 ] * spatial_dims )),
42- (2 , dim , * ([16 ] * spatial_dims )),
43- ]
44- TEST_CASES_CAB .append (test_case )
27+ TEST_CASES_CAB = [
28+ [
29+ {
30+ "spatial_dims" : params ["spatial_dims" ],
31+ "dim" : params ["dim" ],
32+ "num_heads" : params ["num_heads" ],
33+ "bias" : params ["bias" ],
34+ "flash_attention" : False ,
35+ },
36+ (2 , params ["dim" ], * ([16 ] * params ["spatial_dims" ])),
37+ (2 , params ["dim" ], * ([16 ] * params ["spatial_dims" ])),
38+ ]
39+ for params in dict_product (
40+ spatial_dims = [2 , 3 ],
41+ dim = [32 , 64 , 128 ],
42+ num_heads = [2 , 4 , 8 ],
43+ bias = [True , False ],
44+ )
45+ ]
4546
4647
4748TEST_CASES_FEEDFORWARD = [
5354
5455
5556class TestFeedForward (unittest .TestCase ):
56-
5757 @parameterized .expand (TEST_CASES_FEEDFORWARD )
5858 def test_shape (self , input_param , input_shape ):
5959 net = FeedForward (** input_param )
@@ -69,7 +69,6 @@ def test_gating_mechanism(self):
6969
7070
7171class TestCABlock (unittest .TestCase ):
72-
7372 @parameterized .expand (TEST_CASES_CAB )
7473 @skipUnless (has_einops , "Requires einops" )
7574 def test_shape (self , input_param , input_shape , expected_shape ):
0 commit comments