|
20 | 20 | from monai.networks import eval_mode |
21 | 21 | from monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 |
22 | 22 | from monai.utils import ensure_tuple, optional_import |
23 | | -from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_onnx_save, test_script_save |
| 23 | +from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_onnx_save, test_script_save |
24 | 24 |
|
25 | 25 | _, has_torchvision = optional_import("torchvision") |
26 | 26 |
|
|
86 | 86 | (2, 1, 32, 64), |
87 | 87 | ] |
88 | 88 |
|
89 | | -TEST_CASES = [] |
90 | | -for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]: |
91 | | - for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: |
92 | | - TEST_CASES.append([model, *case]) |
| 89 | +# Create all test case combinations using dict_product |
| 90 | +CASE_LIST = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A] |
| 91 | +MODEL_LIST = [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200] |
93 | 92 |
|
94 | | -TEST_CASES_TS = [] |
95 | | -for case in [TEST_CASE_1]: |
96 | | - for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]: |
97 | | - TEST_CASES_TS.append([model, *case]) |
| 93 | +TEST_CASES = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=CASE_LIST)] |
| 94 | +TEST_CASES_TS = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])] |
98 | 95 |
|
99 | 96 |
|
100 | 97 | @SkipIfBeforePyTorchVersion((1, 12)) |
|
0 commit comments