Internal change. #935

Open · wants to merge 1 commit into main
@@ -47,6 +47,7 @@
from orbax.checkpoint import pytree_checkpoint_handler
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
from orbax.checkpoint.multihost import multislice
from orbax.checkpoint.path import step as step_lib
from typing_extensions import Self # for Python version < 3.11

@@ -141,50 +142,6 @@ class CheckpointManagerOptions:
async_options: Optional[checkpoint_manager.AsyncOptions] = None


def _process_slice_id(
process_index: int, global_mesh: jax.sharding.Mesh
) -> int:
"""Returns the slice id that the process_index belongs to."""
for slice_id, device_slice in enumerate(global_mesh.devices):
if process_index in _pid_in_slice(device_slice):
return slice_id

return -1


def _pid_in_slice(device_slice: np.ndarray) -> np.ndarray:
pid = np.vectorize(lambda d: d.process_index)
return pid(device_slice)


def in_slice(process_index: int, device_slice: np.ndarray) -> bool:
return process_index in _pid_in_slice(device_slice)


def in_primary_slice(
process_index: int, global_mesh: jax.sharding.Mesh
) -> bool:
"""Returns true if host is in primary slice (the first slice)."""
primary_slice = global_mesh.devices[0]

return in_slice(process_index, primary_slice)


def _unique_processes_from_devices(device_array: np.ndarray) -> Set[int]:
pid = np.vectorize(lambda d: d.process_index)
return set(pid(device_array).flat)


def _local_slice_devices(devices_array: np.ndarray) -> np.ndarray:
for device_slice in devices_array:
if in_slice(multihost.process_index(), device_slice):
return device_slice
raise ValueError(
f'process_index {multihost.process_index()} does not exist in provided'
' `global_mesh`'
)


def _pad_steps(steps, target):
return steps + [-1] * (target - len(steps))

@@ -260,7 +217,9 @@ def __init__(
async_options=options.async_options,
multiprocessing_options=checkpoint_manager.MultiprocessingOptions(
primary_host=None,
active_processes=_unique_processes_from_devices(device_array),
active_processes=multihost.unique_processes_from_devices(
device_array
),
barrier_sync_key_prefix='local',
),
enable_async_checkpointing=options.enable_async_checkpointing,
@@ -279,10 +238,7 @@ def __init__(

def _global_list_union_interslice(self, steps: Sequence[int]) -> Set[int]:
barrier_processes = self._options.multiprocessing_options.active_processes
barrier_processes = [
multihost.utils._runtime_to_distributed_process_id(runtime_id) # pylint: disable=protected-access
for runtime_id in barrier_processes
]
barrier_processes = list(barrier_processes)

client = multihost.utils._get_jax_distributed_client() # pylint: disable=protected-access
dir_key = f'steps/{next(_module_unique_count)}/'
@@ -317,7 +273,6 @@ def _common_steps_global(self, steps: Sequence[int]) -> np.ndarray:
steps: a list of steps known to all hosts on a slice
"""
unioned_steps = self._global_list_union_interslice(steps)
logging.info('After inter-slice broadcast, found steps: %s.', unioned_steps)
return np.asarray(list(unioned_steps))

def common_steps_within_slice(self, steps: Sequence[int]) -> np.ndarray:
@@ -335,7 +290,7 @@ def common_steps_within_slice(self, steps: Sequence[int]) -> np.ndarray:
steps: a list of known steps on host
"""

devices = _local_slice_devices(self._device_array)
devices = multislice.local_slice_devices(self._device_array)
slice_device_count = devices.size
unioned_steps = _global_list_union(steps, devices)

@@ -387,11 +342,8 @@ def all_steps(self, read: bool = False) -> Sequence[int]:
A sequence of steps (integers)
"""
local_steps = self.local_host_steps(read)

common_steps_on_slice = self.common_steps_within_slice(local_steps)

steps = self._common_steps_global(common_steps_on_slice)

return [x for x in steps if x != -1]

def latest_step(self) -> Optional[int]:
@@ -441,7 +393,7 @@ def __init__(
options = options or CheckpointManagerOptions()
self._global_mesh = global_mesh
self._abstract_state = abstract_state
self._slice_id = _process_slice_id(
self._slice_id = multislice.process_slice_id(
multihost.process_index(), self._global_mesh
)

@@ -450,19 +402,21 @@ def __init__(
self._persistent_directory = persistent_directory
self._options = options
self._metadata = metadata
self._persistent_primary_host = global_mesh.devices[0].flat[0].process_index
self._local_primary_host = (
global_mesh.devices[1].flat[0].process_index
if global_mesh.devices.shape[0] > 1
else None
self._persistent_primary_host = multihost.runtime_to_distributed_process_id(
global_mesh.devices[0].flat[0].process_index
)
self._local_primary_host = None
if global_mesh.devices.shape[0] > 1:
self._local_primary_host = multihost.runtime_to_distributed_process_id(
global_mesh.devices[1].flat[0].process_index
)
if self._local_primary_host is None:
raise AssertionError(
'to use this CheckpointManager, at least 3 data-parallel slices are'
'To use this CheckpointManager, at least 2 data-parallel slices are'
' needed.'
)

self.in_primary_slice = in_primary_slice(
self.in_primary_slice = multislice.in_primary_slice(
multihost.process_index(), global_mesh
)
self._persistent_max_to_keep = self._options.persistent.max_to_keep
@@ -471,7 +425,7 @@ def __init__(
persistent_multiprocessing_options = (
checkpoint_manager.MultiprocessingOptions(
primary_host=self._persistent_primary_host,
active_processes=_unique_processes_from_devices(
active_processes=multihost.unique_processes_from_devices(
self._global_mesh.devices[0]
),
barrier_sync_key_prefix='persistent',
@@ -496,6 +450,7 @@ def __init__(
item_handlers=PyTreeCheckpointHandler(
use_ocdbt=True,
use_zarr3=True,
primary_host=self._persistent_primary_host,
),
)
)
@@ -664,7 +619,7 @@ def should_save(self, step: int) -> bool:
Returns:
True if the checkpoint should be saved.
"""
logging.info('Checking should_save.')
logging.info('Checking should_save at step: %d.', step)
if self.in_primary_slice:
should_save = self._persistent_checkpoint_manager.should_save(step)
else:
4 changes: 4 additions & 0 deletions checkpoint/orbax/checkpoint/multihost/__init__.py
@@ -21,6 +21,8 @@
from orbax.checkpoint.multihost.utils import reached_preemption
from orbax.checkpoint.multihost.utils import sync_global_processes
from orbax.checkpoint.multihost.utils import process_index
from orbax.checkpoint.multihost.utils import unique_processes_from_devices
from orbax.checkpoint.multihost.utils import runtime_to_distributed_process_id

from orbax.checkpoint.multihost.utils import BarrierSyncFn
from orbax.checkpoint.multihost.utils import get_barrier_sync_fn
@@ -29,3 +31,5 @@

from orbax.checkpoint.multihost.utils import DIRECTORY_CREATION_TIMEOUT
from orbax.checkpoint.multihost.utils import DIRECTORY_DELETION_TIMEOUT

from orbax.checkpoint.multihost import multislice_utils as multislice
58 changes: 58 additions & 0 deletions checkpoint/orbax/checkpoint/multihost/multislice_utils.py
@@ -0,0 +1,58 @@
# Copyright 2024 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Multislice utils."""

import jax
import numpy as np
from orbax.checkpoint.multihost import utils


def process_slice_id(process_index: int, global_mesh: jax.sharding.Mesh) -> int:
"""Returns the slice id that the process_index belongs to."""
for slice_id, device_slice in enumerate(global_mesh.devices):
if process_index in _pid_in_slice(device_slice):
return slice_id

return -1


def _pid_in_slice(device_slice: np.ndarray) -> np.ndarray:
pid = np.vectorize(
lambda d: utils.runtime_to_distributed_process_id(d.process_index)
)
return pid(device_slice)


def in_slice(process_index: int, device_slice: np.ndarray) -> bool:
return process_index in _pid_in_slice(device_slice)


def in_primary_slice(
process_index: int, global_mesh: jax.sharding.Mesh
) -> bool:
"""Returns true if host is in primary slice (the first slice)."""
primary_slice = global_mesh.devices[0]

return in_slice(process_index, primary_slice)


def local_slice_devices(devices_array: np.ndarray) -> np.ndarray:
for device_slice in devices_array:
if in_slice(utils.process_index(), device_slice):
return device_slice
raise ValueError(
f'process_index {utils.process_index()} does not exist in provided'
' `global_mesh`'
)
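
For reference, a minimal usage sketch of the new multislice helpers (not part of this change; the 2-slice device layout and the explicit jax.distributed.initialize() / initialize_runtime_to_distributed_ids() calls are assumptions for illustration):

# Sketch only: exercise the new multislice helpers on an assumed 2-slice mesh.
import jax
import numpy as np
from orbax.checkpoint import multihost
from orbax.checkpoint.multihost import utils as multihost_utils

jax.distributed.initialize()  # assumed multi-host setup
multihost_utils.initialize_runtime_to_distributed_ids()  # build runtime->distributed id map

# Arrange all devices into (num_slices, devices_per_slice); 2 slices assumed.
devices = np.asarray(jax.devices()).reshape(2, -1)
mesh = jax.sharding.Mesh(devices, ('replica', 'data'))

my_id = multihost.process_index()  # distributed process id of this host
print(multihost.multislice.process_slice_id(my_id, mesh))
print(multihost.multislice.in_primary_slice(my_id, mesh))
print(multihost.multislice.local_slice_devices(mesh.devices).size)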
18 changes: 12 additions & 6 deletions checkpoint/orbax/checkpoint/multihost/utils.py
@@ -19,6 +19,7 @@
from absl import logging
import jax
from jax.experimental import multihost_utils
import numpy as np

# Default timeout in seconds.
_DEFAULT_BARRIER_TIMEOUT = 1200
@@ -62,7 +63,7 @@ def initialize_runtime_to_distributed_ids():
logging.info('runtime_to_distributed_id: %s', _RUNTIME_TO_DISTRIBUTED_ID)


def _runtime_to_distributed_process_id(pid: int) -> int:
def runtime_to_distributed_process_id(pid: int) -> int:
"""Converts a distributed process index to a runtime process index."""
if _RUNTIME_TO_DISTRIBUTED_ID is None:
raise ValueError('Please call initialize_runtime_to_distributed_ids()')
@@ -137,10 +138,8 @@ def get_barrier_sync_fn(
if processes is None:
barrier_processes = None
else:
barrier_processes = [
_runtime_to_distributed_process_id(runtime_id)
for runtime_id in barrier_processes
]
# Don't map ids anymore if we are using distributed ids.
barrier_processes = list(barrier_processes)

def _fn(*, key: str, timeout_ms: int) -> None:
key = _unique_barrier_key(key)
@@ -230,4 +229,11 @@ def is_primary_host(primary_host: Optional[int]):


def process_index() -> int:
return jax.process_index()
return jax._src.distributed.global_state.process_id # pylint: disable=protected-access


def unique_processes_from_devices(device_array: np.ndarray) -> Set[int]:
pid = np.vectorize(
lambda d: runtime_to_distributed_process_id(d.process_index)
)
return set(pid(device_array).flat)
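
A short sketch of the runtime-id vs. distributed-id distinction these utils.py changes rely on (illustrative only; assumes the distributed runtime is initialized and the id mapping has been built as in the sketch above):

# Sketch: jax.Device.process_index is a *runtime* id, while barriers and
# active_processes sets use *distributed* ids, so the helpers translate.
import jax
import numpy as np
from orbax.checkpoint import multihost

runtime_ids = {d.process_index for d in jax.devices()}
distributed_ids = multihost.unique_processes_from_devices(np.asarray(jax.devices()))

# process_index() now reports this host's distributed id, so it should be a
# member of the translated set (but not necessarily of runtime_ids when the
# two numbering schemes differ).
assert multihost.process_index() in distributed_ids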
40 changes: 38 additions & 2 deletions checkpoint/orbax/checkpoint/test_utils.py
@@ -21,7 +21,7 @@
import string
import time
import typing
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Set, Tuple
from unittest import mock

from absl import logging
@@ -38,6 +38,7 @@
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
from orbax.checkpoint.multihost import multislice


def sync_global_processes(name: str):
@@ -207,11 +208,46 @@ def setup_replica_sharded_arrays(
return sharded_arrs, mesh, mesh_axes


def get_fake_global_mesh_for_slices(
slice_processes: List[Set[int]],
) -> jax.sharding.Mesh:
"""Creates a "multi-slice" global mesh for testing.

Args:
slice_processes: List of sets of process indices, where each element in the
list is a set of processes that are active in a single slice.

Returns:
A global mesh.
"""
devices = jax.devices()
slice_devices = []
devices_per_slices = None
all_processes = set()
for processes in slice_processes:
all_processes |= processes
slice_devices.append([
d
for d in devices
if multihost.utils.runtime_to_distributed_process_id(d.process_index)
in processes
])
devices_per_slices = devices_per_slices or len(slice_devices[-1])
if len(slice_devices[-1]) != devices_per_slices:
raise ValueError('All slices must have the same number of devices.')
if len(all_processes) != jax.process_count():
raise ValueError('All processes must be accounted for.')

slice_devices = np.asarray(slice_devices)
return jax.sharding.Mesh(slice_devices, ('replica', 'data'))
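
# Illustrative sketch (not part of this change): one possible test-side use of
# the helper above, assuming a 4-process run split into two 2-process slices.
from orbax.checkpoint import multihost as multihost_lib
from orbax.checkpoint import test_utils as test_utils_lib

fake_mesh = test_utils_lib.get_fake_global_mesh_for_slices([{0, 1}, {2, 3}])
slice_id = multihost_lib.multislice.process_slice_id(
    multihost_lib.process_index(), fake_mesh
)
assert slice_id in (0, 1)  # each process should land in exactly one slice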


def select_single_replica(
arrays: List[jax.Array], global_mesh: jax.sharding.Mesh
) -> List[jax.Array]:
"""Returns arrays sharded over single slice."""
slice_devices = global_mesh.devices[0]
slice_devices = multislice.local_slice_devices(global_mesh.devices)
# slice_devices = global_mesh.devices[0]
single_slice_mesh = jax.sharding.Mesh(
slice_devices, global_mesh.axis_names[1:]
)