Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 658538282
  • Loading branch information
Orbax Authors committed Aug 2, 2024
1 parent cbd9d57 commit 0fe2162
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def _write_process_metadata(path: epath.Path, mesh: jax.sharding.Mesh):

if multihost.process_index() == 0:
path.mkdir(parents=False, exist_ok=False)
runtime_to_distributed_ids = multihost.utils.runtime_to_distributed_ids()
runtime_to_distributed_ids = (
emergency_multihost.runtime_to_distributed_ids()
)
(path / _GLOBAL_PROCESS_METADATA_FILE_NAME).write_text(
json.dumps(runtime_to_distributed_ids)
)
Expand Down
50 changes: 38 additions & 12 deletions checkpoint/orbax/checkpoint/experimental/emergency/multihost.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _get_runtime_id_across_restarts(
ValueError:
"""
current_dist_to_runtime_id = _int_list_flip_index_and_value(
multihost_utils.runtime_to_distributed_ids()
runtime_to_distributed_ids()
)
previous_dist_to_runtime_id = _int_list_flip_index_and_value(
previous_runtime_to_dist_id
Expand All @@ -70,23 +70,25 @@ def _get_runtime_id_across_restarts(
return result


def _process_index_from_device_id(device_id: int) -> int:
def process_index_from_device_id(device_id: int) -> int:
"""Get process index from device id."""
if jax.devices()[0].platform == 'gpu':
return device_id // jax.local_device_count()
elif jax.devices()[0].platform == 'tpu':
# Multi-slice TPU workload.
if hasattr(jax.devices()[0], 'slice_index'):
# Note that it is possible for single slice TPU devices to have a slice
# index.
num_slices = max([d.slice_index for d in jax.devices()]) + 1
num_processes_per_slice = jax.process_count() // num_slices
# This is based on how Megascale device ids are assigned.
# See platforms/xla/megascale/runtime/common/multi_slice_topology.h.
slice_id = device_id // 100000 - 1
local_process_id = device_id % 100000 // jax.local_device_count()
return slice_id * num_processes_per_slice + local_process_id
# Multi-slice TPU workload.
if num_slices > 1:
num_processes_per_slice = jax.process_count() // num_slices
# This is based on how Megascale device ids are assigned.
# See platforms/xla/megascale/runtime/common/multi_slice_topology.h.
slice_id = device_id // 100000 - 1
local_process_id = device_id % 100000 // jax.local_device_count()
return slice_id * num_processes_per_slice + local_process_id
# Single slice TPU workload.
else:
return device_id // jax.local_device_count()
return device_id // jax.local_device_count()
# CPU workload.
else:
# This is based on how CPU device ids are assigned.
Expand Down Expand Up @@ -119,7 +121,7 @@ def consistent_restore_mesh(
previous_runtime_to_dist_id
)
new_flattened_mesh_device_ids = [
runtime_id_across_restarts[_process_index_from_device_id(raw_id)]
runtime_id_across_restarts[process_index_from_device_id(raw_id)]
* jax.local_device_count()
+ raw_id % jax.local_device_count()
for raw_id in previous_flattened_mesh_device_ids
Expand All @@ -131,3 +133,27 @@ def consistent_restore_mesh(
user_mesh.devices.shape
)
return jax.sharding.Mesh(new_mesh_devices, user_mesh.axis_names)


def runtime_to_distributed_ids() -> List[int]:
  """Returns the runtime to distributed process id mapping.

  The mapping reported by `multihost_utils.runtime_to_distributed_ids()` is
  returned unchanged unless it is the identity (entry `i` equals `i` for every
  index). In the identity case the mapping is rebuilt from device ids.

  Returns:
    A list of process ids with one entry per process.
  """
  # TODO(b/325293150): Deprecate this after jaxlib contains the fix.
  reported_ids = multihost_utils.runtime_to_distributed_ids()

  # Any entry that differs from its own index means the reported mapping is
  # not the identity, so it can be used as is.
  if any(dist_id != idx for idx, dist_id in enumerate(reported_ids)):
    return reported_ids

  # JAX may choose to overwrite the device process index with the distributed
  # process index. In that case, we have to use the device id to infer the real
  # device process index. This is a hack, the intent is to remove it once this
  # workaround is no longer needed.
  inferred_ids = [-1] * jax.process_count()
  all_devices = jax.devices()
  # Visit one device per stride of `jax.local_device_count()` devices.
  for offset in range(0, jax.device_count(), jax.local_device_count()):
    device = all_devices[offset]
    inferred_ids[process_index_from_device_id(device.id)] = (
        device.process_index
    )
  # Every slot must have been filled in by the loop above.
  assert -1 not in inferred_ids
  return inferred_ids

0 comments on commit 0fe2162

Please sign in to comment.