1111
1212from __future__ import annotations
1313
14+ import threading
1415import warnings
1516from collections .abc import Hashable , Mapping
1617from contextlib import contextmanager
@@ -66,15 +67,41 @@ class TraceableTransform(Transform):
6667 The information in the stack of applied transforms must be compatible with the
6768 default collate, by only storing strings, numbers and arrays.
6869
69- `tracing` could be enabled by `self.set_tracing ` or setting
70+ `tracing` could be enabled by assigning to `self.tracing ` or setting
7071 `MONAI_TRACE_TRANSFORM` when initializing the class.
7172 """
7273
73- tracing = MONAIEnvVars .trace_transform () != "0"
74+ def _init_trace_threadlocal (self ):
75+ """Create a `_tracing` instance member to store the thread-local tracing state value."""
76+ # needed since this class is meant to be a trait with no constructor
77+ if not hasattr (self , "_tracing" ):
78+ self ._tracing = threading .local ()
79+
80+ # This is True while the above initialising _tracing is False when this is
81+ # called from a different thread than the one initialising _tracing.
82+ if not hasattr (self ._tracing , "value" ):
83+ self ._tracing .value = MONAIEnvVars .trace_transform () != "0"
84+
85+ def __getstate__ (self ):
86+ """When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
87+ _dict = dict (getattr (self , "__dict__" , {})) # this makes __dict__ always present in the unpickled object
88+ _slots = {k : getattr (self , k ) for k in getattr (self , "__slots__" , [])}
89+ _dict .pop ("_tracing" , None ) # remove tracing
90+ return _dict if len (_slots ) == 0 else (_dict , _slots )
91+
92+ @property
93+ def tracing (self ) -> bool :
94+ """
95+ Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != "0"`.
96+ """
97+ self ._init_trace_threadlocal ()
98+ return bool (self ._tracing .value )
7499
75- def set_tracing (self , tracing : bool ) -> None :
76- """Set whether to trace transforms."""
77- self .tracing = tracing
100+ @tracing .setter
101+ def tracing (self , val : bool ):
102+ """Sets the thread-local tracing state to `val`."""
103+ self ._init_trace_threadlocal ()
104+ self ._tracing .value = val
78105
79106 @staticmethod
80107 def trace_key (key : Hashable = None ):
@@ -291,7 +318,7 @@ def check_transforms_match(self, transform: Mapping) -> None:
291318
292319 def get_most_recent_transform (self , data , key : Hashable = None , check : bool = True , pop : bool = False ):
293320 """
294- Get most recent transform for the stack .
321+ Get most recent matching transform for the current class from the sequence of applied operations .
295322
296323 Args:
297324 data: dictionary of data or `MetaTensor`.
@@ -316,9 +343,14 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
316343 all_transforms = data .get (self .trace_key (key ), MetaTensor .get_default_applied_operations ())
317344 else :
318345 raise ValueError (f"`data` should be either `MetaTensor` or dictionary, got { type (data )} ." )
346+
347+ if not all_transforms :
348+ raise ValueError (f"Item of type { type (data )} (key: { key } , pop: { pop } ) has empty 'applied_operations'" )
349+
319350 if check :
320351 self .check_transforms_match (all_transforms [- 1 ])
321- return all_transforms .pop () if pop else all_transforms [- 1 ]
352+
353+ return all_transforms .pop (- 1 ) if pop else all_transforms [- 1 ]
322354
323355 def pop_transform (self , data , key : Hashable = None , check : bool = True ):
324356 """
0 commit comments