warmup_lr.py
# taken from https://nlp.seas.harvard.edu/2018/04/03/attention.html#optimizer
class WarmupWrapper:
    """Optimizer wrapper that implements the Noam learning-rate schedule:
    the rate grows linearly for `warmup` steps, then decays in proportion
    to the inverse square root of the step number."""

    def __init__(self, hidden_dim, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.hidden_dim = hidden_dim
        self._rate = 0

    def state_dict(self):
        """Returns the state of the warmup scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {key: value for key, value in self.__dict__.items() if key != 'optimizer'}

    def load_state_dict(self, state_dict):
        """Loads the warmup scheduler's state.

        Arguments:
            state_dict (dict): warmup scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self):
        """Update the learning rate of every parameter group, then step the
        wrapped optimizer."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Compute the Noam rate for the given step (defaults to the current
        internal step): hidden_dim**(-0.5) * min(step**(-0.5), step * warmup**(-1.5))."""
        if step is None:
            step = self._step
        return (self.hidden_dim ** (-0.5) *
                min(step ** (-0.5), step * self.warmup ** (-1.5)))
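

# A minimal usage sketch, not part of the original file: the model, loss,
# and the hidden_dim/warmup/Adam hyperparameters below are illustrative
# assumptions, not values prescribed by this repository.
if __name__ == "__main__":
    import torch

    model = torch.nn.Linear(512, 512)  # stand-in for a real model
    base_opt = torch.optim.Adam(model.parameters(), lr=0.0,
                                betas=(0.9, 0.98), eps=1e-9)
    scheduler = WarmupWrapper(hidden_dim=512, warmup=4000, optimizer=base_opt)

    x = torch.randn(8, 512)
    loss = model(x).pow(2).mean()  # dummy loss for demonstration
    base_opt.zero_grad()
    loss.backward()
    scheduler.step()  # sets lr on every param group, then steps the optimizer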