|
26 | 26 | import jax.extend as jex |
27 | 27 | import jax.numpy as jnp |
28 | 28 | from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_grad_with_adagrad |
| 29 | +from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_grad_with_laprop |
29 | 30 | from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_grad_with_sgd |
30 | 31 |
|
31 | 32 | HyperParameterType: TypeAlias = Callable[[], jax.Array] | float |
|
45 | 46 | "AdagradSlotVariables", ["accumulator"] |
46 | 47 | ) |
47 | 48 |
|
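| | +# Slot variables for LaProp: mu holds the momentum accumulator and nu holds the
| | +# exponentially weighted average of the squared gradients.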
| 49 | +LaPropSlotVariables = collections.namedtuple( |
| 50 | + "LaPropSlotVariables", ["mu", "nu"] |
| 51 | +) |
| 52 | + |
48 | 53 |
|
49 | 54 | # TODO(b/365975374): Create helper functions for generating OptimizerSpecs. |
50 | 55 | @dataclasses.dataclass(frozen=True, order=True) |
@@ -81,6 +86,12 @@ def get_learning_rate(self, step: int | None = None) -> jax.Array: |
81 | 86 | else: |
82 | 87 | return jnp.array(self.learning_rate, dtype=jnp.float32) |
83 | 88 |
|
| 89 | + def get_hyperparameters( |
| 90 | + self, step: int | None = None |
| 91 | + ) -> tuple[jax.Array, ...]: |
| 92 | + """Returns the hyperparameters for the optimizer.""" |
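| | + # The base spec only carries the learning rate; subclasses override this to
| | + # supply optimizer-specific terms (see LaPropOptimizerSpec below).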
| 93 | + return (self.get_learning_rate(step),) |
| 94 | + |
84 | 95 | def slot_variables_initializers(self) -> tuple[CallableTableInitializer, ...]: |
85 | 96 | """Slot variables initializers for the optimizer. |
86 | 97 |
|
@@ -199,6 +210,90 @@ def get_optimizer_primitive(self) -> jex.core.Primitive: |
199 | 210 | ) |
200 | 211 |
|
201 | 212 |
|
| 213 | +class LaPropOptimizerSpec(OptimizerSpec): |
| 214 | + """Spec for the LaProp optimizer. |
| 215 | +
|
| 216 | + LaProp decouples momentum and adaptivity in Adam-style methods, leading to
| 217 | + improved speed and stability compared to Adam.
| 218 | + https://arxiv.org/abs/2002.04839 |
| 219 | +
|
| 220 | + Attributes: |
| 221 | + learning_rate: The learning rate for the training variables or embeddings. |
| 222 | + b1: Decay rate for the exponentially weighted average of gradients.
| 223 | + b2: Decay rate for the exponentially weighted average of squared gradients.
| 224 | + eps: Term added to the squared gradient to improve numerical stability.
| 225 | + rms_clip_threshold: Clipping threshold for RMS. |
| 226 | + initial_slot_value: Initial value for the slot variables. |
| 227 | + """ |
| 228 | + |
| 229 | + def __init__( |
| 230 | + self, |
| 231 | + learning_rate: HyperParameterType = 0.001,
| 232 | + b1: float = 0.9, |
| 233 | + b2: float = 0.95, |
| 234 | + eps: float = 1e-30, |
| 235 | + rms_clip_threshold: float = 1.0, |
| 236 | + initial_slot_value: float = 0.0, |
| 237 | + ): |
| 238 | + super().__init__( |
| 239 | + learning_rate=learning_rate, |
| 240 | + ) |
| 241 | + self.b1 = b1 |
| 242 | + self.b2 = b2 |
| 243 | + self.eps = eps |
| 244 | + self.rms_clip_threshold = rms_clip_threshold |
| 245 | + self.initial_slot_value = initial_slot_value |
| 246 | + |
| 247 | + def slot_variables_initializers(self) -> tuple[CallableTableInitializer, ...]: |
| 248 | + return LaPropSlotVariables( |
| 249 | + mu=jax.nn.initializers.constant(self.initial_slot_value), |
| 250 | + nu=jax.nn.initializers.constant(self.initial_slot_value), |
| 251 | + ) |
| 252 | + |
| 253 | + def get_decay_rate(self, step: int | None = None) -> jax.Array: |
| 254 | + """Returns the bias-corrected second-moment decay rate for the given step."""
| 255 | + |
| 256 | + if step is None: |
| 257 | + return jnp.array(self.b2, dtype=jnp.float32) |
| 258 | + |
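| | + # Fold LaProp's bias correction into a single per-step decay rate:
| | + # b2 * (1 - b2**step) / (1 - b2**(step + 1)).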
| 259 | + decay_rate = ( |
| 260 | + self.b2 |
| 261 | + * (1.0 - jnp.power(self.b2, step)) |
| 262 | + / (1.0 - jnp.power(self.b2, step + 1.0))
| 263 | + ) |
| 264 | + |
| 265 | + return jnp.array(decay_rate, dtype=jnp.float32) |
| 266 | + |
| 267 | + def get_hyperparameters( |
| 268 | + self, step: int | None = None |
| 269 | + ) -> tuple[jax.Array, ...]: |
| 270 | + """Returns the LaProp hyperparameters: (learning_rate, b1, decay_rate, eps).""" |
| 271 | + return ( |
| 272 | + self.get_learning_rate(step), |
| 273 | + jnp.array(self.b1, dtype=jnp.float32), |
| 274 | + self.get_decay_rate(step), |
| 275 | + jnp.array(self.eps, dtype=jnp.float32), |
| 276 | + ) |
| 277 | + |
| 278 | + def __hash__(self) -> int: |
| 279 | + return hash(( |
| 280 | + self.learning_rate, |
| 281 | + self.b1, |
| 282 | + self.b2, |
| 283 | + self.eps, |
| 284 | + self.rms_clip_threshold, |
| 285 | + self.initial_slot_value, |
| 286 | + )) |
| 287 | + |
| 288 | + def short_name(self) -> str: |
| 289 | + return "laprop" |
| 290 | + |
| 291 | + def get_optimizer_primitive(self) -> jex.core.Primitive: |
| 292 | + return ( |
| 293 | + sparse_dense_matmul_grad_with_laprop.tpu_sparse_dense_matmul_grad_with_laprop_primitive |
| 294 | + ) |
| 295 | + |
| 296 | + |
202 | 297 | @dataclasses.dataclass(eq=True, frozen=True, order=True) |
203 | 298 | class FeatureIdTransformation: |
204 | 299 | """Transformation to apply to the input feature ids.""" |
|
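A minimal usage sketch of the new LaPropOptimizerSpec (the import path is an assumption, not confirmed by this diff; the class is defined alongside the existing SGD and Adagrad specs, so adjust the module name as needed):

# Assumed import path; adjust to wherever this module actually lives.
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec

laprop = embedding_spec.LaPropOptimizerSpec(
    learning_rate=1e-3,
    b1=0.9,
    b2=0.95,
    eps=1e-30,
)

# Constant-initialized (mu, nu) slot variables.
mu_init, nu_init = laprop.slot_variables_initializers()

# Hyperparameters handed to the TPU primitive at a given step:
# (learning_rate, b1, bias-corrected decay rate, eps).
lr, b1, decay_rate, eps = laprop.get_hyperparameters(step=100)

assert laprop.short_name() == "laprop"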