diff --git a/checkpoint/orbax/checkpoint/utils.py b/checkpoint/orbax/checkpoint/utils.py index c694366ae..c46b3aaca 100644 --- a/checkpoint/orbax/checkpoint/utils.py +++ b/checkpoint/orbax/checkpoint/utils.py @@ -184,7 +184,7 @@ def fully_replicated_host_local_array_to_global_array( # pmap-produced Array has a "scrambled" device order. dbs = sorted( [shard.data for shard in arr.addressable_shards], - key=lambda x: list(x.devices())[0].id, + key=lambda x: list(x.devices())[0], ) return jax.make_array_from_single_device_arrays( global_shape, jax.sharding.NamedSharding(mesh, partition_spec), dbs