Skip to content

Commit 9377b63

Browse files
added slight cleanup and additional test
DCO Remediation Commit for Lukas Folle <[email protected]> I, Lukas Folle <[email protected]>, hereby add my Signed-off-by to this commit: 2d58774 Signed-off-by: Lukas Folle <[email protected]>
1 parent 1d04028 commit 9377b63

File tree

1 file changed

+21
-15
lines changed

1 file changed

+21
-15
lines changed

tests/transforms/compose/test_compose.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -283,26 +283,32 @@ def test_backwards_compatible_imports(self):
283283
from monai.transforms.transform import MapTransform, RandomizableTransform, Transform # noqa: F401
284284

285285
def test_list_extend_multi_sample_trait(self):
286-
from monai.transforms import CenterSpatialCrop, RandSpatialCropSamples
287-
288-
center_crop = CenterSpatialCrop([128, 128])
289-
multi_sample_transform = RandSpatialCropSamples([64, 64], 1)
286+
center_crop = mt.CenterSpatialCrop([128, 128])
287+
multi_sample_transform = mt.RandSpatialCropSamples([64, 64], 1)
290288

291289
img = torch.zeros([1, 512, 512])
292290

293-
assert execute_compose(img, [center_crop]).shape == torch.Size([1, 128, 128])
291+
self.assertEqual(execute_compose(img, [center_crop]).shape, torch.Size([1, 128, 128]))
294292
single_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, center_crop])
295-
assert (
296-
isinstance(single_multi_sample_trait_result, list)
297-
and len(single_multi_sample_trait_result) == 1
298-
and single_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64])
299-
)
293+
self.assertIsInstance(single_multi_sample_trait_result, list)
294+
self.assertEqual(len(single_multi_sample_trait_result), 1)
295+
self.assertEqual(single_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))
296+
300297
double_multi_sample_trait_result = execute_compose(img, [multi_sample_transform, multi_sample_transform, center_crop])
301-
assert (
302-
isinstance(double_multi_sample_trait_result, list)
303-
and len(double_multi_sample_trait_result) == 1
304-
and double_multi_sample_trait_result[0].shape == torch.Size([1, 64, 64])
305-
)
298+
self.assertIsInstance(double_multi_sample_trait_result, list)
299+
self.assertEqual(len(double_multi_sample_trait_result), 1)
300+
self.assertEqual(double_multi_sample_trait_result[0].shape, torch.Size([1, 64, 64]))
301+
302+
def test_multi_sample_trait_cardinality(self):
303+
img = torch.zeros([1, 128, 128])
304+
t2 = mt.RandSpatialCropSamples([32, 32], num_samples=2)
305+
306+
# chaining should multiply counts: 2 x 2 = 4, flattened
307+
res = execute_compose(img, [t2, t2])
308+
self.assertIsInstance(res, list)
309+
self.assertEqual(len(res), 4)
310+
for r in res:
311+
self.assertEqual(r.shape, torch.Size([1, 32, 32]))
306312

307313

308314
TEST_COMPOSE_EXECUTE_TEST_CASES = [

0 commit comments

Comments
 (0)