optimizers.py
from jax.example_libraries.optimizers import optimizer, make_schedule
import jax.numpy as jnp


@optimizer
def adamw(step_size, b1=0.9, b2=0.999, eps=1e-8, wd=1e-6):
  """Construct optimizer triple for AdamW.

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay
      rate for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay
      rate for the second moment estimates (default 0.999).
    eps: optional, a positive scalar value for epsilon, a small constant for
      numerical stability (default 1e-8).
    wd: optional, a positive scalar value for the weight decay coefficient
      (default 1e-6).

  Returns:
    An (init_fun, update_fun, get_params) triple.
  """
  step_size = make_schedule(step_size)

  def init(x0):
    m0 = jnp.zeros_like(x0)  # First moment accumulator.
    v0 = jnp.zeros_like(x0)  # Second moment accumulator.
    return x0, m0, v0

  def update(i, g, state):
    x, m, v = state
    m = (1 - b1) * g + b1 * m  # First moment estimate.
    v = (1 - b2) * jnp.square(g) + b2 * v  # Second moment estimate.
    mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1))  # Bias correction.
    vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
    # Adam step plus decoupled weight decay. Note the decay term is applied
    # directly here (not scaled by the step size schedule).
    x = x - step_size(i) * (mhat / (jnp.sqrt(vhat) + eps)) - wd * x
    return x, m, v

  def get_params(state):
    x, _, _ = state
    return x

  return init, update, get_params
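

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example assuming a toy quadratic loss; the (init_fun, update_fun,
# get_params) triple follows the standard jax.example_libraries.optimizers
# interface, so training is a loop of jax.grad -> update_fun.
if __name__ == "__main__":
  import jax

  def loss(params):
    # Hypothetical objective, for illustration only: minimize ||params||^2.
    return jnp.sum(params ** 2)

  init_fun, update_fun, get_params = adamw(step_size=1e-3, wd=1e-4)
  opt_state = init_fun(jnp.ones(3))

  @jax.jit
  def step(i, opt_state):
    grads = jax.grad(loss)(get_params(opt_state))
    return update_fun(i, grads, opt_state)

  for i in range(100):
    opt_state = step(i, opt_state)
  print(get_params(opt_state))  # Parameters driven toward zero.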