Skip to content

Commit da9fd6c

Browse files
committed
Refactor TEST_CASE_TRANSFORMERBLOCK to use dict_product for cleaner test case generation
Signed-off-by: R. Garcia-Dias <[email protected]>
1 parent 83fc0a4 commit da9fd6c

File tree

1 file changed

+21
-18
lines changed

1 file changed

+21
-18
lines changed

tests/networks/blocks/test_transformerblock.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,26 +21,29 @@
2121
from monai.networks import eval_mode
2222
from monai.networks.blocks.transformerblock import TransformerBlock
2323
from monai.utils import optional_import
24+
from tests.test_utils import dict_product
2425

2526
einops, has_einops = optional_import("einops")
26-
TEST_CASE_TRANSFORMERBLOCK = []
27-
for dropout_rate in np.linspace(0, 1, 4):
28-
for hidden_size in [360, 480, 600, 768]:
29-
for num_heads in [4, 8, 12]:
30-
for mlp_dim in [1024, 3072]:
31-
for cross_attention in [False, True]:
32-
test_case = [
33-
{
34-
"hidden_size": hidden_size,
35-
"num_heads": num_heads,
36-
"mlp_dim": mlp_dim,
37-
"dropout_rate": dropout_rate,
38-
"with_cross_attention": cross_attention,
39-
},
40-
(2, 512, hidden_size),
41-
(2, 512, hidden_size),
42-
]
43-
TEST_CASE_TRANSFORMERBLOCK.append(test_case)
27+
TEST_CASE_TRANSFORMERBLOCK = [
28+
[
29+
{
30+
"hidden_size": params["hidden_size"],
31+
"num_heads": params["num_heads"],
32+
"mlp_dim": params["mlp_dim"],
33+
"dropout_rate": params["dropout_rate"],
34+
"with_cross_attention": params["with_cross_attention"],
35+
},
36+
(2, 512, params["hidden_size"]),
37+
(2, 512, params["hidden_size"]),
38+
]
39+
for params in dict_product(
40+
dropout_rate=np.linspace(0, 1, 4),
41+
hidden_size=[360, 480, 600, 768],
42+
num_heads=[4, 8, 12],
43+
mlp_dim=[1024, 3072],
44+
with_cross_attention=[False, True],
45+
)
46+
]
4447

4548

4649
class TestTransformerBlock(unittest.TestCase):

0 commit comments

Comments (0)