Skip to content

Commit

Permalink
Fix initialization from orbax.
Browse files: browse the repository at this point in the history.
PiperOrigin-RevId: 556785886
  • Loading branch information
jpuigcerver authored and copybara-github committed Aug 14, 2023
1 parent 9505e61 commit 95e34dd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
22 changes: 20 additions & 2 deletions vmoe/initialization/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import os
from typing import Any, Optional, Union

from etils import epath
import flax.serialization
import flax.traverse_util
import jax
Expand Down Expand Up @@ -46,6 +47,24 @@
]


class AsyncCheckpointerWithStructure(orbax_checkpoint.AsyncCheckpointer):
  """AsyncCheckpointer exposing a `structure()` lookup via its handler."""

  def structure(self, directory: epath.PathLike) -> Optional[Any]:
    """Returns the structure that the wrapped handler reports for a checkpoint.

    Args:
      directory: Path-like location of the checkpoint; it is converted to an
        `epath.Path` before being handed to the handler.

    Returns:
      Whatever `self._handler.structure(...)` returns, or None when the handler
      raises NotImplementedError.
    """
    checkpoint_dir = epath.Path(directory)
    try:
      tree = self._handler.structure(checkpoint_dir)  # pytype: disable=attribute-error
    except NotImplementedError:
      # The handler does not support structure lookups.
      return None
    return tree


class PyTreeCheckpointHandlerWithStructure(
    orbax_checkpoint.PyTreeCheckpointHandler):
  """PyTreeCheckpointHandler that adds a `structure()` method."""

  def structure(self, directory):
    """Returns the parent's `_read_aggregate_file(directory)` result.

    NOTE(review): this depends on a private method of the orbax handler;
    confirm it is still available when the orbax version changes.

    Args:
      directory: Checkpoint directory, forwarded unchanged to the parent.
    """
    aggregate = super()._read_aggregate_file(directory)
    return aggregate


def initialize_from_orbax(
*,
target: PyTree,
Expand Down Expand Up @@ -89,8 +108,7 @@ def _get_shape_dtype_struct(item):
return k, jax.ShapeDtypeStruct(shape, jax.numpy.dtype(dtype))
return k, v

ckptr = orbax_checkpoint.AsyncCheckpointer(
orbax_checkpoint.PyTreeCheckpointHandler())
ckptr = AsyncCheckpointerWithStructure(PyTreeCheckpointHandlerWithStructure())

# Restore the structure of the checkpoint.
structure = ckptr.structure(directory)
Expand Down
3 changes: 1 addition & 2 deletions vmoe/initialization/initialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,7 @@ def setUp(self):
'dtype': self.ckpt['foo'][name].dtype.str}, fp)
# Mock the AsyncCheckpointer class.
self.mock_async_checkpointer = self.enter_context(
mock.patch.object(initialization.orbax_checkpoint,
'AsyncCheckpointer'))
mock.patch.object(initialization, 'AsyncCheckpointerWithStructure'))
# Mock the structure() method.
self.mock_structure = self.mock_async_checkpointer.return_value.structure
self.mock_structure.return_value = {
Expand Down

0 comments on commit 95e34dd

Please sign in to comment.