Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nonlinear transformations #12

Merged
merged 5 commits into from
May 9, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 31 additions & 14 deletions pathflow_mixmatch/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def displace_image(img, displacement, gpu_device, dtype=th.float32):
def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similarity', gpu_device=0, opt_cm=True, sigma=[[11,11],[11,11],[3,3]], order=2, pyramid=[[4,4],[2,2]], loss_fn='mse', use_mask=False, interpolation='bicubic'):
assert use_mask==False, "Masking not implemented"
assert transform_type in ['similarity', 'affine', 'rigid', 'non_parametric','bspline','wendland']
start = time.time()
start = time.perf_counter()

# set the used data type
dtype = th.float32
Expand Down Expand Up @@ -142,16 +142,23 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari
for level, (mov_im_level, fix_im_level) in enumerate(zip(moving_image_pyramid, fixed_image_pyramid)):

# choose the affine transformation model
if transform_type in ['non_parametric','bspline','wendland']:
transform_args[0]=mov_im_level.size
if transform_type == 'non_parametric':
transform_args[0]=mov_im_level[level].size
elif transform_type in ['bspline','wendland']:
# for bspline, sigma must be positive tuple of ints
# for bspline, smaller sigma tuple means less loss of
# microarchitectural details

# transform_opts['sigma'] = sigma[level]
transform_opts['sigma'] = (1, 1)

transformation = transforms[transform_type](*transform_args,**transform_opts)

if level > 0 and transform_type=='bspline':
constant_flow = al.transformation.utils.upsample_displacement(constant_flow,
mov_im_level.size,
interpolation=interpolation)
transformation.set_constant_flow(constant_flow)
# if level > 0 and transform_type=='bspline':
# constant_flow = al.transformation.utils.upsample_displacement(constant_flow,
# mov_im_level.size,
# interpolation=interpolation)
# transformation.set_constant_flow(constant_flow)

if transform_type in ['similarity', 'affine', 'rigid']:
# initialize the translation with the center of mass of the fixed image
Expand Down Expand Up @@ -180,8 +187,8 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari
# start the registration
registration.start()

if transform_type == 'bspline':
constant_flow = transformation.get_flow()
# if transform_type == 'bspline':
# constant_flow = transformation.get_flow()

# set the intensities back to the original for the visualisation
fixed_image.image = 1 - fixed_image.image
Expand All @@ -191,13 +198,14 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari
displacement = transformation.get_displacement()
warped_image = al.transformation.utils.warp_image(moving_image, displacement)

end = time.time()
end = time.perf_counter()

print("=================================================================")

print("Registration done in:", end - start, "s")
print("Result parameters:")
transformation.print()
if transform_type in ['similarity', 'affine', 'rigid']:
print("Result parameters:")
transformation.print()

# plot the results
plt.subplot(131)
Expand All @@ -211,7 +219,16 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari
plt.subplot(133)
plt.imshow(warped_image.numpy(), cmap='gray')
plt.title('Warped Moving Image')
return displacement, warped_image, transformation._phi_z, registration.loss.data.item()

if transform_type in ['similarity', 'affine', 'rigid']:
transformation_param = transformation._phi_z
elif transform_type == 'non_parametric':
transformation_param = transformation.trans_parameters
elif transform_type == 'bspline' or transform_type == 'wendland':
transformation_param = transformation._kernel
else:
pass
return displacement, warped_image, transformation_param, registration.loss.data.item()

def get_loss(im1,im2,gpu_device):

Expand Down