@@ -1,3 +1,6 @@
+import dataclasses
+import os
+import tempfile
 import typing
 from typing import Any
 
@@ -10,10 +13,13 @@
 from absl.testing import parameterized
 from jax.experimental import layout as jax_layout
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
-from jax_tpu_embedding.sparsecore.lib.nn import table_stacking
+from jax_tpu_embedding.sparsecore.lib.nn import (
+    table_stacking as table_stacking_lib,
+)
 from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
 
 from keras_rs.src.layers.embedding import test_utils as keras_test_utils
+from keras_rs.src.layers.embedding.jax import checkpoint_utils
 from keras_rs.src.layers.embedding.jax import config_conversion
 from keras_rs.src.layers.embedding.jax import (
     distributed_embedding as jax_distributed_embedding,
@@ -131,7 +137,7 @@ def test_sharded_matches_unsharded(self):
             feature_spec.table_spec.name: feature_spec.table_spec
             for feature_spec in feature_specs
         }
-        table_stacking.stack_tables(
+        table_stacking_lib.stack_tables(
             feature_specs,
             table_names=[table_config.name for table_config in table_configs],
             global_device_count=device_count,
@@ -198,7 +204,7 @@ def test_random_shards(self):
         num_sc_per_device = _num_sparsecores_per_device()
         num_table_shards = device_count * num_sc_per_device
 
-        table_stacking.stack_tables(
+        table_stacking_lib.stack_tables(
            feature_specs,
            table_names=[
                table_spec.name for table_spec in table_specs.values()
@@ -257,7 +263,7 @@ def test_compilability(self):
         num_sc_per_device = _num_sparsecores_per_device()
         num_table_shards = device_count * num_sc_per_device
 
-        table_stacking.stack_tables(
+        table_stacking_lib.stack_tables(
            feature_specs,
            table_names=[
                table_spec.name for table_spec in table_specs.values()
@@ -458,6 +464,130 @@ def loss_fn(y_true, y_pred):
         loss_after = model.evaluate(evaluation_dataset)
         np.testing.assert_array_less(loss_after, loss_before)
 
+    @parameterized.product(
+        ragged=[False, True],
+        target_stacking=[
+            "auto",
+            [["table:0", "table:1", "table:2"]],
+        ],
+    )
+    def test_save_and_restore(
+        self,
+        ragged: bool,
+        target_stacking: str | list[str] | list[list[str]],
+    ):
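+        """Saves a checkpoint during fit and restores it into a new model."""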
+        keras.distribution.set_distribution(keras.distribution.DataParallel())
+
+        table_configs = keras_test_utils.create_random_table_configs(
+            max_vocabulary_size=64,
+            max_embedding_dim=8,
+            optimizer=keras.optimizers.SGD(learning_rate=0.1),
+            seed=10,
+        )
+        feature_configs = keras_test_utils.create_random_feature_configs(
+            table_configs=table_configs,
+            batch_size=16,
+            seed=20,
+        )
+        feature_configs_dict = {
+            feature_config.name: feature_config
+            for feature_config in feature_configs
+        }
+
+        # Create tables for generating labels.
+        seed = keras.random.SeedGenerator(40)
+        tables = {
+            table_config.name: keras.random.uniform(
+                shape=(
+                    table_config.vocabulary_size,
+                    table_config.embedding_dim,
+                ),
+                minval=-5,
+                maxval=5,
+                dtype="float32",
+                seed=seed,
+            )
+            for table_config in table_configs
+        }
+
+        # Loss function and model for fitting.
+        def loss_fn(y_true, y_pred):
+            return jnp.mean(jnp.square(y_true - y_pred))
+
+        embedding_layer_name = "distributed_embedding_chkpt_test"
+        layer = jax_distributed_embedding.DistributedEmbedding(
+            feature_configs_dict,
+            table_stacking=target_stacking,
+            name=embedding_layer_name,
+        )
+        model = keras.Sequential([layer])
+        model.compile(jit_compile=True, loss=loss_fn)
+
+        # Create a random training dataset whose labels come from `tables`.
+        training_dataset = keras_test_utils.RandomInputSampleDataset(
+            feature_configs_dict,
+            tables,
+            ragged=ragged,
+            num_batches=100,
+            seed=42,
+            preprocessor=lambda inputs, weights: layer.preprocess(
+                inputs, weights, training=True
+            ),
+        )
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            chkpt_path = os.path.join(tmp_dir, "checkpoint")
+
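+            # Fit while the checkpoint callback saves state under chkpt_path.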
+            model.fit(
+                training_dataset,
+                epochs=2,
+                steps_per_epoch=1,
+                callbacks=[
+                    checkpoint_utils.JaxKeras3CheckpointCallback(
+                        model,
+                        chkpt_path,
+                        max_to_keep=1,
+                        steps_per_epoch=1,
+                    )
+                ],
+            )
+            # Set up a model with a zero initializer but otherwise the same
+            # feature configs to test restore. Keep the same embedding layer
+            # name to ensure the correct weights are restored.
+            feature_configs_with_zero_init = {
+                feature_config.name: dataclasses.replace(
+                    feature_config,
+                    table=dataclasses.replace(
+                        feature_config.table, initializer="zeros"
+                    ),
+                )
+                for feature_config in feature_configs
+            }
+            layer_for_restore = jax_distributed_embedding.DistributedEmbedding(
+                feature_configs_with_zero_init,
+                table_stacking=target_stacking,
+                name=embedding_layer_name,
+            )
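+            # Build the layer so its variables exist before restoring.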
+            input_shapes = jax.tree.map(
+                lambda f: f.input_shape, feature_configs_with_zero_init
+            )
+            layer_for_restore.build(input_shapes)
+            model_for_restore = keras.Sequential([layer_for_restore])
+            manager_for_restore = checkpoint_utils.JaxKeras3CheckpointManager(
+                model_for_restore,
+                chkpt_path,
+                max_to_keep=1,
+                steps_per_epoch=1,
+            )
+            model_for_restore.compile(jit_compile=True, loss=loss_fn)
+            model_for_restore.build()
+            model_for_restore.optimizer.build(model_for_restore.trainable_variables)
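+            # Restore the saved checkpoint into the zero-initialized model.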
+            manager_for_restore.restore_state()
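+            # The restored variables must match the trained model exactly.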
+            jax.tree.map(
+                np.testing.assert_array_equal,
+                model.trainable_variables,
+                model_for_restore.trainable_variables,
+            )
+
 
 if __name__ == "__main__":
     absltest.main()