Skip to content

Commit

Permalink
Use double precision by default in apply_transforms computations (#585)
Browse files Browse the repository at this point in the history
  • Loading branch information
cookpa authored Mar 21, 2024
1 parent 1e36373 commit 8612b9e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 11 deletions.
18 changes: 12 additions & 6 deletions ants/registration/apply_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
def apply_transforms(fixed, moving, transformlist,
interpolator='linear', imagetype=0,
whichtoinvert=None, compose=None,
defaultvalue=0, verbose=False, **kwargs):
defaultvalue=0, singleprecision=False, verbose=False, **kwargs):
"""
Apply a transform list to map an image from one domain to another.
In image registration, one computes mappings between (usually) pairs
Expand Down Expand Up @@ -67,6 +67,10 @@ def apply_transforms(fixed, moving, transformlist,
defaultvalue : scalar
Default voxel value for mappings outside the image domain.
singleprecision : boolean
if True, use float32 for computations. This is useful for reducing memory
usage for large datasets, at the cost of precision.
verbose : boolean
print command and run verbose application of transform.
Expand Down Expand Up @@ -102,16 +106,18 @@ def apply_transforms(fixed, moving, transformlist,

args = [fixed, moving, transformlist, interpolator]

output_pixel_type = 'float' if singleprecision else 'double'

if not isinstance(fixed, str):
if isinstance(fixed, iio.ANTsImage) and isinstance(moving, iio.ANTsImage):
for tl_path in transformlist:
if not os.path.exists(tl_path):
raise Exception('Transform %s does not exist' % tl_path)

inpixeltype = fixed.pixeltype
fixed = fixed.clone('float')
moving = moving.clone('float')
warpedmovout = moving.clone()
fixed = fixed.clone(output_pixel_type)
moving = moving.clone(output_pixel_type)
warpedmovout = moving.clone(output_pixel_type)
f = fixed
m = moving
if (moving.dimension == 4) and (fixed.dimension == 3) and (imagetype == 0):
Expand Down Expand Up @@ -165,7 +171,7 @@ def apply_transforms(fixed, moving, transformlist,
if verbose:
print(myargs)

processed_args = myargs + ['-z', str(1), '-v', str(myverb), '--float', str(1), '-e', str(imagetype), '-f', str(defaultvalue)]
processed_args = myargs + ['-z', str(1), '-v', str(myverb), '--float', str(int(singleprecision)), '-e', str(imagetype), '-f', str(defaultvalue)]
libfn = utils.get_lib_fn('antsApplyTransforms')
libfn(processed_args)

Expand All @@ -180,7 +186,7 @@ def apply_transforms(fixed, moving, transformlist,
else:
return 1
else:
args = args + ['-z', 1, '--float', 1, '-e', imagetype, '-f', defaultvalue]
args = args + ['-z', str(1), '--float', str(int(singleprecision)), '-e', imagetype, '-f', defaultvalue]
processed_args = utils._int_antsProcessArguments(args)
libfn = utils.get_lib_fn('antsApplyTransforms')
libfn(processed_args)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_core_ants_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def test_apply(self):
img = ants.image_read(ants.get_ants_data("r16")).clone('float')
tx = ants.new_ants_transform(dimension=2)
tx.set_parameters((0.9,0,0,1.1,10,11))
img2 = tx.apply(data=img, reference=img, data_type='image')
img2 = tx.apply(data=img, reference=img, data_type='image')

def test_apply_to_point(self):
tx = ants.new_ants_transform()
Expand All @@ -99,14 +99,14 @@ def test_apply_to_vector(self):
tx = ants.new_ants_transform()
params = tx.parameters
tx.set_parameters(params*2)
pt2 = tx.apply_to_vector((1,2,3)) # should be (2,4,6)
pt2 = tx.apply_to_vector((1,2,3)) # should be (2,4,6)

def test_apply_to_image(self):
for ptype in self.pixeltypes:
img = ants.image_read(ants.get_ants_data("r16")).clone(ptype)
tx = ants.new_ants_transform(dimension=2)
tx.set_parameters((0.9,0,0,1.1,10,11))
img2 = tx.apply_to_image(img, img)
img2 = tx.apply_to_image(img, img)


class TestModule_ants_transform(unittest.TestCase):
Expand Down Expand Up @@ -154,7 +154,8 @@ def test_apply_ants_transform(self):
img = ants.image_read(ants.get_ants_data("r16")).clone('float')
tx = ants.new_ants_transform(dimension=2)
tx.set_parameters((0.9,0,0,1.1,10,11))
img2 = ants.apply_ants_transform(tx, data=img, reference=img, data_type='image')
img2 = ants.apply_ants_transform(tx, data=img, reference=img, data_type='image')


def test_apply_ants_transform_to_point(self):
tx = ants.new_ants_transform()
Expand Down
12 changes: 11 additions & 1 deletion tests/test_registation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,21 @@ def test_example(self):
fixed = ants.image_read(ants.get_ants_data("r16"))
moving = ants.image_read(ants.get_ants_data("r64"))
fixed = ants.resample_image(fixed, (64, 64), 1, 0)
moving = ants.resample_image(moving, (64, 64), 1, 0)
moving = ants.resample_image(moving, (128, 128), 1, 0)
mytx = ants.registration(fixed=fixed, moving=moving, type_of_transform="SyN")
mywarpedimage = ants.apply_transforms(
fixed=fixed, moving=moving, transformlist=mytx["fwdtransforms"]
)
self.assertEqual(mywarpedimage.pixeltype, moving.pixeltype)
self.assertTrue(ants.ants_image.image_physical_space_consistency(fixed, mywarpedimage,
0.0001, datatype = False))

# Call with float precision for transforms, but should still return input type
mywarpedimage2 = ants.apply_transforms(
fixed=fixed, moving=moving, transformlist=mytx["fwdtransforms"], singleprecision=True
)
self.assertEqual(mywarpedimage2.pixeltype, moving.pixeltype)
self.assertAlmostEqual(mywarpedimage.sum(), mywarpedimage2.sum(), places=3)

# bad interpolator
with self.assertRaises(Exception):
Expand Down

0 comments on commit 8612b9e

Please sign in to comment.