Skip to content

Commit b46ccc0

Browse files
committed
redesign dict_product to make it more readable
Signed-off-by: R. Garcia-Dias <[email protected]>
1 parent a5596d7 commit b46ccc0

File tree

3 files changed

+127
-46
lines changed

3 files changed

+127
-46
lines changed

tests/networks/blocks/test_patchembedding.py

Lines changed: 22 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525

2626
einops, has_einops = optional_import("einops")
2727

28-
TEST_CASE_PATCHEMBEDDINGBLOCK = []
29-
for params in dict_product(
28+
29+
TEST_CASE_PATCHEMBEDDINGBLOCK = dict_product(
3030
dropout_rate=[0.5],
3131
in_channels=[1, 4],
3232
hidden_size=[96, 288],
@@ -36,41 +36,27 @@
3636
proj_type=["conv", "perceptron"],
3737
pos_embed_type=["none", "learnable", "sincos"],
3838
nd=[2, 3],
39-
):
40-
test_case = [
41-
{
42-
"in_channels": params["in_channels"],
43-
"img_size": (params["img_size"],) * params["nd"],
44-
"patch_size": (params["patch_size"],) * params["nd"],
45-
"hidden_size": params["hidden_size"],
46-
"num_heads": params["num_heads"],
47-
"proj_type": params["proj_type"],
48-
"pos_embed_type": params["pos_embed_type"],
49-
"dropout_rate": params["dropout_rate"],
50-
"spatial_dims": params["nd"],
51-
},
52-
(2, params["in_channels"], *[params["img_size"]] * params["nd"]),
53-
(2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["hidden_size"]),
54-
]
55-
TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case)
56-
57-
TEST_CASE_PATCHEMBED = []
58-
for params in dict_product(
39+
)
40+
TEST_CASE_PATCHEMBEDDINGBLOCK = [
41+
[
42+
params,
43+
(2, params["in_channels"], *([params["img_size"]] * params["nd"])),
44+
(2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["hidden_size"]),
45+
]
46+
for params in TEST_CASE_PATCHEMBEDDINGBLOCK
47+
]
48+
49+
TEST_CASE_PATCHEMBED = dict_product(
5950
patch_size=[2], in_chans=[1, 4], img_size=[96], embed_dim=[6, 12], norm_layer=[nn.LayerNorm], nd=[2, 3]
60-
):
61-
test_case = [
62-
{
63-
"patch_size": (params["patch_size"],) * params["nd"],
64-
"in_chans": params["in_chans"],
65-
"embed_dim": params["embed_dim"],
66-
"norm_layer": params["norm_layer"],
67-
"spatial_dims": params["nd"],
68-
},
69-
(2, params["in_chans"], *[params["img_size"]] * params["nd"]),
70-
(2, params["embed_dim"], *[params["img_size"] // params["patch_size"]] * params["nd"]),
71-
]
72-
TEST_CASE_PATCHEMBED.append(test_case)
73-
51+
)
52+
TEST_CASE_PATCHEMBED = [
53+
[
54+
params,
55+
(2, params["in_chans"], *([params["img_size"]] * params["nd"])),
56+
(2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["embed_dim"]),
57+
]
58+
for params in TEST_CASE_PATCHEMBED
59+
]
7460

7561
@SkipIfBeforePyTorchVersion((1, 11, 1))
7662
class TestPatchEmbeddingBlock(unittest.TestCase):

tests/test_test_utils.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
from __future__ import annotations
2+
3+
import unittest
4+
5+
from tests.test_utils import dict_product
6+
7+
8+
class TestTestUtils(unittest.TestCase):
9+
def setUp(self):
10+
TEST_CASE_PATCHEMBEDDINGBLOCK = []
11+
for dropout_rate in (0.5,):
12+
for in_channels in [1, 4]:
13+
for hidden_size in [96, 288]:
14+
for img_size in [32, 64]:
15+
for patch_size in [8, 16]:
16+
for num_heads in [8, 12]:
17+
for proj_type in ["conv", "perceptron"]:
18+
for pos_embed_type in ["none", "learnable", "sincos"]:
19+
# for classification in (False, True): # TODO: add classification tests
20+
for nd in (2, 3):
21+
test_case = [
22+
{
23+
"in_channels": in_channels,
24+
"img_size": (img_size,) * nd,
25+
"patch_size": (patch_size,) * nd,
26+
"hidden_size": hidden_size,
27+
"num_heads": num_heads,
28+
"proj_type": proj_type,
29+
"pos_embed_type": pos_embed_type,
30+
"dropout_rate": dropout_rate,
31+
"spatial_dims": nd
32+
},
33+
(2, in_channels, *([img_size] * nd)),
34+
(2, (img_size // patch_size) ** nd, hidden_size),
35+
]
36+
TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case)
37+
38+
self.test_case_patchembeddingblock = TEST_CASE_PATCHEMBEDDINGBLOCK
39+
40+
def test_case_patchembeddingblock(self):
41+
test_case_patchembeddingblock = dict_product(
42+
dropout_rate=[0.5],
43+
in_channels=[1, 4],
44+
hidden_size=[96, 288],
45+
img_size=[32, 64],
46+
patch_size=[8, 16],
47+
num_heads=[8, 12],
48+
proj_type=["conv", "perceptron"],
49+
pos_embed_type=["none", "learnable", "sincos"],
50+
nd=[2, 3],
51+
)
52+
test_case_patchembeddingblock = [
53+
[
54+
params,
55+
(2, params["in_channels"], *([params["img_size"]] * params["nd"])),
56+
(2, (params["img_size"] // params["patch_size"]) ** params["nd"], params["hidden_size"]),
57+
]
58+
for params in test_case_patchembeddingblock
59+
]
60+
61+
self.assertIsInstance(test_case_patchembeddingblock, list)
62+
self.assertGreater(len(test_case_patchembeddingblock), 0)
63+
self.assertEqual(
64+
len(test_case_patchembeddingblock),
65+
len(self.test_case_patchembeddingblock),
66+
)
67+
self.assertEqual(
68+
len(test_case_patchembeddingblock[0]),
69+
len(self.test_case_patchembeddingblock[0]),
70+
)
71+
self.assertEqual(
72+
len(test_case_patchembeddingblock[0][0]),
73+
len(self.test_case_patchembeddingblock[0][0]),
74+
)
75+
self.assertEqual(
76+
test_case_patchembeddingblock[0][0]["in_channels"],
77+
self.test_case_patchembeddingblock[0][0]["in_channels"],
78+
)
79+
self.assertEqual(
80+
test_case_patchembeddingblock[0][1],
81+
self.test_case_patchembeddingblock[0][1],
82+
)
83+
self.assertEqual(
84+
test_case_patchembeddingblock[0][2],
85+
self.test_case_patchembeddingblock[0][2],
86+
)
87+
88+
if __name__ == "__main__":
89+
unittest.main()

tests/test_utils.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
def dict_product(**items: list) -> list[dict]:
    """Create the cartesian product, equivalent to a nested for-loop, of the item lists.

    Note: with ``**kwargs`` the annotation describes each keyword VALUE, so it is
    ``list`` (each argument is a list of candidate values), not ``dict[str, list]``.

    Args:
        items: keyword arguments mapping each name to the list of values to combine.

    Returns:
        list: list of dictionaries, one per combination of the input items;
        the last keyword varies fastest, matching nested-loop order.

    Example:
        >>> dict_product(x=[1, 2], y=[3, 4])
        [{'x': 1, 'y': 3}, {'x': 1, 'y': 4}, {'x': 2, 'y': 3}, {'x': 2, 'y': 4}]
    """
    keys = items.keys()
    values = items.values()
    return [dict(zip(keys, combination)) for combination in product(*values)]
879885

880886

881887
if __name__ == "__main__":

0 commit comments

Comments
 (0)