diff --git a/pathflow_mixmatch/cli.py b/pathflow_mixmatch/cli.py index 5fb30c2..7cddeb9 100644 --- a/pathflow_mixmatch/cli.py +++ b/pathflow_mixmatch/cli.py @@ -87,7 +87,7 @@ def displace_image(img, displacement, gpu_device, dtype=th.float32): def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similarity', gpu_device=0, opt_cm=True, sigma=[[11,11],[11,11],[3,3]], order=2, pyramid=[[4,4],[2,2]], loss_fn='mse', use_mask=False, interpolation='bicubic'): assert use_mask==False, "Masking not implemented" assert transform_type in ['similarity', 'affine', 'rigid', 'non_parametric','bspline','wendland'] - start = time.time() + start = time.perf_counter() # set the used data type dtype = th.float32 @@ -142,16 +142,23 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari for level, (mov_im_level, fix_im_level) in enumerate(zip(moving_image_pyramid, fixed_image_pyramid)): # choose the affine transformation model - if transform_type in ['non_parametric','bspline','wendland']: - transform_args[0]=mov_im_level.size + if transform_type == 'non_parametric': + transform_args[0]=mov_im_level[level].size + elif transform_type in ['bspline','wendland']: + # for bspline, sigma must be positive tuple of ints + # for bspline, smaller sigma tuple means less loss of + # microarchitectural details + + # transform_opts['sigma'] = sigma[level] + transform_opts['sigma'] = (1, 1) transformation = transforms[transform_type](*transform_args,**transform_opts) - if level > 0 and transform_type=='bspline': - constant_flow = al.transformation.utils.upsample_displacement(constant_flow, - mov_im_level.size, - interpolation=interpolation) - transformation.set_constant_flow(constant_flow) + # if level > 0 and transform_type=='bspline': + # constant_flow = al.transformation.utils.upsample_displacement(constant_flow, + # mov_im_level.size, + # interpolation=interpolation) + # transformation.set_constant_flow(constant_flow) if transform_type in ['similarity', 'affine', 'rigid']: # initialize the translation with the center of mass of the fixed image @@ -180,8 +187,8 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari # start the registration registration.start() - if transform_type == 'bspline': - constant_flow = transformation.get_flow() + # if transform_type == 'bspline': + # constant_flow = transformation.get_flow() # set the intensities back to the original for the visualisation fixed_image.image = 1 - fixed_image.image @@ -191,13 +198,14 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari displacement = transformation.get_displacement() warped_image = al.transformation.utils.warp_image(moving_image, displacement) - end = time.time() + end = time.perf_counter() print("=================================================================") print("Registration done in:", end - start, "s") - print("Result parameters:") - transformation.print() + if transform_type in ['similarity', 'affine', 'rigid']: + print("Result parameters:") + transformation.print() # plot the results plt.subplot(131) @@ -211,7 +219,16 @@ def affine_register(im1, im2, iterations=1000, lr=0.01, transform_type='similari plt.subplot(133) plt.imshow(warped_image.numpy(), cmap='gray') plt.title('Warped Moving Image') - return displacement, warped_image, transformation._phi_z, registration.loss.data.item() + + if transform_type in ['similarity', 'affine', 'rigid']: + transformation_param = transformation._phi_z + elif transform_type == 'non_parametric': + transformation_param = transformation.trans_parameters + elif transform_type == 'bspline' or transform_type == 'wendland': + transformation_param = transformation._kernel + else: + pass + return displacement, warped_image, transformation_param, registration.loss.data.item() def get_loss(im1,im2,gpu_device):