
Commit 336b287

Simplify with list comprehension
Signed-off-by: R. Garcia-Dias <[email protected]>
1 parent 84d85ea commit 336b287

File tree

tests/apps/detection/networks/test_retinanet.py
tests/data/meta_tensor/test_meta_tensor.py
tests/networks/blocks/test_patchembedding.py
tests/networks/blocks/test_selfattention.py

4 files changed: +36 -48 lines changed

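All four diffs below lean on the same dict_product helper from the repository's test utilities. Its actual definition is not part of this commit; a minimal sketch consistent with how it is called here, assuming it simply wraps itertools.product over keyword lists, would be:

from itertools import product

def dict_product(**kwargs):
    # Hypothetical sketch: yield one dict per combination of the keyword lists,
    # e.g. dict_product(a=[1, 2], b=["x"]) -> {"a": 1, "b": "x"}, {"a": 2, "b": "x"}
    for values in product(*kwargs.values()):
        yield dict(zip(kwargs.keys(), values))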

tests/apps/detection/networks/test_retinanet.py

Lines changed: 2 additions & 7 deletions
@@ -90,13 +90,8 @@
 CASE_LIST = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]
 MODEL_LIST = [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]
 
-TEST_CASES = []
-for params in dict_product(model=MODEL_LIST, case=CASE_LIST):
-    TEST_CASES.append([params["model"], *params["case"]])
-
-TEST_CASES_TS = []
-for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1]):
-    TEST_CASES_TS.append([params["model"], *params["case"]])
+TEST_CASES = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=CASE_LIST)]
+TEST_CASES_TS = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])]
 
 
 @SkipIfBeforePyTorchVersion((1, 12))
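Under the dict_product sketch above, each TEST_CASES entry is a flat list: the model followed by the unpacked fields of one case. A toy run with hypothetical stand-ins for the real MODEL_LIST and CASE_LIST:

MODEL_LIST = ["resnet10", "resnet18"]  # stand-in strings; the real list holds the resnet constructors
CASE_LIST = [["input_a", "expected_a"], ["input_b", "expected_b"]]  # hypothetical two-field cases
TEST_CASES = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=CASE_LIST)]
# TEST_CASES == [
#     ["resnet10", "input_a", "expected_a"],
#     ["resnet10", "input_b", "expected_b"],
#     ["resnet18", "input_a", "expected_a"],
#     ["resnet18", "input_b", "expected_b"],
# ]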

tests/data/meta_tensor/test_meta_tensor.py

Lines changed: 2 additions & 3 deletions
@@ -37,9 +37,8 @@
 DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32], [None]]
 
 # Replace nested loops with dict_product
-TESTS = []
-for params in dict_product(device=TEST_DEVICES, dtype=DTYPES):
-    TESTS.append((*params["device"], *params["dtype"]))  # type: ignore
+
+TESTS = [(*params["device"], *params["dtype"]) for params in dict_product(device=TEST_DEVICES, dtype=DTYPES)]
 
 
 def rand_string(min_len=5, max_len=10):
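The double splat works because each element of TEST_DEVICES and DTYPES is itself a single-element list, so every combination flattens to one (device, dtype) tuple. Reusing the sketch above, with a hypothetical CPU-only stand-in for TEST_DEVICES:

import torch

TEST_DEVICES = [[torch.device("cpu")]]  # hypothetical stand-in; upstream this may also hold CUDA devices
DTYPES = [[torch.float32], [None]]
TESTS = [(*params["device"], *params["dtype"]) for params in dict_product(device=TEST_DEVICES, dtype=DTYPES)]
# TESTS == [(device(type='cpu'), torch.float32), (device(type='cpu'), None)]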

tests/networks/blocks/test_patchembedding.py

Lines changed: 19 additions & 16 deletions
@@ -26,36 +26,39 @@
 einops, has_einops = optional_import("einops")
 
 
-TEST_CASE_PATCHEMBEDDINGBLOCK = dict_product(
-    dropout_rate=[0.5],
-    in_channels=[1, 4],
-    hidden_size=[96, 288],
-    img_size=[32, 64],
-    patch_size=[8, 16],
-    num_heads=[8, 12],
-    proj_type=["conv", "perceptron"],
-    pos_embed_type=["none", "learnable", "sincos"],
-    spatial_dims=[2, 3],
-)
 TEST_CASE_PATCHEMBEDDINGBLOCK = [
     [
         params,
         (2, params["in_channels"], *([params["img_size"]] * params["spatial_dims"])),
         (2, (params["img_size"] // params["patch_size"]) ** params["spatial_dims"], params["hidden_size"]),
     ]
-    for params in TEST_CASE_PATCHEMBEDDINGBLOCK
+    for params in dict_product(
+        dropout_rate=[0.5],
+        in_channels=[1, 4],
+        hidden_size=[96, 288],
+        img_size=[32, 64],
+        patch_size=[8, 16],
+        num_heads=[8, 12],
+        proj_type=["conv", "perceptron"],
+        pos_embed_type=["none", "learnable", "sincos"],
+        spatial_dims=[2, 3],
+    )
 ]
 
-TEST_CASE_PATCHEMBED = dict_product(
-    patch_size=[2], in_chans=[1, 4], img_size=[96], embed_dim=[6, 12], norm_layer=[nn.LayerNorm], spatial_dims=[2, 3]
-)
 TEST_CASE_PATCHEMBED = [
     [
         params,
         (2, params["in_chans"], *([params["img_size"]] * params["spatial_dims"])),
         (2, (params["img_size"] // params["patch_size"]) ** params["spatial_dims"], params["embed_dim"]),
     ]
-    for params in TEST_CASE_PATCHEMBED
+    for params in dict_product(
+        patch_size=[2],
+        in_chans=[1, 4],
+        img_size=[96],
+        embed_dim=[6, 12],
+        norm_layer=[nn.LayerNorm],
+        spatial_dims=[2, 3],
+    )
 ]
 
 
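Each generated case pairs the params dict with the block's expected input and output shapes. Working through one combination from the grids above (in_channels=1, hidden_size=96, img_size=32, patch_size=8, spatial_dims=2):

params = {"in_channels": 1, "hidden_size": 96, "img_size": 32, "patch_size": 8, "spatial_dims": 2}
input_shape = (2, params["in_channels"], *([params["img_size"]] * params["spatial_dims"]))  # (2, 1, 32, 32)
n_patches = (params["img_size"] // params["patch_size"]) ** params["spatial_dims"]  # (32 // 8) ** 2 = 16
output_shape = (2, n_patches, params["hidden_size"])  # (2, 16, 96)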
tests/networks/blocks/test_selfattention.py

Lines changed: 13 additions & 22 deletions
@@ -26,31 +26,22 @@
 
 einops, has_einops = optional_import("einops")
 
-TEST_CASE_SABLOCK = []
-for params in dict_product(
-    dropout_rate=np.linspace(0, 1, 4),
-    hidden_size=[360, 480, 600, 768],
-    num_heads=[4, 6, 8, 12],
-    rel_pos_embedding=[None, RelPosEmbedding.DECOMPOSED],
-    input_size=[(16, 32), (8, 8, 8)],
-    include_fc=[True, False],
-    use_combined_linear=[True, False],
-):
-    test_case = [
-        {
-            "hidden_size": params["hidden_size"],
-            "num_heads": params["num_heads"],
-            "dropout_rate": params["dropout_rate"],
-            "rel_pos_embedding": params["rel_pos_embedding"],
-            "input_size": params["input_size"],
-            "include_fc": params["include_fc"],
-            "use_combined_linear": params["use_combined_linear"],
-            "use_flash_attention": True if params["rel_pos_embedding"] is None else False,
-        },
+TEST_CASE_SABLOCK = [
+    [
+        params,
         (2, 512, params["hidden_size"]),
         (2, 512, params["hidden_size"]),
     ]
-    TEST_CASE_SABLOCK.append(test_case)
+    for params in dict_product(
+        dropout_rate=np.linspace(0, 1, 4),
+        hidden_size=[360, 480, 600, 768],
+        num_heads=[4, 6, 8, 12],
+        rel_pos_embedding=[None, RelPosEmbedding.DECOMPOSED],
+        input_size=[(16, 32), (8, 8, 8)],
+        include_fc=[True, False],
+        use_combined_linear=[True, False],
+    )
+]
 
 
 class TestResBlock(unittest.TestCase):
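The grids above expand to 4 × 4 × 4 × 2 × 2 × 2 × 2 = 1,024 cases, each of the form [params, input_shape, output_shape] with sequence length fixed at 512. Assuming the dict_product sketch from the top of this page (kwarg order preserved, last axis varied fastest), the first case would materialize as:

import numpy as np

first_case = [
    {
        "dropout_rate": np.linspace(0, 1, 4)[0],  # 0.0; the grid is [0.0, 1/3, 2/3, 1.0]
        "hidden_size": 360,
        "num_heads": 4,
        "rel_pos_embedding": None,
        "input_size": (16, 32),
        "include_fc": True,
        "use_combined_linear": True,
    },
    (2, 512, 360),  # input shape: batch 2, sequence length 512, hidden size 360
    (2, 512, 360),  # expected output shape
]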
