Skip to content

Commit 351ec00

Browse files
committed
fix errors
Signed-off-by: sewon.jeon <[email protected]>
1 parent 8961b77 commit 351ec00

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

monai/transforms/compose.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -293,13 +293,10 @@ def set_group_recursive(obj, gid):
293293
# Skip magic methods and common non-transform attributes
294294
if attr_name.startswith("__") or attr_name in ("transforms", "transform"):
295295
continue
296-
try:
297-
attr = getattr(obj, attr_name, None)
298-
if attr is not None and isinstance(attr, TraceableTransform) and not isinstance(attr, Compose):
299-
# Recursively set group on nested transforms
300-
set_group_recursive(attr, gid)
301-
except Exception:
302-
pass
296+
attr = getattr(obj, attr_name, None)
297+
if attr is not None and isinstance(attr, TraceableTransform) and not isinstance(attr, Compose):
298+
# Recursively set group on nested transforms
299+
set_group_recursive(attr, gid)
303300

304301
for transform in self.transforms:
305302
set_group_recursive(transform, group_id)

monai/transforms/inverse.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ def _init_trace_threadlocal(self):
8282
if not hasattr(self._tracing, "value"):
8383
self._tracing.value = MONAIEnvVars.trace_transform() != "0"
8484

85+
# Initialize group identifier (set by Compose for automatic group tracking)
86+
if not hasattr(self, "_group"):
87+
self._group: str | None = None
88+
8589
def __getstate__(self):
8690
"""When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
8791
_dict = dict(getattr(self, "__dict__", {})) # this makes __dict__ always present in the unpickled object
@@ -119,6 +123,9 @@ def get_transform_info(self) -> dict:
119123
"""
120124
Return a dictionary with the relevant information pertaining to an applied transform.
121125
"""
126+
# Ensure _group is initialized
127+
self._init_trace_threadlocal()
128+
122129
vals = (
123130
self.__class__.__name__,
124131
id(self),
@@ -128,7 +135,7 @@ def get_transform_info(self) -> dict:
128135
info = dict(zip(self.transform_info_keys(), vals))
129136

130137
# Add group if set (automatically set by Compose)
131-
if hasattr(self, "_group") and self._group is not None:
138+
if self._group is not None:
132139
info[TraceKeys.GROUP] = self._group
133140

134141
return info

0 commit comments

Comments
 (0)