diff --git a/checkpoint/orbax/checkpoint/utils.py b/checkpoint/orbax/checkpoint/utils.py
index c694366ae..1170be26c 100644
--- a/checkpoint/orbax/checkpoint/utils.py
+++ b/checkpoint/orbax/checkpoint/utils.py
@@ -158,6 +158,21 @@ def is_scalar(x):
   return isinstance(x, (int, float, np.number))
 
 
+def _shard_sort_device_key(shard_data: jax.Array):
+  """Returns the `(process_index, id)` sort key of a shard's device.
+
+  Args:
+    shard_data: Array exposing `devices()`; assumed to live on a single device
+      (only the first device reported is used).
+
+  Returns:
+    Tuple `(process_index, id)` of that device, so sorting orders shards by
+    owning process first and by device id second.
+  """
+  device = next(iter(shard_data.devices()))
+  return (device.process_index, device.id)
+
+
 def fully_replicated_host_local_array_to_global_array(
     arr: jax.Array,
 ) -> jax.Array:
@@ -184,7 +199,7 @@ def fully_replicated_host_local_array_to_global_array(
   # pmap-produced Array has a "scrambled" device order.
   dbs = sorted(
       [shard.data for shard in arr.addressable_shards],
-      key=lambda x: list(x.devices())[0].id,
+      key=_shard_sort_device_key,
   )
   return jax.make_array_from_single_device_arrays(
       global_shape, jax.sharding.NamedSharding(mesh, partition_spec), dbs