Skip to content

Commit f5ac534

Browse files
committed
Add dict_product utility for test case generation
- Introduced a new `dict_product` function to create combinations of input parameters. - Updated test cases to use the new utility for generating shape and input type combinations. - Cleaned up code by removing old nested loops in favor of the new function.
1 parent 679ac3c commit f5ac534

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

tests/test_utils.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,10 @@
3030
import warnings
3131
from contextlib import contextmanager
3232
from functools import partial, reduce
33+
from itertools import product
3334
from pathlib import Path
3435
from subprocess import PIPE, Popen
35-
from typing import Callable
36+
from typing import Callable, Literal
3637
from urllib.error import ContentTooShortError, HTTPError
3738

3839
import numpy as np
@@ -862,6 +863,21 @@ def equal_state_dict(st_1, st_2):
862863
if torch.cuda.is_available():
863864
TEST_DEVICES.append([torch.device("cuda")])
864865

866+
867+
def dict_product(trailing=False, format: Literal["list", "dict"] = "dict", **items):
868+
keys = items.keys()
869+
values = items.values()
870+
for pvalues in product(*values):
871+
dict_comb = dict(zip(keys, pvalues))
872+
if format == "dict":
873+
if trailing:
874+
yield [dict_comb] + list(pvalues)
875+
else:
876+
yield dict_comb
877+
else:
878+
yield pvalues
879+
880+
865881
if __name__ == "__main__":
866882
parser = argparse.ArgumentParser(prog="util")
867883
parser.add_argument("-c", "--count", default=2, help="max number of gpus")

tests/transforms/test_gibbs_noise.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,12 @@
2121
from monai.transforms import GibbsNoise
2222
from monai.utils.misc import set_determinism
2323
from monai.utils.module import optional_import
24-
from tests.test_utils import TEST_NDARRAYS, assert_allclose
24+
from tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product
2525

2626
_, has_torch_fft = optional_import("torch.fft", name="fftshift")
2727

28-
TEST_CASES = []
29-
for shape in ((128, 64), (64, 48, 80)):
30-
for input_type in TEST_NDARRAYS if has_torch_fft else [np.array]:
31-
TEST_CASES.append((shape, input_type))
28+
params = {"shape": ((128, 64), (64, 48, 80)), "input_type": TEST_NDARRAYS if has_torch_fft else [np.array]}
29+
TEST_CASES = list(dict_product(format="list", **params))
3230

3331

3432
class TestGibbsNoise(unittest.TestCase):

0 commit comments

Comments
 (0)