From c3b3190174484d5112d8c5a2f41f04d7a34d0fc1 Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Mon, 8 Jul 2024 10:24:40 -0700 Subject: [PATCH] Sort devices explicitly by process index, then id (as opposed to IDs alone). IDs may be randomly generated, and are not guaranteed to be ordered based on their process index. PiperOrigin-RevId: 650294623 --- checkpoint/orbax/checkpoint/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/checkpoint/orbax/checkpoint/utils.py b/checkpoint/orbax/checkpoint/utils.py index c694366ae..1170be26c 100644 --- a/checkpoint/orbax/checkpoint/utils.py +++ b/checkpoint/orbax/checkpoint/utils.py @@ -158,6 +158,11 @@ def is_scalar(x): return isinstance(x, (int, float, np.number)) +def _shard_sort_device_key(shard_data): + d = list(shard_data.devices())[0] + return (d.process_index, d.id) + + def fully_replicated_host_local_array_to_global_array( arr: jax.Array, ) -> jax.Array: @@ -184,7 +189,7 @@ def fully_replicated_host_local_array_to_global_array( # pmap-produced Array has a "scrambled" device order. dbs = sorted( [shard.data for shard in arr.addressable_shards], - key=lambda x: list(x.devices())[0].id, + key=_shard_sort_device_key ) return jax.make_array_from_single_device_arrays( global_shape, jax.sharding.NamedSharding(mesh, partition_spec), dbs