|
22 | 22 | from monai.networks.blocks.selfattention import SABlock |
23 | 23 | from monai.networks.layers.factories import RelPosEmbedding |
24 | 24 | from monai.utils import optional_import |
25 | | -from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, test_script_save |
| 25 | +from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product, test_script_save |
26 | 26 |
|
27 | 27 | einops, has_einops = optional_import("einops") |
28 | 28 |
|
29 | 29 | TEST_CASE_SABLOCK = [] |
30 | | -for dropout_rate in np.linspace(0, 1, 4): |
31 | | - for hidden_size in [360, 480, 600, 768]: |
32 | | - for num_heads in [4, 6, 8, 12]: |
33 | | - for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]: |
34 | | - for input_size in [(16, 32), (8, 8, 8)]: |
35 | | - for include_fc in [True, False]: |
36 | | - for use_combined_linear in [True, False]: |
37 | | - test_case = [ |
38 | | - { |
39 | | - "hidden_size": hidden_size, |
40 | | - "num_heads": num_heads, |
41 | | - "dropout_rate": dropout_rate, |
42 | | - "rel_pos_embedding": rel_pos_embedding, |
43 | | - "input_size": input_size, |
44 | | - "include_fc": include_fc, |
45 | | - "use_combined_linear": use_combined_linear, |
46 | | - "use_flash_attention": True if rel_pos_embedding is None else False, |
47 | | - }, |
48 | | - (2, 512, hidden_size), |
49 | | - (2, 512, hidden_size), |
50 | | - ] |
51 | | - TEST_CASE_SABLOCK.append(test_case) |
| 30 | +for params in dict_product( |
| 31 | + dropout_rate=np.linspace(0, 1, 4), |
| 32 | + hidden_size=[360, 480, 600, 768], |
| 33 | + num_heads=[4, 6, 8, 12], |
| 34 | + rel_pos_embedding=[None, RelPosEmbedding.DECOMPOSED], |
| 35 | + input_size=[(16, 32), (8, 8, 8)], |
| 36 | + include_fc=[True, False], |
| 37 | + use_combined_linear=[True, False], |
| 38 | +): |
| 39 | + test_case = [ |
| 40 | + { |
| 41 | + "hidden_size": params["hidden_size"], |
| 42 | + "num_heads": params["num_heads"], |
| 43 | + "dropout_rate": params["dropout_rate"], |
| 44 | + "rel_pos_embedding": params["rel_pos_embedding"], |
| 45 | + "input_size": params["input_size"], |
| 46 | + "include_fc": params["include_fc"], |
| 47 | + "use_combined_linear": params["use_combined_linear"], |
| 48 | + "use_flash_attention": True if params["rel_pos_embedding"] is None else False, |
| 49 | + }, |
| 50 | + (2, 512, params["hidden_size"]), |
| 51 | + (2, 512, params["hidden_size"]), |
| 52 | + ] |
| 53 | + TEST_CASE_SABLOCK.append(test_case) |
52 | 54 |
|
53 | 55 |
|
54 | 56 | class TestResBlock(unittest.TestCase): |
|
0 commit comments