diff --git a/ants/registration/apply_transforms.py b/ants/registration/apply_transforms.py index d7b01046..0c7780ce 100644 --- a/ants/registration/apply_transforms.py +++ b/ants/registration/apply_transforms.py @@ -11,7 +11,7 @@ def apply_transforms(fixed, moving, transformlist, interpolator='linear', imagetype=0, whichtoinvert=None, compose=None, - defaultvalue=0, verbose=False, **kwargs): + defaultvalue=0, singleprecision=False, verbose=False, **kwargs): """ Apply a transform list to map an image from one domain to another. In image registration, one computes mappings between (usually) pairs @@ -67,6 +67,10 @@ def apply_transforms(fixed, moving, transformlist, defaultvalue : scalar Default voxel value for mappings outside the image domain. + singleprecision : boolean + if True, use float32 for computations. This is useful for reducing memory + usage for large datasets, at the cost of precision. + verbose : boolean print command and run verbose application of transform. @@ -102,6 +106,8 @@ def apply_transforms(fixed, moving, transformlist, args = [fixed, moving, transformlist, interpolator] + output_pixel_type = 'float' if singleprecision else 'double' + if not isinstance(fixed, str): if isinstance(fixed, iio.ANTsImage) and isinstance(moving, iio.ANTsImage): for tl_path in transformlist: @@ -109,9 +115,9 @@ def apply_transforms(fixed, moving, transformlist, raise Exception('Transform %s does not exist' % tl_path) inpixeltype = fixed.pixeltype - fixed = fixed.clone('float') - moving = moving.clone('float') - warpedmovout = moving.clone() + fixed = fixed.clone(output_pixel_type) + moving = moving.clone(output_pixel_type) + warpedmovout = moving.clone(output_pixel_type) f = fixed m = moving if (moving.dimension == 4) and (fixed.dimension == 3) and (imagetype == 0): @@ -165,7 +171,7 @@ def apply_transforms(fixed, moving, transformlist, if verbose: print(myargs) - processed_args = myargs + ['-z', str(1), '-v', str(myverb), '--float', str(1), '-e', str(imagetype), '-f', str(defaultvalue)] + processed_args = myargs + ['-z', str(1), '-v', str(myverb), '--float', str(int(singleprecision)), '-e', str(imagetype), '-f', str(defaultvalue)] libfn = utils.get_lib_fn('antsApplyTransforms') libfn(processed_args) @@ -180,7 +186,7 @@ def apply_transforms(fixed, moving, transformlist, else: return 1 else: - args = args + ['-z', 1, '--float', 1, '-e', imagetype, '-f', defaultvalue] + args = args + ['-z', str(1), '--float', str(int(singleprecision)), '-e', imagetype, '-f', defaultvalue] processed_args = utils._int_antsProcessArguments(args) libfn = utils.get_lib_fn('antsApplyTransforms') libfn(processed_args) diff --git a/tests/test_core_ants_transform.py b/tests/test_core_ants_transform.py index b319c1c9..f6aef044 100644 --- a/tests/test_core_ants_transform.py +++ b/tests/test_core_ants_transform.py @@ -87,7 +87,7 @@ def test_apply(self): img = ants.image_read(ants.get_ants_data("r16")).clone('float') tx = ants.new_ants_transform(dimension=2) tx.set_parameters((0.9,0,0,1.1,10,11)) - img2 = tx.apply(data=img, reference=img, data_type='image') + img2 = tx.apply(data=img, reference=img, data_type='image') def test_apply_to_point(self): tx = ants.new_ants_transform() @@ -99,14 +99,14 @@ def test_apply_to_vector(self): tx = ants.new_ants_transform() params = tx.parameters tx.set_parameters(params*2) - pt2 = tx.apply_to_vector((1,2,3)) # should be (2,4,6) + pt2 = tx.apply_to_vector((1,2,3)) # should be (2,4,6) def test_apply_to_image(self): for ptype in self.pixeltypes: img = ants.image_read(ants.get_ants_data("r16")).clone(ptype) tx = ants.new_ants_transform(dimension=2) tx.set_parameters((0.9,0,0,1.1,10,11)) - img2 = tx.apply_to_image(img, img) + img2 = tx.apply_to_image(img, img) class TestModule_ants_transform(unittest.TestCase): @@ -154,7 +154,8 @@ def test_apply_ants_transform(self): img = ants.image_read(ants.get_ants_data("r16")).clone('float') tx = ants.new_ants_transform(dimension=2) tx.set_parameters((0.9,0,0,1.1,10,11)) - img2 = ants.apply_ants_transform(tx, data=img, reference=img, data_type='image') + img2 = ants.apply_ants_transform(tx, data=img, reference=img, data_type='image') + def test_apply_ants_transform_to_point(self): tx = ants.new_ants_transform() diff --git a/tests/test_registation.py b/tests/test_registation.py index 0c6705f6..97af91d7 100644 --- a/tests/test_registation.py +++ b/tests/test_registation.py @@ -44,11 +44,21 @@ def test_example(self): fixed = ants.image_read(ants.get_ants_data("r16")) moving = ants.image_read(ants.get_ants_data("r64")) fixed = ants.resample_image(fixed, (64, 64), 1, 0) - moving = ants.resample_image(moving, (64, 64), 1, 0) + moving = ants.resample_image(moving, (128, 128), 1, 0) mytx = ants.registration(fixed=fixed, moving=moving, type_of_transform="SyN") mywarpedimage = ants.apply_transforms( fixed=fixed, moving=moving, transformlist=mytx["fwdtransforms"] ) + self.assertEqual(mywarpedimage.pixeltype, moving.pixeltype) + self.assertTrue(ants.ants_image.image_physical_space_consistency(fixed, mywarpedimage, + 0.0001, datatype = False)) + + # Call with float precision for transforms, but should still return input type + mywarpedimage2 = ants.apply_transforms( + fixed=fixed, moving=moving, transformlist=mytx["fwdtransforms"], singleprecision=True + ) + self.assertEqual(mywarpedimage2.pixeltype, moving.pixeltype) + self.assertAlmostEqual(mywarpedimage.sum(), mywarpedimage2.sum(), places=3) # bad interpolator with self.assertRaises(Exception):