Skip to content

Commit

Permalink
Write tree metadata when checkpointing.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 577883417
  • Loading branch information
jpuigcerver authored and copybara-github committed Oct 30, 2023
1 parent 44906f2 commit a9b35bc
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion vmoe/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,9 @@ def create_checkpoint_manager(
directory,
{
'state': orbax.checkpoint.AsyncCheckpointer(
orbax.checkpoint.PyTreeCheckpointHandler(),
orbax.checkpoint.PyTreeCheckpointHandler(
write_tree_metadata=True,
),
timeout_secs=wait_seconds,
),
'dataset_iterator': orbax.checkpoint.Checkpointer(
Expand Down

0 comments on commit a9b35bc

Please sign in to comment.