Skip to content

Commit 836cf6e

Browse files
committed
Refactor PatchEmbeddingBlock test cases to use dict_product for parameter combinations
Signed-off-by: R. Garcia-Dias <[email protected]>
1 parent 0139125 commit 836cf6e

File tree

1 file changed

+49
-46
lines changed

1 file changed

+49
-46
lines changed

tests/networks/blocks/test_patchembedding.py

Lines changed: 49 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -21,58 +21,61 @@
2121
from monai.networks import eval_mode
2222
from monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock
2323
from monai.utils import optional_import
24-
from tests.test_utils import SkipIfBeforePyTorchVersion
24+
from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product
2525

2626
einops, has_einops = optional_import("einops")
2727

2828
TEST_CASE_PATCHEMBEDDINGBLOCK = []
29-
for dropout_rate in (0.5,):
30-
for in_channels in [1, 4]:
31-
for hidden_size in [96, 288]:
32-
for img_size in [32, 64]:
33-
for patch_size in [8, 16]:
34-
for num_heads in [8, 12]:
35-
for proj_type in ["conv", "perceptron"]:
36-
for pos_embed_type in ["none", "learnable", "sincos"]:
37-
# for classification in (False, True): # TODO: add classification tests
38-
for nd in (2, 3):
39-
test_case = [
40-
{
41-
"in_channels": in_channels,
42-
"img_size": (img_size,) * nd,
43-
"patch_size": (patch_size,) * nd,
44-
"hidden_size": hidden_size,
45-
"num_heads": num_heads,
46-
"proj_type": proj_type,
47-
"pos_embed_type": pos_embed_type,
48-
"dropout_rate": dropout_rate,
49-
},
50-
(2, in_channels, *([img_size] * nd)),
51-
(2, (img_size // patch_size) ** nd, hidden_size),
52-
]
53-
if nd == 2:
54-
test_case[0]["spatial_dims"] = 2 # type: ignore
55-
TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case)
29+
for params in dict_product(
30+
dropout_rate=[0.5],
31+
in_channels=[1, 4],
32+
hidden_size=[96, 288],
33+
img_size=[32, 64],
34+
patch_size=[8, 16],
35+
num_heads=[8, 12],
36+
proj_type=["conv", "perceptron"],
37+
pos_embed_type=["none", "learnable", "sincos"],
38+
nd=[2, 3],
39+
):
40+
test_case = [
41+
{
42+
"in_channels": params["in_channels"],
43+
"img_size": (params["img_size"],) * params["nd"],
44+
"patch_size": (params["patch_size"],) * params["nd"],
45+
"hidden_size": params["hidden_size"],
46+
"num_heads": params["num_heads"],
47+
"proj_type": params["proj_type"],
48+
"pos_embed_type": params["pos_embed_type"],
49+
"dropout_rate": params["dropout_rate"],
50+
},
51+
(2, params["in_channels"], *[params["img_size"]] * params["nd"]),
52+
(2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["hidden_size"]),
53+
]
54+
if params["nd"] == 2:
55+
test_case[0]["spatial_dims"] = 2
56+
TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case)
5657

5758
TEST_CASE_PATCHEMBED = []
58-
for patch_size in [2]:
59-
for in_chans in [1, 4]:
60-
for img_size in [96]:
61-
for embed_dim in [6, 12]:
62-
for norm_layer in [nn.LayerNorm]:
63-
for nd in [2, 3]:
64-
test_case = [
65-
{
66-
"patch_size": (patch_size,) * nd,
67-
"in_chans": in_chans,
68-
"embed_dim": embed_dim,
69-
"norm_layer": norm_layer,
70-
"spatial_dims": nd,
71-
},
72-
(2, in_chans, *([img_size] * nd)),
73-
(2, embed_dim, *([img_size // patch_size] * nd)),
74-
]
75-
TEST_CASE_PATCHEMBED.append(test_case)
59+
for params in dict_product(
60+
patch_size=[2],
61+
in_chans=[1, 4],
62+
img_size=[96],
63+
embed_dim=[6, 12],
64+
norm_layer=[nn.LayerNorm],
65+
nd=[2, 3],
66+
):
67+
test_case = [
68+
{
69+
"patch_size": (params["patch_size"],) * params["nd"],
70+
"in_chans": params["in_chans"],
71+
"embed_dim": params["embed_dim"],
72+
"norm_layer": params["norm_layer"],
73+
"spatial_dims": params["nd"],
74+
},
75+
(2, params["in_chans"], *[params["img_size"]] * params["nd"]),
76+
(2, params["embed_dim"], *[params["img_size"] // params["patch_size"]] * params["nd"]),
77+
]
78+
TEST_CASE_PATCHEMBED.append(test_case)
7679

7780

7881
@SkipIfBeforePyTorchVersion((1, 11, 1))

0 commit comments

Comments
 (0)