@@ -90,12 +90,22 @@ def _apply_transform(
9090 """
9191 from monai .transforms .lazy .functional import apply_pending_transforms_in_order
9292
93- data = apply_pending_transforms_in_order (transform , data , lazy , overrides , logger_name )
93+ data = apply_pending_transforms_in_order (
94+ transform , data , lazy , overrides , logger_name
95+ )
9496
9597 if isinstance (data , tuple ) and unpack_parameters :
96- return transform (* data , lazy = lazy ) if isinstance (transform , LazyTrait ) else transform (* data )
98+ return (
99+ transform (* data , lazy = lazy )
100+ if isinstance (transform , LazyTrait )
101+ else transform (* data )
102+ )
97103
98- return transform (data , lazy = lazy ) if isinstance (transform , LazyTrait ) else transform (data )
104+ return (
105+ transform (data , lazy = lazy )
106+ if isinstance (transform , LazyTrait )
107+ else transform (data )
108+ )
99109
100110
101111def apply_transform (
@@ -143,31 +153,49 @@ def apply_transform(
143153 try :
144154 map_items_ = int (map_items ) if isinstance (map_items , bool ) else map_items
145155 if isinstance (data , (list , tuple )) and map_items_ > 0 :
146- return [
147- apply_transform (transform , item , map_items_ - 1 , unpack_items , log_stats , lazy , overrides )
148- for item in data
149- ]
150- return _apply_transform (transform , data , unpack_items , lazy , overrides , log_stats )
156+ res = []
157+ for item in data :
158+ res_item = _apply_transform (
159+ transform , item , unpack_items , lazy , overrides , log_stats
160+ )
161+ if isinstance (res_item , list | tuple ):
162+ res .extend (res_item )
163+ else :
164+ res .append (res_item )
165+ return res
166+ return _apply_transform (
167+ transform , data , unpack_items , lazy , overrides , log_stats
168+ )
151169 except Exception as e :
152170 # if in debug mode, don't swallow exception so that the breakpoint
153171 # appears where the exception was raised.
154172 if MONAIEnvVars .debug ():
155173 raise
156- if log_stats is not False and not isinstance (transform , transforms .compose .Compose ):
174+ if log_stats is not False and not isinstance (
175+ transform , transforms .compose .Compose
176+ ):
157177 # log the input data information of exact transform in the transform chain
158178 if isinstance (log_stats , str ):
159- datastats = transforms .utility .array .DataStats (data_shape = False , value_range = False , name = log_stats )
179+ datastats = transforms .utility .array .DataStats (
180+ data_shape = False , value_range = False , name = log_stats
181+ )
160182 else :
161- datastats = transforms .utility .array .DataStats (data_shape = False , value_range = False )
183+ datastats = transforms .utility .array .DataStats (
184+ data_shape = False , value_range = False
185+ )
162186 logger = logging .getLogger (datastats ._logger_name )
163- logger .error (f"\n === Transform input info -- { type (transform ).__name__ } ===" )
187+ logger .error (
188+ f"\n === Transform input info -- { type (transform ).__name__ } ==="
189+ )
164190 if isinstance (data , (list , tuple )):
165191 data = data [0 ]
166192
167193 def _log_stats (data , prefix : str | None = "Data" ):
168194 if isinstance (data , (np .ndarray , torch .Tensor )):
169195 # log data type, shape, range for array
170- datastats (img = data , data_shape = True , value_range = True , prefix = prefix )
196+ datastats (
197+ img = data , data_shape = True , value_range = True , prefix = prefix
198+ )
171199 else :
172200 # log data type and value for other metadata
173201 datastats (img = data , data_value = True , prefix = prefix )
@@ -194,7 +222,9 @@ class Randomizable(ThreadUnsafe, RandomizableTrait):
194222
195223 R : np .random .RandomState = np .random .RandomState ()
196224
197- def set_random_state (self , seed : int | None = None , state : np .random .RandomState | None = None ) -> Randomizable :
225+ def set_random_state (
226+ self , seed : int | None = None , state : np .random .RandomState | None = None
227+ ) -> Randomizable :
198228 """
199229 Set the random state locally, to control the randomness, the derived
200230 classes should use :py:attr:`self.R` instead of `np.random` to introduce random
@@ -212,14 +242,20 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState
212242
213243 """
214244 if seed is not None :
215- _seed = np .int64 (id (seed ) if not isinstance (seed , (int , np .integer )) else seed )
216- _seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64
245+ _seed = np .int64 (
246+ id (seed ) if not isinstance (seed , (int , np .integer )) else seed
247+ )
248+ _seed = (
249+ _seed % MAX_SEED
250+ ) # need to account for Numpy2.0 which doesn't silently convert to int64
217251 self .R = np .random .RandomState (_seed )
218252 return self
219253
220254 if state is not None :
221255 if not isinstance (state , np .random .RandomState ):
222- raise TypeError (f"state must be None or a np.random.RandomState but is { type (state ).__name__ } ." )
256+ raise TypeError (
257+ f"state must be None or a np.random.RandomState but is { type (state ).__name__ } ."
258+ )
223259 self .R = state
224260 return self
225261
@@ -238,7 +274,9 @@ def randomize(self, data: Any) -> None:
238274 Raises:
239275 NotImplementedError: When the subclass does not override this method.
240276 """
241- raise NotImplementedError (f"Subclass { self .__class__ .__name__ } must implement this method." )
277+ raise NotImplementedError (
278+ f"Subclass { self .__class__ .__name__ } must implement this method."
279+ )
242280
243281
244282class Transform (ABC ):
@@ -294,7 +332,9 @@ def __call__(self, data: Any):
294332 NotImplementedError: When the subclass does not override this method.
295333
296334 """
297- raise NotImplementedError (f"Subclass { self .__class__ .__name__ } must implement this method." )
335+ raise NotImplementedError (
336+ f"Subclass { self .__class__ .__name__ } must implement this method."
337+ )
298338
299339
300340class LazyTransform (Transform , LazyTrait ):
@@ -397,11 +437,15 @@ def __call__(self, data):
397437 def __new__ (cls , * args , ** kwargs ):
398438 if config .USE_META_DICT :
399439 # call_update after MapTransform.__call__
400- cls .__call__ = transforms .attach_hook (cls .__call__ , MapTransform .call_update , "post" ) # type: ignore
440+ cls .__call__ = transforms .attach_hook (
441+ cls .__call__ , MapTransform .call_update , "post"
442+ ) # type: ignore
401443
402444 if hasattr (cls , "inverse" ):
403445 # inverse_update before InvertibleTransform.inverse
404- cls .inverse : Any = transforms .attach_hook (cls .inverse , transforms .InvertibleTransform .inverse_update )
446+ cls .inverse : Any = transforms .attach_hook (
447+ cls .inverse , transforms .InvertibleTransform .inverse_update
448+ )
405449 return Transform .__new__ (cls )
406450
407451 def __init__ (self , keys : KeysCollection , allow_missing_keys : bool = False ) -> None :
@@ -412,7 +456,9 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
412456 raise ValueError ("keys must be non empty." )
413457 for key in self .keys :
414458 if not isinstance (key , Hashable ):
415- raise TypeError (f"keys must be one of (Hashable, Iterable[Hashable]) but is { type (keys ).__name__ } ." )
459+ raise TypeError (
460+ f"keys must be one of (Hashable, Iterable[Hashable]) but is { type (keys ).__name__ } ."
461+ )
416462
417463 def call_update (self , data ):
418464 """
@@ -432,7 +478,9 @@ def call_update(self, data):
432478 for k in dict_i :
433479 if not isinstance (dict_i [k ], MetaTensor ):
434480 continue
435- list_d [idx ] = transforms .sync_meta_info (k , dict_i , t = not isinstance (self , transforms .InvertD ))
481+ list_d [idx ] = transforms .sync_meta_info (
482+ k , dict_i , t = not isinstance (self , transforms .InvertD )
483+ )
436484 return list_d [0 ] if is_dict else list_d
437485
438486 @abstractmethod
@@ -460,9 +508,13 @@ def __call__(self, data):
460508 An updated dictionary version of ``data`` by applying the transform.
461509
462510 """
463- raise NotImplementedError (f"Subclass { self .__class__ .__name__ } must implement this method." )
511+ raise NotImplementedError (
512+ f"Subclass { self .__class__ .__name__ } must implement this method."
513+ )
464514
465- def key_iterator (self , data : Mapping [Hashable , Any ], * extra_iterables : Iterable | None ) -> Generator :
515+ def key_iterator (
516+ self , data : Mapping [Hashable , Any ], * extra_iterables : Iterable | None
517+ ) -> Generator :
466518 """
467519 Iterate across keys and optionally extra iterables. If key is missing, exception is raised if
468520 `allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped.
0 commit comments