# File: lars_optimizer.py (169 lines, 147 loc, 6.17 KB).
# Repository forked from google-research/simclr.
# coding=utf-8
# Copyright 2020 The SimCLR Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions and classes related to optimization (weight updates)."""
import re
import tensorflow.compat.v2 as tf
# Default value of the `eeta` constructor argument of `LARSOptimizer`: the
# coefficient that scales the learning rate when the trust ratio is computed.
EETA_DEFAULT = 0.001
class LARSOptimizer(tf.keras.optimizers.Optimizer):
    """Layer-wise Adaptive Rate Scaling for large batch training.

    Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
    I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)

    NOTE(review): this class relies on private hooks of its base class
    (`_set_hyper`, `_serialize_hyperparameter`, `_fallback_apply_state`,
    `_resource_apply_dense`); confirm that the installed TensorFlow/Keras
    version still exposes them on `tf.keras.optimizers.Optimizer`.
    """

    def __init__(self,
                 learning_rate,
                 momentum=0.9,
                 use_nesterov=False,
                 weight_decay=0.0,
                 exclude_from_weight_decay=None,
                 exclude_from_layer_adaptation=None,
                 classic_momentum=True,
                 eeta=EETA_DEFAULT,
                 name="LARSOptimizer"):
        """Constructs a LARSOptimizer.

        Args:
          learning_rate: A `float` for learning rate.
          momentum: A `float` for momentum.
          use_nesterov: A `boolean` for whether to use nesterov momentum.
          weight_decay: A `float` for weight decay. A falsy value (e.g. 0.0)
            disables weight decay for every variable.
          exclude_from_weight_decay: A list of `string` patterns for variable
            screening. Each entry is matched against the variable name with
            `re.search`; if any entry matches, the variable is excluded from
            weight decay. For example, one could specify the list like
            ['batch_normalization', 'bias'] to exclude BN and bias from
            weight decay.
          exclude_from_layer_adaptation: Similar to exclude_from_weight_decay,
            but for layer adaptation. If it is None, it defaults to the value
            of exclude_from_weight_decay. An explicit empty list means that
            no variable is excluded from layer adaptation.
          classic_momentum: A `boolean` for whether to use classic (or
            popular) momentum. The learning rate is applied during the
            momentum update in classic momentum, but after the momentum
            update for popular momentum.
          eeta: A `float` for scaling of learning rate when computing trust
            ratio.
          name: The name for the scope.
        """
        super(LARSOptimizer, self).__init__(name)

        self._set_hyper("learning_rate", learning_rate)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eeta = eeta
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # Fall back to the weight-decay exclusions only when the argument was
        # not supplied at all. Testing `is not None` (rather than truthiness)
        # lets callers pass an empty list to adapt every layer while still
        # excluding some variables from weight decay.
        if exclude_from_layer_adaptation is not None:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def _create_slots(self, var_list):
        """Adds one "Momentum" slot for every variable in `var_list`."""
        for v in var_list:
            self.add_slot(v, "Momentum")

    def _compute_trust_ratio(self, param, direction):
        """Returns the layer-wise trust ratio for `param`.

        The ratio is `eeta * ||param||_2 / ||direction||_2`. It is replaced
        by 1.0 whenever either norm is not strictly greater than zero.

        Args:
          param: The variable being updated.
          direction: The tensor whose norm goes in the denominator (the
            gradient for classic momentum, the momentum update otherwise).
        """
        w_norm = tf.norm(param, ord=2)
        d_norm = tf.norm(direction, ord=2)
        return tf.where(
            tf.greater(w_norm, 0),
            tf.where(tf.greater(d_norm, 0),
                     (self.eeta * w_norm / d_norm), 1.0),
            1.0)

    def _resource_apply_dense(self, grad, param, apply_state=None):
        """Builds the op that applies one LARS step to `param`.

        Args:
          grad: Gradient for `param`, or None.
          param: The variable to update, or None.
          apply_state: Optional dict keyed by `(device, dtype)` holding the
            per-step coefficients; when the key is missing (or the dict is
            empty/None) `_fallback_apply_state` is used instead.

        Returns:
          `tf.no_op()` if `grad` or `param` is None, otherwise a grouped op
          that assigns the new value of `param` and of its "Momentum" slot.
        """
        if grad is None or param is None:
            return tf.no_op()

        var_device, var_dtype = param.device, param.dtype.base_dtype
        coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
                        self._fallback_apply_state(var_device, var_dtype))
        learning_rate = coefficients["lr_t"]

        param_name = param.name
        v = self.get_slot(param, "Momentum")

        # Weight decay is folded into the gradient, so it also influences the
        # gradient norm used for the trust ratio below.
        if self._use_weight_decay(param_name):
            grad += self.weight_decay * param

        adapt_layer = self._do_layer_adaptation(param_name)

        if self.classic_momentum:
            # Classic momentum: the (trust-ratio scaled) learning rate is
            # applied to the gradient before it enters the momentum buffer.
            trust_ratio = 1.0
            if adapt_layer:
                trust_ratio = self._compute_trust_ratio(param, grad)
            scaled_lr = learning_rate * trust_ratio

            next_v = tf.multiply(self.momentum, v) + scaled_lr * grad
            if self.use_nesterov:
                update = tf.multiply(self.momentum, next_v) + scaled_lr * grad
            else:
                update = next_v
            next_param = param - update
        else:
            # Popular momentum: the buffer accumulates raw gradients and the
            # scaled learning rate is applied to the resulting update.
            next_v = tf.multiply(self.momentum, v) + grad
            if self.use_nesterov:
                update = tf.multiply(self.momentum, next_v) + grad
            else:
                update = next_v

            trust_ratio = 1.0
            if adapt_layer:
                trust_ratio = self._compute_trust_ratio(param, update)
            scaled_lr = trust_ratio * learning_rate
            next_param = param - scaled_lr * update

        return tf.group(*[
            param.assign(next_param, use_locking=False),
            v.assign(next_v, use_locking=False)
        ])

    def _use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay:
            return False
        if self.exclude_from_weight_decay:
            # TODO(srbs): Try to avoid name based filtering.
            return not any(
                re.search(r, param_name) is not None
                for r in self.exclude_from_weight_decay)
        return True

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        if self.exclude_from_layer_adaptation:
            # TODO(srbs): Try to avoid name based filtering.
            return not any(
                re.search(r, param_name) is not None
                for r in self.exclude_from_layer_adaptation)
        return True

    def get_config(self):
        """Returns the optimizer configuration as a dict.

        Extends the base-class config with every constructor argument of
        this class. The two exclusion lists are included so that an optimizer
        rebuilt by passing this dict back to the constructor keeps the same
        weight-decay and layer-adaptation filtering.
        """
        config = super(LARSOptimizer, self).get_config()
        config.update({
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "momentum": self.momentum,
            "classic_momentum": self.classic_momentum,
            "weight_decay": self.weight_decay,
            "eeta": self.eeta,
            "use_nesterov": self.use_nesterov,
            "exclude_from_weight_decay": self.exclude_from_weight_decay,
            "exclude_from_layer_adaptation":
                self.exclude_from_layer_adaptation,
        })
        return config