From b28852d36d42e15a34ef894584645acb6a383196 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Mon, 8 Jul 2024 10:24:40 -0700 Subject: [PATCH] Sort devices explicitly by process index, then id (as opposed to IDs alone). IDs may be randomly generated, and are not guaranteed to be ordered based on their process index. PiperOrigin-RevId: 650294623 --- checkpoint/orbax/checkpoint/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/checkpoint/orbax/checkpoint/utils.py b/checkpoint/orbax/checkpoint/utils.py index c694366ae..1170be26c 100644 --- a/checkpoint/orbax/checkpoint/utils.py +++ b/checkpoint/orbax/checkpoint/utils.py @@ -158,6 +158,11 @@ def is_scalar(x): return isinstance(x, (int, float, np.number)) +def _shard_sort_device_key(shard_data): + d = list(shard_data.devices())[0] + return (d.process_index, d.id) + + def fully_replicated_host_local_array_to_global_array( arr: jax.Array, ) -> jax.Array: @@ -184,7 +189,7 @@ def fully_replicated_host_local_array_to_global_array( # pmap-produced Array has a "scrambled" device order. dbs = sorted( [shard.data for shard in arr.addressable_shards], - key=lambda x: list(x.devices())[0].id, + key=_shard_sort_device_key ) return jax.make_array_from_single_device_arrays( global_shape, jax.sharding.NamedSharding(mesh, partition_spec), dbs