diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py
index f94f11eca9..2f57f4614a 100644
--- a/monai/transforms/inverse.py
+++ b/monai/transforms/inverse.py
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+import threading
 import warnings
 from collections.abc import Hashable, Mapping
 from contextlib import contextmanager
@@ -66,15 +67,41 @@ class TraceableTransform(Transform):
     The information in the stack of applied transforms must be compatible with the
     default collate, by only storing strings, numbers and arrays.
 
-    `tracing` could be enabled by `self.set_tracing` or setting
+    `tracing` can be enabled by assigning to `self.tracing` or by setting
     `MONAI_TRACE_TRANSFORM` when initializing the class.
     """
 
-    tracing = MONAIEnvVars.trace_transform() != "0"
+    def _init_trace_threadlocal(self):
+        """Create a `_tracing` instance member to store the thread-local tracing state value."""
+        # needed since this class is meant to be a trait with no constructor
+        if not hasattr(self, "_tracing"):
+            self._tracing = threading.local()
+
+        # `_tracing` may already exist, having been created by another thread, but its
+        # `value` attribute is per-thread and so must be initialised once in each thread
+        if not hasattr(self._tracing, "value"):
+            self._tracing.value = MONAIEnvVars.trace_transform() != "0"
+
+    def __getstate__(self):
+        """When pickling, remove the `_tracing` member from the output, if present, since it's not picklable."""
+        _dict = dict(getattr(self, "__dict__", {}))  # this makes __dict__ always present in the unpickled object
+        _slots = {k: getattr(self, k) for k in getattr(self, "__slots__", [])}
+        _dict.pop("_tracing", None)  # remove tracing
+        return _dict if len(_slots) == 0 else (_dict, _slots)
+
+    @property
+    def tracing(self) -> bool:
+        """
+        Returns the tracing state, which is thread-local and initialised to `MONAIEnvVars.trace_transform() != "0"`.
+        """
+        self._init_trace_threadlocal()
+        return bool(self._tracing.value)
 
-    def set_tracing(self, tracing: bool) -> None:
-        """Set whether to trace transforms."""
-        self.tracing = tracing
+    @tracing.setter
+    def tracing(self, val: bool):
+        """Sets the thread-local tracing state to `val`."""
+        self._init_trace_threadlocal()
+        self._tracing.value = val
 
     @staticmethod
     def trace_key(key: Hashable = None):
@@ -291,7 +318,7 @@ def check_transforms_match(self, transform: Mapping) -> None:
 
     def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False):
         """
-        Get most recent transform for the stack.
+        Get the most recent matching transform for the current class from the sequence of applied operations.
 
         Args:
             data: dictionary of data or `MetaTensor`.
@@ -316,9 +343,14 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr
             all_transforms = data.get(self.trace_key(key), MetaTensor.get_default_applied_operations())
         else:
             raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.")
+
+        if not all_transforms:
+            raise ValueError(f"Item of type {type(data)} (key: {key}, pop: {pop}) has empty 'applied_operations'")
+
         if check:
             self.check_transforms_match(all_transforms[-1])
-        return all_transforms.pop() if pop else all_transforms[-1]
+
+        return all_transforms.pop(-1) if pop else all_transforms[-1]
 
     def pop_transform(self, data, key: Hashable = None, check: bool = True):
         """
diff --git a/tests/transforms/test_inverse.py b/tests/transforms/inverse/test_inverse.py
similarity index 100%
rename from tests/transforms/test_inverse.py
rename to tests/transforms/inverse/test_inverse.py
diff --git a/tests/transforms/inverse/test_inverse_dict.py b/tests/transforms/inverse/test_inverse_dict.py
new file mode 100644
index 0000000000..466be7411c
--- /dev/null
+++ b/tests/transforms/inverse/test_inverse_dict.py
@@ -0,0 +1,105 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+from itertools import product
+
+import torch
+from parameterized import parameterized
+
+from monai.data import DataLoader, Dataset, MetaTensor, ThreadDataLoader, create_test_image_2d
+from monai.engines.evaluator import SupervisedEvaluator
+from monai.transforms import Compose, EnsureChannelFirstd, Invertd, Spacingd
+from monai.utils.enums import CommonKeys
+from tests.test_utils import TEST_DEVICES, SkipIfNoModule
+
+
+class TestInvertDict(unittest.TestCase):
+
+    def setUp(self):
+        self.orig_size = (60, 60)
+        img, _ = create_test_image_2d(*self.orig_size, 2, 10, num_seg_classes=2)
+        self.img = MetaTensor(img, meta={"original_channel_dim": float("nan"), "pixdim": [1.0, 1.0]})
+        self.key = CommonKeys.IMAGE
+        self.pred = CommonKeys.PRED
+        self.new_pixdim = 2.0
+
+        self.preprocessing = Compose([EnsureChannelFirstd(self.key), Spacingd(self.key, pixdim=[self.new_pixdim] * 2)])
+
+        self.postprocessing = Compose([Invertd(self.pred, transform=self.preprocessing, orig_keys=self.key)])
+
+    @parameterized.expand(TEST_DEVICES)
+    def test_simple_processing(self, device):
+        """
+        Tests that the postprocessing operations perform correctly, in particular that `Invertd` inverts correctly.
+
+        This applies the preprocessing sequence, which resizes the input, then the postprocessing sequence, which
+        returns it to the original shape using `Invertd`. The test checks that the output has the same shape as the
+        original image, and that `Invertd` isn't confused when other transforms in the postprocessing sequence are
+        tracing and so adding information to `applied_operations`.
+ """ + + item = {self.key: self.img.to(device)} + pre = self.preprocessing(item) + + nw = int(self.orig_size[0] / self.new_pixdim) + nh = int(self.orig_size[1] / self.new_pixdim) + + self.assertTupleEqual(pre[self.key].shape, (1, nh, nw), "Pre-processing did not reshape input correctly") + self.assertTrue(len(pre[self.key].applied_operations) > 0, "Pre-processing transforms did not trace correctly") + + pre[self.pred] = pre[self.key] # the inputs are the prediction for this test + + post = self.postprocessing(pre) + + self.assertTupleEqual( + post[self.pred].shape, (1, *self.orig_size), "Result does not have same shape as original input" + ) + + @parameterized.expand(product(sum(TEST_DEVICES, []), [True, False])) + @SkipIfNoModule("ignite") + def test_workflow(self, device, use_threads): + """ + This tests the interaction between pre and postprocesing transform sequences being executed in parallel. + + When the `ThreadDataLoader` is used to load batches, this is done in parallel at times with the execution of + the post-process transform sequence. Previously this encountered a race condition at times because the + `TraceableTransform.tracing` variables of transforms was being toggled in different threads, so at times a + pre-process transform wouldn't trace correctly and so confuse `Invertd`. Using a `SupervisedEvaluator` is + the best way to induce this race condition, other methods didn't get the timing right.. + """ + batch_size = 2 + ds_size = 4 + test_data = [{self.key: self.img.clone().to(device)} for _ in range(ds_size)] + ds = Dataset(test_data, transform=self.preprocessing) + dl_type = ThreadDataLoader if use_threads else DataLoader + dl = dl_type(ds, num_workers=0, batch_size=batch_size) + + class AssertAppliedOps(torch.nn.Module): + def forward(self, x): + assert len(x.applied_operations) == x.shape[0] + assert all(len(a) > 0 for a in x.applied_operations) + return x + + evaluator = SupervisedEvaluator( + device=device, network=AssertAppliedOps(), postprocessing=self.postprocessing, val_data_loader=dl + ) + + evaluator.run() + + self.assertTupleEqual(evaluator.state.output[0][self.pred].shape, (1, *self.orig_size)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/transforms/test_invert.py b/tests/transforms/inverse/test_invert.py similarity index 100% rename from tests/transforms/test_invert.py rename to tests/transforms/inverse/test_invert.py diff --git a/tests/transforms/test_invertd.py b/tests/transforms/inverse/test_invertd.py similarity index 100% rename from tests/transforms/test_invertd.py rename to tests/transforms/inverse/test_invertd.py diff --git a/tests/transforms/inverse/test_traceable_transform.py b/tests/transforms/inverse/test_traceable_transform.py index 6a499b2dd9..8ee7c9e62f 100644 --- a/tests/transforms/inverse/test_traceable_transform.py +++ b/tests/transforms/inverse/test_traceable_transform.py @@ -45,13 +45,13 @@ def test_default(self): self.assertEqual(len(data[expected_key]), 2) self.assertEqual(data[expected_key][-1]["class"], "_TraceTest") - with self.assertRaises(IndexError): + with self.assertRaises(ValueError): a.pop({"test": "test"}) # no stack in the data data = a.pop(data) data = a.pop(data) self.assertEqual(data[expected_key], []) - with self.assertRaises(IndexError): # no more items + with self.assertRaises(ValueError): # no more items a.pop(data)