Skip to content

Commit c3b42bb

Browse files
committed
Refactor TEST_CASE_MaskedAutoEncoderViT to use dict_product for cleaner test case generation
Signed-off-by: R. Garcia-Dias <[email protected]>
1 parent 6c3eeac commit c3b42bb

File tree

1 file changed

+48
-46
lines changed

1 file changed

+48
-46
lines changed

tests/test_masked_autoencoder_vit.py

Lines changed: 48 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -18,54 +18,56 @@
1818

1919
from monai.networks import eval_mode
2020
from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT
21-
from tests.test_utils import skip_if_quick
21+
from tests.test_utils import skip_if_quick, dict_product
2222

2323
TEST_CASE_MaskedAutoEncoderViT = []
24-
for masking_ratio in [0.5]:
25-
for dropout_rate in [0.6]:
26-
for in_channels in [4]:
27-
for hidden_size in [768]:
28-
for img_size in [96, 128]:
29-
for patch_size in [16]:
30-
for num_heads in [12]:
31-
for mlp_dim in [3072]:
32-
for num_layers in [4]:
33-
for decoder_hidden_size in [384]:
34-
for decoder_mlp_dim in [512]:
35-
for decoder_num_layers in [4]:
36-
for decoder_num_heads in [16]:
37-
for pos_embed_type in ["sincos", "learnable"]:
38-
for proj_type in ["conv", "perceptron"]:
39-
for nd in (2, 3):
40-
test_case = [
41-
{
42-
"in_channels": in_channels,
43-
"img_size": (img_size,) * nd,
44-
"patch_size": (patch_size,) * nd,
45-
"hidden_size": hidden_size,
46-
"mlp_dim": mlp_dim,
47-
"num_layers": num_layers,
48-
"decoder_hidden_size": decoder_hidden_size,
49-
"decoder_mlp_dim": decoder_mlp_dim,
50-
"decoder_num_layers": decoder_num_layers,
51-
"decoder_num_heads": decoder_num_heads,
52-
"pos_embed_type": pos_embed_type,
53-
"masking_ratio": masking_ratio,
54-
"decoder_pos_embed_type": pos_embed_type,
55-
"num_heads": num_heads,
56-
"proj_type": proj_type,
57-
"dropout_rate": dropout_rate,
58-
},
59-
(2, in_channels, *([img_size] * nd)),
60-
(
61-
2,
62-
(img_size // patch_size) ** nd,
63-
in_channels * (patch_size**nd),
64-
),
65-
]
66-
if nd == 2:
67-
test_case[0]["spatial_dims"] = 2 # type: ignore
68-
TEST_CASE_MaskedAutoEncoderViT.append(test_case)
24+
25+
for base_params in dict_product(
26+
masking_ratio=[0.5],
27+
dropout_rate=[0.6],
28+
in_channels=[4],
29+
hidden_size=[768],
30+
img_size_scalar=[96, 128],
31+
patch_size_scalar=[16],
32+
num_heads=[12],
33+
mlp_dim=[3072],
34+
num_layers=[4],
35+
decoder_hidden_size=[384],
36+
decoder_mlp_dim=[512],
37+
decoder_num_layers=[4],
38+
decoder_num_heads=[16],
39+
pos_embed_type=["sincos", "learnable"],
40+
proj_type=["conv", "perceptron"],
41+
):
42+
img_size_scalar = base_params.pop("img_size_scalar")
43+
patch_size_scalar = base_params.pop("patch_size_scalar")
44+
for nd in (2, 3):
45+
# Parameters for the MaskedAutoEncoderViT model
46+
model_params = base_params.copy()
47+
model_params["img_size"] = (img_size_scalar,) * nd
48+
model_params["patch_size"] = (patch_size_scalar,) * nd
49+
model_params["decoder_pos_embed_type"] = model_params["pos_embed_type"]
50+
51+
# Expected input and output shapes
52+
input_shape = (2, model_params["in_channels"], *([img_size_scalar] * nd))
53+
# N, num_patches, patch_dim_product
54+
# num_patches = (img_size // patch_size) ** nd
55+
# patch_dim_product = in_channels * (patch_size**nd)
56+
expected_shape = (
57+
2,
58+
(img_size_scalar // patch_size_scalar) ** nd,
59+
model_params["in_channels"] * (patch_size_scalar**nd),
60+
)
61+
62+
if nd == 2:
63+
model_params["spatial_dims"] = 2
64+
65+
test_case = [
66+
model_params,
67+
input_shape,
68+
expected_shape,
69+
]
70+
TEST_CASE_MaskedAutoEncoderViT.append(test_case)
6971

7072
TEST_CASE_ill_args = [
7173
[{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}],

0 commit comments

Comments
 (0)