Skip to content

Commit 0139125

Browse files
committed
Refactor self-attention test cases to use dict_product for parameter combinations
Signed-off-by: R. Garcia-Dias <[email protected]>
1 parent 5342cba commit 0139125

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

tests/networks/blocks/test_selfattention.py

Lines changed: 25 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -22,33 +22,35 @@
2222
from monai.networks.blocks.selfattention import SABlock
2323
from monai.networks.layers.factories import RelPosEmbedding
2424
from monai.utils import optional_import
25-
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save
25+
from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product, test_script_save
2626

2727
einops, has_einops = optional_import("einops")
2828

2929
TEST_CASE_SABLOCK = []
30-
for dropout_rate in np.linspace(0, 1, 4):
31-
for hidden_size in [360, 480, 600, 768]:
32-
for num_heads in [4, 6, 8, 12]:
33-
for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
34-
for input_size in [(16, 32), (8, 8, 8)]:
35-
for include_fc in [True, False]:
36-
for use_combined_linear in [True, False]:
37-
test_case = [
38-
{
39-
"hidden_size": hidden_size,
40-
"num_heads": num_heads,
41-
"dropout_rate": dropout_rate,
42-
"rel_pos_embedding": rel_pos_embedding,
43-
"input_size": input_size,
44-
"include_fc": include_fc,
45-
"use_combined_linear": use_combined_linear,
46-
"use_flash_attention": True if rel_pos_embedding is None else False,
47-
},
48-
(2, 512, hidden_size),
49-
(2, 512, hidden_size),
50-
]
51-
TEST_CASE_SABLOCK.append(test_case)
30+
for params in dict_product(
31+
dropout_rate=np.linspace(0, 1, 4),
32+
hidden_size=[360, 480, 600, 768],
33+
num_heads=[4, 6, 8, 12],
34+
rel_pos_embedding=[None, RelPosEmbedding.DECOMPOSED],
35+
input_size=[(16, 32), (8, 8, 8)],
36+
include_fc=[True, False],
37+
use_combined_linear=[True, False],
38+
):
39+
test_case = [
40+
{
41+
"hidden_size": params["hidden_size"],
42+
"num_heads": params["num_heads"],
43+
"dropout_rate": params["dropout_rate"],
44+
"rel_pos_embedding": params["rel_pos_embedding"],
45+
"input_size": params["input_size"],
46+
"include_fc": params["include_fc"],
47+
"use_combined_linear": params["use_combined_linear"],
48+
"use_flash_attention": True if params["rel_pos_embedding"] is None else False,
49+
},
50+
(2, 512, params["hidden_size"]),
51+
(2, 512, params["hidden_size"]),
52+
]
53+
TEST_CASE_SABLOCK.append(test_case)
5254

5355

5456
class TestResBlock(unittest.TestCase):

0 commit comments

Comments (0)