Commit be49522: checkpointing with orbax (#122)
1 parent 7c96b83

2 files changed: +238 -4 lines
keras_rs/src/layers/embedding/jax/checkpoint_utils.py (new file)

Lines changed: 104 additions & 0 deletions
"""A wrapper over orbax CheckpointManager for Keras3 Jax TPU Embeddings."""

from typing import Any

import keras
import orbax.checkpoint as ocp
from etils import epath


class JaxKeras3CheckpointManager(ocp.CheckpointManager):
    """A wrapper over orbax CheckpointManager for Keras3 Jax TPU Embeddings."""

    def __init__(
        self,
        model: keras.Model,
        checkpoint_dir: epath.PathLike,
        max_to_keep: int,
        steps_per_epoch: int = 1,
        **kwargs: Any,
    ):
        options = ocp.CheckpointManagerOptions(
            max_to_keep=max_to_keep, enable_async_checkpointing=False, **kwargs
        )
        self._model = model
        self._steps_per_epoch = steps_per_epoch
        self._checkpoint_dir = checkpoint_dir
        super().__init__(checkpoint_dir, options=options)

    def _get_state(self) -> tuple[dict[str, Any], Any | None]:
        """Gets the model state and metrics."""
        model_state = self._model.get_state_tree()
        state = {}
        metrics = None
        for k, v in model_state.items():
            if k == "metrics_variables":
                metrics = v
            else:
                state[k] = v
        return state, metrics

    def save_state(self, epoch: int) -> None:
        """Saves the model to the checkpoint directory.

        Args:
            epoch: The epoch number at which the state is saved.
        """
        state, metrics_value = self._get_state()
        self.save(
            epoch * self._steps_per_epoch,
            args=ocp.args.StandardSave(item=state),
            metrics=metrics_value,
        )

    def restore_state(self, step: int | None = None) -> None:
        """Restores the model from the checkpoint directory.

        Args:
            step: The step number to restore the state from. Default=None
                restores the latest step.
        """
        if step is None:
            step = self.latest_step()
        # Restore the model state only, not metrics.
        state, _ = self._get_state()
        restored_state = self.restore(
            step, args=ocp.args.StandardRestore(item=state)
        )
        self._model.set_state_tree(restored_state)


class JaxKeras3CheckpointCallback(keras.callbacks.Callback):
    """A callback for checkpointing and restoring state using Orbax."""

    def __init__(
        self,
        model: keras.Model,
        checkpoint_dir: epath.PathLike,
        max_to_keep: int,
        steps_per_epoch: int = 1,
        **kwargs: Any,
    ):
        if keras.backend.backend() != "jax":
            raise ValueError(
                "`JaxKeras3CheckpointCallback` is only supported on a "
                "`jax` backend."
            )
        self._checkpoint_manager = JaxKeras3CheckpointManager(
            model, checkpoint_dir, max_to_keep, steps_per_epoch, **kwargs
        )

    def on_train_begin(self, logs: dict[str, Any] | None = None) -> None:
        if not self.model.built or not self.model.optimizer.built:
            raise ValueError(
                "To use `JaxKeras3CheckpointCallback`, your model and "
                "optimizer must be built before you call `fit()`."
            )
        latest_epoch = self._checkpoint_manager.latest_step()
        if latest_epoch is not None:
            self._checkpoint_manager.restore_state(step=latest_epoch)

    def on_epoch_end(
        self, epoch: int, logs: dict[str, Any] | None = None
    ) -> None:
        self._checkpoint_manager.save_state(epoch)
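
A minimal usage sketch of the new callback (not part of the commit): the toy `Dense` model, the random data, and the `/tmp/keras_rs_ckpts` directory are illustrative assumptions, and the `jax` Keras backend is required. The callback restores the latest checkpoint in `on_train_begin` and saves at each `on_epoch_end`, so both the model and the optimizer must be built before `fit()`.

import numpy as np
import keras  # assumes KERAS_BACKEND=jax

from keras_rs.src.layers.embedding.jax import checkpoint_utils

# Illustrative stand-in model; any built Keras model works here.
model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])
model.compile(optimizer="adam", loss="mse", jit_compile=True)

# The callback raises in `on_train_begin` unless both the model and the
# optimizer are built, so build the optimizer explicitly before `fit()`.
model.optimizer.build(model.trainable_variables)

x = np.random.normal(size=(32, 8)).astype("float32")
y = np.random.normal(size=(32, 4)).astype("float32")

# On a rerun, `on_train_begin` restores the latest checkpoint; `on_epoch_end`
# saves at step `epoch * steps_per_epoch`.
callback = checkpoint_utils.JaxKeras3CheckpointCallback(
    model,
    "/tmp/keras_rs_ckpts",  # illustrative checkpoint directory
    max_to_keep=2,
    steps_per_epoch=1,
)
model.fit(x, y, epochs=3, callbacks=[callback])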

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 134 additions & 4 deletions
@@ -1,3 +1,6 @@
+import dataclasses
+import os
+import tempfile
 import typing
 from typing import Any
 
@@ -10,10 +13,13 @@
 from absl.testing import parameterized
 from jax.experimental import layout as jax_layout
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
-from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
+from jax_tpu_embedding.sparsecore.lib.nn import (
+    table_stacking as table_stacking_lib,
+)
 from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
 
 from keras_rs.src.layers.embedding import test_utils as keras_test_utils
+from keras_rs.src.layers.embedding.jax import checkpoint_utils
 from keras_rs.src.layers.embedding.jax import config_conversion
 from keras_rs.src.layers.embedding.jax import (
     distributed_embedding as jax_distributed_embedding,
@@ -131,7 +137,7 @@ def test_sharded_matches_unsharded(self):
             feature_spec.table_spec.name: feature_spec.table_spec
             for feature_spec in feature_specs
         }
-        table_stacking.stack_tables(
+        table_stacking_lib.stack_tables(
             feature_specs,
             table_names=[table_config.name for table_config in table_configs],
             global_device_count=device_count,
@@ -198,7 +204,7 @@ def test_random_shards(self):
         num_sc_per_device = _num_sparsecores_per_device()
         num_table_shards = device_count * num_sc_per_device
 
-        table_stacking.stack_tables(
+        table_stacking_lib.stack_tables(
             feature_specs,
             table_names=[
                 table_spec.name for table_spec in table_specs.values()
@@ -257,7 +263,7 @@ def test_compilability(self):
         num_sc_per_device = _num_sparsecores_per_device()
         num_table_shards = device_count * num_sc_per_device
 
-        table_stacking.stack_tables(
+        table_stacking_lib.stack_tables(
             feature_specs,
             table_names=[
                 table_spec.name for table_spec in table_specs.values()
@@ -458,6 +464,130 @@ def loss_fn(y_true, y_pred):
         loss_after = model.evaluate(evaluation_dataset)
         np.testing.assert_array_less(loss_after, loss_before)
 
+    @parameterized.product(
+        ragged=[False, True],
+        target_stacking=[
+            "auto",
+            [["table:0", "table:1", "table:2"]],
+        ],
+    )
+    def test_save_and_restore(
+        self,
+        ragged: bool,
+        target_stacking: str | list[str] | list[list[str]],
+    ):
+        keras.distribution.set_distribution(keras.distribution.DataParallel())
+
+        table_configs = keras_test_utils.create_random_table_configs(
+            max_vocabulary_size=64,
+            max_embedding_dim=8,
+            optimizer=keras.optimizers.SGD(learning_rate=0.1),
+            seed=10,
+        )
+        feature_configs = keras_test_utils.create_random_feature_configs(
+            table_configs=table_configs,
+            batch_size=16,
+            seed=20,
+        )
+        feature_configs_dict = {
+            feature_config.name: feature_config
+            for feature_config in feature_configs
+        }
+
+        # Create tables for generating labels.
+        seed = keras.random.SeedGenerator(40)
+        tables = {
+            table_config.name: keras.random.uniform(
+                shape=(
+                    table_config.vocabulary_size,
+                    table_config.embedding_dim,
+                ),
+                minval=-5,
+                maxval=5,
+                dtype="float32",
+                seed=seed,
+            )
+            for table_config in table_configs
+        }
+
+        # Fit and evaluate.
+        def loss_fn(y_true, y_pred):
+            return jnp.mean(jnp.square(y_true - y_pred))
+
+        embedding_layer_name = "distributed_embedding_chkpt_test"
+        layer = jax_distributed_embedding.DistributedEmbedding(
+            feature_configs_dict,
+            table_stacking=target_stacking,
+            name=embedding_layer_name,
+        )
+        model = keras.Sequential([layer])
+        model.compile(jit_compile=True, loss=loss_fn)
+
+        # Fit model to different dataset.
+        training_dataset = keras_test_utils.RandomInputSampleDataset(
+            feature_configs_dict,
+            tables,
+            ragged=ragged,
+            num_batches=100,
+            seed=42,
+            preprocessor=lambda inputs, weights: layer.preprocess(
+                inputs, weights, training=True
+            ),
+        )
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            chkpt_path = os.path.join(tmp_dir, "checkpoint")
+
+            model.fit(
+                training_dataset,
+                epochs=2,
+                steps_per_epoch=1,
+                callbacks=[
+                    checkpoint_utils.JaxKeras3CheckpointCallback(
+                        model,
+                        chkpt_path,
+                        max_to_keep=1,
+                        steps_per_epoch=1,
+                    )
+                ],
+            )
+            # Set up a model with a zero initializer but otherwise the same
+            # feature configs to test restore. Keep the same embedding layer
+            # name to ensure the correct weights are restored.
+            feature_configs_with_zero_init = {
+                feature_config.name: dataclasses.replace(
+                    feature_config,
+                    table=dataclasses.replace(
+                        feature_config.table, initializer="zeros"
+                    ),
+                )
+                for feature_config in feature_configs
+            }
+            layer_for_restore = jax_distributed_embedding.DistributedEmbedding(
+                feature_configs_with_zero_init,
+                table_stacking=target_stacking,
+                name=embedding_layer_name,
+            )
+            input_shapes = jax.tree.map(
+                lambda f: f.input_shape, feature_configs_with_zero_init
+            )
+            layer_for_restore.build(input_shapes)
+            model_for_restore = keras.Sequential([layer_for_restore])
+            manager_for_restore = checkpoint_utils.JaxKeras3CheckpointManager(
+                model_for_restore,
+                chkpt_path,
+                max_to_keep=1,
+                steps_per_epoch=1,
+            )
+            model_for_restore.compile(jit_compile=True, loss=loss_fn)
+            model_for_restore.build()
+            model_for_restore.optimizer.build(model_for_restore.trainable_variables)
+            manager_for_restore.restore_state()
+            jax.tree.map(
+                np.testing.assert_array_equal,
+                model.trainable_variables,
+                model_for_restore.trainable_variables,
+            )
+
 
 if __name__ == "__main__":
     absltest.main()
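
The test above exercises save and restore through `fit()` and a fresh zero-initialized model. For completeness, a hedged sketch of using `JaxKeras3CheckpointManager` directly to restore outside of training; the toy model and the `/tmp/keras_rs_ckpts` directory are illustrative assumptions carried over from the earlier sketch, and checkpoints must already exist there.

import keras  # assumes KERAS_BACKEND=jax

from keras_rs.src.layers.embedding.jax import checkpoint_utils

# Same illustrative toy model as the earlier sketch; in practice this is the
# model whose state was checkpointed, with the same variable structure.
model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.optimizer.build(model.trainable_variables)

manager = checkpoint_utils.JaxKeras3CheckpointManager(
    model,
    "/tmp/keras_rs_ckpts",  # must already contain saved checkpoints
    max_to_keep=2,
)
# Restores the latest step; pass `step=epoch * steps_per_epoch` to roll back
# to an earlier checkpoint instead.
manager.restore_state()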
