diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py
index 551e5f2c..a2998c53 100644
--- a/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py
+++ b/checkpoint/orbax/checkpoint/experimental/emergency/checkpoint_manager.py
@@ -47,6 +47,7 @@
 from orbax.checkpoint import pytree_checkpoint_handler
 from orbax.checkpoint import type_handlers
 from orbax.checkpoint import utils
+from orbax.checkpoint.multihost import multislice
 from orbax.checkpoint.path import step as step_lib
 from typing_extensions import Self  # for Python version < 3.11
 
@@ -141,50 +142,6 @@ class CheckpointManagerOptions:
   async_options: Optional[checkpoint_manager.AsyncOptions] = None
 
 
-def _process_slice_id(
-    process_index: int, global_mesh: jax.sharding.Mesh
-) -> int:
-  """Returns the slice id that the process_index belongs to."""
-  for slice_id, device_slice in enumerate(global_mesh.devices):
-    if process_index in _pid_in_slice(device_slice):
-      return slice_id
-
-  return -1
-
-
-def _pid_in_slice(device_slice: np.ndarray) -> np.ndarray:
-  pid = np.vectorize(lambda d: d.process_index)
-  return pid(device_slice)
-
-
-def in_slice(process_index: int, device_slice: np.ndarray) -> bool:
-  return process_index in _pid_in_slice(device_slice)
-
-
-def in_primary_slice(
-    process_index: int, global_mesh: jax.sharding.Mesh
-) -> bool:
-  """Returns true if host is in primary slice (the first slice)."""
-  primary_slice = global_mesh.devices[0]
-
-  return in_slice(process_index, primary_slice)
-
-
-def _unique_processes_from_devices(device_array: np.ndarray) -> Set[int]:
-  pid = np.vectorize(lambda d: d.process_index)
-  return set(pid(device_array).flat)
-
-
-def _local_slice_devices(devices_array: np.ndarray) -> np.ndarray:
-  for device_slice in devices_array:
-    if in_slice(multihost.process_index(), device_slice):
-      return device_slice
-  raise ValueError(
-      f'process_index {multihost.process_index()} does not exist in provided'
-      ' `global_mesh`'
-  )
-
-
 def _pad_steps(steps, target):
   return steps + [-1] * (target - len(steps))
 
@@ -260,7 +217,9 @@ def __init__(
           async_options=options.async_options,
           multiprocessing_options=checkpoint_manager.MultiprocessingOptions(
               primary_host=None,
-              active_processes=_unique_processes_from_devices(device_array),
+              active_processes=multihost.unique_processes_from_devices(
+                  device_array
+              ),
               barrier_sync_key_prefix='local',
           ),
           enable_async_checkpointing=options.enable_async_checkpointing,
@@ -279,10 +238,7 @@ def __init__(
 
   def _global_list_union_interslice(self, steps: Sequence[int]) -> Set[int]:
     barrier_processes = self._options.multiprocessing_options.active_processes
-    barrier_processes = [
-        multihost.utils._runtime_to_distributed_process_id(runtime_id)  # pylint: disable=protected-access
-        for runtime_id in barrier_processes
-    ]
+    barrier_processes = list(barrier_processes)
 
     client = multihost.utils._get_jax_distributed_client()  # pylint: disable=protected-access
     dir_key = f'steps/{next(_module_unique_count)}/'
@@ -317,7 +273,6 @@ def _common_steps_global(self, steps: Sequence[int]) -> np.ndarray:
       steps: a list of steps known to all hosts on a slice
     """
     unioned_steps = self._global_list_union_interslice(steps)
-    logging.info('After inter-slice broadcast, found steps: %s.', unioned_steps)
     return np.asarray(list(unioned_steps))
 
   def common_steps_within_slice(self, steps: Sequence[int]) -> np.ndarray:
@@ -335,7 +290,7 @@ def common_steps_within_slice(self, steps: Sequence[int]) -> np.ndarray:
       steps: a list of known steps on host
     """
-    devices = _local_slice_devices(self._device_array)
+    devices = multislice.local_slice_devices(self._device_array)
     slice_device_count = devices.size
     unioned_steps = _global_list_union(steps, devices)
 
@@ -387,11 +342,8 @@ def all_steps(self, read: bool = False) -> Sequence[int]:
       A sequence of steps (integers)
     """
     local_steps = self.local_host_steps(read)
-
     common_steps_on_slice = self.common_steps_within_slice(local_steps)
-
     steps = self._common_steps_global(common_steps_on_slice)
-
     return [x for x in steps if x != -1]
 
   def latest_step(self) -> Optional[int]:
@@ -441,7 +393,7 @@ def __init__(
     options = options or CheckpointManagerOptions()
     self._global_mesh = global_mesh
     self._abstract_state = abstract_state
-    self._slice_id = _process_slice_id(
+    self._slice_id = multislice.process_slice_id(
         multihost.process_index(), self._global_mesh
     )
 
@@ -450,19 +402,21 @@ def __init__(
     self._persistent_directory = persistent_directory
     self._options = options
     self._metadata = metadata
-    self._persistent_primary_host = global_mesh.devices[0].flat[0].process_index
-    self._local_primary_host = (
-        global_mesh.devices[1].flat[0].process_index
-        if global_mesh.devices.shape[0] > 1
-        else None
+    self._persistent_primary_host = multihost.runtime_to_distributed_process_id(
+        global_mesh.devices[0].flat[0].process_index
     )
+    self._local_primary_host = None
+    if global_mesh.devices.shape[0] > 1:
+      self._local_primary_host = multihost.runtime_to_distributed_process_id(
+          global_mesh.devices[1].flat[0].process_index
+      )
     if self._local_primary_host is None:
       raise AssertionError(
-          'to use this CheckpointManager, at least 3 data-parallel slices are'
+          'To use this CheckpointManager, at least 2 data-parallel slices are'
           ' needed.'
       )
-    self.in_primary_slice = in_primary_slice(
+    self.in_primary_slice = multislice.in_primary_slice(
         multihost.process_index(), global_mesh
     )
     self._persistent_max_to_keep = self._options.persistent.max_to_keep
 
@@ -471,7 +425,7 @@ def __init__(
     persistent_multiprocessing_options = (
        checkpoint_manager.MultiprocessingOptions(
             primary_host=self._persistent_primary_host,
-            active_processes=_unique_processes_from_devices(
+            active_processes=multihost.unique_processes_from_devices(
                 self._global_mesh.devices[0]
             ),
             barrier_sync_key_prefix='persistent',
@@ -496,6 +450,7 @@ def __init__(
             item_handlers=PyTreeCheckpointHandler(
                 use_ocdbt=True,
                 use_zarr3=True,
+                primary_host=self._persistent_primary_host,
             ),
         )
     )
@@ -664,7 +619,7 @@ def should_save(self, step: int) -> bool:
     Returns:
       True if the checkpoint should be saved.
     """
-    logging.info('Checking should_save.')
+    logging.info('Checking should_save at step: %d.', step)
     if self.in_primary_slice:
       should_save = self._persistent_checkpoint_manager.should_save(step)
     else:
diff --git a/checkpoint/orbax/checkpoint/multihost/__init__.py b/checkpoint/orbax/checkpoint/multihost/__init__.py
index b5bf13f5..aacdc003 100644
--- a/checkpoint/orbax/checkpoint/multihost/__init__.py
+++ b/checkpoint/orbax/checkpoint/multihost/__init__.py
@@ -21,6 +21,8 @@
 from orbax.checkpoint.multihost.utils import reached_preemption
 from orbax.checkpoint.multihost.utils import sync_global_processes
 from orbax.checkpoint.multihost.utils import process_index
+from orbax.checkpoint.multihost.utils import unique_processes_from_devices
+from orbax.checkpoint.multihost.utils import runtime_to_distributed_process_id
 from orbax.checkpoint.multihost.utils import BarrierSyncFn
 from orbax.checkpoint.multihost.utils import get_barrier_sync_fn
@@ -29,3 +31,5 @@
 
 from orbax.checkpoint.multihost.utils import DIRECTORY_CREATION_TIMEOUT
 from orbax.checkpoint.multihost.utils import DIRECTORY_DELETION_TIMEOUT
+
+from orbax.checkpoint.multihost import multislice_utils as multislice
diff --git a/checkpoint/orbax/checkpoint/multihost/multislice_utils.py b/checkpoint/orbax/checkpoint/multihost/multislice_utils.py
new file mode 100644
index 00000000..02e58b89
--- /dev/null
+++ b/checkpoint/orbax/checkpoint/multihost/multislice_utils.py
@@ -0,0 +1,58 @@
+# Copyright 2024 The Orbax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Multislice utils."""
+
+import jax
+import numpy as np
+from orbax.checkpoint.multihost import utils
+
+
+def process_slice_id(process_index: int, global_mesh: jax.sharding.Mesh) -> int:
+  """Returns the slice id that the process_index belongs to."""
+  for slice_id, device_slice in enumerate(global_mesh.devices):
+    if process_index in _pid_in_slice(device_slice):
+      return slice_id
+
+  return -1
+
+
+def _pid_in_slice(device_slice: np.ndarray) -> np.ndarray:
+  pid = np.vectorize(
+      lambda d: utils.runtime_to_distributed_process_id(d.process_index)
+  )
+  return pid(device_slice)
+
+
+def in_slice(process_index: int, device_slice: np.ndarray) -> bool:
+  return process_index in _pid_in_slice(device_slice)
+
+
+def in_primary_slice(
+    process_index: int, global_mesh: jax.sharding.Mesh
+) -> bool:
+  """Returns true if host is in primary slice (the first slice)."""
+  primary_slice = global_mesh.devices[0]
+
+  return in_slice(process_index, primary_slice)
+
+
+def local_slice_devices(devices_array: np.ndarray) -> np.ndarray:
+  for device_slice in devices_array:
+    if in_slice(utils.process_index(), device_slice):
+      return device_slice
+  raise ValueError(
+      f'process_index {utils.process_index()} does not exist in provided'
+      ' `global_mesh`'
+  )
diff --git a/checkpoint/orbax/checkpoint/multihost/utils.py b/checkpoint/orbax/checkpoint/multihost/utils.py
index d4d2d658..b048b35f 100644
--- a/checkpoint/orbax/checkpoint/multihost/utils.py
+++ b/checkpoint/orbax/checkpoint/multihost/utils.py
@@ -19,6 +19,7 @@
 from absl import logging
 import jax
 from jax.experimental import multihost_utils
+import numpy as np
 
 # Default timeout in seconds.
 _DEFAULT_BARRIER_TIMEOUT = 1200
@@ -62,7 +63,7 @@ def initialize_runtime_to_distributed_ids():
   logging.info('runtime_to_distributed_id: %s', _RUNTIME_TO_DISTRIBUTED_ID)
 
 
-def _runtime_to_distributed_process_id(pid: int) -> int:
-  """Converts a distributed process index to a runtime process index."""
+def runtime_to_distributed_process_id(pid: int) -> int:
+  """Converts a runtime process index to a distributed process index."""
   if _RUNTIME_TO_DISTRIBUTED_ID is None:
     raise ValueError('Please call initialize_runtime_to_distributed_ids()')
@@ -137,10 +138,8 @@ def get_barrier_sync_fn(
   if processes is None:
     barrier_processes = None
   else:
-    barrier_processes = [
-        _runtime_to_distributed_process_id(runtime_id)
-        for runtime_id in barrier_processes
-    ]
+    # Don't map ids anymore if we are using distributed ids.
+    barrier_processes = list(barrier_processes)
 
   def _fn(*, key: str, timeout_ms: int) -> None:
     key = _unique_barrier_key(key)
@@ -230,4 +229,11 @@ def is_primary_host(primary_host: Optional[int]):
 
 
 def process_index() -> int:
-  return jax.process_index()
+  return jax._src.distributed.global_state.process_id  # pylint: disable=protected-access
+
+
+def unique_processes_from_devices(device_array: np.ndarray) -> Set[int]:
+  pid = np.vectorize(
+      lambda d: runtime_to_distributed_process_id(d.process_index)
+  )
+  return set(pid(device_array).flat)
diff --git a/checkpoint/orbax/checkpoint/test_utils.py b/checkpoint/orbax/checkpoint/test_utils.py
index bb1847a6..7fef8b86 100644
--- a/checkpoint/orbax/checkpoint/test_utils.py
+++ b/checkpoint/orbax/checkpoint/test_utils.py
@@ -21,7 +21,7 @@
 import string
 import time
 import typing
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Set, Tuple
 from unittest import mock
 
 from absl import logging
@@ -38,6 +38,7 @@
 from orbax.checkpoint import tree as tree_utils
 from orbax.checkpoint import type_handlers
 from orbax.checkpoint import utils
+from orbax.checkpoint.multihost import multislice
 
 
 def sync_global_processes(name: str):
@@ -207,11 +208,45 @@ def setup_replica_sharded_arrays(
   return sharded_arrs, mesh, mesh_axes
 
 
+def get_fake_global_mesh_for_slices(
+    slice_processes: List[Set[int]],
+) -> jax.sharding.Mesh:
+  """Creates a "multi-slice" global mesh for testing.
+
+  Args:
+    slice_processes: List of sets of process indices, where each element in the
+      list is a set of processes that are active in a single slice.
+
+  Returns:
+    A global mesh.
+  """
+  devices = jax.devices()
+  slice_devices = []
+  devices_per_slice = None
+  all_processes = set()
+  for processes in slice_processes:
+    all_processes |= processes
+    slice_devices.append([
+        d
+        for d in devices
+        if multihost.utils.runtime_to_distributed_process_id(d.process_index)
+        in processes
+    ])
+    devices_per_slice = devices_per_slice or len(slice_devices[-1])
+    if len(slice_devices[-1]) != devices_per_slice:
+      raise ValueError('All slices must have the same number of devices.')
+  if len(all_processes) != jax.process_count():
+    raise ValueError('All processes must be accounted for.')
+
+  slice_devices = np.asarray(slice_devices)
+  return jax.sharding.Mesh(slice_devices, ('replica', 'data'))
+
+
 def select_single_replica(
     arrays: List[jax.Array], global_mesh: jax.sharding.Mesh
 ) -> List[jax.Array]:
   """Returns arrays sharded over single slice."""
-  slice_devices = global_mesh.devices[0]
+  slice_devices = multislice.local_slice_devices(global_mesh.devices)
   single_slice_mesh = jax.sharding.Mesh(
       slice_devices, global_mesh.axis_names[1:]
   )
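Reviewer note: taken together, these hunks (a) move the slice-topology helpers into the shared `multihost.multislice` module and (b) standardize the ids they return and consume on distributed process ids via `runtime_to_distributed_process_id`, so callers no longer remap runtime ids by hand. Below is a minimal usage sketch of the relocated helpers, not part of this patch: the two-slice mesh shape is hypothetical, and it assumes a multi-process JAX runtime.

    import jax
    import numpy as np
    from orbax.checkpoint import multihost

    # Prerequisites for the runtime->distributed id mapping used throughout
    # this change.
    jax.distributed.initialize()
    multihost.utils.initialize_runtime_to_distributed_ids()

    # Hypothetical topology: reshape the global device list into
    # (num_slices, devices_per_slice); two slices here is only an example.
    devices = np.asarray(jax.devices()).reshape(2, -1)
    global_mesh = jax.sharding.Mesh(devices, ('replica', 'data'))

    # process_index() now returns a distributed id, which is what the
    # relocated multislice helpers expect.
    my_id = multihost.process_index()
    slice_id = multihost.multislice.process_slice_id(my_id, global_mesh)
    is_primary = multihost.multislice.in_primary_slice(my_id, global_mesh)

    # Distributed ids of the processes owning slice 0, e.g. for the
    # persistent checkpoint manager's active_processes option.
    active = multihost.unique_processes_from_devices(global_mesh.devices[0])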