From 2826f11c4b1b3f9f091c3900f1f3405f42a33b1d Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 17 Sep 2024 12:49:52 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 675673632 --- vmoe/initialization/initialization.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/vmoe/initialization/initialization.py b/vmoe/initialization/initialization.py index b778b9c..0bac15e 100644 --- a/vmoe/initialization/initialization.py +++ b/vmoe/initialization/initialization.py @@ -127,6 +127,8 @@ def _get_shape_dtype_struct(item): tree=structure, axis_resources_regexes=axis_resources_regexes or ()) def _array_restore_args(value, spec): + if value is None: + return None if isinstance(value, jax.ShapeDtypeStruct): sharding = NamedSharding(mesh, spec) return orbax_checkpoint.ArrayRestoreArgs( @@ -135,7 +137,11 @@ def _array_restore_args(value, spec): return orbax_checkpoint.RestoreArgs() restore_args = jax.tree_util.tree_map( - _array_restore_args, structure, axis_resources) + _array_restore_args, + structure, + axis_resources, + is_leaf=lambda x: x is None, + ) ckpt = ckptr.restore(directory, restore_args=restore_args) return mapping.map_state_dict(ckpt, target, rules, **map_state_dict_kwargs)