Skip to content

Commit 2d58774

Browse files
added test for many multisample transforms; refactored code
1 parent 3aa1288 commit 2d58774

File tree

2 files changed

+24
-5
lines changed

2 files changed

+24
-5
lines changed

monai/transforms/transform.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,8 @@ def apply_transform(
148148
res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
149149
# Only extend if we're at the leaf level (map_items_ == 1) and the transform
150150
# actually returned a list (not preserving nested structure)
151-
if isinstance(res_item, list) and map_items_ == 1:
152-
if not isinstance(item, (list, tuple)):
153-
res.extend(res_item)
154-
else:
155-
res.append(res_item)
151+
if isinstance(res_item, list) and map_items_ == 1 and not isinstance(item, (list, tuple)):
152+
res.extend(res_item)
156153
else:
157154
res.append(res_item)
158155
return res

tests/transforms/compose/test_compose.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,28 @@ def test_flatten_and_len(self):
282282
def test_backwards_compatible_imports(self):
283283
from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401
284284

285+
def test_list_extend_multi_sample_trait(self):
286+
from monai.transforms import CenterSpatialCrop, RandSpatialCropSamples
287+
288+
center_crop = CenterSpatialCrop([128, 128])
289+
multi_sample_transform = RandSpatialCropSamples([64, 64], 1)
290+
291+
img = torch.zeros([1, 512, 512])
292+
293+
assert execute_compose(img, [center_crop]).shape == torch.Size([1, 128, 128])
294+
single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop])
295+
assert (
296+
isinstance(single_multi_sample_trait_result, list)
297+
and len(single_multi_sample_trait_result) == 1
298+
and single_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64])
299+
)
300+
double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop])
301+
assert (
302+
isinstance(double_multi_sample_trait_result, list)
303+
and len(double_multi_sample_trait_result) == 1
304+
and double_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64])
305+
)
306+
285307

286308
TEST_COMPOSE_EXECUTE_TEST_CASES = [
287309
[None, tuple()],

0 commit comments

Comments
 (0)