Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update lion_tf_2 #1184

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 106 additions & 84 deletions lion/lion_tf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,91 +14,113 @@
# ==============================================================================
"""TF2 implementation of the Lion optimizer."""

import tensorflow.compat.v2 as tf
import tensorflow as tf
from packaging.version import parse

if parse(tf.__version__) > parse('2.11.0'):
from tensorflow.keras.optimizers.legacy import Optimizer as keras_opt
else:
from tensorflow.keras.optimizers import Optimizer as keras_opt


class Lion(tf.keras.optimizers.legacy.Optimizer):
r"""Optimizer that implements the Lion algorithm."""
class Lion(keras_opt):
def __init__(
self,
learning_rate=0.0001,
beta_1=0.9,
beta_2=0.99,
wd=0,
name='lion',
**kwargs
):
super().__init__(name, **kwargs)
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
self._set_hyper('beta_1', beta_1)
self._set_hyper('beta_2', beta_2)
self._set_hyper('wd', wd)

def _create_slots(self, var_list):
# Create slots for the first and second moments.
# Separate for-loops to respect the ordering of slot variables from v1.
for var in var_list:
self.add_slot(var, 'm')

def _prepare_local(self, var_device, var_dtype, apply_state):
super(Lion, self)._prepare_local(var_device, var_dtype, apply_state)
beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
wd_t = tf.identity(self._get_hyper('wd', var_dtype))
lr = apply_state[(var_device, var_dtype)]['lr_t']
apply_state[(var_device, var_dtype)].update(
dict(
lr=lr,
beta_1_t=beta_1_t,
one_minus_beta_1_t=1 - beta_1_t,
beta_2_t=beta_2_t,
one_minus_beta_2_t=1 - beta_2_t,
wd_t=wd_t
)
)

@tf.function(jit_compile=True)
def _resource_apply_dense(self, grad, var, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = (
(apply_state or {}).get(
(
var_device, var_dtype
)
) or self._fallback_apply_state(var_device, var_dtype)
)

m = self.get_slot(var, 'm')
var_t = var.assign_sub(
coefficients['lr_t'] * (
tf.math.sign(
m * coefficients['beta_1_t'] +
grad * coefficients['one_minus_beta_1_t']
) + var * coefficients['wd_t'])
)

with tf.control_dependencies([var_t]):
m.assign(
m * coefficients['beta_2_t'] + grad * coefficients['one_minus_beta_2_t']
)

@tf.function(jit_compile=True)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = (
(apply_state or {}).get(
(
var_device, var_dtype
)
) or self._fallback_apply_state(var_device, var_dtype)
)

def __init__(self,
learning_rate=0.0001,
beta_1=0.9,
beta_2=0.99,
wd=0,
name='lion',
**kwargs):
"""Construct a new Lion optimizer."""
m = self.get_slot(var, 'm')
m_t = m.assign(m * coefficients['beta_1_t'])
m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
m_t = m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
var_t = var.assign_sub(
coefficients['lr'] * (
tf.math.sign(m_t) + var * coefficients['wd_t'])
)

super(Lion, self).__init__(name, **kwargs)
self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
self._set_hyper('beta_1', beta_1)
self._set_hyper('beta_2', beta_2)
self._set_hyper('wd', wd)

def _create_slots(self, var_list):
# Create slots for the first and second moments.
# Separate for-loops to respect the ordering of slot variables from v1.
for var in var_list:
self.add_slot(var, 'm')

def _prepare_local(self, var_device, var_dtype, apply_state):
super(Lion, self)._prepare_local(var_device, var_dtype, apply_state)

beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
wd_t = tf.identity(self._get_hyper('wd', var_dtype))
lr = apply_state[(var_device, var_dtype)]['lr_t']
apply_state[(var_device, var_dtype)].update(
dict(
lr=lr,
beta_1_t=beta_1_t,
one_minus_beta_1_t=1 - beta_1_t,
beta_2_t=beta_2_t,
one_minus_beta_2_t=1 - beta_2_t,
wd_t=wd_t))

@tf.function(jit_compile=True)
def _resource_apply_dense(self, grad, var, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
self._fallback_apply_state(var_device, var_dtype))

m = self.get_slot(var, 'm')
var_t = var.assign_sub(
coefficients['lr_t'] *
(tf.math.sign(m * coefficients['beta_1_t'] +
grad * coefficients['one_minus_beta_1_t']) +
var * coefficients['wd_t']))
with tf.control_dependencies([var_t]):
m.assign(m * coefficients['beta_2_t'] +
grad * coefficients['one_minus_beta_2_t'])

@tf.function(jit_compile=True)
def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
var_device, var_dtype = var.device, var.dtype.base_dtype
coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
self._fallback_apply_state(var_device, var_dtype))

m = self.get_slot(var, 'm')
m_t = m.assign(m * coefficients['beta_1_t'])
m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
m_t = m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
var_t = var.assign_sub(coefficients['lr'] *
(tf.math.sign(m_t) + var * coefficients['wd_t']))

with tf.control_dependencies([var_t]):
m_t = m_t.scatter_add(tf.IndexedSlices(-m_scaled_g_values, indices))
m_t = m_t.assign(m_t * coefficients['beta_2_t'] /
coefficients['beta_1_t'])
m_scaled_g_values = grad * coefficients['one_minus_beta_2_t']
m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))

def get_config(self):
config = super(Lion, self).get_config()
config.update({
'learning_rate': self._serialize_hyperparameter('learning_rate'),
'beta_1': self._serialize_hyperparameter('beta_1'),
'beta_2': self._serialize_hyperparameter('beta_2'),
'wd': self._serialize_hyperparameter('wd'),
})
return config
with tf.control_dependencies([var_t]):
m_t = m_t.scatter_add(tf.IndexedSlices(-m_scaled_g_values, indices))
m_t = m_t.assign(
m_t * coefficients['beta_2_t'] / coefficients['beta_1_t']
)
m_scaled_g_values = grad * coefficients['one_minus_beta_2_t']
m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))

def get_config(self):
config = super(Lion, self).get_config()
config.update({
'learning_rate': self._serialize_hyperparameter('learning_rate'),
'beta_1': self._serialize_hyperparameter('beta_1'),
'beta_2': self._serialize_hyperparameter('beta_2'),
'wd': self._serialize_hyperparameter('wd'),
})
return config