Skip to content

Commit 6c3eeac

Browse files
committed
Refactor TEST_CASES_CAB to use dict_product for cleaner test case generation
Signed-off-by: R. Garcia-Dias <[email protected]>
1 parent 8fad281 commit 6c3eeac

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

tests/networks/blocks/test_CABlock.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -20,28 +20,29 @@
2020
from monai.networks import eval_mode
2121
from monai.networks.blocks.cablock import CABlock, FeedForward
2222
from monai.utils import optional_import
23-
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose
23+
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product
2424

2525
einops, has_einops = optional_import("einops")
2626

27-
28-
TEST_CASES_CAB = []
29-
for spatial_dims in [2, 3]:
30-
for dim in [32, 64, 128]:
31-
for num_heads in [2, 4, 8]:
32-
for bias in [True, False]:
33-
test_case = [
34-
{
35-
"spatial_dims": spatial_dims,
36-
"dim": dim,
37-
"num_heads": num_heads,
38-
"bias": bias,
39-
"flash_attention": False,
40-
},
41-
(2, dim, *([16] * spatial_dims)),
42-
(2, dim, *([16] * spatial_dims)),
43-
]
44-
TEST_CASES_CAB.append(test_case)
27+
TEST_CASES_CAB = [
28+
[
29+
{
30+
"spatial_dims": params["spatial_dims"],
31+
"dim": params["dim"],
32+
"num_heads": params["num_heads"],
33+
"bias": params["bias"],
34+
"flash_attention": False,
35+
},
36+
(2, params["dim"], *([16] * params["spatial_dims"])),
37+
(2, params["dim"], *([16] * params["spatial_dims"])),
38+
]
39+
for params in dict_product(
40+
spatial_dims=[2, 3],
41+
dim=[32, 64, 128],
42+
num_heads=[2, 4, 8],
43+
bias=[True, False],
44+
)
45+
]
4546

4647

4748
TEST_CASES_FEEDFORWARD = [
@@ -53,7 +54,6 @@
5354

5455

5556
class TestFeedForward(unittest.TestCase):
56-
5757
@parameterized.expand(TEST_CASES_FEEDFORWARD)
5858
def test_shape(self, input_param, input_shape):
5959
net = FeedForward(**input_param)
@@ -69,7 +69,6 @@ def test_gating_mechanism(self):
6969

7070

7171
class TestCABlock(unittest.TestCase):
72-
7372
@parameterized.expand(TEST_CASES_CAB)
7473
@skipUnless(has_einops, "Requires einops")
7574
def test_shape(self, input_param, input_shape, expected_shape):

0 commit comments

Comments
 (0)