Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675673632
  • Loading branch information
hawkinsp authored and copybara-github committed Sep 17, 2024
1 parent a13fe02 commit 2826f11
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion vmoe/initialization/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ def _get_shape_dtype_struct(item):
tree=structure, axis_resources_regexes=axis_resources_regexes or ())

def _array_restore_args(value, spec):
if value is None:
return None
if isinstance(value, jax.ShapeDtypeStruct):
sharding = NamedSharding(mesh, spec)
return orbax_checkpoint.ArrayRestoreArgs(
Expand All @@ -135,7 +137,11 @@ def _array_restore_args(value, spec):
return orbax_checkpoint.RestoreArgs()

restore_args = jax.tree_util.tree_map(
_array_restore_args, structure, axis_resources)
_array_restore_args,
structure,
axis_resources,
is_leaf=lambda x: x is None,
)
ckpt = ckptr.restore(directory, restore_args=restore_args)
return mapping.map_state_dict(ckpt, target, rules, **map_state_dict_kwargs)

Expand Down

0 comments on commit 2826f11

Please sign in to comment.