Skip to content

Commit 341538d

Browse files
committed
Add new test
Signed-off-by: Mason Cleveland <[email protected]>
1 parent d16616c commit 341538d

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

tests/data/test_persistentdataset.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import torch
2121
from parameterized import parameterized
2222

23-
from monai.data import PersistentDataset, json_hashing
23+
from monai.data import MetaTensor, PersistentDataset, json_hashing
2424
from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform
2525

2626
TEST_CASE_1 = [
@@ -43,6 +43,14 @@
4343

4444
TEST_CASE_3 = [None, (128, 128, 128)]
4545

46+
TEST_CASE_4 = [True, False, False, MetaTensor]
47+
48+
TEST_CASE_5 = [True, True, True, None]
49+
50+
TEST_CASE_6 = [False, False, False, torch.Tensor]
51+
52+
TEST_CASE_7 = [False, True, False, torch.Tensor]
53+
4654

4755
class _InplaceXform(Transform):
4856

@@ -168,6 +176,40 @@ def test_different_transforms(self):
168176
l2 = ((im1 - im2) ** 2).sum() ** 0.5
169177
self.assertGreater(l2, 1)
170178

179+
@parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
180+
def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_error, expected_type):
181+
"""
182+
Ensure expected behavior for all combinations of `track_meta` and `weights_only`.
183+
"""
184+
test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
185+
with tempfile.TemporaryDirectory() as tempdir:
186+
nib.save(test_image, os.path.join(tempdir, "test_image.nii.gz"))
187+
test_data = [{"image": os.path.join(tempdir, "test_image.nii.gz")}]
188+
transform = Compose([LoadImaged(keys=["image"])])
189+
cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data")
190+
191+
if expected_error:
192+
with self.assertRaises(ValueError):
193+
PersistentDataset(
194+
data=test_data,
195+
transform=transform,
196+
cache_dir=cache_dir,
197+
track_meta=track_meta,
198+
weights_only=weights_only,
199+
)
200+
201+
else:
202+
test_dataset = PersistentDataset(
203+
data=test_data,
204+
transform=transform,
205+
cache_dir=cache_dir,
206+
track_meta=track_meta,
207+
weights_only=weights_only,
208+
)
209+
210+
im = test_dataset[0]["image"]
211+
self.assertIsInstance(im, expected_type)
212+
171213

172214
if __name__ == "__main__":
173215
unittest.main()

0 commit comments

Comments
 (0)