|
20 | 20 | from monai.networks import eval_mode |
21 | 21 | from monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 |
22 | 22 | from monai.utils import ensure_tuple, optional_import |
23 | | -from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_onnx_save, test_script_save |
| 23 | +from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_onnx_save, test_script_save |
24 | 24 |
|
25 | 25 | _, has_torchvision = optional_import("torchvision") |
26 | 26 |
|
|
86 | 86 | (2, 1, 32, 64), |
87 | 87 | ] |
88 | 88 |
|
| 89 | +# Create all test case combinations using dict_product |
| 90 | +CASE_LIST = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A] |
| 91 | +MODEL_LIST = [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200] |
| 92 | + |
89 | 93 | TEST_CASES = [] |
90 | | -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: |
91 | | - for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: |
92 | | - TEST_CASES.append([model, *case]) |
| 94 | +for params in dict_product( |
| 95 | + model=MODEL_LIST, |
| 96 | + case=CASE_LIST, |
| 97 | +): |
| 98 | + TEST_CASES.append([params["model"], *params["case"]]) |
93 | 99 |
|
94 | 100 | TEST_CASES_TS = [] |
95 | | -for case in [TEST_CASE_1]: |
96 | | - for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: |
97 | | - TEST_CASES_TS.append([model, *case]) |
| 101 | +for params in dict_product( |
| 102 | + model=MODEL_LIST, |
| 103 | + case=[TEST_CASE_1], |
| 104 | +): |
| 105 | + TEST_CASES_TS.append([params["model"], *params["case"]]) |
98 | 106 |
|
99 | 107 |
|
100 | 108 | @SkipIfBeforePyTorchVersion((1, 12)) |
|
0 commit comments