From 8320329aa5c4d6293730a716015fb2e18ca1cbfc Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Mon, 8 Jul 2024 10:24:40 -0700 Subject: [PATCH] Sort devices by their implicit order instead of explicitly by id. IDs may be randomly generated, so it's better to rely on the implicit order, which is currently based on (process index, id). PiperOrigin-RevId: 650294623 --- checkpoint/orbax/checkpoint/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/checkpoint/orbax/checkpoint/utils.py b/checkpoint/orbax/checkpoint/utils.py index c694366ae..c46b3aaca 100644 --- a/checkpoint/orbax/checkpoint/utils.py +++ b/checkpoint/orbax/checkpoint/utils.py @@ -184,7 +184,7 @@ def fully_replicated_host_local_array_to_global_array( # pmap-produced Array has a "scrambled" device order. dbs = sorted( [shard.data for shard in arr.addressable_shards], - key=lambda x: list(x.devices())[0].id, + key=lambda x: list(x.devices())[0], ) return jax.make_array_from_single_device_arrays( global_shape, jax.sharding.NamedSharding(mesh, partition_spec), dbs