
Commit 5542696

nicolagp authored and Google-ML-Automation committed
Add LaProp to JAX SC optimizer specs and modify base optimizer to return general hyperparameter array instead of learning_rate only.
PiperOrigin-RevId: 746093821
1 parent bc8655d commit 5542696

File tree

5 files changed: +163 −2 lines changed

jax_tpu_embedding/sparsecore/lib/nn/BUILD
jax_tpu_embedding/sparsecore/lib/nn/embedding.py
jax_tpu_embedding/sparsecore/lib/nn/embedding_spec.py
jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD
jax_tpu_embedding/sparsecore/lib/nn/tests/embedding_spec_test.py

jax_tpu_embedding/sparsecore/lib/nn/BUILD

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ pytype_strict_library(
    ],
    deps = [
        "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_adagrad",
+        "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_laprop",
        "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_grad_with_sgd",
        pypi_requirement("jax"),
        pypi_requirement("jax/extend"),

jax_tpu_embedding/sparsecore/lib/nn/embedding.py

Lines changed: 1 addition & 2 deletions
@@ -660,8 +660,7 @@ def tpu_sparse_dense_matmul_grad(
    embedding_variable = embedding_variables[stacked_table_name]
    activation_gradient = gradients[stacked_table_name]
    stack_table_spec = stacked_table_specs[stacked_table_name]
-    learning_rate = stack_table_spec.optimizer.get_learning_rate(step)
-    hyper_params = [learning_rate]
+    hyper_params = stack_table_spec.optimizer.get_hyperparameters(step)
    # The MLIR computation symbol names need to be different. We attach the
    # table name to the symbol name to ensure that.
    symbol_name = "{}-{}{}".format(

jax_tpu_embedding/sparsecore/lib/nn/embedding_spec.py

Lines changed: 95 additions & 0 deletions
@@ -26,6 +26,7 @@
 import jax.extend as jex
 import jax.numpy as jnp
 from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_grad_with_adagrad
+from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_grad_with_laprop
 from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_grad_with_sgd

 HyperParameterType: TypeAlias = Callable[[], jax.Array] | float
@@ -45,6 +46,10 @@
    "AdagradSlotVariables", ["accumulator"]
 )

+LaPropSlotVariables = collections.namedtuple(
+    "LaPropSlotVariables", ["mu", "nu"]
+)
+

 # TODO(b/365975374): Create helper functions for generating OptimizerSpecs.
 @dataclasses.dataclass(frozen=True, order=True)
@@ -81,6 +86,12 @@ def get_learning_rate(self, step: int | None = None) -> jax.Array:
    else:
      return jnp.array(self.learning_rate, dtype=jnp.float32)

+  def get_hyperparameters(
+      self, step: int | None = None
+  ) -> tuple[jax.Array, ...]:
+    """Returns the hyperparameters for the optimizer."""
+    return (self.get_learning_rate(step),)
+
  def slot_variables_initializers(self) -> tuple[CallableTableInitializer, ...]:
    """Slot variables initializers for the optimizer.
@@ -199,6 +210,90 @@ def get_optimizer_primitive(self) -> jex.core.Primitive:
    )


+class LaPropOptimizerSpec(OptimizerSpec):
+  """Spec for the LaProp optimizer.
+
+  LaProp decouples momentum and adaptivity in Adam-style methods, leading to
+  improved speed and stability compared to Adam.
+  https://arxiv.org/abs/2002.04839
+
+  Attributes:
+    learning_rate: The learning rate for the training variables or embeddings.
+    b1: Decay rate for the exponentially weighted average of gradients.
+    b2: Decay rate for the exponentially weighted average of squared gradients.
+    eps: Term added to the squared gradient to improve numerical stability.
+    rms_clip_threshold: Clipping threshold for RMS.
+    initial_slot_value: Initial value for the slot variables.
+  """
+
+  def __init__(
+      self,
+      learning_rate=0.001,
+      b1: float = 0.9,
+      b2: float = 0.95,
+      eps: float = 1e-30,
+      rms_clip_threshold: float = 1.0,
+      initial_slot_value: float = 0.0,
+  ):
+    super().__init__(
+        learning_rate=learning_rate,
+    )
+    self.b1 = b1
+    self.b2 = b2
+    self.eps = eps
+    self.rms_clip_threshold = rms_clip_threshold
+    self.initial_slot_value = initial_slot_value
+
+  def slot_variables_initializers(self) -> tuple[CallableTableInitializer, ...]:
+    return LaPropSlotVariables(
+        mu=jax.nn.initializers.constant(self.initial_slot_value),
+        nu=jax.nn.initializers.constant(self.initial_slot_value),
+    )
+
+  def get_decay_rate(self, step: int | None = None) -> jax.Array:
+    """Returns the bias-corrected second-moment decay rate for the optimizer."""
+
+    if step is None:
+      return jnp.array(self.b2, dtype=jnp.float32)
+
+    decay_rate = (
+        self.b2
+        * (1.0 - jnp.power(self.b2, step))
+        / (1.0 - jnp.power(self.b2, step + 1.0))
+    )
+
+    return jnp.array(decay_rate, dtype=jnp.float32)
+
+  def get_hyperparameters(
+      self, step: int | None = None
+  ) -> tuple[jax.Array, ...]:
+    """Returns the LaProp hyperparameters: (learning_rate, b1, decay_rate, eps)."""
+    return (
+        self.get_learning_rate(step),
+        jnp.array(self.b1, dtype=jnp.float32),
+        self.get_decay_rate(step),
+        jnp.array(self.eps, dtype=jnp.float32),
+    )
+
+  def __hash__(self) -> int:
+    return hash((
+        self.learning_rate,
+        self.b1,
+        self.b2,
+        self.eps,
+        self.rms_clip_threshold,
+        self.initial_slot_value,
+    ))
+
+  def short_name(self) -> str:
+    return "laprop"
+
+  def get_optimizer_primitive(self) -> jex.core.Primitive:
+    return (
+        sparse_dense_matmul_grad_with_laprop.tpu_sparse_dense_matmul_grad_with_laprop_primitive
+    )
+
+
 @dataclasses.dataclass(eq=True, frozen=True, order=True)
 class FeatureIdTransformation:
   """Transformation to apply to the input feature ids."""

jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD

Lines changed: 1 addition & 0 deletions
@@ -184,6 +184,7 @@ pytype_strict_contrib_test(
    deps = [
        "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
        pypi_requirement("absl/testing:absltest"),
+        pypi_requirement("jax"),
        pypi_requirement("optax/schedules"),
    ],
)

jax_tpu_embedding/sparsecore/lib/nn/tests/embedding_spec_test.py

Lines changed: 65 additions & 0 deletions
@@ -14,6 +14,7 @@
 """Tests for embedding spec."""

 from absl.testing import absltest
+import jax.numpy as jnp
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
 from optax import schedules

@@ -72,6 +73,44 @@ def test_compare_adagrad(self):
    self.assertEqual(op.learning_rate, 0.1)
    self.assertEqual(op.initial_accumulator_value, 0.1)

+  def test_compare_laprop(self):
+    self.assertEqual(
+        embedding_spec.LaPropOptimizerSpec(
+            learning_rate=0.1,
+            b1=0.9,
+            b2=0.95,
+            eps=1e-30,
+            rms_clip_threshold=1.0,
+            initial_slot_value=0.0,
+        ),
+        embedding_spec.LaPropOptimizerSpec(
+            learning_rate=0.1,
+            b1=0.9,
+            b2=0.95,
+            eps=1e-30,
+            rms_clip_threshold=1.0,
+            initial_slot_value=0.0,
+        ),
+    )
+    self.assertNotEqual(
+        embedding_spec.LaPropOptimizerSpec(
+            learning_rate=0.1,
+            b1=0.8,
+            b2=0.95,
+            eps=1e-30,
+            rms_clip_threshold=1.0,
+            initial_slot_value=0.0,
+        ),
+        embedding_spec.LaPropOptimizerSpec(
+            learning_rate=0.1,
+            b1=0.9,
+            b2=0.95,
+            eps=1e-30,
+            rms_clip_threshold=1.0,
+            initial_slot_value=0.0,
+        ),
+    )
+
  def test_learning_rate_callable(self):
    def lr():
      return 0.1
@@ -90,6 +129,32 @@ def test_learning_rate_schedule(self):
    self.assertEqual(op.get_learning_rate(50), 0.55)
    self.assertEqual(op.get_learning_rate(100), 0.1)

+  def test_hyperparameters(self):
+    op = embedding_spec.AdagradOptimizerSpec(
+        learning_rate=schedules.linear_schedule(
+            init_value=1.0, end_value=0.1, transition_steps=100
+        )
+    )
+    self.assertEqual(op.get_hyperparameters(0), (1.0,))
+
+    op = embedding_spec.LaPropOptimizerSpec(
+        learning_rate=schedules.linear_schedule(
+            init_value=1.0, end_value=0.1, transition_steps=100
+        ),
+        b1=0.9,
+        b2=0.95,
+        eps=1e-30,
+        rms_clip_threshold=1.0,
+        initial_slot_value=0.0,
+    )
+    expected_hyperparameters = (
+        jnp.array(1.0, dtype=jnp.float32),
+        jnp.array(0.9, dtype=jnp.float32),
+        jnp.array(0.0, dtype=jnp.float32),
+        jnp.array(1e-30, dtype=jnp.float32),
+    )
+    self.assertEqual(op.get_hyperparameters(0), expected_hyperparameters)
+

 if __name__ == "__main__":
   absltest.main()
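Note on test_hyperparameters: the third element of expected_hyperparameters is 0.0 because at step 0 the bias-corrected decay rate has a zero numerator. A quick standalone check of the formula used by get_decay_rate:

import jax.numpy as jnp

b2, step = 0.95, 0
# b2 * (1 - b2**step) / (1 - b2**(step + 1)); the numerator vanishes at step 0.
decay_rate = b2 * (1.0 - jnp.power(b2, step)) / (1.0 - jnp.power(b2, step + 1.0))
print(decay_rate)  # 0.0, matching expected_hyperparameters[2]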
