-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmesoSfM.py
1338 lines (1181 loc) · 77.3 KB
/
mesoSfM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import numpy as np
import os
import matplotlib.pyplot as plt
import cv2
import tensorflow as tf
from tqdm.notebook import tqdm
import scipy.signal
from tensorflow.python.training.tracking.data_structures import ListWrapper
class mesoSfM:
def __init__(self, stack, ul_coords, recon_shape, ul_offset=(0, 0), batch_size=None, scale=1, momentum=None,
             batch_across_images=False, report_error_map=False, restrict_function='clip'):
    """Store the image stack, per-image coordinates and stitching/optimization settings.

    Args:
        stack: stack of images to be stitched, shape (num_images, num_rows, num_cols, num_channels); a singleton
            channel dimension is required. Stored as uint8 to save memory for the batch generator.
        ul_coords: upper-left (row, col) pixel coordinates of each image, shape (num_images, 2). Stored as uint16
            (uint8 is too narrow for large pixel coordinates). NOTE(review): the uint16 cast wraps negative values
            and truncates fractional ones -- confirm callers never pass those.
        recon_shape: base shape (row, col, channels) of the final stitched reconstruction; the effective shape is
            derived from `scale` in create_variables.
        ul_offset: length-2 (row, col) constant offset for all ul_coords, in pixels.
        batch_size: None disables batching.
        scale: factor between 0 and 1 specifying how much to downsample the reconstruction.
        momentum: only relevant for batching; how much averaging of previous iterations to use for the recon.
            batch_size and momentum must either both be None or both be non-None.
        batch_across_images: when batching, True batches across the image dimension, False across pixels
            (batching while using a unet requires batching across images).
        report_error_map: stored as self.report_error_map (its use is outside this method).
        restrict_function: what to do when a point goes beyond the boundaries: 'clip', 'mod' or
            'mod_with_random_shifts'.
    """
    # numeric precision for tf tensors and numpy arrays:
    self.tf_dtype = tf.float32
    self.np_dtype = np.float32

    # --- input data ---
    stack_u8 = np.uint8(stack)  # batches are cast to float32 later;
    self.stack = stack_u8
    self.num_channels = stack_u8.shape[3]
    self.num_images = stack_u8.shape[0]
    self.ul_coords = np.uint16(ul_coords)
    self.recon_shape_base = recon_shape
    self.ul_offset = np.array(ul_offset)
    self.scale = scale

    # --- batching / optimization settings ---
    self.momentum = momentum
    self.batch_size = batch_size
    self.batch_across_images = batch_across_images
    self.report_error_map = report_error_map
    self.restrict_function = restrict_function
    self.sig_proj = .42465  # width of the interpolation kernel;
    self.subtract_min_from_height_map = True
    self.optimizer = tf.keras.optimizers.Adam

    # --- unet settings; filters_list and skip_list must be set manually before using a unet model ---
    self.filters_list = None
    self.skip_list = None
    self.unet_scale = .01  # scales the unet output;
    self.output_nonlinearity = 'linear'  # 'linear' or 'leaky_relu';
    self.upsample_method = 'bilinear'  # 'bilinear' or 'nearest';
    self.ckpt = None  # model checkpointing when using a unet;
    self.save_iter = 15  # save the model every this many iterations;
    self.recompute_CNN = False  # save memory using tf.recompute_grad;
    self.height_scale_factor = 4000  # scales ego_height so it does not have to hold large values;

    # --- camera parameters for obtaining absolute scale ---
    self.use_absolute_scale_calibration = False
    self.effective_focal_length_mm = 4.3  # effective focal length in mm;
    self.magnification_j = None  # the user must determine the magnification of the j-th camera view;
    self.j = 0  # index of that camera view; the first camera by default;
def create_variables(self, deformation_model, learning_rates=None, variable_initial_values=None, recon=None,
                     normalize=None, remove_global_transform=False, antialiasing_filter=False,
                     stack_downsample_factor=None, force_ground_surface_up=False):
    """Define the tf.Variables/optimizers for `deformation_model` and prepare the downsampled stack/coordinates.

    Args:
        deformation_model: string selecting which variables are created. It must contain 'camera_parameters',
            otherwise an Exception is raised. With 'unet', a unet height map is created instead of a per-pixel
            ego_height map. With 'perspective_to_perspective', the (non-trained) reference camera
            height/position variables are also created.
        learning_rates: dict of variable name -> learning rate. Missing keys or None values use the defaults
            below. A learning rate <= 0 marks the variable as not trained (see self.trainable_or_not). The
            optional keys 'illum_flat', 'radial_camera_distortion', 'radial_camera_distortion_piecewise_linear'
            and 'camera_distortion_center' enable the corresponding corrections.
        variable_initial_values: dict of initial variable values; defaults are filled in the same way.
        recon, normalize: results of an earlier (most likely lower-resolution) reconstruction. They are resized
            to the current recon shape to initialize the momentum/running-average buffers. Both or neither must
            be supplied.
        remove_global_transform: e.g. no global shift or scale; requires a negative camera_focal_length learning
            rate.
        antialiasing_filter: Gaussian-blur each image before downsampling it.
        stack_downsample_factor: integer downsampling of the stack and coordinates; if None, it is
            max(int(1 / self.scale), 1).
        force_ground_surface_up: when using the camera model, force the mean surface normal to be [0, 0, -1];
            stored as self.force_ground_surface_up.

    The caller's learning_rates / variable_initial_values dicts are not modified.
    """
    # batch_size and momentum must both be None or both be non-None; likewise recon and normalize:
    assert (self.batch_size is not None) == (self.momentum is not None)
    assert (recon is None) == (normalize is None)

    # effective recon shape; np.asarray lets recon_shape be a tuple/list even when scale is a float:
    self.recon_shape = np.int32(np.asarray(self.recon_shape_base) * self.scale)
    if stack_downsample_factor is None:
        self.downsample = np.int32(1 / self.scale)  # also downsample the images to save computation;
        self.downsample = np.maximum(self.downsample, 1)  # obviously can't downsample with 0;
    else:
        self.downsample = stack_downsample_factor
    # ceil(num_rows / downsample), ceil(num_cols / downsample): the shape produced by [::downsample] slicing;
    self.im_downsampled_shape = np.array([(self.stack.shape[1] - 1) // self.downsample + 1,
                                          (self.stack.shape[2] - 1) // self.downsample + 1])

    # pixel coordinates of the images:
    c = np.arange(self.stack.shape[2], dtype=np.uint16)
    r = np.arange(self.stack.shape[1], dtype=np.uint16)
    r, c = np.meshgrid(r, c, indexing='ij')
    rc_base = np.stack([r, c]).T  # (num_cols, num_rows, 2); row/col axes are swapped back further below;
    # base coordinates (could be downsampled after applying a scale factor), one copy per image:
    self.rc_base = np.tile(rc_base[None], [self.num_images, 1, 1, 1])

    self.deformation_model = deformation_model
    self.remove_global_transform = remove_global_transform
    self.force_ground_surface_up = force_ground_surface_up  # only relevant if using a camera model;
    # tf.Variables and optimizers, populated by the _create_variables_* methods and the code below:
    self.train_var_list = list()
    self.optimizer_list = list()
    self.non_train_list = list()  # variables that aren't trained (probably .assign()'d; for checkpoints);
    self.tensors_to_track = dict()  # intermediate tensors to track; have a tf.function return the contents;

    def use_default_for_missing(input_dict, default_dict):
        # Return input_dict with any key that is missing (or set to None) filled from default_dict. Works on a
        # shallow copy so the caller's dictionary is never mutated (it may be reused for another model).
        if input_dict is None:
            return default_dict
        merged = dict(input_dict)
        for key, default in default_dict.items():
            if merged.get(key) is None:
                merged[key] = default
        return merged

    if 'camera_parameters' not in deformation_model:
        raise Exception('invalid deformation model: ' + deformation_model)

    if 'unet' in deformation_model:
        # user has to manually define these:
        assert self.filters_list is not None
        assert self.skip_list is not None
    default_learning_rates = {'camera_focal_length': 1e-3, 'camera_height': 1e-3, 'ground_surface_normal': 1e-3,
                              'camera_in_plane_angle': 1e-3, 'rc': 10, 'gain': 1e-3, 'ego_height': 1e-3,
                              'bias': 1e-3}
    default_variable_initial_values = {
        'camera_focal_length': np.float32(1),  # initialize same value as ...
        'camera_height': 1 * np.ones([self.num_images]),  # ...for height;
        'ground_surface_normal': np.concatenate([np.zeros((self.num_images, 2)) + 1e-7,  # to avoid /0;
                                                 -np.ones((self.num_images, 1))], axis=1),
        'camera_in_plane_angle': np.zeros([self.num_images]),
        'rc': self.ul_coords,
        # add a small value to allow gradients to propagate:
        'ego_height': 1e-7 + np.zeros([self.num_images,
                                       self.im_downsampled_shape[0],
                                       self.im_downsampled_shape[1]]),
        'gain': np.ones(self.num_images),
        'bias': np.zeros(self.num_images)}
    learning_rates = use_default_for_missing(learning_rates, default_learning_rates)
    variable_initial_values = use_default_for_missing(variable_initial_values, default_variable_initial_values)
    self._create_variables_camera_parameters(learning_rates, variable_initial_values)
    # these are for perspective to perspective and perspective to orthographic:
    if 'unet' in deformation_model:
        # create neural network layers:
        self._create_variables_height_map_unet(learning_rates, variable_initial_values)
    else:
        # create radial deformation field:
        self._create_variables_perspective_to_orthographic(learning_rates, variable_initial_values)
    if self.remove_global_transform:
        # removing the global scaling transform entails adapting the focal length to the mean height;
        assert learning_rates['camera_focal_length'] < 0
    if 'perspective_to_perspective' in deformation_model:
        # these are not optimized, and their initial values are placeholders that are expected to be replaced;
        # they are tf.Variables so that the user can modify them in eager mode with .assign(). Because there's
        # no tilt, reference_camera_height is also the camera-to-vanishing-point distance, and
        # reference_camera_rc is the vanishing point position;
        self.reference_camera_height = tf.Variable(1, dtype=self.tf_dtype, name='reference_camera_height')
        self.reference_camera_rc = tf.Variable(np.zeros(2), dtype=self.tf_dtype, name='reference_camera_rc')
        self.non_train_list.append(self.reference_camera_rc)
        self.non_train_list.append(self.reference_camera_height)

    # intensity gain (uniform within a given image for now):
    self.gain = tf.Variable(variable_initial_values['gain'], dtype=self.tf_dtype, name='gain')
    self.gain_optim = self.optimizer(learning_rate=learning_rates['gain'])
    self.train_var_list.append(self.gain)
    self.optimizer_list.append(self.gain_optim)
    # intensity bias (also uniform within each image):
    self.bias = tf.Variable(variable_initial_values['bias'], dtype=self.tf_dtype, name='bias')
    self.bias_optim = self.optimizer(learning_rate=learning_rates['bias'])
    self.train_var_list.append(self.bias)
    self.optimizer_list.append(self.bias_optim)

    # illumination flattening with a 2nd order polynomial (this does everything gain+vignetting does):
    if 'illum_flat' in learning_rates:
        if 'illum_flat' in variable_initial_values:
            illum_flat_init = variable_initial_values['illum_flat']
        else:
            illum_flat_init = np.zeros((self.num_images, 6))  # a0 + a1x + a2y + a3x^2 + a4y^2 + a5xy;
        self.correct_illum_flat = True
        self.illum_flat_params = tf.Variable(illum_flat_init, dtype=self.tf_dtype, name='illum_flat')
        self.illum_flat_params_optim = self.optimizer(learning_rate=learning_rates['illum_flat'])
        self.train_var_list.append(self.illum_flat_params)
        self.optimizer_list.append(self.illum_flat_params_optim)
    else:
        self.correct_illum_flat = False

    # barrel or pincushion distortion correction (polynomial in even powers of the radius):
    if 'radial_camera_distortion' in learning_rates:
        assert 'camera' in deformation_model  # this is explicitly a camera modeling option;
        if 'radial_camera_distortion' in variable_initial_values:
            camera_distortion_init = np.asarray(variable_initial_values['radial_camera_distortion'])
            if camera_distortion_init.ndim == 0:  # one parameter common to all cameras;
                camera_distortion_init = np.reshape(camera_distortion_init, (1, 1))
            elif camera_distortion_init.ndim == 1:
                # assume this one dimension refers to the camera view, not the polynomial order;
                camera_distortion_init = np.reshape(camera_distortion_init, (-1, 1))
        else:
            camera_distortion_init = np.zeros((self.num_images, 1))  # second dim: polynomial order;
        num_poly_terms = camera_distortion_init.shape[1]  # number of (even-power) terms in the polynomial;
        self.radial_powers = (np.arange(num_poly_terms) + 1)  # half of the even powers to raise to;
        # number of parameter rows (1 if shared, num_images if per-camera); this info is used if batching;
        self.correct_radial_camera_distortion = camera_distortion_init.shape[0]
        self.radial_camera_distortion = tf.Variable(camera_distortion_init, dtype=self.tf_dtype,
                                                    name='radial_camera_distortion')
        self.radial_camera_distortion_optim = self.optimizer(
            learning_rate=learning_rates['radial_camera_distortion'])
        self.train_var_list.append(self.radial_camera_distortion)
        self.optimizer_list.append(self.radial_camera_distortion_optim)
    else:
        self.correct_radial_camera_distortion = False

    # piecewise-linear radial distortion:
    if 'radial_camera_distortion_piecewise_linear' in learning_rates:
        assert 'camera' in deformation_model
        if 'radial_camera_distortion_piecewise_linear' in variable_initial_values:
            camera_distortion_init = variable_initial_values['radial_camera_distortion_piecewise_linear']
            # for now, only a distortion common to all cameras; the vector length sets the number of nodes;
            assert np.ndim(camera_distortion_init) == 1
        else:
            camera_distortion_init = np.zeros(50)
        self.num_radial_pixels = len(camera_distortion_init)  # number of discretization levels (nodes);
        # used if batching images; -1 never equals num_images (and is still truthy);
        self.correct_radial_camera_distortion_piecewise_linear = -1
        self.radial_camera_distortion_piecewise_linear = tf.Variable(
            camera_distortion_init, dtype=self.tf_dtype, name='radial_camera_distortion_piecewise_linear')
        self.radial_camera_distortion_piecewise_linear_optim = self.optimizer(
            learning_rate=learning_rates['radial_camera_distortion_piecewise_linear'])
        self.train_var_list.append(self.radial_camera_distortion_piecewise_linear)
        self.optimizer_list.append(self.radial_camera_distortion_piecewise_linear_optim)
    else:
        self.correct_radial_camera_distortion_piecewise_linear = False

    # center of the two distortions above, for a camera that is not centered (center the camera first, apply
    # the distortions, then decenter back):
    if 'camera_distortion_center' in learning_rates:
        assert 'camera' in deformation_model
        if 'camera_distortion_center' in variable_initial_values:
            # numpy (not tf) reshape keeps integer inputs convertible to a float32 Variable:
            camera_distortion_center_init = np.asarray(variable_initial_values['camera_distortion_center'])
            if camera_distortion_center_init.ndim == 0:
                raise Exception('must supply two values for x/y centration parameters')
            if camera_distortion_center_init.ndim == 1:  # one common pair of parameters;
                assert len(camera_distortion_center_init) == 2  # x and y centers;
                camera_distortion_center_init = np.reshape(camera_distortion_center_init, (1, -1))
        else:
            camera_distortion_center_init = np.zeros((self.num_images, 2))  # by default, a pair per camera;
        self.correct_camera_distortion_center = camera_distortion_center_init.shape[0]  # used if batching;
        self.camera_distortion_center = tf.Variable(camera_distortion_center_init, dtype=self.tf_dtype,
                                                    name='camera_distortion_center')
        self.camera_distortion_center_optim = self.optimizer(
            learning_rate=learning_rates['camera_distortion_center'])
        self.train_var_list.append(self.camera_distortion_center)
        self.optimizer_list.append(self.camera_distortion_center_optim)
    else:
        self.correct_camera_distortion_center = False

    # booleans accompanying self.train_var_list/self.optimizer_list that say whether to train each variable
    # (learning rate > 0). Done this way so that autograph doesn't traverse all branches of the conditionals.
    # To stop optimizing a variable mid-optimization, .assign(0) to its learning rate so the update still
    # happens but changes nothing;
    self.trainable_or_not = list()
    for var in self.train_var_list:
        if isinstance(var, list):
            # a list of variables should be the unet's weights; modify here if there are other scenarios;
            assert 'unet' in self.deformation_model
            name = 'ego_height'
        else:
            name = var.name[:-2]  # strip the ':0' suffix;
        self.trainable_or_not.append(learning_rates[name] > 0)

    # downsample rc coordinates and stack; the transpose restores (num_images, rows, cols, 2):
    rc = np.transpose(self.rc_base[:, ::self.downsample, ::self.downsample, :], (0, 2, 1, 3))
    if antialiasing_filter:
        downsample = int(self.downsample)  # plain int for cv2;
        if downsample == 1:
            print('warning: antialiasing filter is applied even though there is no downsampling')
        ksize = int(downsample * 2.5)
        if ksize % 2 == 0:  # must be odd;
            ksize += 1
        # sigmas passed by keyword: the 4th positional argument of cv2.GaussianBlur is `dst`, not sigmaY;
        stack_downsamp = np.stack(
            [cv2.GaussianBlur(im, (ksize, ksize), sigmaX=downsample, sigmaY=downsample)[::downsample, ::downsample]
             for im in self.stack])
    else:
        stack_downsamp = self.stack[:, ::self.downsample, ::self.downsample, :]
    rc_downsamp = np.reshape(rc, (self.rc_base.shape[0], -1, self.rc_base.shape[-1]))  # flatten spatial dims;
    stack_downsamp = np.reshape(stack_downsamp, [self.num_images, -1, self.num_channels])  # flatten spatial dims;
    self.rc_downsamp = rc_downsamp
    self.stack_downsamp = stack_downsamp

    # create variables relevant for batching (or set them to None if not batching):
    if self.momentum is not None:
        if 'camera_parameters_perspective_' in deformation_model:
            # we're going to give the coregistered height map a ride as an extra channel;
            num_channels = self.num_channels + 1
        else:
            num_channels = self.num_channels
        if recon is None:  # if none supplied, initialize with 0s;
            recon_previous = np.zeros([self.recon_shape[0], self.recon_shape[1], num_channels])
            normalize_previous = np.zeros(self.recon_shape)
        else:  # otherwise, resize to the current shape;
            # cv2's dsize is (width, height) and must be exactly two plain ints, even if recon_shape has a
            # channel entry;
            recon_dsize = (int(self.recon_shape[1]), int(self.recon_shape[0]))
            recon_previous = cv2.resize(np.nan_to_num(recon), recon_dsize)
            normalize_previous = cv2.resize(np.nan_to_num(normalize), recon_dsize)
            if num_channels == 1:
                # cv2 seems to squeeze singleton channels dimensions, so if it's singleton, add it back:
                recon_previous = recon_previous[:, :, None]
            if recon_previous.shape[-1] != num_channels:
                # the previous RGB result is used as initialization, but that run didn't estimate a height map;
                assert 'camera_parameters_perspective_' in deformation_model
                # add empty height channel:
                recon_previous = np.concatenate([recon_previous, np.zeros_like(recon_previous[:, :, 0:1])],
                                                axis=-1)
        # initialize first recon and normalize tensors for momentum; use the scaled recon shape, not the base
        # shape;
        self.recon_previous = tf.Variable(recon_previous, dtype=self.tf_dtype, trainable=False)
        self.non_train_list.append(self.recon_previous)
        self.normalize_previous = tf.Variable(normalize_previous, dtype=self.tf_dtype, trainable=False)
        self.non_train_list.append(self.normalize_previous)
    else:
        self.recon_previous = None
        self.normalize_previous = None
def _create_variables_perspective_to_orthographic(self, learning_rates, variable_initial_values):
    """Create the trainable per-image ego_height map and its optimizer.

    ego_height holds radially inward-pointing vector magnitudes: the larger the magnitude, the taller the object.
    variable_initial_values['ego_height'] must have num_images as its first dimension. If its remaining shape
    differs from self.im_downsampled_shape (e.g. it came from an optimization at a different scale), each map is
    resized with cv2.resize to the current downsampled shape. The variable is appended to self.train_var_list and
    its optimizer, using learning_rates['ego_height'], to self.optimizer_list.
    """
    ego_height = variable_initial_values['ego_height']
    # make sure the first dimensions match (number of images in stack):
    assert ego_height.shape[0] == self.num_images
    if ego_height.shape[1:] != tuple(self.im_downsampled_shape):
        if not isinstance(ego_height, np.ndarray):
            # e.g. a tf tensor/Variable from a previous run; convert to numpy for cv2:
            ego_height = ego_height.numpy()
        # cv2's dsize is (width, height) as plain ints:
        dsize = (int(self.im_downsampled_shape[1]), int(self.im_downsampled_shape[0]))
        ego_height = np.stack([cv2.resize(im, dsize) for im in ego_height])
    self.ego_height = tf.Variable(ego_height, dtype=self.tf_dtype, name='ego_height')
    self.ego_height_optim = self.optimizer(learning_rate=learning_rates['ego_height'])
    self.train_var_list.append(self.ego_height)
    self.optimizer_list.append(self.ego_height_optim)
def _create_variables_height_map_unet(self, learning_rates, variable_initial_values):
    """Build the unet height-map network, register its weights for training, and create padding layers.

    All of the network's trainable variables are appended to self.train_var_list as ONE list entry, with a single
    optimizer using learning_rates['ego_height']. variable_initial_values is accepted for interface symmetry but
    not used here. self.padded_shape comes from get_compatible_size (defined elsewhere in this file). Then
    self.pad_layer (ZeroPadding2D) and self.depad_layer (Cropping2D) are created with the same per-side amounts,
    which map between self.im_downsampled_shape and that padded shape.
    """
    num_levels = len(self.filters_list)
    self.network = unet(self.filters_list, self.skip_list, output_nonlinearity=self.output_nonlinearity,
                        upsample_method=self.upsample_method)
    # run the network once so that we can access network.trainable_variables:
    self.network(tf.zeros([1, 2 ** num_levels, 2 ** num_levels, self.num_channels], dtype=self.tf_dtype))
    self.train_var_list.append(self.network.trainable_variables)
    self.optimizer_list.append(self.optimizer(learning_rate=learning_rates['ego_height']))
    if self.recompute_CNN:
        # save memory using tf.recompute_grad;
        self.network = tf.recompute_grad(self.network)
    # get padded shape that the network likes:
    self.padded_shape = [get_compatible_size(dim, num_levels) for dim in self.im_downsampled_shape]
    pad_r = self.padded_shape[0] - self.im_downsampled_shape[0]
    pad_c = self.padded_shape[1] - self.im_downsampled_shape[1]
    # split each pad as evenly as possible, the extra pixel going to the bottom/right. p - p // 2 == ceil(p / 2)
    # for any integer p, without a float/tf round trip (which would also fail outside eager mode):
    pad_top = int(pad_r // 2)
    pad_bottom = int(pad_r - pad_r // 2)
    pad_left = int(pad_c // 2)
    pad_right = int(pad_c - pad_c // 2)
    pad_specs = ((pad_top, pad_bottom), (pad_left, pad_right))
    self.pad_layer = tf.keras.layers.ZeroPadding2D(pad_specs)
    self.depad_layer = tf.keras.layers.Cropping2D(pad_specs)
def _create_variables_camera_parameters(self, learning_rates, variable_initial_values):
    """Create the per-camera tf.Variables (coordinates are in camera space) and their optimizers.

    For each (attribute, key) pair below, a tf.Variable named `key` is initialized from
    variable_initial_values[key] and stored as self.<attribute>. An optimizer using learning_rates[key] is
    stored as self.<attribute>_optim. Both are appended, in this order, to self.train_var_list and
    self.optimizer_list.
    """
    attribute_and_key = (
        ('camera_focal_length', 'camera_focal_length'),
        ('camera_height', 'camera_height'),  # height from the camera perspective;
        ('ground_surface_normal', 'ground_surface_normal'),
        ('camera_in_plane_angle', 'camera_in_plane_angle'),
        ('rc_ul_per_im', 'rc'),  # per-image upper-left coordinates;
    )
    for attribute, key in attribute_and_key:
        variable = tf.Variable(variable_initial_values[key], dtype=self.tf_dtype, name=key)
        optim = self.optimizer(learning_rate=learning_rates[key])
        setattr(self, attribute, variable)
        setattr(self, attribute + '_optim', optim)
        self.train_var_list.append(variable)
        self.optimizer_list.append(optim)
def generate_dataset(self):
    """Return the data to iterate over during optimization.

    Without batching (self.batch_size is None), this returns the 2-tuple (self.stack_downsamp, self.rc_downsamp),
    i.e. the whole dataset as a single batch. With batching, it returns a shuffled tf.data.Dataset that repeats
    indefinitely:
      - batch_across_images: elements are (stack, (rc, image_index)), batched over images with the remainder
        dropped, so the image indices can be used to gather the corresponding variables;
      - otherwise: batched over pixels; elements are (stack, rc), or (stack, (rc, pixel_index)) when the
        deformation model contains 'camera_parameters_perspective_' (the indices address the pixel-wise
        deformation fields).
    """
    if self.batch_size is None:
        return self.stack_downsamp, self.rc_downsamp

    if self.batch_across_images:
        image_indices = np.arange(self.num_images, dtype=np.int32)
        slices = (self.stack_downsamp, (self.rc_downsamp, image_indices))
        return (tf.data.Dataset.from_tensor_slices(slices)
                .shuffle(self.num_images)
                .batch(self.batch_size, drop_remainder=True)
                .repeat(None)
                .prefetch(1))

    # move the pixel axis first so that batching samples pixels rather than images:
    rc_by_pixel = np.transpose(self.rc_downsamp, (1, 0, 2))
    stack_by_pixel = np.transpose(self.stack_downsamp, (1, 0, 2))
    if 'camera_parameters_perspective_' in self.deformation_model:
        pixel_indices = np.arange(np.prod(self.im_downsampled_shape), dtype=np.int32)
        slices = (stack_by_pixel, (rc_by_pixel, pixel_indices))
    else:
        slices = (stack_by_pixel, rc_by_pixel)
    return (tf.data.Dataset.from_tensor_slices(slices)
            .shuffle(len(rc_by_pixel))
            .batch(self.batch_size)
            .repeat(None)
            .prefetch(1))
def _warp_camera_parameters(self, rc_downsamp, use_radial_deformation, p2p_warp_mode=None,
                            inds_downsamp=None, stack_downsamp=None):
    """Map downsampled pixel coordinates to reconstruction coordinates using the camera model.

    Pipeline: optional lens-distortion correction (distortion center, even-polynomial radial, piecewise-linear
    radial), in-plane rotation, projection onto the tilted ground plane (Taylor-expanded), per-image
    translation and, optionally, a per-pixel radial deformation driven by a height map. The radial deformation
    is perspective-to-orthographic when p2p_warp_mode is None and perspective-to-perspective otherwise.

    Args:
        rc_downsamp: pixel row/column coordinates, indexed as [image, pixel, 2].
        use_radial_deformation: whether to apply the per-pixel radial (height-map) deformation fields.
        p2p_warp_mode: reference view for perspective-to-perspective warping: 'mean', 'random', 'fixed' or
            'random_choice'; None means perspective-to-orthographic (or no radial deformation at all).
        inds_downsamp: flat pixel indices of the current batch; passed when batching across pixels so that
            the matching entries of the height map can be gathered.
        stack_downsamp: image intensities; only needed when the height map is generated by the unet.

    Returns:
        Warped coordinates flattened to shape (-1, 2) and multiplied by self.scale.

    Raises:
        Exception: for an unknown p2p_warp_mode, or if use_absolute_scale_calibration is combined with
            perspective-to-perspective warping.

    Side effects: sets self.ego_height_to_regularize and self.flat_xy; with radial deformation also sets
    self.vanish_xy, self.camera_to_vanish_point_rc and self.ego_height_for_concat; records entries in
    self.tensors_to_track; sets self.ego_height for unet models; assigns self.camera_focal_length when
    remove_global_transform is set, and self.reference_camera_height / self.reference_camera_rc in
    perspective-to-perspective mode.
    """
    if p2p_warp_mode is None:
        assert 'perspective_to_perspective' not in self.deformation_model
    else:
        assert 'perspective_to_perspective' in self.deformation_model

    if self.batch_across_images and self.batch_size is not None:
        # in _generate_recon, the per-image variables were gathered for the current image batch;
        rc_ul_per_im = self.rc_ul_per_im_batch
        camera_height = self.camera_height_batch
        ground_surface_normal = self.ground_surface_normal_batch
        camera_in_plane_angle = self.camera_in_plane_angle_batch
        if 'unet' not in self.deformation_model:
            # with a unet, ego_height is instead generated from the image batch further below;
            ego_height = self.ego_height_batch
            self.ego_height_to_regularize = self.ego_height_batch
        # a distortion variable is gathered only if it has one entry per image:
        if self.correct_radial_camera_distortion:
            if self.correct_radial_camera_distortion == self.num_images:
                radial_camera_distortion = self.radial_camera_distortion_batch
            else:
                radial_camera_distortion = self.radial_camera_distortion
        if self.correct_radial_camera_distortion_piecewise_linear:
            if self.correct_radial_camera_distortion_piecewise_linear == self.num_images:
                radial_camera_distortion_piecewise_linear = self.radial_camera_distortion_piecewise_linear_batch
            else:
                radial_camera_distortion_piecewise_linear = self.radial_camera_distortion_piecewise_linear
        if self.correct_camera_distortion_center:
            if self.correct_camera_distortion_center == self.num_images:
                camera_distortion_center = self.camera_distortion_center_batch
            else:
                camera_distortion_center = self.camera_distortion_center
        num_images = self.batch_size  # for reshaping below;
        camera_focal_length = self.camera_focal_length  # not batched;
    else:
        rc_ul_per_im = self.rc_ul_per_im
        camera_height = self.camera_height
        ground_surface_normal = self.ground_surface_normal
        camera_in_plane_angle = self.camera_in_plane_angle
        if 'unet' not in self.deformation_model:
            ego_height = self.ego_height
            self.ego_height_to_regularize = self.ego_height
        if self.correct_radial_camera_distortion:
            radial_camera_distortion = self.radial_camera_distortion
        if self.correct_radial_camera_distortion_piecewise_linear:
            radial_camera_distortion_piecewise_linear = self.radial_camera_distortion_piecewise_linear
        if self.correct_camera_distortion_center:
            camera_distortion_center = self.camera_distortion_center
        num_images = self.num_images  # for reshaping below;
        camera_focal_length = self.camera_focal_length

    if self.remove_global_transform:
        # don't use self.camera_focal_length; set it to the geometric mean of the camera heights, always
        # computed from the full tf.Variable (not the batch version). The mean is taken in log space because
        # reduce_prod over many heights can overflow/underflow the float range; assumes positive heights;
        camera_focal_length = tf.exp(tf.reduce_mean(tf.math.log(self.camera_height), axis=0))
        self.camera_focal_length.assign(camera_focal_length)
        camera_in_plane_angle = camera_in_plane_angle - tf.reduce_mean(self.camera_in_plane_angle)
        rc_ul_per_im = rc_ul_per_im - tf.reduce_mean(self.rc_ul_per_im, axis=0, keepdims=True)

    im_dims = np.array(self.stack.shape)[1:3]  # for normalization of image coordinates;
    max_dim = np.max(im_dims)  # to keep isotropic;
    camera_yx = (rc_downsamp - .5 * im_dims[None, None, :]) / max_dim  # normalize to -.5 to .5;
    if self.correct_camera_distortion_center:
        camera_yx -= camera_distortion_center[:, None, :]

    if self.correct_radial_camera_distortion:
        camera_r2 = camera_yx[:, :, 0] ** 2 + camera_yx[:, :, 1] ** 2  # radial distance squared;
        # dims^: camera, pixels
        camera_r2 *= 2  # rescale so that r^2 reaches ~1 at the corners of a square image rather than .5;
        # even polynomial to account for distortion:
        camera_even_poly = tf.math.pow(camera_r2[:, :, None], self.radial_powers[None, None, :])
        # dims^: camera, pixels, power
        camera_even_poly = tf.reduce_sum(camera_even_poly * radial_camera_distortion[:, None, :], 2)
        # dims^: camera, pixels
        radial_correction_factor = 1 + camera_even_poly[:, :, None]
        self.tensors_to_track['camera_distortion_radial'] = radial_correction_factor
    else:
        radial_correction_factor = 1
    camera_yx = camera_yx * radial_correction_factor

    if self.correct_radial_camera_distortion_piecewise_linear:
        camera_r = tf.sqrt(camera_yx[:, :, 0] ** 2 + camera_yx[:, :, 1] ** 2)  # radial distance;
        # dims^: camera, pixels; these radial coordinates should be between 0 and .5*sqrt(2), but could go
        # higher (if the center moves); thus to be safe, just multiply by num_radial_pixels;
        r_scale = camera_r * self.num_radial_pixels
        # find nearest samples and the fractional distance to the lower one:
        r_floor = tf.floor(r_scale)
        r_middle = r_scale - r_floor
        # clamp BOTH neighbours so that radii beyond the last sample fall back to the outermost sample
        # instead of indexing out of range;
        r_ceil = tf.minimum(r_floor + 1, self.num_radial_pixels - 1)
        r_floor = tf.minimum(r_floor, self.num_radial_pixels - 1)
        r_floor = tf.cast(r_floor, dtype=tf.int32)
        r_ceil = tf.cast(r_ceil, dtype=tf.int32)
        distortion = 1 + radial_camera_distortion_piecewise_linear
        distortion /= tf.reduce_max(distortion)  # to prevent global expansion;
        correction_factor_floor = tf.gather(distortion, r_floor)
        correction_factor_ceil = tf.gather(distortion, r_ceil)
        # linear interpolation along the radius:
        correction_factor = correction_factor_ceil * r_middle + correction_factor_floor * (1 - r_middle)
        camera_yx *= correction_factor[:, :, None]
        self.tensors_to_track['camera_distortion_piecewise_linear'] = correction_factor

    if self.correct_camera_distortion_center:
        camera_yx += camera_distortion_center[:, None, :]

    # in-plane rotation:
    cos = tf.cos(camera_in_plane_angle)
    sin = tf.sin(camera_in_plane_angle)
    rotmat_xy = tf.stack([[cos, sin], [-sin, cos]])  # indexed [i, j, camera] in the einsum below;
    camera_yx = tf.einsum('cri,ijc->crj', camera_yx, rotmat_xy)

    n_ground, _ = tf.linalg.normalize(ground_surface_normal, axis=1)  # normalize to unit mag;
    # shape^: num_images, 3
    # projecting to object space (computed analytically and taylor-expanded);
    nx = n_ground[:, 0][:, None, None]
    ny = n_ground[:, 1][:, None, None]
    nx2 = nx ** 2
    ny2 = ny ** 2
    x = camera_yx[:, :, 1][:, :, None]
    y = camera_yx[:, :, 0][:, :, None]
    h = camera_height[:, None, None]
    f = camera_focal_length
    # using a taylor expansion:
    flat_x = (h * x / f +
              h * x * (nx * x + ny * y) / f ** 2 +
              h * (f ** 2 * nx2 * x + 2 * nx2 * x ** 3 + f ** 2 * nx * ny * y + 4 * nx * ny * x ** 2 * y +
                   2 * ny2 * x * y ** 2) / 2 / f ** 3)
    flat_y = (h * y / f +
              h * y * (nx * x + ny * y) / f ** 2 +
              h * (f ** 2 * ny2 * y + 2 * ny2 * y ** 3 + f ** 2 * nx * ny * x + 4 * nx * ny * y ** 2 * x +
                   2 * nx2 * y * x ** 2) / 2 / f ** 3)
    flat_xy = tf.concat([flat_x, flat_y], axis=2)
    n_dot_r = n_ground[:, 2] * camera_height  # shape: num_images; dot product between the normal and a point
    # on the ground, r (0, 0, camera_height); (this is needed below);
    self.flat_xy = flat_xy

    if use_radial_deformation:
        # compute the vanishing point, which is used by the radial deformation below;
        vanish_xyz = n_dot_r[:, None] * n_ground  # shape: num_images, 3;
        camera_to_vanish_point_xyz = tf.norm(vanish_xyz, axis=1)  # distance from camera to the ground; the
        # actual height of the camera;
        # projection to object space simplifies to this:
        vanish_xy = -camera_height[:, None] * n_ground[:, :2]
        # vanishing point in camera plane: follow the surface normal to the camera plane;
        vanish_camera_xyz = n_ground * camera_focal_length / n_ground[:, 2:]
        vanish_camera_xy = vanish_camera_xyz[:, 0:2]  # don't need z, it's just the focal length;
        # account for in-plane camera rotation:
        vanish_camera_xy = tf.einsum('ci,ijc->cj', vanish_camera_xy, rotmat_xy)
        self.vanish_xy = vanish_xy

    # convert back to row-column:
    flat_rc = (flat_xy[:, :, ::-1] * max_dim + .5 * im_dims[None, None, :])
    if use_radial_deformation:
        vanish_rc = (vanish_xy[:, ::-1] * max_dim + .5 * im_dims[None, :])
        vanish_camera_rc = (vanish_camera_xy[:, ::-1] * max_dim + .5 * im_dims[None, :])
        self.camera_to_vanish_point_rc = camera_to_vanish_point_xyz * max_dim  # xy units -> rc units;
        self.tensors_to_track['camera_to_vanish_point_xyz'] = self.camera_to_vanish_point_rc

    # add translations (same as for the homography and affine implementations);
    # these translations don't affect camera_to_vanish_point;
    rc_warp = rc_ul_per_im[:, None, :] + flat_rc
    if use_radial_deformation:
        vanish_warp = rc_ul_per_im + vanish_rc
    if self.restrict_function == 'mod_with_random_shifts':
        # to discourage registration with overlapped regions;
        random_shift = tf.random.uniform(shape=(1, 1, 2), minval=0, maxval=self.recon_shape.max() / self.scale)
        rc_warp += random_shift
        if use_radial_deformation:
            vanish_warp += random_shift[0]
    else:
        rc_warp += self.ul_offset[None, None, :]
        if use_radial_deformation:
            vanish_warp += self.ul_offset[None, :]

    if 'unet' in self.deformation_model:
        # generate self.ego_height; doesn't matter if batching or not because it's generated from the image batch;
        unet_input = tf.reshape(stack_downsamp, [num_images, self.im_downsampled_shape[0],
                                                 self.im_downsampled_shape[1], self.num_channels])
        unet_input = self.pad_layer(unet_input)
        unet_output = self.network(unet_input)
        ego_height = tf.reduce_mean(self.depad_layer(unet_output), [-1])  # remove last dimension;
        ego_height *= self.unet_scale
        self.ego_height_to_regularize = ego_height  # need this for regularization;
        self.ego_height = ego_height

    # multiplicative version:
    if use_radial_deformation:
        self.tensors_to_track['vanish_warp'] = vanish_warp
        self.tensors_to_track['vanish_camera'] = vanish_camera_rc
        if p2p_warp_mode is None:  # using perspective-to-orthographic
            ego_height_flat = tf.reshape(ego_height, [num_images, -1])  # flatten spatial dimensions;
            if self.subtract_min_from_height_map:
                ego_height_flat -= tf.reduce_min(ego_height_flat)
            if inds_downsamp is not None:
                assert self.batch_size is not None  # this should never be raised, but just in case;
                ego_height = tf.gather(ego_height_flat, inds_downsamp, axis=1)  # batch along pixels;
            else:
                ego_height = ego_height_flat
            if self.use_absolute_scale_calibration:
                H = self.camera_to_vanish_point_rc[:, None]
                if self.batch_size is not None:
                    # need to compute the self.j'th camera-to-vanish-point height in case you're batching, which
                    # means that the self.j'th entry may not be computed;
                    n_ground, _ = tf.linalg.normalize(self.ground_surface_normal[self.j], axis=0)
                    n_dot_r = n_ground[2] * self.camera_height[self.j]
                    vanish_xyz = n_dot_r * n_ground  # shape: 3;
                    camera_to_vanish_point_xyz = tf.norm(vanish_xyz, axis=0)
                    H_j = camera_to_vanish_point_xyz * max_dim
                else:
                    H_j = H[self.j]
                r = rc_warp - vanish_warp[:, None, :]  # lateral distance to vanishing point;
                M_j = self.magnification_j
                f_eff = self.effective_focal_length_mm
                # scale ego_height again to make this case similar to the other case:
                self.another_height_scale_factor = f_eff * (1 + 1 / M_j)
                h = ego_height * self.another_height_scale_factor
                delta_radial = h / f_eff / (1 + 1 / M_j * H / H_j)
                rc_warp = r * (1 - delta_radial[:, :, None]) + vanish_warp[:, None, :]
                ego_height *= self.height_scale_factor  # to keep consistent with regularization coefficients;
                # note that you have to divide by height_scale_factor because the height map is scaled by this,
                # but you have to divide by another_height_scale_factor because the multiplication above allows
                # ego_height to shrink;
            else:
                # denominator is large in the next line, so multiply by a large value to allow self.ego_height to
                # take on smaller values;
                ego_height *= self.height_scale_factor
                delta_radial = ego_height / self.camera_to_vanish_point_rc[:, None]
                rc_warp = ((rc_warp - vanish_warp[:, None, :]) * (1 - delta_radial[:, :, None]) +
                           vanish_warp[:, None, :])
            self.ego_height_for_concat = ego_height  # for concatenating with self.im in _generate_recon;
        else:  # perspective-to-perspective warping
            # first, need to define the reference camera view:
            if p2p_warp_mode == 'mean':
                self.reference_camera_height.assign(tf.reduce_mean(self.camera_to_vanish_point_rc))
                self.reference_camera_rc.assign(tf.reduce_mean(vanish_warp, axis=0))
            elif p2p_warp_mode == 'random':
                height_min = tf.reduce_min(self.camera_to_vanish_point_rc)
                height_max = tf.reduce_max(self.camera_to_vanish_point_rc)
                self.reference_camera_height.assign(tf.random.uniform((), height_min, height_max))
                rc_min = tf.reduce_min(vanish_warp, axis=0)
                rc_max = tf.reduce_max(vanish_warp, axis=0)
                self.reference_camera_rc.assign(tf.random.uniform((2,), rc_min, rc_max))
            elif p2p_warp_mode == 'fixed':
                pass  # do nothing, accept current values;
            elif p2p_warp_mode == 'random_choice':
                # pick one of the camera views in the current set; camera_to_vanish_point_rc and vanish_warp
                # have num_images entries (batch_size when batching across images), so draw from that range;
                random_choice = tf.random.uniform((), 0, num_images, dtype=tf.int32)
                self.reference_camera_height.assign(tf.gather(self.camera_to_vanish_point_rc, random_choice))
                self.reference_camera_rc.assign(tf.gather(vanish_warp, random_choice))
            else:
                raise Exception('invalid perspective-to-perspective warp mode passed to gradient_update: '
                                + p2p_warp_mode)
            # vector deformation field to warp to the reference perspective:
            ego_height_flat = tf.reshape(ego_height, [num_images, -1])  # flatten spatial dimensions;
            if inds_downsamp is not None:
                # batching along pixels: keep only the height-map entries of this batch's pixels so that h
                # lines up with rc_warp (as in the perspective-to-orthographic branch);
                ego_height_flat = tf.gather(ego_height_flat, inds_downsamp, axis=1)
            h = ego_height_flat[:, :, None] * self.height_scale_factor
            H_r = self.reference_camera_height
            H = self.camera_to_vanish_point_rc[:, None, None]
            R_r = self.reference_camera_rc[None, None, :]
            R = vanish_warp[:, None, :]
            r = rc_warp - R  # position vectors relative to each camera's vanishing point;
            p2p_warp = h / (H_r - h) * (R - R_r) + r * h * (H - H_r) / H / (H_r - h)  # the magic equation;
            rc_warp += p2p_warp  # shape: num_images, flattened spatial, 2
            self.ego_height_for_concat = h  # for concatenating with self.im in _generate_recon;
            if self.use_absolute_scale_calibration:
                raise Exception('not yet implemented for perspective-to-perspective')

    rc_warp = tf.reshape(rc_warp, [-1, 2]) * self.scale  # flatten
    return rc_warp
def _generate_recon(self, stack_downsamp, rc_downsamp, dither_coords, p2p_warp_mode=None, assign_update_recon=True):
    """Backproject the images into the reconstruction at the specified scale.

    If not batching, this generates the full reconstruction. If batching, the gradient_update function uses it
    to update the reconstruction; the user should not call it directly in that case, as repeated calls would
    keep updating the recon with the same batch.

    Args:
        stack_downsamp: downsampled image intensities (stored as uint8 to save CPU memory; cast to float here).
        rc_downsamp: pixel coordinates (stored as uint16; cast to float here). When batching across images it
            is a (coordinates, image indices) pair. When batching across pixels it may be a
            (coordinates, pixel indices) pair, as produced by generate_dataset for perspective camera models.
        dither_coords: if truthy, add one random shift in [-1, 1) to all warped coordinates.
        p2p_warp_mode: perspective-to-perspective reference mode (None, 'mean', 'random', 'fixed' or
            'random_choice'); forwarded to _warp_camera_parameters.
        assign_update_recon: only used in the momentum (running-average) branch; whether to .assign() the
            updated reconstruction back into self.recon_previous.

    Raises:
        Exception: for image-level batching with a non-camera deformation model, an invalid restrict_function,
            or an invalid deformation model.

    Side effects: sets self.im, self.rc_warp, the neighbour coordinates/weights (self.rc_ff, self.rc_cc,
    self.rc_cf, self.rc_fc, self.frc, self.crc), self.recon, self.normalize, self.num_channels_recon,
    self.restrict and, when batching across images, the *_batch variables.
    """
    inds_downsamp = None  # pixel-level batch indices (image-level ones are inds_image_downsamp);
    if self.batch_size is not None and self.batch_across_images:
        rc_downsamp, inds_image_downsamp = rc_downsamp  # unpack
        # now, for all variables whose first dimension corresponds to the image dimension, gather:
        self.rc_ul_per_im_batch = tf.gather(self.rc_ul_per_im, inds_image_downsamp, axis=0)
        self.gain_batch = tf.gather(self.gain, inds_image_downsamp, axis=0)
        self.bias_batch = tf.gather(self.bias, inds_image_downsamp, axis=0)
        # these are used below:
        gain = self.gain_batch
        bias = self.bias_batch
        if 'camera_parameters' in self.deformation_model:
            self.camera_height_batch = tf.gather(self.camera_height, inds_image_downsamp, axis=0)
            self.ground_surface_normal_batch = tf.gather(self.ground_surface_normal,
                                                         inds_image_downsamp, axis=0)
            self.camera_in_plane_angle_batch = tf.gather(self.camera_in_plane_angle,
                                                         inds_image_downsamp, axis=0)
            if 'unet' not in self.deformation_model:
                # if using unet, then ego_height will already be gathered as it is generated by the unet;
                self.ego_height_batch = tf.gather(self.ego_height, inds_image_downsamp, axis=0)
            # the following self.correct__ variables serve two purposes: 1) to signify whether they are being
            # used, and 2) specify length of first dimension of the corresponding distortion variable to decide
            # whether we need to use tf.gather;
            if self.correct_radial_camera_distortion == self.num_images:
                self.radial_camera_distortion_batch = tf.gather(self.radial_camera_distortion,
                                                                inds_image_downsamp, axis=0)
            if self.correct_radial_camera_distortion_piecewise_linear == self.num_images:
                self.radial_camera_distortion_piecewise_linear_batch = tf.gather(
                    self.radial_camera_distortion_piecewise_linear, inds_image_downsamp, axis=0)
            if self.correct_camera_distortion_center == self.num_images:
                self.camera_distortion_center_batch = tf.gather(self.camera_distortion_center,
                                                                inds_image_downsamp, axis=0)
        else:
            raise Exception('image-level batching not yet implemented for a non-camera model')
    else:
        if self.batch_size is not None:
            # batching across pixels: stack_downsamp and rc_downsamp are transposed and need to be untransposed;
            # the pixel indices (needed for pixel-wise radial deformations) are packaged together with the
            # coordinates only when generate_dataset added them, i.e. for perspective camera models;
            if isinstance(rc_downsamp, (tuple, list)):
                rc_downsamp, inds_downsamp = rc_downsamp
            stack_downsamp = tf.transpose(stack_downsamp, (1, 0, 2))
            rc_downsamp = tf.transpose(rc_downsamp, (1, 0, 2))
        # these are used below:
        gain = self.gain
        bias = self.bias

    # to save CPU memory, the dataset and coordinates are stored as uint8 and uint16, respectively; thus, cast to
    # float here;
    stack_downsamp = tf.cast(stack_downsamp, self.tf_dtype)
    rc_downsamp = tf.cast(rc_downsamp, self.tf_dtype)

    # function that restricts coordinates to the grid (store as self.variable so that error_map can use it):
    if self.restrict_function == 'clip':
        self.restrict = lambda x: tf.clip_by_value(x, tf.zeros_like(x), self.recon_shape[None] - 1)
    elif 'mod' in self.restrict_function:  # 'mod' or 'mod_with_random_shifts';
        self.restrict = lambda x: tf.math.floormod(x, self.recon_shape[None])
    else:
        raise Exception('invalid restrict_function')

    # apply gain:
    gain_norm = gain / tf.reduce_mean(gain)  # normalize so that there's no global gain;
    im = stack_downsamp * gain_norm[:, None, None] + bias[:, None, None]

    # apply illumination flattening:
    if self.correct_illum_flat:
        if self.batch_size is not None and self.batch_across_images:
            # gather the illumination parameters (not the bias) of the images in this batch:
            self.illum_flat_batch = tf.gather(self.illum_flat_params, inds_image_downsamp, axis=0)
            illum_flat = self.illum_flat_batch
        else:  # no batching, or parts of every image present, so no gathering:
            illum_flat = self.illum_flat_params
        # same normalization as used for distortion, but no decentering:
        im_dims = np.array(self.stack.shape)[1:3]  # for normalization of image coordinates;
        max_dim = np.max(im_dims)  # to keep isotropic;
        camera_yx = (rc_downsamp - .5 * im_dims[None, None, :]) / max_dim
        y = camera_yx[:, :, 0]
        x = camera_yx[:, :, 1]
        DC = illum_flat[:, 0:1]
        DC = DC - tf.reduce_mean(DC)  # to avoid global gain;
        # second-order polynomial in x and y per image:
        correction = (1 + DC + illum_flat[:, 1:2] * x + illum_flat[:, 2:3] * y
                      + illum_flat[:, 3:4] * x ** 2 + illum_flat[:, 4:5] * y ** 2 + illum_flat[:, 5:] * x * y)
        # shape: num_images, _;
        im = im * correction[:, :, None]
    self.im = tf.reshape(im, (-1, self.num_channels))  # flatten all but channels;

    # warped coordinates:
    if self.deformation_model == 'camera_parameters':
        self.rc_warp = self._warp_camera_parameters(rc_downsamp, use_radial_deformation=False)
    elif self.deformation_model == 'camera_parameters_perspective_to_orthographic':
        self.rc_warp = self._warp_camera_parameters(rc_downsamp, inds_downsamp=inds_downsamp,
                                                    use_radial_deformation=True)
    elif self.deformation_model == 'camera_parameters_perspective_to_orthographic_unet':
        self.rc_warp = self._warp_camera_parameters(rc_downsamp, inds_downsamp=inds_downsamp,
                                                    use_radial_deformation=True, stack_downsamp=stack_downsamp)
    elif self.deformation_model == 'camera_parameters_perspective_to_perspective':
        self.rc_warp = self._warp_camera_parameters(rc_downsamp, p2p_warp_mode=p2p_warp_mode,
                                                    inds_downsamp=inds_downsamp,
                                                    use_radial_deformation=True)
    elif self.deformation_model == 'camera_parameters_perspective_to_perspective_unet':
        self.rc_warp = self._warp_camera_parameters(rc_downsamp, p2p_warp_mode=p2p_warp_mode,
                                                    inds_downsamp=inds_downsamp,
                                                    use_radial_deformation=True, stack_downsamp=stack_downsamp)
    else:
        raise Exception('invalid deformation model: ' + self.deformation_model)

    if 'camera_parameters_perspective_to_' in self.deformation_model:
        # adding the height map as a channel to the reconstruction, so first augment self.im with the height map:
        self.im = tf.concat([self.im, tf.reshape(self.ego_height_for_concat, [-1])[:, None]], axis=1)
        self.num_channels_recon = self.num_channels + 1  # for the recon, need one more channel;
    else:
        self.num_channels_recon = self.num_channels

    if dither_coords:
        self.rc_warp += tf.random.uniform([1, 2], -1, 1, dtype=self.tf_dtype)
        if self.batch_size is not None:
            print('Minor warning: using a running average for the recon while dithering coordinates')

    # neighboring pixels:
    rc_floor = tf.floor(self.rc_warp)
    rc_ceil = rc_floor + 1
    # distance to neighboring pixels:
    frc = self.rc_warp - rc_floor
    crc = rc_ceil - self.rc_warp
    # cast
    rc_floor = tf.cast(rc_floor, tf.int32)
    rc_ceil = tf.cast(rc_ceil, tf.int32)
    self.rc_ff = self.restrict(rc_floor)
    self.rc_cc = self.restrict(rc_ceil)
    self.rc_cf = self.restrict(tf.stack([rc_ceil[:, 0], rc_floor[:, 1]], 1))
    self.rc_fc = self.restrict(tf.stack([rc_floor[:, 0], rc_ceil[:, 1]], 1))
    # gaussian weights; e.g., sig_proj = .42465 is chosen so that if a point is exactly in between two pixels,
    # .5 weight is assigned to each pixel;
    self.frc = tf.exp(-frc ** 2 / 2. / self.sig_proj ** 2)
    self.crc = tf.exp(-crc ** 2 / 2. / self.sig_proj ** 2)
    # augmented coordinates:
    rc_4 = tf.concat([self.rc_ff, self.rc_cc, self.rc_cf, self.rc_fc], 0)
    # interpolated:
    im_4 = tf.concat([self.im * self.frc[:, 0, None] * self.frc[:, 1, None],
                      self.im * self.crc[:, 0, None] * self.crc[:, 1, None],
                      self.im * self.crc[:, 0, None] * self.frc[:, 1, None],
                      self.im * self.frc[:, 0, None] * self.crc[:, 1, None]], 0)
    w_4 = tf.concat([self.frc[:, 0] * self.frc[:, 1],
                     self.crc[:, 0] * self.crc[:, 1],
                     self.crc[:, 0] * self.frc[:, 1],
                     self.frc[:, 0] * self.crc[:, 1]], 0)

    if self.momentum is not None:
        # update with moving average:
        self.im_4_previous = tf.gather_nd(self.recon_previous,
                                          rc_4) * w_4[:, None]  # with appropriate weighting by w_4;
        self.im_4_updated = (im_4 * self.momentum + self.im_4_previous * (1 - self.momentum))
        normalize = tf.scatter_nd(rc_4, w_4, self.recon_shape)
        self.norm_updated_regathered = tf.gather_nd(normalize, rc_4)
        self.im_4_updated_norm = self.im_4_updated / self.norm_updated_regathered[:, None]  # pre-normalize;
        # since tensor_scatter_nd_update doesn't accumulate values, but tensor_scatter_nd_add does, first zero
        # out the regions to be updated and then just add them:
        recon_zeroed = tf.tensor_scatter_nd_update(self.recon_previous, rc_4,
                                                   tf.zeros_like(self.im_4_updated_norm))
        self.recon = tf.tensor_scatter_nd_add(recon_zeroed, rc_4, self.im_4_updated_norm)
        if assign_update_recon:
            with tf.device('/CPU:0'):
                self.recon_previous.assign(self.recon)
        self.normalize = None  # normalize not needed; in fact, normalize_previous also not needed;
    else:
        self.normalize = tf.scatter_nd(rc_4, w_4, self.recon_shape)
        self.recon = tf.scatter_nd(rc_4, im_4, [self.recon_shape[0], self.recon_shape[1], self.num_channels_recon])
        self.recon = tf.math.divide_no_nan(self.recon, self.normalize[:, :, None])  # creates recon H by W by C;

    if 'camera_parameters_perspective_to' in self.deformation_model:
        self.height_map = self.recon[:, :, -1]
        if self.use_absolute_scale_calibration:
            # divide out the scale factors to get the true height in mm:
            self.tensors_to_track['height_map'] = self.height_map / (self.height_scale_factor /
                                                                     self.another_height_scale_factor)
        else:
            self.tensors_to_track['height_map'] = self.height_map
def _forward_prediction(self):
# given the reconstruction, generate forward prediction;
# forward model:
ff = tf.gather_nd(self.recon, self.rc_ff)
cc = tf.gather_nd(self.recon, self.rc_cc)
cf = tf.gather_nd(self.recon, self.rc_cf)
fc = tf.gather_nd(self.recon, self.rc_fc)
self.forward = (ff * self.frc[:, 0, None] * self.frc[:, 1, None] +
cc * self.crc[:, 0, None] * self.crc[:, 1, None] +
cf * self.crc[:, 0, None] * self.frc[:, 1, None] +
fc * self.frc[:, 0, None] * self.crc[:, 1, None])
self.forward /= ((self.frc[:, 0, None] * self.frc[:, 1, None]) +
(self.crc[:, 0, None] * self.crc[:, 1, None]) +
(self.crc[:, 0, None] * self.frc[:, 1, None]) +
(self.frc[:, 0, None] * self.crc[:, 1, None]))
if 'camera_parameters_perspective' in self.deformation_model:
# split off the last dimension, the height dimension, to compute the height map MSE:
self.forward_height = self.forward[:, -1]
self.error_height = self.forward_height - self.im[:, -1] # save this for computing error map;
self.MSE_height = tf.reduce_mean(self.error_height ** 2)
self.error = self.forward[:, :-1] - self.im[:, :-1] # remaining channels are the actual recon;
self.MSE = tf.reduce_mean(self.error ** 2)
self.recon = self.recon[:, :, :-1] # discard the height map channel, as it's recorded elsewhere;
else: