diff --git a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
index 07efbee08..3a2e5e1e1 100644
--- a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
+++ b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
@@ -26,6 +26,7 @@
 import os
 import time
 from typing import Any, Callable, List, Optional, Tuple, Union
+import uuid
 
 from absl import logging
 from etils import epath
@@ -35,6 +36,7 @@
 from orbax.checkpoint import async_checkpoint_handler
 from orbax.checkpoint import checkpoint_args
 from orbax.checkpoint import future
+from orbax.checkpoint import proto_checkpoint_handler
 from orbax.checkpoint import transform_utils
 from orbax.checkpoint import tree as tree_utils
 from orbax.checkpoint import type_handlers
@@ -43,6 +45,7 @@
 from orbax.checkpoint.metadata import value as value_metadata
 import tensorstore as ts
 
+
 PyTree = Any
 TupleKey = Tuple[str, ...]
 RestoreArgs = type_handlers.RestoreArgs
@@ -66,6 +69,7 @@
 _CHECKPOINT_FILE = 'checkpoint'
 
 
+
 def get_byte_limiter(concurrent_gb: int):
   async def _create_byte_limiter():
     # Wrap creation in async function to avoid issues on python<=3.9.
@@ -362,7 +366,8 @@ def __init__(
         to None, then all hosts will be considered as primary. It's useful in
         the case that all hosts are only working with local storage.
       type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
-        specified, the global type handler registry will be used.
+        specified, the global type handler registry will be used.
+      enable_descriptor: If True, logs a Descriptor proto that contains lineage
     """
     self._aggregate_handler = MsgpackHandler(primary_host=primary_host)
     if aggregate_filename is None:
@@ -582,30 +587,28 @@ def _maybe_set_default_save_args(value, args_):
     logging.debug('param_info: %s', param_infos)
     logging.debug('save_args: %s', save_args)
 
-    metadata_future = None
-    if utils.is_primary_host(self._primary_host):
-      metadata_write_start_time = time.time()
-      metadata_future = await self._write_metadata_file(
-          directory, item, save_args, self._use_zarr3
-      )
-      jax.monitoring.record_event_duration_secs(
-          '/jax/checkpoint/write/async/metadata_write_duration_secs',
-          time.time() - metadata_write_start_time,
-      )
+    metadata_write_start_time = time.time()
+    commit_futures.append(
+        await self._write_metadata_file(
+            directory, item, save_args, self._use_zarr3
+        )
+    )
+    jax.monitoring.record_event_duration_secs(
+        '/jax/checkpoint/write/async/metadata_write_duration_secs',
+        time.time() - metadata_write_start_time,
+    )
 
     aggregate_file_write_start_time = time.time()
-    aggregate_commit_future = await self._write_aggregate_file(
-        directory, item, param_infos, save_args
+    commit_futures.append(
+        await self._write_aggregate_file(
+            directory, item, param_infos, save_args
+        )
     )
     jax.monitoring.record_event_duration_secs(
         '/jax/checkpoint/write/async/aggregate_write_duration_secs',
         time.time() - aggregate_file_write_start_time,
     )
-    return (
-        commit_futures + [aggregate_commit_future] + [metadata_future]
-        if metadata_future is not None
-        else commit_futures + [aggregate_commit_future]
-    )
+    return commit_futures
 
   def save(self, directory: epath.Path, *args, **kwargs):
     """Saves the provided item.
@@ -840,6 +843,7 @@ def _read_aggregate_file(self, directory: epath.Path) -> PyTree:
     else:
       return utils.pytree_structure(directory)
 
+
   async def _write_metadata_file(
       self,
      directory: epath.Path,
@@ -847,6 +851,9 @@ async def _write_metadata_file(
       save_args: PyTree,
       use_zarr3: bool = False,
   ) -> future.Future:
+    if not utils.is_primary_host(self._primary_host):
+      return future.NoopFuture()
+
     tspec = type_handlers._get_tensorstore_spec(  # pylint: disable=protected-access
         os.fspath(directory), name=METADATA_FILE, use_ocdbt=False
     )['kvstore']