-
Notifications
You must be signed in to change notification settings - Fork 0
/
LBFGS.py
executable file
·1106 lines (900 loc) · 41.5 KB
/
LBFGS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import reduce
from copy import deepcopy
from torch.optim import Optimizer
def is_legal(v):
    """
    Checks that tensor is not NaN or Inf.

    Works for tensors of any shape (the previous implementation raised for
    tensors with more than one element because `not torch.isinf(v)` needs a
    single boolean).

    Inputs:
        v (tensor): tensor to be checked

    Outputs:
        legal (bool): True iff every entry of v is finite
    """
    return bool(torch.isfinite(v).all())
def polyinterp(points, x_min_bound=None, x_max_bound=None, plot=False):
    """
    Gives the minimizer of the interpolating polynomial over given points
    based on function and derivative information. Defaults to bisection if no critical
    points are valid.

    Based on polyinterp.m Matlab function in minFunc by Mark Schmidt with some slight
    modifications.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 12/6/18.

    Inputs:
        points (nparray): two-dimensional array with each point of form [x f g]
        x_min_bound (float): minimum value that brackets minimum (default: minimum of points)
        x_max_bound (float): maximum value that brackets minimum (default: maximum of points)
        plot (bool): plot interpolating polynomial

    Outputs:
        x_sol (float): minimizer of the interpolating polynomial restricted to
            [x_min_bound, x_max_bound]. The midpoint of that interval
            (bisection) is returned when the interpolation problem is
            ill-posed, has no valid critical point, or yields a non-finite value.

    Note:
      . Set f or g to np.nan if they are unknown
      . The minimum value of the polynomial is only used for plotting; it is
        not returned.
    """
    no_points = points.shape[0]
    xs, fs, gs = points[:, 0], points[:, 1], points[:, 2]

    # degree of the interpolant: one less than the number of known f/g values
    order = np.sum(1 - np.isnan(points[:, 1:3]).astype('int')) - 1

    # compute bounds of interpolation area
    if x_min_bound is None:
        x_min_bound = np.min(xs)
    if x_max_bound is None:
        x_max_bound = np.max(xs)
    midpoint = (x_min_bound + x_max_bound) / 2

    coeff = None  # polynomial coefficients; only set on the general path
    f_min = np.inf

    if no_points == 2 and order == 2 and plot is False and np.isnan(gs[1]):
        # explicit formula for quadratic interpolation from f1, g1, f2
        # (only valid when the missing value is g2):
        #   a = -(f1 - f2 - g1(x1 - x2))/(x1 - x2)^2
        #   x_min = x1 - g1/(2a)
        # if x1 = 0, then is given by:
        #   x_min = - (g1*x2^2)/(2(f2 - f1 - g1*x2))
        if xs[0] == 0:
            x_sol = -gs[0] * xs[1] ** 2 / (2 * (fs[1] - fs[0] - gs[0] * xs[1]))
        else:
            a = -(fs[0] - fs[1] - gs[0] * (xs[0] - xs[1])) / (xs[0] - xs[1]) ** 2
            x_sol = xs[0] - gs[0] / (2 * a)
        x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)

    elif no_points == 2 and order == 3 and plot is False:
        # explicit formula for cubic interpolation:
        #   d1 = g1 + g2 - 3((f1 - f2)/(x1 - x2))
        #   d2 = sqrt(d1^2 - g1*g2)
        #   x_min = x2 - (x2 - x1)*((g2 + d2 - d1)/(g2 - g1 + 2*d2))
        d1 = gs[0] + gs[1] - 3 * ((fs[0] - fs[1]) / (xs[0] - xs[1]))
        disc = d1 ** 2 - gs[0] * gs[1]
        # np.sqrt of a negative float is NaN (not complex), so test the sign
        # explicitly; a negative discriminant means no real critical point
        if disc >= 0:
            d2 = np.sqrt(disc)
            x_sol = xs[1] - (xs[1] - xs[0]) * ((gs[1] + d2 - d1) / (gs[1] - gs[0] + 2 * d2))
            x_sol = np.minimum(np.maximum(x_min_bound, x_sol), x_max_bound)
        else:
            x_sol = midpoint

    elif order < 0:
        # nothing known about the function: bisect
        x_sol = midpoint

    else:
        # general case: solve a square linear system for the coefficients
        # (highest degree first, as expected by np.polyval / np.roots)
        rows, rhs = [], []
        # constraints on function values
        for x_i, f_i in zip(xs, fs):
            if not np.isnan(f_i):
                rows.append([x_i ** j for j in range(order, -1, -1)])
                rhs.append(f_i)
        # constraints on gradient values
        for x_i, g_i in zip(xs, gs):
            if not np.isnan(g_i):
                rows.append([(order - j) * x_i ** (order - j - 1) for j in range(order)] + [0.0])
                rhs.append(g_i)
        A = np.array(rows, dtype=float)
        b = np.array(rhs, dtype=float)

        # check if system is solvable
        if A.shape[0] != A.shape[1] or np.linalg.matrix_rank(A) != A.shape[0]:
            x_sol = midpoint
        else:
            coeff = np.linalg.solve(A, b)

            # candidate minimizers: interval ends, data points, derivative roots
            dcoeff = coeff[:-1] * np.arange(order, 0, -1)
            crit_pts = np.concatenate(([x_min_bound, x_max_bound], xs))
            if np.isfinite(dcoeff).all():  # np.roots rejects inf and NaN
                crit_pts = np.append(crit_pts, np.roots(dcoeff))

            # test critical points
            x_sol = midpoint  # defaults to bisection
            for crit_pt in crit_pts:
                if not np.isreal(crit_pt):
                    continue
                cp = np.real(crit_pt)
                if x_min_bound <= cp <= x_max_bound:
                    F_cp = np.polyval(coeff, cp)
                    if F_cp < f_min:
                        x_sol = cp
                        f_min = F_cp

    # degenerate denominators above can produce NaN; fall back to bisection
    if not np.isfinite(x_sol):
        x_sol = midpoint

    if plot:
        plt.figure()
        if coeff is not None:
            x = np.arange(x_min_bound, x_max_bound, (x_max_bound - x_min_bound) / 10000)
            plt.plot(x, np.polyval(coeff, x))
        plt.plot(x_sol, f_min, 'x')

    return float(x_sol)
class LBFGS(Optimizer):
    """
    Implements the L-BFGS algorithm. Compatible with multi-batch and full-overlap
    L-BFGS implementations and (stochastic) Powell damping. Partly based on the
    original L-BFGS implementation in PyTorch, Mark Schmidt's minFunc MATLAB code,
    and Michael Overton's weak Wolfe line search MATLAB code.

    Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
    Last edited 10/20/20.

    Warnings:
      . Does not support per-parameter options and parameter groups.
      . All parameters have to be on a single device.

    Inputs:
        lr (float): steplength or learning rate (default: 1)
        history_size (int): update history size (default: 10); with 0 no
            curvature pairs are stored and only the scalar H_diag is used
        line_search (str): designates line search to use (default: 'Wolfe')
            Options:
                'None': uses steplength designated in algorithm
                'Armijo': uses Armijo backtracking line search
                'Wolfe': uses Armijo-Wolfe bracketing line search
        dtype: data type (default: torch.float)
        debug (bool): debugging mode

    References:
    [1] Berahas, Albert S., Jorge Nocedal, and Martin Takác. "A Multi-Batch L-BFGS
        Method for Machine Learning." Advances in Neural Information Processing
        Systems. 2016.
    [2] Bollapragada, Raghu, et al. "A Progressive Batching L-BFGS Method for Machine
        Learning." International Conference on Machine Learning. 2018.
    [3] Lewis, Adrian S., and Michael L. Overton. "Nonsmooth Optimization via Quasi-Newton
        Methods." Mathematical Programming 141.1-2 (2013): 135-163.
    [4] Liu, Dong C., and Jorge Nocedal. "On the Limited Memory BFGS Method for
        Large Scale Optimization." Mathematical Programming 45.1-3 (1989): 503-528.
    [5] Nocedal, Jorge. "Updating Quasi-Newton Matrices With Limited Storage."
        Mathematics of Computation 35.151 (1980): 773-782.
    [6] Nocedal, Jorge, and Stephen J. Wright. "Numerical Optimization." Springer New York,
        2006.
    [7] Schmidt, Mark. "minFunc: Unconstrained Differentiable Multivariate Optimization
        in Matlab." Software available at http://www.cs.ubc.ca/~schmidtm/Software/minFunc.html
        (2005).
    [8] Schraudolph, Nicol N., Jin Yu, and Simon Günter. "A Stochastic Quasi-Newton
        Method for Online Convex Optimization." Artificial Intelligence and Statistics.
        2007.
    [9] Wang, Xiao, et al. "Stochastic Quasi-Newton Methods for Nonconvex Stochastic
        Optimization." SIAM Journal on Optimization 27.2 (2017): 927-956.
    """

    def __init__(self, params, lr=1., history_size=10, line_search='Wolfe',
                 dtype=torch.float, debug=False):

        # ensure inputs are valid
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0 <= history_size:
            raise ValueError("Invalid history size: {}".format(history_size))
        if line_search not in ['Armijo', 'Wolfe', 'None']:
            raise ValueError("Invalid line search: {}".format(line_search))

        defaults = dict(lr=lr, history_size=history_size, line_search=line_search, dtype=dtype, debug=debug)
        super().__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError("L-BFGS doesn't support per-parameter options "
                             "(parameter groups)")

        self._params = self.param_groups[0]['params']
        self._numel_cache = None

        # all optimizer state is kept under one key instead of per parameter
        state = self.state['global_state']
        state.setdefault('n_iter', 0)
        state.setdefault('curv_skips', 0)
        state.setdefault('fail_skips', 0)
        state.setdefault('H_diag', 1)
        state.setdefault('fail', True)

        # curvature pair history, oldest first (see curvature_update)
        state['old_dirs'] = []  # steps s = t * d
        state['old_stps'] = []  # gradient differences y (possibly damped)

    def _numel(self):
        """Total number of scalar entries across all parameters (cached)."""
        if self._numel_cache is None:
            self._numel_cache = reduce(lambda total, p: total + p.numel(), self._params, 0)
        return self._numel_cache

    def _gather_flat_grad(self):
        """Concatenates all parameter gradients into one 1-D tensor (zeros where grad is None)."""
        views = []
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().reshape(-1)
            else:
                view = p.grad.data.reshape(-1)
            views.append(view)
        return torch.cat(views, 0)

    def _add_update(self, step_size, update):
        """Performs params += step_size * update, where update is a flat 1-D tensor."""
        offset = 0
        for p in self._params:
            numel = p.numel()
            # view_as to avoid deprecated pointwise semantics; alpha= replaces the
            # deprecated add_(Number, Tensor) overload
            p.data.add_(update[offset:offset + numel].view_as(p.data), alpha=step_size)
            offset += numel
        assert offset == self._numel()

    def _copy_params(self):
        """Returns independent copies of all parameter tensors."""
        return [param.data.clone() for param in self._params]

    def _load_params(self, current_params):
        """Restores parameter values from a list produced by _copy_params."""
        for param, saved in zip(self._params, current_params):
            # copy_ also works for 0-dim parameters, unlike slice assignment
            param.data.copy_(saved)

    def _nan_scalar(self, dtype):
        """0-dim NaN tensor on the parameters' device, used as an 'unknown value' placeholder."""
        return torch.tensor(np.nan, dtype=dtype, device=self._params[0].device)

    def line_search(self, line_search):
        """
        Switches line search option.

        Inputs:
            line_search (str): designates line search to use
                Options:
                    'None': uses steplength designated in algorithm
                    'Armijo': uses Armijo backtracking line search
                    'Wolfe': uses Armijo-Wolfe bracketing line search

        Raises:
            ValueError: if line_search is not one of the options above
        """
        if line_search not in ['Armijo', 'Wolfe', 'None']:
            raise ValueError("Invalid line search: {}".format(line_search))
        group = self.param_groups[0]
        group['line_search'] = line_search
        return

    def two_loop_recursion(self, vec):
        """
        Performs two-loop recursion on given vector to obtain Hv.

        The input tensor is not modified.

        Inputs:
            vec (tensor): 1-D tensor to apply two-loop recursion to

        Output:
            r (tensor): matrix-vector product Hv
        """
        group = self.param_groups[0]
        history_size = group['history_size']

        state = self.state['global_state']
        old_dirs = state.get('old_dirs')  # steps s_i
        old_stps = state.get('old_stps')  # gradient differences y_i
        H_diag = state.get('H_diag')

        # compute the product of the inverse Hessian approximation and the gradient
        num_old = len(old_dirs)

        if 'rho' not in state:
            state['rho'] = [None] * history_size
            state['alpha'] = [None] * history_size
        rho = state['rho']
        alpha = state['alpha']

        for i in range(num_old):
            rho[i] = 1. / old_stps[i].dot(old_dirs[i])

        # work on a copy so the caller's tensor is left untouched
        q = vec.clone()
        for i in range(num_old - 1, -1, -1):
            alpha[i] = old_dirs[i].dot(q) * rho[i]
            q.sub_(alpha[i] * old_stps[i])

        # multiply by initial Hessian
        # r/d is the final direction
        r = torch.mul(q, H_diag)
        for i in range(num_old):
            beta = old_stps[i].dot(r) * rho[i]
            r.add_((alpha[i] - beta) * old_dirs[i])

        return r

    def curvature_update(self, flat_grad, eps=1e-2, damping=False):
        """
        Performs curvature update.

        Inputs:
            flat_grad (tensor): 1-D tensor of flattened gradient for computing
                gradient difference with previously stored gradient
            eps (float): constant for curvature pair rejection or damping (default: 1e-2)
            damping (bool): flag for using Powell damping (default: False)

        Raises:
            ValueError: if eps is not positive
        """
        assert len(self.param_groups) == 1

        # load parameters
        if eps <= 0:
            raise ValueError('Invalid eps; must be positive.')

        group = self.param_groups[0]
        history_size = group['history_size']
        debug = group['debug']

        # variables cached in state (for tracing)
        state = self.state['global_state']

        # the previous step's line search failed: nothing to learn from it
        if state.get('fail'):
            state['fail_skips'] += 1
            if debug:
                print('Line search failed; curvature pair update skipped')
            return

        d = state.get('d')
        t = state.get('t')
        old_dirs = state.get('old_dirs')
        old_stps = state.get('old_stps')
        prev_flat_grad = state.get('prev_flat_grad')
        Bs = state.get('Bs')

        # compute y's
        y = flat_grad.sub(prev_flat_grad)
        s = d.mul(t)
        sBs = s.dot(Bs)
        ys = y.dot(s)  # y*s

        # update L-BFGS matrix
        if ys > eps * sBs or damping:

            # perform Powell damping
            if damping and ys < eps * sBs:
                if debug:
                    print('Applying Powell damping...')
                theta = ((1 - eps) * sBs) / (sBs - ys)
                y = theta * y + (1 - theta) * Bs
                # y changed, so y's must be recomputed; the undamped value can be
                # negative and would make H_diag (and the direction) wrong
                ys = y.dot(s)

            # updating memory (nothing is stored when history_size == 0)
            if history_size > 0:
                while len(old_dirs) >= history_size:
                    # shift history by one (limited-memory)
                    old_dirs.pop(0)
                    old_stps.pop(0)
                # store new direction/step
                old_dirs.append(s)
                old_stps.append(y)

            # update scale of initial Hessian approximation
            H_diag = ys / y.dot(y)  # (y*y)

            state['old_dirs'] = old_dirs
            state['old_stps'] = old_stps
            state['H_diag'] = H_diag

        else:
            # save skip
            state['curv_skips'] += 1
            if debug:
                print('Curvature pair skipped due to failed criterion')

        return

    @staticmethod
    def _parse_ls_options(options, wolfe):
        """
        Validates line search options and fills in defaults.

        'current_loss' and 'gtd' are left to the caller because their defaults
        require a closure call / the search direction.

        Inputs:
            options (dict): user options (see _step)
            wolfe (bool): True for the Wolfe search (eta > 1 and c2 required),
                False for Armijo (eta > 0)

        Outputs:
            dict with keys closure, eta, c1, c2 (None for Armijo), max_ls,
            interpolate, inplace, ls_debug

        Raises:
            ValueError: on missing closure or out-of-range values
        """
        if not options:
            raise ValueError('Options are not specified; need closure evaluating function.')
        if 'closure' not in options:
            raise ValueError('closure option not specified.')

        if 'eta' not in options:
            eta = 2
        elif wolfe and options['eta'] <= 1:
            raise ValueError('Invalid eta; must be greater than 1.')
        elif not wolfe and options['eta'] <= 0:
            raise ValueError('Invalid eta; must be positive.')
        else:
            eta = options['eta']

        if 'c1' not in options:
            c1 = 1e-4
        elif options['c1'] >= 1 or options['c1'] <= 0:
            raise ValueError('Invalid c1; must be strictly between 0 and 1.')
        else:
            c1 = options['c1']

        c2 = None
        if wolfe:
            if 'c2' not in options:
                c2 = 0.9
            elif options['c2'] >= 1 or options['c2'] <= 0:
                raise ValueError('Invalid c2; must be strictly between 0 and 1.')
            elif options['c2'] <= c1:
                raise ValueError('Invalid c2; must be strictly larger than c1.')
            else:
                c2 = options['c2']

        if 'max_ls' not in options:
            max_ls = 10
        elif options['max_ls'] <= 0:
            raise ValueError('Invalid max_ls; must be positive.')
        else:
            max_ls = options['max_ls']

        return dict(closure=options['closure'], eta=eta, c1=c1, c2=c2, max_ls=max_ls,
                    interpolate=options.get('interpolate', True),
                    inplace=options.get('inplace', True),
                    ls_debug=options.get('ls_debug', False))

    def _armijo_line_search(self, d, t, g_Sk, options, dtype, debug):
        """
        Armijo backtracking line search along d, starting from steplength t.

        Shrinks t (by 1/eta or by safeguarded polynomial interpolation) until
        F(x + t d) <= F(x) + c1 t g'd with a finite loss. Parameters are left at
        the accepted point; on failure they are restored and t = 0.

        Outputs:
            (F_new, t, ls_step, closure_eval, desc_dir, fail)
        """
        opts = self._parse_ls_options(options, wolfe=False)
        closure = opts['closure']
        eta, c1, max_ls = opts['eta'], opts['c1'], opts['max_ls']
        interpolate, inplace, ls_debug = opts['interpolate'], opts['inplace'], opts['ls_debug']

        closure_eval = 0
        gtd = options['gtd'] if 'gtd' in options else g_Sk.dot(d)
        if 'current_loss' in options:
            F_k = options['current_loss']
        else:
            F_k = closure()
            closure_eval += 1

        # loss at the previous rejected trial point (NaN = unknown)
        F_prev = self._nan_scalar(dtype) if interpolate else None
        ls_step = 0
        t_prev = 0  # old steplength
        fail = False  # failure flag

        # begin print for debug mode
        if ls_debug:
            print('==================================== Begin Armijo line search ===================================')
            print('F(x): %.8e g*d: %.8e' % (F_k, gtd))

        # check if search direction is descent direction
        if gtd >= 0:
            desc_dir = False
            if debug:
                print('Not a descent direction!')
        else:
            desc_dir = True

        # store values if not in-place
        if not inplace:
            current_params = self._copy_params()

        # update and evaluate at new point
        self._add_update(t, d)
        F_new = closure()
        closure_eval += 1

        # print info if debugging
        if ls_debug:
            print('LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e'
                  % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k))

        # check Armijo condition
        while F_new > F_k + c1 * t * gtd or not is_legal(F_new):

            # check if maximum number of iterations reached
            if ls_step >= max_ls:
                # give up and return to the starting point
                if inplace:
                    self._add_update(-t, d)
                else:
                    self._load_params(current_params)

                t = 0
                F_new = closure()
                closure_eval += 1
                fail = True
                break

            # steplength of the rejected trial point
            t_new = t

            # if first step or not interpolating, then multiply by factor
            if ls_step == 0 or not interpolate or not is_legal(F_new):
                t = t / eta
            # if second step, use function value at new point along with
            # gradient and function at current iterate
            elif ls_step == 1 or not is_legal(F_prev):
                t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan]]))
            # otherwise, use function values at new point, previous point,
            # and gradient and function at current iterate
            else:
                t = polyinterp(np.array([[0, F_k.item(), gtd.item()], [t_new, F_new.item(), np.nan],
                                         [t_prev, F_prev.item(), np.nan]]))

            # if values are too extreme, adjust t
            if interpolate:
                if t < 1e-3 * t_new:
                    t = 1e-3 * t_new
                elif t > 0.6 * t_new:
                    t = 0.6 * t_new

            # store old point
            F_prev = F_new
            t_prev = t_new

            # update iterate and reevaluate
            if inplace:
                self._add_update(t - t_new, d)
            else:
                self._load_params(current_params)
                self._add_update(t, d)

            F_new = closure()
            closure_eval += 1
            ls_step += 1  # iterate

            # print info if debugging
            if ls_debug:
                print('LS Step: %d t: %.8e F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e'
                      % (ls_step, t, F_new, F_k + c1 * t * gtd, F_k))

        # print final steplength
        if ls_debug:
            print('Final Steplength:', t)
            print('===================================== End Armijo line search ====================================')

        return F_new, t, ls_step, closure_eval, desc_dir, fail

    def _wolfe_line_search(self, d, t, g_Sk, options, dtype, debug):
        """
        Weak Wolfe bracketing line search along d, starting from steplength t.

        Maintains a bracket [alpha, beta]: extrapolates by eta while beta is
        infinite, otherwise bisects or uses safeguarded interpolation. A
        non-finite loss is treated as an Armijo failure. On failure the
        parameters are restored and t = 0.

        Outputs:
            (F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail)
        """
        opts = self._parse_ls_options(options, wolfe=True)
        closure = opts['closure']
        eta, c1, c2, max_ls = opts['eta'], opts['c1'], opts['c2'], opts['max_ls']
        interpolate, inplace, ls_debug = opts['interpolate'], opts['inplace'], opts['ls_debug']

        closure_eval = 0
        if 'current_loss' in options:
            F_k = options['current_loss']
        else:
            F_k = closure()
            closure_eval += 1
        gtd = options['gtd'] if 'gtd' in options else g_Sk.dot(d)

        # initialize counters
        ls_step = 0
        grad_eval = 0  # tracks gradient evaluations
        t_prev = 0  # steplength at which the parameters currently sit

        # initialize bracketing variables and flag
        alpha = 0
        beta = float('Inf')
        fail = False

        # interpolation data at the bracket ends (NaN = unknown)
        if interpolate:
            F_a = F_k
            g_a = gtd
            F_b = self._nan_scalar(dtype)
            g_b = self._nan_scalar(dtype)

        # begin print for debug mode
        if ls_debug:
            print('==================================== Begin Wolfe line search ====================================')
            print('F(x): %.8e g*d: %.8e' % (F_k, gtd))

        # check if search direction is descent direction
        if gtd >= 0:
            desc_dir = False
            if debug:
                print('Not a descent direction!')
        else:
            desc_dir = True

        # store values if not in-place
        if not inplace:
            current_params = self._copy_params()

        # update and evaluate at new point
        self._add_update(t, d)
        F_new = closure()
        closure_eval += 1

        # main loop
        while True:

            # check if maximum number of line search steps have been reached
            if ls_step >= max_ls:
                if inplace:
                    self._add_update(-t, d)
                else:
                    self._load_params(current_params)

                t = 0
                F_new = closure()
                F_new.backward()
                g_new = self._gather_flat_grad()
                closure_eval += 1
                grad_eval += 1
                fail = True
                break

            # print info if debugging
            if ls_debug:
                print('LS Step: %d t: %.8e alpha: %.8e beta: %.8e'
                      % (ls_step, t, alpha, beta))
                print('Armijo: F(x+td): %.8e F-c1*t*g*d: %.8e F(x): %.8e'
                      % (F_new, F_k + c1 * t * gtd, F_k))

            # check Armijo condition (NaN/Inf losses count as violations)
            if F_new > F_k + c1 * t * gtd or not is_legal(F_new):

                # set upper bound
                beta = t
                t_prev = t

                # update interpolation quantities
                if interpolate:
                    F_b = F_new
                    g_b = self._nan_scalar(dtype)

            else:

                # compute gradient
                F_new.backward()
                g_new = self._gather_flat_grad()
                grad_eval += 1
                gtd_new = g_new.dot(d)

                # print info if debugging
                if ls_debug:
                    print('Wolfe: g(x+td)*d: %.8e c2*g*d: %.8e gtd: %.8e'
                          % (gtd_new, c2 * gtd, gtd))

                # check curvature condition
                if gtd_new < c2 * gtd:

                    # set lower bound
                    alpha = t
                    t_prev = t

                    # update interpolation quantities
                    if interpolate:
                        F_a = F_new
                        g_a = gtd_new

                else:
                    break

            # compute new steplength

            # if first step or not interpolating, then bisect or multiply by factor
            if not interpolate or not is_legal(F_b):
                if beta == float('Inf'):
                    t = eta * t
                else:
                    t = (alpha + beta) / 2.0

            # otherwise interpolate between a and b
            else:
                t = polyinterp(np.array([[alpha, F_a.item(), g_a.item()], [beta, F_b.item(), g_b.item()]]))

                # if values are too extreme, adjust t
                if beta == float('Inf'):
                    if t > 2 * eta * t_prev:
                        t = 2 * eta * t_prev
                    elif t < eta * t_prev:
                        t = eta * t_prev
                else:
                    # keep t inside [alpha + 0.2 (beta - alpha), (alpha + beta) / 2];
                    # the upper end used to be (beta - alpha) / 2, which equals the
                    # midpoint only when alpha == 0 and otherwise leaves the bracket
                    if t < alpha + 0.2 * (beta - alpha):
                        t = alpha + 0.2 * (beta - alpha)
                    elif t > (alpha + beta) / 2.0:
                        t = (alpha + beta) / 2.0

                # if we obtain nonsensical value from interpolation
                if np.isnan(t) or t <= 0:
                    t = (alpha + beta) / 2.0

            # update parameters
            if inplace:
                self._add_update(t - t_prev, d)
            else:
                self._load_params(current_params)
                self._add_update(t, d)

            # evaluate closure
            F_new = closure()
            closure_eval += 1
            ls_step += 1

        # print final steplength
        if ls_debug:
            print('Final Steplength:', t)
            print('===================================== End Wolfe line search =====================================')

        return F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail

    def _save_step_state(self, d, prev_flat_grad, t, g_Sk, fail):
        """
        Caches what curvature_update needs: the direction d, the previous
        gradient, the final steplength t, Bs = -t * g_Sk and the failure flag.
        """
        state = self.state['global_state']
        # assumes p_k = -H g_Sk (usual L-BFGS direction), so B s = B (t p_k) ~ -t g_Sk
        Bs = state.get('Bs')
        if Bs is None:
            Bs = g_Sk.mul(-t)
        else:
            Bs.copy_(g_Sk.mul(-t))

        state['d'] = d
        state['prev_flat_grad'] = prev_flat_grad
        state['t'] = t
        state['Bs'] = Bs
        state['fail'] = fail

    def _step(self, p_k, g_Ok, g_Sk=None, options=None):
        """
        Performs a single optimization step.

        Inputs:
            p_k (tensor): 1-D tensor specifying search direction
            g_Ok (tensor): 1-D tensor of flattened gradient over overlap O_k used
                            for gradient differencing in curvature pair update
            g_Sk (tensor): 1-D tensor of flattened gradient over full sample S_k
                            used for curvature pair damping or rejection criterion,
                            if None, will use g_Ok (default: None)
            options (dict): contains options for performing line search (default: None)

        Options for Armijo backtracking line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (tensor): factor for decreasing steplength > 0 (default: 2)
            'c1' (tensor): sufficient decrease constant in (0, 1) (default: 1e-4)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Options for Wolfe line search:
            'closure' (callable): reevaluates model and returns function value
            'current_loss' (tensor): objective value at current iterate (default: F(x_k))
            'gtd' (tensor): inner product g_Ok'd in line search (default: g_Ok'd)
            'eta' (float): factor for extrapolation (default: 2)
            'c1' (float): sufficient decrease constant in (0, 1) (default: 1e-4)
            'c2' (float): curvature condition constant in (0, 1) (default: 0.9)
            'max_ls' (int): maximum number of line search steps permitted (default: 10)
            'interpolate' (bool): flag for using interpolation (default: True)
            'inplace' (bool): flag for inplace operations (default: True)
            'ls_debug' (bool): debugging mode for line search

        Outputs (depends on line search):
          . No line search:
                t (float): steplength
          . Armijo backtracking line search:
                F_new (tensor): loss function at new iterate
                t (tensor): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded
          . Wolfe line search:
                F_new (tensor): loss function at new iterate
                g_new (tensor): gradient at new iterate
                t (float): final steplength
                ls_step (int): number of backtracks
                closure_eval (int): number of closure evaluations
                grad_eval (int): number of gradient evaluations
                desc_dir (bool): descent direction flag
                    True: p_k is descent direction with respect to the line search
                    function
                    False: p_k is not a descent direction with respect to the line
                    search function
                fail (bool): failure flag
                    True: line search reached maximum number of iterations, failed
                    False: line search succeeded

        Notes:
          . If encountering line search failure in the deterministic setting, one
            should try increasing the maximum number of line search steps max_ls.
        """
        if options is None:
            options = {}

        assert len(self.param_groups) == 1

        # load parameter options
        group = self.param_groups[0]
        lr = group['lr']
        line_search = group['line_search']
        dtype = group['dtype']
        debug = group['debug']

        # variables cached in state (for tracing)
        state = self.state['global_state']
        prev_flat_grad = state.get('prev_flat_grad')

        # keep track of nb of iterations
        state['n_iter'] += 1

        # set search direction
        d = p_k

        # modify previous gradient
        if prev_flat_grad is None:
            prev_flat_grad = g_Ok.clone()
        else:
            prev_flat_grad.copy_(g_Ok)

        # set initial step size
        t = lr

        if g_Sk is None:
            g_Sk = g_Ok.clone()

        # perform Armijo backtracking line search
        if line_search == 'Armijo':
            F_new, t, ls_step, closure_eval, desc_dir, fail = self._armijo_line_search(
                d, t, g_Sk, options, dtype, debug)
            self._save_step_state(d, prev_flat_grad, t, g_Sk, fail)
            return F_new, t, ls_step, closure_eval, desc_dir, fail

        # perform weak Wolfe line search
        if line_search == 'Wolfe':
            (F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir,
             fail) = self._wolfe_line_search(d, t, g_Sk, options, dtype, debug)
            self._save_step_state(d, prev_flat_grad, t, g_Sk, fail)
            return F_new, g_new, t, ls_step, closure_eval, grad_eval, desc_dir, fail

        # no line search: take the fixed step t = lr
        self._add_update(t, d)
        self._save_step_state(d, prev_flat_grad, t, g_Sk, False)
        return t

    def step(self, p_k, g_Ok, g_Sk=None, options=None):
        """Performs a single optimization step; see _step for inputs and outputs."""
        return self._step(p_k, g_Ok, g_Sk, options)
class FullBatchLBFGS(LBFGS):
"""
Implements full-batch or deterministic L-BFGS algorithm. Compatible with
Powell damping. Can be used when evaluating a deterministic function and
gradient. Wraps the LBFGS optimizer. Performs the two-loop recursion,
updating, and curvature updating in a single step.
Implemented by: Hao-Jun Michael Shi and Dheevatsa Mudigere
Last edited 11/15/18.
Warnings:
. Does not support per-parameter options and parameter groups.
. All parameters have to be on a single device.
Inputs:
lr (float): steplength or learning rate (default: 1)
history_size (int): update history size (default: 10)
line_search (str): designates line search to use (default: 'Wolfe')
Options:
'None': uses steplength designated in algorithm
'Armijo': uses Armijo backtracking line search
'Wolfe': uses Armijo-Wolfe bracketing line search