
Commit 23b7fb6

Merge remote-tracking branch 'origin/dev' into docker_slim

2 parents: 463bb91 + c968907

3 files changed, +58 -7 lines

monai/data/dataset.py
Lines changed: 22 additions & 3 deletions

@@ -230,6 +230,8 @@ def __init__(
         pickle_protocol: int = DEFAULT_PROTOCOL,
         hash_transform: Callable[..., bytes] | None = None,
         reset_ops_id: bool = True,
+        track_meta: bool = False,
+        weights_only: bool = True,
     ) -> None:
         """
         Args:
@@ -264,7 +266,17 @@ def __init__(
                 When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
                 This is useful for skipping the transform instance checks when inverting applied operations
                 using the cached content and with re-created transform instances.
-
+            track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
+                default to `False`. Cannot be used with `weights_only=True`.
+            weights_only: keyword argument passed to `torch.load` when reading cached files.
+                default to `True`. When set to `True`, `torch.load` restricts loading to tensors and
+                other safe objects. Setting this to `False` is required for loading `MetaTensor`
+                objects saved with `track_meta=True`, however this creates the possibility of remote
+                code execution through `torch.load` so be aware of the security implications of doing so.
+
+        Raises:
+            ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
+                prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.
         """
         super().__init__(data=data, transform=transform)
         self.cache_dir = Path(cache_dir) if cache_dir is not None else None
@@ -280,6 +292,13 @@ def __init__(
         if hash_transform is not None:
             self.set_transform_hash(hash_transform)
         self.reset_ops_id = reset_ops_id
+        if track_meta and weights_only:
+            raise ValueError(
+                "Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. "
+                "To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`."
+            )
+        self.track_meta = track_meta
+        self.weights_only = weights_only

     def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
         """Get hashable transforms, and then hash them. Hashable transforms
@@ -377,7 +396,7 @@ def _cachecheck(self, item_transformed):

         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                return torch.load(hashfile, weights_only=True)
+                return torch.load(hashfile, weights_only=self.weights_only)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
@@ -398,7 +417,7 @@ def _cachecheck(self, item_transformed):
         with tempfile.TemporaryDirectory() as tmpdirname:
             temp_hash_file = Path(tmpdirname) / hashfile.name
             torch.save(
-                obj=convert_to_tensor(_item_transformed, convert_numeric=False),
+                obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),
                 f=temp_hash_file,
                 pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                 pickle_protocol=self.pickle_protocol,
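A minimal usage sketch of the new arguments (not part of the commit; the image path and cache directory below are hypothetical): caching MetaTensors requires pairing `track_meta=True` with `weights_only=False`, while the invalid pairing raises the ValueError added above.

    from monai.data import PersistentDataset
    from monai.transforms import Compose, LoadImaged

    data = [{"image": "/path/to/image.nii.gz"}]  # hypothetical input file
    transform = Compose([LoadImaged(keys=["image"])])

    # Cache items as MetaTensors and reload them on later runs; this needs
    # weights_only=False, so only use it with cache files you trust.
    ds = PersistentDataset(
        data=data,
        transform=transform,
        cache_dir="/path/to/cache",  # hypothetical cache location
        track_meta=True,
        weights_only=False,
    )

    # PersistentDataset(..., track_meta=True, weights_only=True) raises
    # ValueError, because the cached MetaTensors could never be read back.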

monai/transforms/croppad/functional.py
Lines changed: 1 addition & 1 deletion

@@ -144,7 +144,7 @@ def crop_or_pad_nd(img: torch.Tensor, translation_mat, spatial_size: tuple[int,
         _mode = _convert_pt_pad_mode(mode)
         img = pad_nd(img, to_pad, mode=_mode, **kwargs)
     if do_crop:
-        img = img[to_crop]
+        img = img[tuple(to_crop)]
     return img
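A small standalone illustration (not from the commit) of why the crop now converts `to_crop` to a tuple before indexing: a tuple of slices selects one range per axis, whereas a plain list of slices is treated as advanced indexing and may be rejected or reinterpreted by recent PyTorch and NumPy versions.

    import torch

    img = torch.arange(24).reshape(2, 3, 4)
    to_crop = [slice(0, 1), slice(1, 3), slice(0, 2)]  # one slice per axis

    cropped = img[tuple(to_crop)]
    print(cropped.shape)  # torch.Size([1, 2, 2])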

tests/data/test_persistentdataset.py
Lines changed: 35 additions & 3 deletions

@@ -11,6 +11,7 @@

 from __future__ import annotations

+import contextlib
 import os
 import tempfile
 import unittest
@@ -20,7 +21,7 @@
 import torch
 from parameterized import parameterized

-from monai.data import PersistentDataset, json_hashing
+from monai.data import MetaTensor, PersistentDataset, json_hashing
 from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform

 TEST_CASE_1 = [
@@ -43,9 +44,16 @@

 TEST_CASE_3 = [None, (128, 128, 128)]

+TEST_CASE_4 = [True, False, False, MetaTensor]
+
+TEST_CASE_5 = [True, True, True, None]
+
+TEST_CASE_6 = [False, False, False, torch.Tensor]
+
+TEST_CASE_7 = [False, True, False, torch.Tensor]

-class _InplaceXform(Transform):

+class _InplaceXform(Transform):
     def __call__(self, data):
         if data:
             data[0] = data[0] + np.pi
@@ -55,7 +63,6 @@ def __call__(self, data):


 class TestDataset(unittest.TestCase):
-
     def test_cache(self):
         """testing no inplace change to the hashed item"""
         items = [[list(range(i))] for i in range(5)]
@@ -168,6 +175,31 @@ def test_different_transforms(self):
         l2 = ((im1 - im2) ** 2).sum() ** 0.5
         self.assertGreater(l2, 1)

+    @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
+    def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_error, expected_type):
+        """
+        Ensure expected behavior for all combinations of `track_meta` and `weights_only`.
+        """
+        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
+        with tempfile.TemporaryDirectory() as tempdir:
+            nib.save(test_image, os.path.join(tempdir, "test_image.nii.gz"))
+            test_data = [{"image": os.path.join(tempdir, "test_image.nii.gz")}]
+            transform = Compose([LoadImaged(keys=["image"])])
+            cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data")
+
+            cm = self.assertRaises(ValueError) if expected_error else contextlib.nullcontext()
+            with cm:
+                test_dataset = PersistentDataset(
+                    data=test_data,
+                    transform=transform,
+                    cache_dir=cache_dir,
+                    track_meta=track_meta,
+                    weights_only=weights_only,
+                )
+
+                im = test_dataset[0]["image"]
+                self.assertIsInstance(im, expected_type)
+

 if __name__ == "__main__":
     unittest.main()
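The combinations exercised by the new test come down to how the `weights_only` flag of `torch.load` behaves; a standalone sketch (assuming a PyTorch version that accepts the `weights_only` keyword, which the changed dataset code already relies on):

    import io

    import torch

    buf = io.BytesIO()
    torch.save(torch.rand(2, 3), buf)

    buf.seek(0)
    safe = torch.load(buf, weights_only=True)  # plain tensors round-trip safely

    buf.seek(0)
    # weights_only=False falls back to full pickle: required for custom classes
    # such as MetaTensor cached with track_meta=True, but it can execute
    # arbitrary code if the cache file has been tampered with.
    unrestricted = torch.load(buf, weights_only=False)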
