@@ -90,22 +90,12 @@ def _apply_transform(
9090 """
9191 from monai .transforms .lazy .functional import apply_pending_transforms_in_order
9292
93- data = apply_pending_transforms_in_order (
94- transform , data , lazy , overrides , logger_name
95- )
93+ data = apply_pending_transforms_in_order (transform , data , lazy , overrides , logger_name )
9694
9795 if isinstance (data , tuple ) and unpack_parameters :
98- return (
99- transform (* data , lazy = lazy )
100- if isinstance (transform , LazyTrait )
101- else transform (* data )
102- )
96+ return transform (* data , lazy = lazy ) if isinstance (transform , LazyTrait ) else transform (* data )
10397
104- return (
105- transform (data , lazy = lazy )
106- if isinstance (transform , LazyTrait )
107- else transform (data )
108- )
98+ return transform (data , lazy = lazy ) if isinstance (transform , LazyTrait ) else transform (data )
10999
110100
111101def apply_transform (
@@ -155,45 +145,38 @@ def apply_transform(
155145 if isinstance (data , (list , tuple )) and map_items_ > 0 :
156146 res : list [ReturnType ] = []
157147 for item in data :
158- res_item = _apply_transform (transform , item , unpack_items , lazy , overrides , log_stats )
159- if isinstance (res_item , (list , tuple )):
160- res .extend (res_item )
148+ res_item = apply_transform (transform , item , map_items_ - 1 , unpack_items , log_stats , lazy , overrides )
149+ # Only extend if we're at the leaf level (map_items_ == 1) and the transform
150+ # actually returned a list (not preserving nested structure)
151+ if isinstance (res_item , list ) and map_items_ == 1 :
152+ if not isinstance (item , (list , tuple )):
153+ res .extend (res_item )
154+ else :
155+ res .append (res_item )
161156 else :
162157 res .append (res_item )
163158 return res
164- return _apply_transform (
165- transform , data , unpack_items , lazy , overrides , log_stats
166- )
159+ return _apply_transform (transform , data , unpack_items , lazy , overrides , log_stats )
167160 except Exception as e :
168161 # if in debug mode, don't swallow exception so that the breakpoint
169162 # appears where the exception was raised.
170163 if MONAIEnvVars .debug ():
171164 raise
172- if log_stats is not False and not isinstance (
173- transform , transforms .compose .Compose
174- ):
165+ if log_stats is not False and not isinstance (transform , transforms .compose .Compose ):
175166 # log the input data information of exact transform in the transform chain
176167 if isinstance (log_stats , str ):
177- datastats = transforms .utility .array .DataStats (
178- data_shape = False , value_range = False , name = log_stats
179- )
168+ datastats = transforms .utility .array .DataStats (data_shape = False , value_range = False , name = log_stats )
180169 else :
181- datastats = transforms .utility .array .DataStats (
182- data_shape = False , value_range = False
183- )
170+ datastats = transforms .utility .array .DataStats (data_shape = False , value_range = False )
184171 logger = logging .getLogger (datastats ._logger_name )
185- logger .error (
186- f"\n === Transform input info -- { type (transform ).__name__ } ==="
187- )
172+ logger .error (f"\n === Transform input info -- { type (transform ).__name__ } ===" )
188173 if isinstance (data , (list , tuple )):
189174 data = data [0 ]
190175
191176 def _log_stats (data , prefix : str | None = "Data" ):
192177 if isinstance (data , (np .ndarray , torch .Tensor )):
193178 # log data type, shape, range for array
194- datastats (
195- img = data , data_shape = True , value_range = True , prefix = prefix
196- )
179+ datastats (img = data , data_shape = True , value_range = True , prefix = prefix )
197180 else :
198181 # log data type and value for other metadata
199182 datastats (img = data , data_value = True , prefix = prefix )
@@ -220,9 +203,7 @@ class Randomizable(ThreadUnsafe, RandomizableTrait):
220203
221204 R : np .random .RandomState = np .random .RandomState ()
222205
223- def set_random_state (
224- self , seed : int | None = None , state : np .random .RandomState | None = None
225- ) -> Randomizable :
206+ def set_random_state (self , seed : int | None = None , state : np .random .RandomState | None = None ) -> Randomizable :
226207 """
227208 Set the random state locally, to control the randomness, the derived
228209 classes should use :py:attr:`self.R` instead of `np.random` to introduce random
@@ -240,20 +221,14 @@ def set_random_state(
240221
241222 """
242223 if seed is not None :
243- _seed = np .int64 (
244- id (seed ) if not isinstance (seed , (int , np .integer )) else seed
245- )
246- _seed = (
247- _seed % MAX_SEED
248- ) # need to account for Numpy2.0 which doesn't silently convert to int64
224+ _seed = np .int64 (id (seed ) if not isinstance (seed , (int , np .integer )) else seed )
225+ _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64
249226 self .R = np .random .RandomState (_seed )
250227 return self
251228
252229 if state is not None :
253230 if not isinstance (state , np .random .RandomState ):
254- raise TypeError (
255- f"state must be None or a np.random.RandomState but is { type (state ).__name__ } ."
256- )
231+ raise TypeError (f"state must be None or a np.random.RandomState but is { type (state ).__name__ } ." )
257232 self .R = state
258233 return self
259234
@@ -272,9 +247,7 @@ def randomize(self, data: Any) -> None:
272247 Raises:
273248 NotImplementedError: When the subclass does not override this method.
274249 """
275- raise NotImplementedError (
276- f"Subclass { self .__class__ .__name__ } must implement this method."
277- )
250+ raise NotImplementedError (f"Subclass { self .__class__ .__name__ } must implement this method." )
278251
279252
280253class Transform (ABC ):
@@ -330,9 +303,7 @@ def __call__(self, data: Any):
330303 NotImplementedError: When the subclass does not override this method.
331304
332305 """
333- raise NotImplementedError (
334- f"Subclass { self .__class__ .__name__ } must implement this method."
335- )
306+ raise NotImplementedError (f"Subclass { self .__class__ .__name__ } must implement this method." )
336307
337308
338309class LazyTransform (Transform , LazyTrait ):
@@ -435,15 +406,11 @@ def __call__(self, data):
435406 def __new__ (cls , * args , ** kwargs ):
436407 if config .USE_META_DICT :
437408 # call_update after MapTransform.__call__
438- cls .__call__ = transforms .attach_hook (
439- cls .__call__ , MapTransform .call_update , "post"
440- ) # type: ignore
409+ cls .__call__ = transforms .attach_hook (cls .__call__ , MapTransform .call_update , "post" ) # type: ignore
441410
442411 if hasattr (cls , "inverse" ):
443412 # inverse_update before InvertibleTransform.inverse
444- cls .inverse : Any = transforms .attach_hook (
445- cls .inverse , transforms .InvertibleTransform .inverse_update
446- )
413+ cls .inverse : Any = transforms .attach_hook (cls .inverse , transforms .InvertibleTransform .inverse_update )
447414 return Transform .__new__ (cls )
448415
449416 def __init__ (self , keys : KeysCollection , allow_missing_keys : bool = False ) -> None :
@@ -454,9 +421,7 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
454421 raise ValueError ("keys must be non empty." )
455422 for key in self .keys :
456423 if not isinstance (key , Hashable ):
457- raise TypeError (
458- f"keys must be one of (Hashable, Iterable[Hashable]) but is { type (keys ).__name__ } ."
459- )
424+ raise TypeError (f"keys must be one of (Hashable, Iterable[Hashable]) but is { type (keys ).__name__ } ." )
460425
461426 def call_update (self , data ):
462427 """
@@ -476,9 +441,7 @@ def call_update(self, data):
476441 for k in dict_i :
477442 if not isinstance (dict_i [k ], MetaTensor ):
478443 continue
479- list_d [idx ] = transforms .sync_meta_info (
480- k , dict_i , t = not isinstance (self , transforms .InvertD )
481- )
444+ list_d [idx ] = transforms .sync_meta_info (k , dict_i , t = not isinstance (self , transforms .InvertD ))
482445 return list_d [0 ] if is_dict else list_d
483446
484447 @abstractmethod
@@ -506,13 +469,9 @@ def __call__(self, data):
506469 An updated dictionary version of ``data`` by applying the transform.
507470
508471 """
509- raise NotImplementedError (
510- f"Subclass { self .__class__ .__name__ } must implement this method."
511- )
472+ raise NotImplementedError (f"Subclass { self .__class__ .__name__ } must implement this method." )
512473
513- def key_iterator (
514- self , data : Mapping [Hashable , Any ], * extra_iterables : Iterable | None
515- ) -> Generator :
474+ def key_iterator (self , data : Mapping [Hashable , Any ], * extra_iterables : Iterable | None ) -> Generator :
516475 """
517476 Iterate across keys and optionally extra iterables. If key is missing, exception is raised if
518477 `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped.
@@ -532,8 +491,7 @@ def key_iterator(
532491 yield (key ,) + tuple (_ex_iters ) if extra_iterables else key
533492 elif not self .allow_missing_keys :
534493 raise KeyError (
535- f"Key `{ key } ` of transform `{ self .__class__ .__name__ } ` was missing in the data"
536- " and allow_missing_keys==False."
494+ f"Key `{ key } ` of transform `{ self .__class__ .__name__ } ` was missing in the data and allow_missing_keys==False."
537495 )
538496
539497 def first_key (self , data : dict [Hashable , Any ]):
0 commit comments