Skip to content

Commit e0cda55

Browse files
added list extend to MultiSampleTrait
1 parent b92b2ce commit e0cda55

File tree

1 file changed

+77
-25
lines changed

1 file changed

+77
-25
lines changed

monai/transforms/transform.py

Lines changed: 77 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,22 @@ def _apply_transform(
9090
"""
9191
from monai.transforms.lazy.functional import apply_pending_transforms_in_order
9292

93-
data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name)
93+
data = apply_pending_transforms_in_order(
94+
transform, data, lazy, overrides, logger_name
95+
)
9496

9597
if isinstance(data, tuple) and unpack_parameters:
96-
return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
98+
return (
99+
transform(*data, lazy=lazy)
100+
if isinstance(transform, LazyTrait)
101+
else transform(*data)
102+
)
97103

98-
return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
104+
return (
105+
transform(data, lazy=lazy)
106+
if isinstance(transform, LazyTrait)
107+
else transform(data)
108+
)
99109

100110

101111
def apply_transform(
@@ -143,31 +153,49 @@ def apply_transform(
143153
try:
144154
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
145155
if isinstance(data, (list, tuple)) and map_items_ > 0:
146-
return [
147-
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
148-
for item in data
149-
]
150-
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
156+
res = []
157+
for item in data:
158+
res_item = _apply_transform(
159+
transform, item, unpack_items, lazy, overrides, log_stats
160+
)
161+
if isinstance(res_item, list | tuple):
162+
res.extend(res_item)
163+
else:
164+
res.append(res_item)
165+
return res
166+
return _apply_transform(
167+
transform, data, unpack_items, lazy, overrides, log_stats
168+
)
151169
except Exception as e:
152170
# if in debug mode, don't swallow exception so that the breakpoint
153171
# appears where the exception was raised.
154172
if MONAIEnvVars.debug():
155173
raise
156-
if log_stats is not False and not isinstance(transform, transforms.compose.Compose):
174+
if log_stats is not False and not isinstance(
175+
transform, transforms.compose.Compose
176+
):
157177
# log the input data information of exact transform in the transform chain
158178
if isinstance(log_stats, str):
159-
datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats)
179+
datastats = transforms.utility.array.DataStats(
180+
data_shape=False, value_range=False, name=log_stats
181+
)
160182
else:
161-
datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False)
183+
datastats = transforms.utility.array.DataStats(
184+
data_shape=False, value_range=False
185+
)
162186
logger = logging.getLogger(datastats._logger_name)
163-
logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===")
187+
logger.error(
188+
f"\n=== Transform input info -- {type(transform).__name__} ==="
189+
)
164190
if isinstance(data, (list, tuple)):
165191
data = data[0]
166192

167193
def _log_stats(data, prefix: str | None = "Data"):
168194
if isinstance(data, (np.ndarray, torch.Tensor)):
169195
# log data type, shape, range for array
170-
datastats(img=data, data_shape=True, value_range=True, prefix=prefix)
196+
datastats(
197+
img=data, data_shape=True, value_range=True, prefix=prefix
198+
)
171199
else:
172200
# log data type and value for other metadata
173201
datastats(img=data, data_value=True, prefix=prefix)
@@ -194,7 +222,9 @@ class Randomizable(ThreadUnsafe, RandomizableTrait):
194222

195223
R: np.random.RandomState = np.random.RandomState()
196224

197-
def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable:
225+
def set_random_state(
226+
self, seed: int | None = None, state: np.random.RandomState | None = None
227+
) -> Randomizable:
198228
"""
199229
Set the random state locally, to control the randomness, the derived
200230
classes should use :py:attr:`self.R` instead of `np.random` to introduce random
@@ -212,14 +242,20 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState
212242
213243
"""
214244
if seed is not None:
215-
_seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed)
216-
_seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64
245+
_seed = np.int64(
246+
id(seed) if not isinstance(seed, (int, np.integer)) else seed
247+
)
248+
_seed = (
249+
_seed % MAX_SEED
250+
) # need to account for Numpy2.0 which doesn't silently convert to int64
217251
self.R = np.random.RandomState(_seed)
218252
return self
219253

220254
if state is not None:
221255
if not isinstance(state, np.random.RandomState):
222-
raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.")
256+
raise TypeError(
257+
f"state must be None or a np.random.RandomState but is {type(state).__name__}."
258+
)
223259
self.R = state
224260
return self
225261

@@ -238,7 +274,9 @@ def randomize(self, data: Any) -> None:
238274
Raises:
239275
NotImplementedError: When the subclass does not override this method.
240276
"""
241-
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
277+
raise NotImplementedError(
278+
f"Subclass {self.__class__.__name__} must implement this method."
279+
)
242280

243281

244282
class Transform(ABC):
@@ -294,7 +332,9 @@ def __call__(self, data: Any):
294332
NotImplementedError: When the subclass does not override this method.
295333
296334
"""
297-
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
335+
raise NotImplementedError(
336+
f"Subclass {self.__class__.__name__} must implement this method."
337+
)
298338

299339

300340
class LazyTransform(Transform, LazyTrait):
@@ -397,11 +437,15 @@ def __call__(self, data):
397437
def __new__(cls, *args, **kwargs):
398438
if config.USE_META_DICT:
399439
# call_update after MapTransform.__call__
400-
cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, "post") # type: ignore
440+
cls.__call__ = transforms.attach_hook(
441+
cls.__call__, MapTransform.call_update, "post"
442+
) # type: ignore
401443

402444
if hasattr(cls, "inverse"):
403445
# inverse_update before InvertibleTransform.inverse
404-
cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update)
446+
cls.inverse: Any = transforms.attach_hook(
447+
cls.inverse, transforms.InvertibleTransform.inverse_update
448+
)
405449
return Transform.__new__(cls)
406450

407451
def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
@@ -412,7 +456,9 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
412456
raise ValueError("keys must be non empty.")
413457
for key in self.keys:
414458
if not isinstance(key, Hashable):
415-
raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.")
459+
raise TypeError(
460+
f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}."
461+
)
416462

417463
def call_update(self, data):
418464
"""
@@ -432,7 +478,9 @@ def call_update(self, data):
432478
for k in dict_i:
433479
if not isinstance(dict_i[k], MetaTensor):
434480
continue
435-
list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD))
481+
list_d[idx] = transforms.sync_meta_info(
482+
k, dict_i, t=not isinstance(self, transforms.InvertD)
483+
)
436484
return list_d[0] if is_dict else list_d
437485

438486
@abstractmethod
@@ -460,9 +508,13 @@ def __call__(self, data):
460508
An updated dictionary version of ``data`` by applying the transform.
461509
462510
"""
463-
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
511+
raise NotImplementedError(
512+
f"Subclass {self.__class__.__name__} must implement this method."
513+
)
464514

465-
def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator:
515+
def key_iterator(
516+
self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None
517+
) -> Generator:
466518
"""
467519
Iterate across keys and optionally extra iterables. If key is missing, exception is raised if
468520
`allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped.

0 commit comments

Comments
 (0)