Skip to content

Commit

Permalink
Sort devices explicitly by process index, then id (as opposed to IDs …
Browse files Browse the repository at this point in the history
…alone). IDs may be randomly generated, and are not guaranteed to be ordered based on their process index.

PiperOrigin-RevId: 650294623
  • Loading branch information
Orbax Authors committed Jul 12, 2024
1 parent a704408 commit 667a46a
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion checkpoint/orbax/checkpoint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,11 @@ def is_scalar(x):
return isinstance(x, (int, float, np.number))


def _shard_sort_device_key(shard_data):
  """Builds the sort key used to order shard buffers by their device.

  Ordering on the device id alone is not sufficient: ids may be randomly
  generated and are not guaranteed to follow the process index. The key
  therefore compares the process index first and the device id second.

  Args:
    shard_data: An object exposing `devices()`. Only the first device yielded
      by that collection is inspected (presumably a single-device array, so
      there is exactly one -- confirm against callers).

  Returns:
    A `(process_index, id)` tuple for that device.
  """
  shard_devices = list(shard_data.devices())
  device = shard_devices[0]
  return device.process_index, device.id


def fully_replicated_host_local_array_to_global_array(
arr: jax.Array,
) -> jax.Array:
Expand All @@ -184,7 +189,7 @@ def fully_replicated_host_local_array_to_global_array(
# pmap-produced Array has a "scrambled" device order.
dbs = sorted(
[shard.data for shard in arr.addressable_shards],
key=lambda x: list(x.devices())[0].id,
key=_shard_sort_device_key
)
return jax.make_array_from_single_device_arrays(
global_shape, jax.sharding.NamedSharding(mesh, partition_spec), dbs
Expand Down

0 comments on commit 667a46a

Please sign in to comment.