Skip to content

Commit 2c28c37

Browse files
committed
allow map_items to be int
Signed-off-by: YunLiu <[email protected]>
1 parent c0daf6f commit 2c28c37

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

monai/transforms/compose.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
def execute_compose(
4848
data: NdarrayOrTensor | Sequence[NdarrayOrTensor] | Mapping[Any, NdarrayOrTensor],
4949
transforms: Sequence[Any],
50-
map_items: bool = True,
50+
map_items: bool | int = True,
5151
unpack_items: bool = False,
5252
start: int = 0,
5353
end: int | None = None,
@@ -66,7 +66,7 @@ def execute_compose(
6666
data: a tensor-like object to be transformed
6767
transforms: a sequence of transforms to be carried out
6868
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
69-
defaults to `True`.
69+
defaults to `True`. If set to an integer, the transform will be applied to that index of the input `data`.
7070
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
7171
defaults to `False`.
7272
start: the index of the first transform to be executed. If not set, this defaults to 0
@@ -206,7 +206,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
206206
Args:
207207
transforms: sequence of callables.
208208
map_items: whether to apply transform to each item in the input `data` if `data` is a list or tuple.
209-
defaults to `True`.
209+
defaults to `True`. If set to an integer, the transform will be applied to that index of the input `data`.
210210
unpack_items: whether to unpack input `data` with `*` as parameters for the callable function of transform.
211211
defaults to `False`.
212212
log_stats: this optional parameter allows you to specify a logger by name for logging of pipeline execution.
@@ -227,7 +227,7 @@ class Compose(Randomizable, InvertibleTransform, LazyTransform):
227227
def __init__(
228228
self,
229229
transforms: Sequence[Callable] | Callable | None = None,
230-
map_items: bool = True,
230+
map_items: bool | int = True,
231231
unpack_items: bool = False,
232232
log_stats: bool | str = False,
233233
lazy: bool | None = False,
@@ -238,9 +238,9 @@ def __init__(
238238
if transforms is None:
239239
transforms = []
240240

241-
if not isinstance(map_items, bool):
241+
if not isinstance(map_items, (bool, int)):
242242
raise ValueError(
243-
f"Argument 'map_items' should be boolean. Got {type(map_items)}."
243+
f"Argument 'map_items' should be boolean or int. Got {type(map_items)}."
244244
"Check brackets when passing a sequence of callables."
245245
)
246246

monai/transforms/transform.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def _apply_transform(
101101
def apply_transform(
102102
transform: Callable[..., ReturnType],
103103
data: Any,
104-
map_items: bool = True,
104+
map_items: bool | int = True,
105105
unpack_items: bool = False,
106106
log_stats: bool | str = False,
107107
lazy: bool | None = None,
@@ -119,6 +119,7 @@ def apply_transform(
119119
data: an object to be transformed.
120120
map_items: whether to apply transform to each item in `data`,
121121
if `data` is a list or tuple. Defaults to True.
122+
it can also be an int, if so, apply the transform to each item in the list `map_items` times.
122123
unpack_items: whether to unpack parameters using `*`. Defaults to False.
123124
log_stats: log errors when they occur in the processing pipeline. By default, this is set to False, which
124125
disables the logger for processing pipeline errors. Setting it to None or True will enable logging to the
@@ -136,8 +137,15 @@ def apply_transform(
136137
Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof.
137138
"""
138139
try:
139-
if isinstance(data, (list, tuple)) and map_items:
140-
return [apply_transform(transform, item, map_items, unpack_items, log_stats, lazy, overrides) for item in data]
140+
if isinstance(data, (list, tuple)) and isinstance(map_items, (int, bool)):
141+
# if map_items is an int, apply the transform to each item in the list `map_items` times
142+
if isinstance(map_items, int) and type(map_items) is not bool and map_items > 0:
143+
return [
144+
apply_transform(transform, item, map_items - 2, unpack_items, log_stats, lazy, overrides)
145+
for item in data
146+
]
147+
else:
148+
return [_apply_transform(transform, item, unpack_items, lazy, overrides, log_stats) for item in data]
141149
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
142150
except Exception as e:
143151
# if in debug mode, don't swallow exception so that the breakpoint

tests/test_compose.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,20 @@ def b(i, i2):
141141
self.assertEqual(mt.Compose(transforms, unpack_items=True)(data), expected)
142142
self.assertEqual(execute_compose(data, transforms, unpack_items=True), expected)
143143

144+
def test_list_non_dict_compose_with_unpack_map_2(self):
145+
146+
def a(i, i2):
147+
return i + "a", i2 + "a2"
148+
149+
def b(i, i2):
150+
return i + "b", i2 + "b2"
151+
152+
transforms = [a, b, a, b]
153+
data = [[("", ""), ("", "")], [("t", "t"), ("t", "t")]]
154+
expected = [[("abab", "a2b2a2b2"), ("abab", "a2b2a2b2")], [("tabab", "ta2b2a2b2"), ("tabab", "ta2b2a2b2")]]
155+
self.assertEqual(mt.Compose(transforms, map_items=2, unpack_items=True)(data), expected)
156+
self.assertEqual(execute_compose(data, transforms, map_items=2, unpack_items=True), expected)
157+
144158
def test_list_dict_compose_no_map(self):
145159

146160
def a(d): # transform to handle dict data

0 commit comments

Comments
 (0)