# hessianfree.py -- forked from fmeirinhos/pytorch-hessianfree
import torch
from torch.nn.utils.convert_parameters import vector_to_parameters, parameters_to_vector
from functools import reduce
class HessianFree(torch.optim.Optimizer):
    """
    Implements the Hessian-free algorithm presented in `Training Deep and
    Recurrent Networks with Hessian-Free Optimization`_.

    Arguments:
        params (iterable): iterable of parameters to optimize. Only a single
            parameter group is supported.
        lr (float, optional): initial step size of the line search, must lie
            in (0, 1] (default: 1)
        damping (float, optional): Initial value of the Tikhonov damping
            coefficient, must lie in (0, 1] (default: 0.5)
        delta_decay (float, optional): Decay applied to the previous
            conjugate-gradient solution before it initializes the next
            conjugate-gradient run (Section 20.10) (default: 0.95)
        cg_max_iter (int, optional): Maximum number of Conjugate-Gradient
            iterations per step (default: 100)
        use_gnm (bool, optional): Use the generalized Gauss-Newton matrix
            instead of the Hessian as curvature matrix; probably solves the
            indefiniteness of the Hessian (Section 20.6) (default: True)
        verbose (bool, optional): Print statements (debugging)
            (default: False)

    .. _Training Deep and Recurrent Networks with Hessian-Free Optimization:
        https://doi.org/10.1007/978-3-642-35289-8_27
    """

    def __init__(self, params,
                 lr=1,
                 damping=0.5,
                 delta_decay=0.95,
                 cg_max_iter=100,
                 use_gnm=True,
                 verbose=False):
        if not (0.0 < lr <= 1):
            raise ValueError("Invalid lr: {}".format(lr))
        if not (0.0 < damping <= 1):
            raise ValueError("Invalid damping: {}".format(damping))
        if not cg_max_iter > 0:
            raise ValueError("Invalid cg_max_iter: {}".format(cg_max_iter))
        # The learning rate is stored as 'alpha': it is the starting step
        # size of the line search in `step`.
        defaults = dict(alpha=lr,
                        damping=damping,
                        delta_decay=delta_decay,
                        cg_max_iter=cg_max_iter,
                        use_gnm=use_gnm,
                        verbose=verbose)
        super().__init__(params, defaults)
        if len(self.param_groups) != 1:
            raise ValueError(
                "HessianFree doesn't support per-parameter options (parameter groups)")
        self._params = self.param_groups[0]['params']

    def _gather_flat_grad(self):
        """
        Returns the gradients of all parameters concatenated into one vector.

        Parameters without a gradient contribute zeros and sparse gradients
        are densified. Dense gradients are NOT detached, so a graph built by
        ``backward(create_graph=True)`` survives (``_Hv`` differentiates it).
        """
        views = list()
        for p in self._params:
            if p.grad is None:
                view = p.data.new(p.data.numel()).zero_()
            elif p.grad.data.is_sparse:
                view = p.grad.data.to_dense().view(-1)
            else:
                view = p.grad.contiguous().view(-1)
            views.append(view)
        return torch.cat(views, 0)

    @staticmethod
    def _damped_inverse_preconditioner(m_inv, damping):
        """
        Builds the operator applied as preconditioner inside CG
        (Section 20.13).

        Arguments:
            m_inv (torch.Tensor): undamped curvature approximation, either a
                vector (treated as a diagonal) or a square matrix.
            damping (float): Tikhonov damping added before inversion.

        Returns:
            callable: x -> (diag(m_inv) + damping)^-0.85 * x for a vector, or
            x -> (m_inv + damping * I)^-1 @ x for a matrix.
        """
        if m_inv.dim() == 1:
            m = (m_inv + damping) ** (-0.85)

            def M(x):
                return m * x
        else:
            # Build the identity on the same device/dtype as m_inv; the
            # default torch.eye is a CPU float tensor.
            eye = torch.eye(m_inv.shape[-1], dtype=m_inv.dtype,
                            device=m_inv.device)
            m = torch.inverse(m_inv + damping * eye)

            def M(x):
                return m @ x
        return M

    def step(self, closure, b=None, M_inv=None):
        """
        Performs a single optimization step.

        Arguments:
            closure (callable): A closure that re-evaluates the model
                and returns a tuple of the loss and the output. The gradient
                is read from ``p.grad`` after the first call, and the graphs
                of that loss/output (and, with ``use_gnm=False``, of the
                gradient itself) are differentiated again, so they must still
                be alive, e.g. via ``loss.backward(create_graph=True)``.
            b (callable, optional): A closure that calculates the vector b in
                the minimization problem x^T . A . x + x^T b. Defaults to the
                flattened gradient.
            M_inv (callable, optional): A closure returning the undamped
                curvature approximation (vector = diagonal, or matrix) whose
                damped inverse preconditions CG (Section 20.13).

        Returns:
            The loss at the accepted parameters, or the loss before the step
            if the line search found no acceptable step size.
        """
        if len(self.param_groups) != 1:
            raise ValueError(
                "HessianFree doesn't support per-parameter options (parameter groups)")
        group = self.param_groups[0]
        alpha = group['alpha']
        delta_decay = group['delta_decay']
        cg_max_iter = group['cg_max_iter']
        damping = group['damping']
        use_gnm = group['use_gnm']
        verbose = group['verbose']

        # All optimizer state is stored under the first parameter.
        state = self.state[self._params[0]]
        state.setdefault('func_evals', 0)
        state.setdefault('n_iter', 0)

        loss_before, output = closure()
        state['func_evals'] += 1

        # Gather current parameters (detached: only their values are needed)
        # and respective gradients
        flat_params = parameters_to_vector(self._params).detach()
        flat_grad = self._gather_flat_grad()

        # Define linear operator (curvature-vector product incl. damping)
        if use_gnm:
            # Generalized Gauss-Newton vector product
            def A(x):
                return self._Gv(loss_before, output, x, damping)
        else:
            # Hessian-vector product
            def A(x):
                return self._Hv(flat_grad, x, damping)

        M = None
        if M_inv is not None:
            M = self._damped_inverse_preconditioner(M_inv(), damping)

        b = flat_grad.detach() if b is None else b().detach().flatten()

        # Initializing Conjugate-Gradient with the decayed previous solution
        # (Section 20.10)
        if state.get('init_delta') is not None:
            init_delta = delta_decay * state.get('init_delta')
        else:
            init_delta = torch.zeros_like(flat_params)

        eps = torch.finfo(b.dtype).eps

        # Conjugate-Gradient on q(d) = 1/2 d^T A d + b^T d; `qs[i]` is the
        # model value q(deltas[i]).
        deltas, qs = self._CG(A=A, b=b.neg(), x0=init_delta,
                              M=M, max_iter=cg_max_iter,
                              tol=1e1 * eps, eps=eps, martens=True)

        # The final CG iterate warm-starts the next step, even if
        # backtracking below selects an earlier one.
        delta = state['init_delta'] = deltas[-1]
        q_delta = qs[-1]
        vector_to_parameters(flat_params + delta, self._params)
        loss_now = closure()[0]
        state['func_evals'] += 1

        if verbose:
            print("Loss before CG: {}".format(float(loss_before)))
            print("Loss before BT: {}".format(float(loss_now)))

        # Conjugate-Gradient backtracking (Section 20.8.7): walk back through
        # every other earlier iterate while the loss keeps improving.
        for (d, q) in zip(reversed(deltas[:-1][::2]), reversed(qs[:-1][::2])):
            vector_to_parameters(flat_params + d, self._params)
            loss_prev = closure()[0]
            state['func_evals'] += 1
            if float(loss_prev) > float(loss_now):
                break
            delta = d
            q_delta = q
            loss_now = loss_prev

        if verbose:
            print("Loss after BT: {}".format(float(loss_now)))

        # The Levenberg-Marquardt Heuristic (Section 20.8.5): actual vs.
        # predicted reduction (q(0) = 0).
        q_delta = float(q_delta)
        if q_delta != 0:
            reduction_ratio = (float(loss_now) - float(loss_before)) / q_delta
        else:
            reduction_ratio = 1
        if reduction_ratio < 0.25:
            group['damping'] *= 3 / 2
        elif reduction_ratio > 0.75:
            group['damping'] *= 2 / 3
        if reduction_ratio < 0:
            # The model was worse than useless: drop the CG warm start.
            state['init_delta'] = torch.zeros_like(delta)

        # Line Searching (Section 20.8.8): Armijo backtracking on alpha.
        beta = 0.8
        c = 1e-2
        min_improv = min(float(c * torch.dot(b, delta)), 0.0)

        # loss_now was measured at the full step; re-measure it at the
        # initial step size when that is not 1.
        if alpha != 1:
            vector_to_parameters(flat_params + alpha * delta, self._params)
            loss_now = closure()[0]
            state['func_evals'] += 1

        accepted = float(loss_now) <= float(loss_before) + alpha * min_improv
        n_shrinks = 0
        while not accepted and n_shrinks < 60:
            alpha *= beta
            vector_to_parameters(flat_params + alpha * delta, self._params)
            loss_now = closure()[0]
            state['func_evals'] += 1
            n_shrinks += 1
            accepted = (float(loss_now)
                        <= float(loss_before) + alpha * min_improv)
        if not accepted:  # No good update found
            alpha = 0.0
            loss_now = loss_before

        # Update the parameters (this time for real)
        vector_to_parameters(flat_params + alpha * delta, self._params)
        state['n_iter'] += 1

        if verbose:
            print("Loss after LS: {0} (lr: {1:.3f})".format(
                float(loss_now), alpha))
            print("Tikhonov damping: {0:.3f} (reduction ratio: {1:.3f})".format(
                group['damping'], reduction_ratio), end='\n\n')

        return loss_now

    def _CG(self, A, b, x0, M=None, max_iter=50, tol=1.2e-6, eps=1.2e-7,
            martens=False):
        """
        Minimizes q(x) = 1/2 x^T.A.x - x^T b (i.e. solves A.x = b) using the
        (preconditioned) conjugate gradient method.

        Arguments:
            A (callable): An abstract linear operator implementing the
                product A.x. A must represent a hermitian, positive definite
                matrix.
            b (torch.Tensor): The vector b.
            x0 (torch.Tensor): An initial guess for x.
            M (callable, optional): An abstract linear operator implementing
                the product of the INVERSE preconditioner (for A) with a
                vector.
            max_iter (int, optional): Maximum number of CG iterations.
            tol (float, optional): Tolerance for convergence on the
                (preconditioned) squared residual norm.
            eps (float, optional): Small constant guarding divisions.
            martens (bool, optional): Flag for Martens' convergence criterion.

        Returns:
            (xs, qs): the list of iterates [x0, x1, ...] and, if `martens`,
            the list of q(x_i) aligned with it (otherwise None).
        """
        x = [x0]
        # r = A.x - b is the gradient of q at x (negated classic residual).
        r = A(x0) - b
        if M is not None:
            y = M(r)
            p = -y
            # PCG uses r^T M r, not r^T r, in the step size and beta ratio.
            res_i_norm = r @ y
        else:
            p = -r
            res_i_norm = r @ r
        if martens:
            # q(x) = 1/2 (A.x - 2b)^T x = 1/2 (r - b)^T x
            m = [0.5 * (r - b) @ x0]
        for i in range(max_iter):
            Ap = A(p)
            alpha = res_i_norm / ((p @ Ap) + eps)
            x.append(x[i] + alpha * p)
            r = r + alpha * Ap
            if M is not None:
                y = M(r)
                res_ip1_norm = y @ r
            else:
                res_ip1_norm = r @ r
            beta = res_ip1_norm / (res_i_norm + eps)
            res_i_norm = res_ip1_norm
            # Martens' Relative Progress stopping condition (Section 20.4)
            if martens:
                # Reuse the maintained residual instead of another A product.
                m.append(0.5 * (r - b) @ x[i + 1])
                j = i + 1  # index of the newest iterate
                k = max(10, int(j / 10))
                # The criterion is only meaningful once q has become negative.
                if j > k and m[j] < 0:
                    stop = (m[j] - m[j - k]) / m[j]
                    if stop < 1e-4:
                        break
            if res_i_norm < tol or torch.isnan(res_i_norm):
                break
            if M is not None:
                p = - y + beta * p
            else:
                p = - r + beta * p
        return (x, m) if martens else (x, None)

    def _Hv(self, gradient, vec, damping):
        """
        Computes the damped Hessian-vector product (H + damping * I) vec.

        `gradient` must be the flattened gradient still attached to its graph.
        """
        Hv = self._Rop(gradient, self._params, vec)
        # Tikhonov damping (Section 20.8.1)
        return Hv.detach() + damping * vec

    def _Gv(self, loss, output, vec, damping):
        """
        Computes the damped generalized Gauss-Newton vector product
        (J^T H_L J + damping * I) vec, where J is the Jacobian of `output`
        w.r.t. the parameters and H_L the Hessian of `loss` w.r.t. `output`.
        """
        Jv = self._Rop(output, self._params, vec)
        gradient = torch.autograd.grad(loss, output, create_graph=True)
        HJv = self._Rop(gradient, output, Jv)
        JHJv = torch.autograd.grad(
            output, self._params, grad_outputs=HJv.reshape_as(output), retain_graph=True)
        # Tikhonov damping (Section 20.8.1)
        return parameters_to_vector(JHJv).detach() + damping * vec

    @staticmethod
    def _Rop(y, x, v, create_graph=False):
        """
        Computes the product (dy_i/dx_j) v_j: R-operator.

        Uses the double-backward trick: first J^T w for a dummy `w`
        (with a graph), then differentiates that w.r.t. `w` along `v`.
        Returns the result flattened into one vector.
        """
        if isinstance(y, tuple):
            ws = [torch.zeros_like(y_i, requires_grad=True) for y_i in y]
        else:
            ws = torch.zeros_like(y, requires_grad=True)
        jacobian = torch.autograd.grad(
            y, x, grad_outputs=ws, create_graph=True)
        Jv = torch.autograd.grad(parameters_to_vector(
            jacobian), ws, grad_outputs=v, create_graph=create_graph)
        return parameters_to_vector(Jv)
# The empirical Fisher diagonal (Section 20.11.3)
def empirical_fisher_diagonal(net, xs, ys, criterion):
    """
    Computes the diagonal of the empirical Fisher matrix.

    For every pair ``(x, y)`` taken from ``zip(xs, ys)``, the gradient of
    ``criterion(net(x), y)`` w.r.t. ``net.parameters()`` is computed. The
    result is the mean over pairs of the element-wise squared gradients,
    flattened and concatenated in parameter order. It matches the layout of
    ``parameters_to_vector(net.parameters())``.

    Squares are accumulated in a running sum, so memory is O(#params) rather
    than O(#pairs * #params).

    Arguments:
        net (callable): the model; must expose ``parameters()``.
        xs (iterable): inputs, one per gradient sample.
        ys (iterable): targets, zipped with ``xs``.
        criterion (callable): loss function ``criterion(output, target)``.

    Returns:
        torch.Tensor: 1-D tensor with one entry per parameter element.

    Raises:
        ValueError: if ``zip(xs, ys)`` yields no pairs.
    """
    sq_sum = None
    n_samples = 0
    for (x, y) in zip(xs, ys):
        fi = criterion(net(x), y)
        grads = torch.autograd.grad(fi, net.parameters(),
                                    retain_graph=False)
        flat_sq = torch.cat([g.detach().flatten() ** 2 for g in grads])
        sq_sum = flat_sq if sq_sum is None else sq_sum + flat_sq
        n_samples += 1
    if n_samples == 0:
        raise ValueError(
            "empirical_fisher_diagonal requires at least one (x, y) pair")
    return sq_sum / n_samples
# The empirical Fisher matrix (Section 20.11.3)
def empirical_fisher_matrix(net, xs, ys, criterion):
    """
    Computes the empirical Fisher matrix G^T G / N.

    Row ``i`` of ``G`` is the flattened, detached gradient of
    ``criterion(net(x_i), y_i)`` w.r.t. ``net.parameters()`` (concatenated
    in parameter order). ``N`` is the number of ``(x, y)`` pairs taken from
    ``zip(xs, ys)``.

    Arguments:
        net (callable): the model; must expose ``parameters()``.
        xs (iterable): inputs, one per gradient sample.
        ys (iterable): targets, zipped with ``xs``.
        criterion (callable): loss function ``criterion(output, target)``.

    Returns:
        torch.Tensor: square matrix of size #params x #params.
    """
    def sample_gradient(x, y):
        # Flattened gradient of the loss on a single (x, y) pair.
        sample_loss = criterion(net(x), y)
        parts = torch.autograd.grad(sample_loss, net.parameters(),
                                    retain_graph=False)
        return torch.cat([part.detach().flatten() for part in parts])

    per_sample = torch.stack([sample_gradient(x, y) for (x, y) in zip(xs, ys)])
    n_samples = per_sample.shape[0]
    return torch.einsum('ij,ik->jk', per_sample, per_sample) / n_samples