diff --git a/lion/lion_tf2.py b/lion/lion_tf2.py
index 82432a77..cb5e3737 100644
--- a/lion/lion_tf2.py
+++ b/lion/lion_tf2.py
@@ -14,91 +14,115 @@
 # ==============================================================================
 """TF2 implementation of the Lion optimizer."""
 
-import tensorflow.compat.v2 as tf
+import tensorflow as tf
+from packaging.version import parse
+
+# Keras rewrote its Optimizer API in TF 2.11; the v1-style hooks used below
+# (_set_hyper, slots, _resource_apply_*) live on the legacy class from then on.
+if parse(tf.__version__) >= parse('2.11.0'):
+    from tensorflow.keras.optimizers.legacy import Optimizer as keras_opt
+else:
+    from tensorflow.keras.optimizers import Optimizer as keras_opt
 
 
-class Lion(tf.keras.optimizers.legacy.Optimizer):
-  r"""Optimizer that implements the Lion algorithm."""
+class Lion(keras_opt):
+    r"""Optimizer that implements the Lion algorithm."""
+
+    def __init__(
+        self,
+        learning_rate=0.0001,
+        beta_1=0.9,
+        beta_2=0.99,
+        wd=0,
+        name='lion',
+        **kwargs
+    ):
+        super().__init__(name, **kwargs)
+        self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
+        self._set_hyper('beta_1', beta_1)
+        self._set_hyper('beta_2', beta_2)
+        self._set_hyper('wd', wd)
+
+    def _create_slots(self, var_list):
+        # Lion keeps a single slot: the momentum m (an EMA of past gradients).
+        for var in var_list:
+            self.add_slot(var, 'm')
+
+    def _prepare_local(self, var_device, var_dtype, apply_state):
+        super(Lion, self)._prepare_local(var_device, var_dtype, apply_state)
+        beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
+        beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
+        wd_t = tf.identity(self._get_hyper('wd', var_dtype))
+        lr = apply_state[(var_device, var_dtype)]['lr_t']
+        apply_state[(var_device, var_dtype)].update(
+            dict(
+                lr=lr,
+                beta_1_t=beta_1_t,
+                one_minus_beta_1_t=1 - beta_1_t,
+                beta_2_t=beta_2_t,
+                one_minus_beta_2_t=1 - beta_2_t,
+                wd_t=wd_t,
+            )
+        )
+
+    @tf.function(jit_compile=True)
+    def _resource_apply_dense(self, grad, var, apply_state=None):
+        var_device, var_dtype = var.device, var.dtype.base_dtype
+        coefficients = (
+            (apply_state or {}).get((var_device, var_dtype))
+            or self._fallback_apply_state(var_device, var_dtype)
+        )
+
+        m = self.get_slot(var, 'm')
+        # var <- var - lr * (sign(beta_1 * m + (1 - beta_1) * grad) + wd * var)
+        var_t = var.assign_sub(
+            coefficients['lr_t'] * (
+                tf.math.sign(
+                    m * coefficients['beta_1_t'] +
+                    grad * coefficients['one_minus_beta_1_t']
+                ) + var * coefficients['wd_t']
+            )
+        )
+
+        with tf.control_dependencies([var_t]):
+            # m <- beta_2 * m + (1 - beta_2) * grad
+            m.assign(
+                m * coefficients['beta_2_t'] +
+                grad * coefficients['one_minus_beta_2_t']
+            )
+
+    @tf.function(jit_compile=True)
+    def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
+        var_device, var_dtype = var.device, var.dtype.base_dtype
+        coefficients = (
+            (apply_state or {}).get((var_device, var_dtype))
+            or self._fallback_apply_state(var_device, var_dtype)
+        )
 
-  def __init__(self,
-               learning_rate=0.0001,
-               beta_1=0.9,
-               beta_2=0.99,
-               wd=0,
-               name='lion',
-               **kwargs):
-    """Construct a new Lion optimizer."""
+        m = self.get_slot(var, 'm')
+        m_t = m.assign(m * coefficients['beta_1_t'])
+        m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
+        m_t = m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
+        var_t = var.assign_sub(
+            coefficients['lr'] * (tf.math.sign(m_t) + var * coefficients['wd_t'])
+        )
 
-    super(Lion, self).__init__(name, **kwargs)
-    self._set_hyper('learning_rate', kwargs.get('lr', learning_rate))
-    self._set_hyper('beta_1', beta_1)
-    self._set_hyper('beta_2', beta_2)
-    self._set_hyper('wd', wd)
-
-  def _create_slots(self, var_list):
-    # Create slots for the first and second moments.
-    # Separate for-loops to respect the ordering of slot variables from v1.
-    for var in var_list:
-      self.add_slot(var, 'm')
-
-  def _prepare_local(self, var_device, var_dtype, apply_state):
-    super(Lion, self)._prepare_local(var_device, var_dtype, apply_state)
-
-    beta_1_t = tf.identity(self._get_hyper('beta_1', var_dtype))
-    beta_2_t = tf.identity(self._get_hyper('beta_2', var_dtype))
-    wd_t = tf.identity(self._get_hyper('wd', var_dtype))
-    lr = apply_state[(var_device, var_dtype)]['lr_t']
-    apply_state[(var_device, var_dtype)].update(
-        dict(
-            lr=lr,
-            beta_1_t=beta_1_t,
-            one_minus_beta_1_t=1 - beta_1_t,
-            beta_2_t=beta_2_t,
-            one_minus_beta_2_t=1 - beta_2_t,
-            wd_t=wd_t))
-
-  @tf.function(jit_compile=True)
-  def _resource_apply_dense(self, grad, var, apply_state=None):
-    var_device, var_dtype = var.device, var.dtype.base_dtype
-    coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
-                    self._fallback_apply_state(var_device, var_dtype))
-
-    m = self.get_slot(var, 'm')
-    var_t = var.assign_sub(
-        coefficients['lr_t'] *
-        (tf.math.sign(m * coefficients['beta_1_t'] +
-                      grad * coefficients['one_minus_beta_1_t']) +
-         var * coefficients['wd_t']))
-    with tf.control_dependencies([var_t]):
-      m.assign(m * coefficients['beta_2_t'] +
-               grad * coefficients['one_minus_beta_2_t'])
-
-  @tf.function(jit_compile=True)
-  def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
-    var_device, var_dtype = var.device, var.dtype.base_dtype
-    coefficients = ((apply_state or {}).get((var_device, var_dtype)) or
-                    self._fallback_apply_state(var_device, var_dtype))
-
-    m = self.get_slot(var, 'm')
-    m_t = m.assign(m * coefficients['beta_1_t'])
-    m_scaled_g_values = grad * coefficients['one_minus_beta_1_t']
-    m_t = m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
-    var_t = var.assign_sub(coefficients['lr'] *
-                           (tf.math.sign(m_t) + var * coefficients['wd_t']))
-
-    with tf.control_dependencies([var_t]):
-      m_t = m_t.scatter_add(tf.IndexedSlices(-m_scaled_g_values, indices))
-      m_t = m_t.assign(m_t * coefficients['beta_2_t'] /
-                       coefficients['beta_1_t'])
-      m_scaled_g_values = grad * coefficients['one_minus_beta_2_t']
-      m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
-
-  def get_config(self):
-    config = super(Lion, self).get_config()
-    config.update({
-        'learning_rate': self._serialize_hyperparameter('learning_rate'),
-        'beta_1': self._serialize_hyperparameter('beta_1'),
-        'beta_2': self._serialize_hyperparameter('beta_2'),
-        'wd': self._serialize_hyperparameter('wd'),
-    })
-    return config
+        with tf.control_dependencies([var_t]):
+            # Undo the update-rule interpolation, then rescale so that m ends
+            # up as beta_2 * m + (1 - beta_2) * grad.
+            m_t = m_t.scatter_add(tf.IndexedSlices(-m_scaled_g_values, indices))
+            m_t = m_t.assign(
+                m_t * coefficients['beta_2_t'] / coefficients['beta_1_t']
+            )
+            m_scaled_g_values = grad * coefficients['one_minus_beta_2_t']
+            m_t.scatter_add(tf.IndexedSlices(m_scaled_g_values, indices))
+
+    def get_config(self):
+        config = super(Lion, self).get_config()
+        config.update({
+            'learning_rate': self._serialize_hyperparameter('learning_rate'),
+            'beta_1': self._serialize_hyperparameter('beta_1'),
+            'beta_2': self._serialize_hyperparameter('beta_2'),
+            'wd': self._serialize_hyperparameter('wd'),
+        })
+        return config
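
For a quick smoke test of the patched optimizer, a minimal usage sketch
follows. It assumes the module is importable as `lion.lion_tf2`; the toy model
and the random data are placeholders, not part of the repository:

    import numpy as np
    import tensorflow as tf

    from lion.lion_tf2 import Lion

    # Any Keras model works; a tiny regression net keeps the example short.
    model = tf.keras.Sequential([
        tf.keras.layers.Dense(32, activation='relu', input_shape=(8,)),
        tf.keras.layers.Dense(1),
    ])

    # The Lion paper suggests a learning rate 3-10x smaller than AdamW's and
    # a correspondingly larger decoupled weight decay (`wd`).
    model.compile(optimizer=Lion(learning_rate=1e-4, wd=0.01), loss='mse')

    x = np.random.randn(256, 8).astype('float32')
    y = np.random.randn(256, 1).astype('float32')
    model.fit(x, y, batch_size=32, epochs=1)

On TF >= 2.11 the import gate resolves `Lion` to the legacy base class, so the
instance is accepted by `model.compile` as-is; and since `get_config` serializes
all four hyperparameters, the optimizer also survives a `from_config` round
trip.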