 from monai.networks import eval_mode
 from monai.networks.blocks.transformerblock import TransformerBlock
 from monai.utils import optional_import
+from tests.test_utils import dict_product

 einops, has_einops = optional_import("einops")
-TEST_CASE_TRANSFORMERBLOCK = []
-for dropout_rate in np.linspace(0, 1, 4):
-    for hidden_size in [360, 480, 600, 768]:
-        for num_heads in [4, 8, 12]:
-            for mlp_dim in [1024, 3072]:
-                for cross_attention in [False, True]:
-                    test_case = [
-                        {
-                            "hidden_size": hidden_size,
-                            "num_heads": num_heads,
-                            "mlp_dim": mlp_dim,
-                            "dropout_rate": dropout_rate,
-                            "with_cross_attention": cross_attention,
-                        },
-                        (2, 512, hidden_size),
-                        (2, 512, hidden_size),
-                    ]
-                    TEST_CASE_TRANSFORMERBLOCK.append(test_case)
+TEST_CASE_TRANSFORMERBLOCK = [
+    [
+        {
+            "hidden_size": params["hidden_size"],
+            "num_heads": params["num_heads"],
+            "mlp_dim": params["mlp_dim"],
+            "dropout_rate": params["dropout_rate"],
+            "with_cross_attention": params["with_cross_attention"],
+        },
+        (2, 512, params["hidden_size"]),
+        (2, 512, params["hidden_size"]),
+    ]
+    for params in dict_product(
+        dropout_rate=np.linspace(0, 1, 4),
+        hidden_size=[360, 480, 600, 768],
+        num_heads=[4, 8, 12],
+        mlp_dim=[1024, 3072],
+        with_cross_attention=[False, True],
+    )
+]


 class TestTransformerBlock(unittest.TestCase):
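For context, the dict_product helper imported from tests/test_utils replaces the five nested loops: it yields one parameter dict per combination of the supplied keyword lists, so the test cases can be built in a single comprehension. A minimal sketch of the assumed behavior follows (the signature and implementation here are illustrative, not necessarily the actual MONAI helper):

from itertools import product


def dict_product(**param_lists):
    """Yield one dict per combination (Cartesian product) of the keyword iterables."""
    keys = list(param_lists.keys())
    for values in product(*param_lists.values()):
        yield dict(zip(keys, values))


# Example:
# list(dict_product(a=[1, 2], b=["x"]))
# -> [{"a": 1, "b": "x"}, {"a": 2, "b": "x"}]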