Skip to content

Commit

Permalink
Merge pull request #12 from sumanthratna/fix-nonlinear-transformations
Browse files Browse the repository at this point in the history
Fix nonlinear transformations
  • Loading branch information
jlevy44 authored May 9, 2020
2 parents fe064ed + 312a973 commit 3334417
Showing 1 changed file with 31 additions and 14 deletions.
45 changes: 31 additions & 14 deletions pathflow_mixmatch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def displace_image(img, displacement, gpu_device, dtype=th.float32):
def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similarity', gpu_device=0, opt_cm=True, sigma=[[11,11],[11,11],[3,3]], order=2, pyramid=[[4,4],[2,2]], loss_fn='mse', use_mask=False, interpolation='bicubic'):
assert use_mask==False, "Masking not implemented"
assert transform_type in ['similarity', 'affine', 'rigid', 'non_parametric','bspline','wendland']
start = time.time()
start = time.perf_counter()

# set the used data type
dtype = th.float32
Expand Down Expand Up @@ -142,16 +142,23 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari
for level, (mov_im_level, fix_im_level) in enumerate(zip(moving_image_pyramid, fixed_image_pyramid)):

# choose the affine transformation model
if transform_type in ['non_parametric','bspline','wendland']:
transform_args[0]=mov_im_level.size
if transform_type == 'non_parametric':
transform_args[0]=mov_im_level[level].size
elif transform_type in ['bspline','wendland']:
# for bspline, sigma must be positive tuple of ints
# for bspline, smaller sigma tuple means less loss of
# microarchitectural details

# transform_opts['sigma'] = sigma[level]
transform_opts['sigma'] = (1, 1)

transformation = transforms[transform_type](*transform_args,**transform_opts)

if level > 0 and transform_type=='bspline':
constant_flow = al.transformation.utils.upsample_displacement(constant_flow,
mov_im_level.size,
interpolation=interpolation)
transformation.set_constant_flow(constant_flow)
# if level > 0 and transform_type=='bspline':
# constant_flow = al.transformation.utils.upsample_displacement(constant_flow,
# mov_im_level.size,
# interpolation=interpolation)
# transformation.set_constant_flow(constant_flow)

if transform_type in ['similarity', 'affine', 'rigid']:
# initialize the translation with the center of mass of the fixed image
Expand Down Expand Up @@ -180,8 +187,8 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari
# start the registration
registration.start()

if transform_type == 'bspline':
constant_flow = transformation.get_flow()
# if transform_type == 'bspline':
# constant_flow = transformation.get_flow()

# set the intensities back to the original for the visualisation
fixed_image.image = 1 - fixed_image.image
Expand All @@ -191,13 +198,14 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari
displacement = transformation.get_displacement()
warped_image = al.transformation.utils.warp_image(moving_image, displacement)

end = time.time()
end = time.perf_counter()

print("=================================================================")

print("Registration done in:", end - start, "s")
print("Result parameters:")
transformation.print()
if transform_type in ['similarity', 'affine', 'rigid']:
print("Result parameters:")
transformation.print()

# plot the results
plt.subplot(131)
Expand All @@ -211,7 +219,16 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari
plt.subplot(133)
plt.imshow(warped_image.numpy(), cmap='gray')
plt.title('Warped Moving Image')
return displacement, warped_image, transformation._phi_z, registration.loss.data.item()

if transform_type in ['similarity', 'affine', 'rigid']:
transformation_param = transformation._phi_z
elif transform_type == 'non_parametric':
transformation_param = transformation.trans_parameters
elif transform_type == 'bspline' or transform_type == 'wendland':
transformation_param = transformation._kernel
else:
pass
return displacement, warped_image, transformation_param, registration.loss.data.item()

def get_loss(im1,im2,gpu_device):

Expand Down

0 comments on commit 3334417

Please sign in to comment.