|
 from monai.networks import eval_mode
 from monai.networks.nets.masked_autoencoder_vit import MaskedAutoEncoderViT
-from tests.test_utils import skip_if_quick
+from tests.test_utils import skip_if_quick, dict_product

 TEST_CASE_MaskedAutoEncoderViT = []
-for masking_ratio in [0.5]:
-    for dropout_rate in [0.6]:
-        for in_channels in [4]:
-            for hidden_size in [768]:
-                for img_size in [96, 128]:
-                    for patch_size in [16]:
-                        for num_heads in [12]:
-                            for mlp_dim in [3072]:
-                                for num_layers in [4]:
-                                    for decoder_hidden_size in [384]:
-                                        for decoder_mlp_dim in [512]:
-                                            for decoder_num_layers in [4]:
-                                                for decoder_num_heads in [16]:
-                                                    for pos_embed_type in ["sincos", "learnable"]:
-                                                        for proj_type in ["conv", "perceptron"]:
-                                                            for nd in (2, 3):
-                                                                test_case = [
-                                                                    {
-                                                                        "in_channels": in_channels,
-                                                                        "img_size": (img_size,) * nd,
-                                                                        "patch_size": (patch_size,) * nd,
-                                                                        "hidden_size": hidden_size,
-                                                                        "mlp_dim": mlp_dim,
-                                                                        "num_layers": num_layers,
-                                                                        "decoder_hidden_size": decoder_hidden_size,
-                                                                        "decoder_mlp_dim": decoder_mlp_dim,
-                                                                        "decoder_num_layers": decoder_num_layers,
-                                                                        "decoder_num_heads": decoder_num_heads,
-                                                                        "pos_embed_type": pos_embed_type,
-                                                                        "masking_ratio": masking_ratio,
-                                                                        "decoder_pos_embed_type": pos_embed_type,
-                                                                        "num_heads": num_heads,
-                                                                        "proj_type": proj_type,
-                                                                        "dropout_rate": dropout_rate,
-                                                                    },
-                                                                    (2, in_channels, *([img_size] * nd)),
-                                                                    (
-                                                                        2,
-                                                                        (img_size // patch_size) ** nd,
-                                                                        in_channels * (patch_size**nd),
-                                                                    ),
-                                                                ]
-                                                                if nd == 2:
-                                                                    test_case[0]["spatial_dims"] = 2  # type: ignore
-                                                                TEST_CASE_MaskedAutoEncoderViT.append(test_case)
+
+for base_params in dict_product(
+    masking_ratio=[0.5],
+    dropout_rate=[0.6],
+    in_channels=[4],
+    hidden_size=[768],
+    img_size_scalar=[96, 128],
+    patch_size_scalar=[16],
+    num_heads=[12],
+    mlp_dim=[3072],
+    num_layers=[4],
+    decoder_hidden_size=[384],
+    decoder_mlp_dim=[512],
+    decoder_num_layers=[4],
+    decoder_num_heads=[16],
+    pos_embed_type=["sincos", "learnable"],
+    proj_type=["conv", "perceptron"],
+):
+    img_size_scalar = base_params.pop("img_size_scalar")
+    patch_size_scalar = base_params.pop("patch_size_scalar")
+    for nd in (2, 3):
+        # Parameters for the MaskedAutoEncoderViT model
+        model_params = base_params.copy()
+        model_params["img_size"] = (img_size_scalar,) * nd
+        model_params["patch_size"] = (patch_size_scalar,) * nd
+        model_params["decoder_pos_embed_type"] = model_params["pos_embed_type"]
+
+        # Expected input and output shapes
+        input_shape = (2, model_params["in_channels"], *([img_size_scalar] * nd))
+        # N, num_patches, patch_dim_product
+        # num_patches = (img_size // patch_size) ** nd
+        # patch_dim_product = in_channels * (patch_size**nd)
+        expected_shape = (
+            2,
+            (img_size_scalar // patch_size_scalar) ** nd,
+            model_params["in_channels"] * (patch_size_scalar**nd),
+        )
+
+        if nd == 2:
+            model_params["spatial_dims"] = 2
+
+        test_case = [
+            model_params,
+            input_shape,
+            expected_shape,
+        ]
+        TEST_CASE_MaskedAutoEncoderViT.append(test_case)

 TEST_CASE_ill_args = [
     [{"in_channels": 1, "img_size": (128, 128, 128), "patch_size": (16, 16, 16), "dropout_rate": 5.0}],
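Note: the refactor replaces the sixteen nested loops with dict_product from tests.test_utils, which expands keyword lists into one parameter dict per combination. The sketch below is a minimal stand-in illustrating the assumed behaviour (the repository helper may accept additional options), followed by a quick sanity check of the resulting test-case count and output shape.

from itertools import product


def dict_product(**value_lists):
    # Hypothetical minimal stand-in for tests.test_utils.dict_product:
    # yield one dict per combination of the keyword lists.
    keys = list(value_lists)
    for combo in product(*value_lists.values()):
        yield dict(zip(keys, combo))


# The grid above varies img_size_scalar (2), pos_embed_type (2) and proj_type (2),
# so dict_product yields 2 * 2 * 2 = 8 dicts; the inner `for nd in (2, 3)` doubles
# that to 16 appended test cases.
# Shape check for img_size=96, patch_size=16, in_channels=4, nd=3:
# num_patches = (96 // 16) ** 3 = 216 and patch_dim = 4 * 16 ** 3 = 16384,
# so expected_shape == (2, 216, 16384).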
|