Skip to content

Commit 7df8cb9

Browse files
avoided breaking map_item functionality
DCO Remediation Commit for Lukas Folle <[email protected]> I, Lukas Folle <[email protected]>, hereby add my Signed-off-by to this commit: e0cda55 Signed-off-by: Lukas Folle <[email protected]>
1 parent 77c138d commit 7df8cb9

File tree

1 file changed

+30
-72
lines changed

1 file changed

+30
-72
lines changed

monai/transforms/transform.py

Lines changed: 30 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -90,22 +90,12 @@ def _apply_transform(
9090
"""
9191
from monai.transforms.lazy.functional import apply_pending_transforms_in_order
9292

93-
data = apply_pending_transforms_in_order(
94-
transform, data, lazy, overrides, logger_name
95-
)
93+
data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name)
9694

9795
if isinstance(data, tuple) and unpack_parameters:
98-
return (
99-
transform(*data, lazy=lazy)
100-
if isinstance(transform, LazyTrait)
101-
else transform(*data)
102-
)
96+
return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
10397

104-
return (
105-
transform(data, lazy=lazy)
106-
if isinstance(transform, LazyTrait)
107-
else transform(data)
108-
)
98+
return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
10999

110100

111101
def apply_transform(
@@ -155,45 +145,38 @@ def apply_transform(
155145
if isinstance(data, (list, tuple)) and map_items_ > 0:
156146
res: list[ReturnType] = []
157147
for item in data:
158-
res_item = _apply_transform(transform, item, unpack_items, lazy, overrides, log_stats)
159-
if isinstance(res_item, (list, tuple)):
160-
res.extend(res_item)
148+
res_item = apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
149+
# Only extend if we're at the leaf level (map_items_ == 1) and the transform
150+
# actually returned a list (not preserving nested structure)
151+
if isinstance(res_item, list) and map_items_ == 1:
152+
if not isinstance(item, (list, tuple)):
153+
res.extend(res_item)
154+
else:
155+
res.append(res_item)
161156
else:
162157
res.append(res_item)
163158
return res
164-
return _apply_transform(
165-
transform, data, unpack_items, lazy, overrides, log_stats
166-
)
159+
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
167160
except Exception as e:
168161
# if in debug mode, don't swallow exception so that the breakpoint
169162
# appears where the exception was raised.
170163
if MONAIEnvVars.debug():
171164
raise
172-
if log_stats is not False and not isinstance(
173-
transform, transforms.compose.Compose
174-
):
165+
if log_stats is not False and not isinstance(transform, transforms.compose.Compose):
175166
# log the input data information of exact transform in the transform chain
176167
if isinstance(log_stats, str):
177-
datastats = transforms.utility.array.DataStats(
178-
data_shape=False, value_range=False, name=log_stats
179-
)
168+
datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats)
180169
else:
181-
datastats = transforms.utility.array.DataStats(
182-
data_shape=False, value_range=False
183-
)
170+
datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False)
184171
logger = logging.getLogger(datastats._logger_name)
185-
logger.error(
186-
f"\n=== Transform input info -- {type(transform).__name__} ==="
187-
)
172+
logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===")
188173
if isinstance(data, (list, tuple)):
189174
data = data[0]
190175

191176
def _log_stats(data, prefix: str | None = "Data"):
192177
if isinstance(data, (np.ndarray, torch.Tensor)):
193178
# log data type, shape, range for array
194-
datastats(
195-
img=data, data_shape=True, value_range=True, prefix=prefix
196-
)
179+
datastats(img=data, data_shape=True, value_range=True, prefix=prefix)
197180
else:
198181
# log data type and value for other metadata
199182
datastats(img=data, data_value=True, prefix=prefix)
@@ -220,9 +203,7 @@ class Randomizable(ThreadUnsafe, RandomizableTrait):
220203

221204
R: np.random.RandomState = np.random.RandomState()
222205

223-
def set_random_state(
224-
self, seed: int | None = None, state: np.random.RandomState | None = None
225-
) -> Randomizable:
206+
def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable:
226207
"""
227208
Set the random state locally, to control the randomness, the derived
228209
classes should use :py:attr:`self.R` instead of `np.random` to introduce random
@@ -240,20 +221,14 @@ def set_random_state(
240221
241222
"""
242223
if seed is not None:
243-
_seed = np.int64(
244-
id(seed) if not isinstance(seed, (int, np.integer)) else seed
245-
)
246-
_seed = (
247-
_seed % MAX_SEED
248-
) # need to account for Numpy2.0 which doesn't silently convert to int64
224+
_seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed)
225+
_seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64
249226
self.R = np.random.RandomState(_seed)
250227
return self
251228

252229
if state is not None:
253230
if not isinstance(state, np.random.RandomState):
254-
raise TypeError(
255-
f"state must be None or a np.random.RandomState but is {type(state).__name__}."
256-
)
231+
raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.")
257232
self.R = state
258233
return self
259234

@@ -272,9 +247,7 @@ def randomize(self, data: Any) -> None:
272247
Raises:
273248
NotImplementedError: When the subclass does not override this method.
274249
"""
275-
raise NotImplementedError(
276-
f"Subclass {self.__class__.__name__} must implement this method."
277-
)
250+
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
278251

279252

280253
class Transform(ABC):
@@ -330,9 +303,7 @@ def __call__(self, data: Any):
330303
NotImplementedError: When the subclass does not override this method.
331304
332305
"""
333-
raise NotImplementedError(
334-
f"Subclass {self.__class__.__name__} must implement this method."
335-
)
306+
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
336307

337308

338309
class LazyTransform(Transform, LazyTrait):
@@ -435,15 +406,11 @@ def __call__(self, data):
435406
def __new__(cls, *args, **kwargs):
436407
if config.USE_META_DICT:
437408
# call_update after MapTransform.__call__
438-
cls.__call__ = transforms.attach_hook(
439-
cls.__call__, MapTransform.call_update, "post"
440-
) # type: ignore
409+
cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, "post") # type: ignore
441410

442411
if hasattr(cls, "inverse"):
443412
# inverse_update before InvertibleTransform.inverse
444-
cls.inverse: Any = transforms.attach_hook(
445-
cls.inverse, transforms.InvertibleTransform.inverse_update
446-
)
413+
cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update)
447414
return Transform.__new__(cls)
448415

449416
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
@@ -454,9 +421,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
454421
raise ValueError("keys must be non empty.")
455422
for key in self.keys:
456423
if not isinstance(key, Hashable):
457-
raise TypeError(
458-
f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}."
459-
)
424+
raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.")
460425

461426
def call_update(self, data):
462427
"""
@@ -476,9 +441,7 @@ def call_update(self, data):
476441
for k in dict_i:
477442
if not isinstance(dict_i[k], MetaTensor):
478443
continue
479-
list_d[idx] = transforms.sync_meta_info(
480-
k, dict_i, t=not isinstance(self, transforms.InvertD)
481-
)
444+
list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD))
482445
return list_d[0] if is_dict else list_d
483446

484447
@abstractmethod
@@ -506,13 +469,9 @@ def __call__(self, data):
506469
An updated dictionary version of ``data`` by applying the transform.
507470
508471
"""
509-
raise NotImplementedError(
510-
f"Subclass {self.__class__.__name__} must implement this method."
511-
)
472+
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
512473

513-
def key_iterator(
514-
self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None
515-
) -> Generator:
474+
def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator:
516475
"""
517476
Iterate across keys and optionally extra iterables. If key is missing, exception is raised if
518477
`allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped.
@@ -532,8 +491,7 @@ def key_iterator(
532491
yield (key,) + tuple(_ex_iters) if extra_iterables else key
533492
elif not self.allow_missing_keys:
534493
raise KeyError(
535-
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data"
536-
" and allow_missing_keys==False."
494+
f"Key `{key}` of transform `{self.__class__.__name__}` was missing in the data and allow_missing_keys==False."
537495
)
538496

539497
def first_key(self, data: dict[Hashable, Any]):

0 commit comments

Comments
 (0)