@@ -283,26 +283,32 @@ def test_backwards_compatible_imports(self):
283283 from monai .transforms .transform import MapTransform , RandomizableTransform , Transform # noqa: F401
284284
285285 def test_list_extend_multi_sample_trait (self ):
286- from monai .transforms import CenterSpatialCrop , RandSpatialCropSamples
287-
288- center_crop = CenterSpatialCrop ([128 , 128 ])
289- multi_sample_transform = RandSpatialCropSamples ([64 , 64 ], 1 )
286+ center_crop = mt .CenterSpatialCrop ([128 , 128 ])
287+ multi_sample_transform = mt .RandSpatialCropSamples ([64 , 64 ], 1 )
290288
291289 img = torch .zeros ([1 , 512 , 512 ])
292290
293- assert execute_compose (img , [center_crop ]).shape == torch .Size ([1 , 128 , 128 ])
291+ self . assertEqual ( execute_compose (img , [center_crop ]).shape , torch .Size ([1 , 128 , 128 ]) )
294292 single_multi_sample_trait_result = execute_compose (img , [multi_sample_transform , center_crop ])
295- assert (
296- isinstance (single_multi_sample_trait_result , list )
297- and len (single_multi_sample_trait_result ) == 1
298- and single_multi_sample_trait_result [0 ].shape == torch .Size ([1 , 64 , 64 ])
299- )
293+ self .assertIsInstance (single_multi_sample_trait_result , list )
294+ self .assertEqual (len (single_multi_sample_trait_result ), 1 )
295+ self .assertEqual (single_multi_sample_trait_result [0 ].shape , torch .Size ([1 , 64 , 64 ]))
296+
300297 double_multi_sample_trait_result = execute_compose (img , [multi_sample_transform , multi_sample_transform , center_crop ])
301- assert (
302- isinstance (double_multi_sample_trait_result , list )
303- and len (double_multi_sample_trait_result ) == 1
304- and double_multi_sample_trait_result [0 ].shape == torch .Size ([1 , 64 , 64 ])
305- )
298+ self .assertIsInstance (double_multi_sample_trait_result , list )
299+ self .assertEqual (len (double_multi_sample_trait_result ), 1 )
300+ self .assertEqual (double_multi_sample_trait_result [0 ].shape , torch .Size ([1 , 64 , 64 ]))
301+
302+ def test_multi_sample_trait_cardinality (self ):
303+ img = torch .zeros ([1 , 128 , 128 ])
304+ t2 = mt .RandSpatialCropSamples ([32 , 32 ], num_samples = 2 )
305+
306+ # chaining should multiply counts: 2 x 2 = 4, flattened
307+ res = execute_compose (img , [t2 , t2 ])
308+ self .assertIsInstance (res , list )
309+ self .assertEqual (len (res ), 4 )
310+ for r in res :
311+ self .assertEqual (r .shape , torch .Size ([1 , 32 , 32 ]))
306312
307313
308314TEST_COMPOSE_EXECUTE_TEST_CASES = [
0 commit comments