Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ def __init__(
pickle_protocol: int = DEFAULT_PROTOCOL,
hash_transform: Callable[..., bytes] | None = None,
reset_ops_id: bool = True,
track_meta: bool = False,
weights_only: bool = True,
) -> None:
"""
Args:
Expand Down Expand Up @@ -264,7 +266,16 @@ def __init__(
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.
track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
default to `False`. Cannot be used with `weights_only=True`.
weights_only: keyword argument passed to `torch.load` when reading cached files.
default to `True`. When set to `True`, `torch.load` restricts loading to tensors and
other safe objects. Setting this to `False` is required for loading `MetaTensor`
objects saved with `track_meta=True`.

Raises:
ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.
"""
super().__init__(data=data, transform=transform)
self.cache_dir = Path(cache_dir) if cache_dir is not None else None
Expand All @@ -280,6 +291,13 @@ def __init__(
if hash_transform is not None:
self.set_transform_hash(hash_transform)
self.reset_ops_id = reset_ops_id
if track_meta and weights_only:
raise ValueError(
"Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. "
"To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`."
)
self.track_meta = track_meta
self.weights_only = weights_only

def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
"""Get hashable transforms, and then hash them. Hashable transforms
Expand Down Expand Up @@ -377,7 +395,7 @@ def _cachecheck(self, item_transformed):

if hashfile is not None and hashfile.is_file(): # cache hit
try:
return torch.load(hashfile, weights_only=True)
return torch.load(hashfile, weights_only=self.weights_only)
except PermissionError as e:
if sys.platform != "win32":
raise e
Expand All @@ -398,7 +416,7 @@ def _cachecheck(self, item_transformed):
with tempfile.TemporaryDirectory() as tmpdirname:
temp_hash_file = Path(tmpdirname) / hashfile.name
torch.save(
obj=convert_to_tensor(_item_transformed, convert_numeric=False),
obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),
f=temp_hash_file,
pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
pickle_protocol=self.pickle_protocol,
Expand Down
44 changes: 43 additions & 1 deletion tests/data/test_persistentdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import torch
from parameterized import parameterized

from monai.data import PersistentDataset, json_hashing
from monai.data import MetaTensor, PersistentDataset, json_hashing
from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform

TEST_CASE_1 = [
Expand All @@ -43,6 +43,14 @@

TEST_CASE_3 = [None, (128, 128, 128)]

# Parameterized cases: [track_meta, weights_only, expect ValueError, expected loaded image type].
# Meta tracked and unrestricted torch.load -> cached MetaTensor round-trips intact.
TEST_CASE_4 = [True, False, False, MetaTensor]

# track_meta=True with weights_only=True is rejected at construction time (see
# PersistentDataset.__init__), so no dataset is built and there is no expected type.
TEST_CASE_5 = [True, True, True, None]

# No meta tracking: cached items are plain tensors even with unrestricted loading.
TEST_CASE_6 = [False, False, False, torch.Tensor]

# Safe default combination: plain tensors cached and reloaded with weights_only=True.
TEST_CASE_7 = [False, True, False, torch.Tensor]


class _InplaceXform(Transform):

Expand Down Expand Up @@ -168,6 +176,40 @@ def test_different_transforms(self):
l2 = ((im1 - im2) ** 2).sum() ** 0.5
self.assertGreater(l2, 1)

@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_error, expected_type):
    """
    Verify `PersistentDataset` behavior for every combination of `track_meta`
    and `weights_only`: the invalid pairing raises `ValueError` at construction,
    and the valid pairings yield a loaded image of the expected type.
    """
    volume = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
    with tempfile.TemporaryDirectory() as tempdir:
        image_path = os.path.join(tempdir, "test_image.nii.gz")
        nib.save(volume, image_path)

        # Common constructor arguments shared by both branches below.
        dataset_kwargs = {
            "data": [{"image": image_path}],
            "transform": Compose([LoadImaged(keys=["image"])]),
            "cache_dir": os.path.join(tempdir, "cache", "data"),
            "track_meta": track_meta,
            "weights_only": weights_only,
        }

        if expected_error:
            # The invalid combination must be rejected before any caching happens.
            with self.assertRaises(ValueError):
                PersistentDataset(**dataset_kwargs)
        else:
            dataset = PersistentDataset(**dataset_kwargs)
            loaded_image = dataset[0]["image"]
            self.assertIsInstance(loaded_image, expected_type)


if __name__ == "__main__":
unittest.main()
Loading