|
20 | 20 | import torch |
21 | 21 | from parameterized import parameterized |
22 | 22 |
|
23 | | -from monai.data import PersistentDataset, json_hashing |
| 23 | +from monai.data import MetaTensor, PersistentDataset, json_hashing |
24 | 24 | from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform |
25 | 25 |
|
26 | 26 | TEST_CASE_1 = [ |
|
43 | 43 |
|
44 | 44 | TEST_CASE_3 = [None, (128, 128, 128)] |
45 | 45 |
|
| 46 | +TEST_CASE_4 = [True, False, False, MetaTensor] |
| 47 | + |
| 48 | +TEST_CASE_5 = [True, True, True, None] |
| 49 | + |
| 50 | +TEST_CASE_6 = [False, False, False, torch.Tensor] |
| 51 | + |
| 52 | +TEST_CASE_7 = [False, True, False, torch.Tensor] |
| 53 | + |
46 | 54 |
|
47 | 55 | class _InplaceXform(Transform): |
48 | 56 |
|
@@ -168,6 +176,40 @@ def test_different_transforms(self): |
168 | 176 | l2 = ((im1 - im2) ** 2).sum() ** 0.5 |
169 | 177 | self.assertGreater(l2, 1) |
170 | 178 |
|
| 179 | + @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) |
| 180 | + def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_error, expected_type): |
| 181 | + """ |
| 182 | + Ensure expected behavior for all combinations of `track_meta` and `weights_only`. |
| 183 | + """ |
| 184 | + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) |
| 185 | + with tempfile.TemporaryDirectory() as tempdir: |
| 186 | + nib.save(test_image, os.path.join(tempdir, "test_image.nii.gz")) |
| 187 | + test_data = [{"image": os.path.join(tempdir, "test_image.nii.gz")}] |
| 188 | + transform = Compose([LoadImaged(keys=["image"])]) |
| 189 | + cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") |
| 190 | + |
| 191 | + if expected_error: |
| 192 | + with self.assertRaises(ValueError): |
| 193 | + PersistentDataset( |
| 194 | + data=test_data, |
| 195 | + transform=transform, |
| 196 | + cache_dir=cache_dir, |
| 197 | + track_meta=track_meta, |
| 198 | + weights_only=weights_only, |
| 199 | + ) |
| 200 | + |
| 201 | + else: |
| 202 | + test_dataset = PersistentDataset( |
| 203 | + data=test_data, |
| 204 | + transform=transform, |
| 205 | + cache_dir=cache_dir, |
| 206 | + track_meta=track_meta, |
| 207 | + weights_only=weights_only, |
| 208 | + ) |
| 209 | + |
| 210 | + im = test_dataset[0]["image"] |
| 211 | + self.assertIsInstance(im, expected_type) |
| 212 | + |
171 | 213 |
|
172 | 214 | if __name__ == "__main__": |
173 | 215 | unittest.main() |
0 commit comments