From 6b8b87b7519f5c68f1595bf8d9c1851235a8e314 Mon Sep 17 00:00:00 2001 From: Colin Gaffney Date: Mon, 8 Jul 2024 12:32:10 -0700 Subject: [PATCH] Stop writing msgpack file for new checkpoints and update empty nodes handling so that it no longer depends on this file. PiperOrigin-RevId: 650338576 --- checkpoint/CHANGELOG.md | 2 + .../base_pytree_checkpoint_handler.py | 686 +++++------------- checkpoint/orbax/checkpoint/checkpointer.py | 2 +- checkpoint/orbax/checkpoint/metadata/tree.py | 94 ++- .../checkpoint/pytree_checkpoint_handler.py | 292 +++++++- .../standard_checkpoint_handler_test_utils.py | 18 +- 6 files changed, 529 insertions(+), 565 deletions(-) diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 0d2f69362..219964ce1 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -59,6 +59,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 transactions for the OCDBT storage format, and specify the new `can_reference_source_data_indefinitely=True` option to avoid a redundant copy when writing into the TensorStore chunk cache. +- Stop writing msgpack file for new checkpoints and update empty nodes handling +so that it no longer depends on this file. ## [0.5.15] - 2024-05-31 diff --git a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py index 07efbee08..441e2c247 100644 --- a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py @@ -20,29 +20,26 @@ """ import asyncio -import collections import dataclasses import json import os import time -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, List, Optional, Tuple, Union from absl import logging from etils import epath import jax -import numpy as np -from orbax.checkpoint import aggregate_handlers from orbax.checkpoint import async_checkpoint_handler from orbax.checkpoint import checkpoint_args from orbax.checkpoint import future -from orbax.checkpoint import transform_utils +from orbax.checkpoint import multihost from orbax.checkpoint import tree as tree_utils from orbax.checkpoint import type_handlers from orbax.checkpoint import utils from orbax.checkpoint.metadata import tree as tree_metadata -from orbax.checkpoint.metadata import value as value_metadata import tensorstore as ts + PyTree = Any TupleKey = Tuple[str, ...] RestoreArgs = type_handlers.RestoreArgs @@ -51,19 +48,13 @@ ParamInfo = type_handlers.ParamInfo TypeHandler = type_handlers.TypeHandler TypeHandlerRegistry = type_handlers.TypeHandlerRegistry -AggregateHandler = aggregate_handlers.AggregateHandler -MsgpackHandler = aggregate_handlers.MsgpackHandler -LegacyTransformFn = Callable[[PyTree, PyTree, PyTree], Tuple[PyTree, PyTree]] -Transform = transform_utils.Transform -RestoreTransform = transform_utils.RestoreTransform + # TODO(b/298487158) Clean up protected access. 
LimitInFlightBytes = type_handlers.LimitInFlightBytes CheckpointArgs = checkpoint_args.CheckpointArgs register_with_handler = checkpoint_args.register_with_handler - METADATA_FILE = '_METADATA' -_CHECKPOINT_FILE = 'checkpoint' def get_byte_limiter(concurrent_gb: int): @@ -86,39 +77,12 @@ async def _create_param_save_dir(param_info: ParamInfo, args: SaveArgs): return # TODO(b/273803615): Note that keys with slashes ('/', generated by Haiku, # for example) will result in the creation of nested sub-directories, rather - # than flat parameter directories like for a standard nested PyTree. Ensure - # `use_ocdbt=True` in this scenario. + # than flat parameter directories like for a standard nested PyTree. This + # discrepancy, while potentially problematic, will not be addressed since we + # anticipate moving fully to OCDBT within a quarter or two. await utils.async_makedirs(path, parents=True) -def _maybe_set_default_restore_args(args): - if isinstance(args, RestoreArgs): - return args - return RestoreArgs(restore_type=None) - - -def _try_array_cast(arr, dtype): - if dtype is not None: - if utils.is_scalar(arr): - arr = np.asarray(arr).astype(dtype).item() - else: - if hasattr(arr, 'astype'): - arr = arr.astype(dtype) - return arr - - -def _maybe_shard_array(value, args): - if hasattr(value, 'reshape') and isinstance(args, ArrayRestoreArgs): - value = value.reshape(args.global_shape) - sharding = args.sharding or jax.sharding.NamedSharding( - args.mesh, args.mesh_axes - ) - value = jax.make_array_from_callback( - value.shape, sharding, lambda idx: value[idx] - ) - return value - - def get_param_names(item: PyTree) -> PyTree: """Gets parameter names for PyTree elements.""" @@ -132,104 +96,6 @@ def _param_name_from_keypath(keypath: Tuple[Any, ...]) -> str: ) -def _keystr(key: Tuple[Any, ...]) -> str: - return '/'.join(key) - - -@dataclasses.dataclass -class _InternalValueMetadata: - restore_type: Optional[str] - skip_deserialize: bool = False - aggregate_value: Optional[Any] = None - - -def _get_restore_parameters( - directory: epath.Path, - structure: PyTree, - restore_args: Optional[PyTree], - byte_limiter: Optional[LimitInFlightBytes] = None, - use_zarr3: bool = False, -) -> Tuple[PyTree, PyTree]: - """Construct parameters needed for restoration. - - param_infos are - constructed from the structure of the original checkpoint, and restore_args - are serialized to a tree structure compatible with param_infos and structure. - - Args: - directory: Checkpoint directory. - structure: The structure of the original checkpoint. - restore_args: User-provided restoration arguments. If None, they were not - provided. Otherwise, the tree has the same structure as the desired output - tree. - byte_limiter: A _LimitInFlightBytes object. - use_zarr3: If True, use Zarr ver3 otherwise Zarr ver2 - - Returns: - Tuple of param_infos, and restore_args. - """ - flat_structure = tree_utils.to_flat_dict(structure, keep_empty_nodes=True) - if restore_args is None: - restore_args = jax.tree.map(lambda x: RestoreArgs(), structure) - flat_param_infos = {} - is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory) - ts_context = type_handlers.get_ts_context() - - def _get_param_info( - nested_name: Tuple[str, ...], - meta: _InternalValueMetadata, - ) -> Union[ParamInfo, Any]: - if type_handlers.is_supported_empty_aggregation_type(meta): - # Empty node, ParamInfo should not be returned. 
-      return meta
-    name = '.'.join(nested_name)
-    return ParamInfo(
-        name=name,
-        path=directory / name,
-        parent_dir=directory,
-        skip_deserialize=meta.skip_deserialize,
-        is_ocdbt_checkpoint=is_ocdbt_checkpoint,
-        byte_limiter=byte_limiter,
-        use_zarr3=use_zarr3,
-        ts_context=ts_context,
-    )
-
-  for key, meta in flat_structure.items():
-    flat_param_infos[key] = _get_param_info(key, meta)
-  restore_args = tree_utils.serialize_tree(restore_args, keep_empty_nodes=True)
-  return (
-      tree_utils.from_flat_dict(flat_param_infos, target=structure),
-      restore_args,
-  )
-
-
-def _get_tree_for_aggregation(param_infos, save_args, item):
-  """Get tree for aggregated checkpoint."""
-
-  # TODO(b/283164080): These type checks result in logic from the lower layer
-  # (TypeHandler/AggregateHandler) leaking into the upper layer
-  # (CheckpointHandler). Ideally, AggregateHandler could define its own
-  # supported values and error conditions.
-  def _get_leaf_for_aggregation(param_info, arg, value):
-    if arg.aggregate:  # Param was aggregated, return value after cast.
-      if isinstance(value, jax.Array) and not value.is_fully_replicated:
-        raise ValueError(
-            'jax.Array must be fully replicated to be saved in aggregate file.'
-        )
-      if not type_handlers.is_supported_aggregation_type(value):
-        # Not an error because users' training states often have a bunch of
-        # random unserializable objects in them (empty states, optimizer
-        # objects, etc.).
-        value = None
-      return _try_array_cast(value, arg.dtype)
-    else:  # Placeholder string for non-aggregated value.
-      return utils.leaf_placeholder(param_info.name)
-
-  return jax.tree.map(
-      _get_leaf_for_aggregation, param_infos, save_args, item
-  )
-
-
 @dataclasses.dataclass
 class _BatchRequest:
   """Represents a request for batched serialization or deserialization.
@@ -259,7 +125,7 @@ def __post_init__(self):
       raise AssertionError('Found `_BatchRequest` with mismatched parameters.')
 
 
-def _batched_serialization_requests(
+def batched_serialization_requests(
     tree: PyTree,
     param_infos: PyTree,
     args: PyTree,
@@ -267,28 +133,37 @@
 ) -> List[_BatchRequest]:
   """Gets a list of batched serialization or deserialization requests."""
   grouped = {}
-
   def _group_value(
       keypath: Tuple[Any, ...],
       info: ParamInfo,
-      value: Union[Any, _InternalValueMetadata],
-      arg: RestoreArgs,
+      value: Union[Any, tree_metadata.ValueMetadataEntry],
+      arg: Union[SaveArgs, RestoreArgs],
   ):
     nonlocal grouped
     tuple_key = tree_utils.tuple_path_from_keypath(keypath)
-    # Exclude from serialize/deserialize with TypeHandler if aggregated.
     if info.skip_deserialize:
       return
 
     if isinstance(arg, RestoreArgs):
-      assert isinstance(value, _InternalValueMetadata)
-      restore_type = value.restore_type
-      if arg.restore_type is not None:
-        # Give user the chance to override restore_type if they want.
-        restore_type = arg.restore_type
-      type_for_registry_lookup = restore_type
-    else:
+      assert isinstance(value, tree_metadata.ValueMetadataEntry), type(value)
+      metadata_restore_type = value.value_type
+      requested_restore_type = arg.restore_type or metadata_restore_type
+      # TODO(cpgaffney): Add a warning message if the requested_restore_type
+      # is not the same as the metadata_restore_type.
+      if type_handlers.is_empty_typestr(requested_restore_type):
+        # Skip deserialization of empty node using TypeHandler.
+        return
+      type_for_registry_lookup = requested_restore_type
+    elif isinstance(arg, SaveArgs):
+      # Skip serialization of empty node using TypeHandler.
+      if tree_utils.is_empty_node(value):
+        return
       type_for_registry_lookup = type(value)
+    else:
+      raise AssertionError(
+          f'Expected `RestoreArgs` or `SaveArgs`. Got {type(arg)}.'
+      )
+
     try:
       handler = registry.get(type_for_registry_lookup)
     except ValueError as e:
@@ -318,29 +193,56 @@ def _group_value(
   return list(grouped.values())
 
 
+def _fill_missing_save_or_restore_args(
+    item: PyTree, args: Optional[PyTree], *, mode: str
+) -> PyTree:
+  """Fills in missing values in the tree of SaveArgs or RestoreArgs.
+
+  Values may be "missing" because of empty nodes in `item`. After this
+  function returns, every key in `item`, empty node or not, has a
+  corresponding value in the result.
+
+  Args:
+    item: tree to save or target to restore.
+    args: tree of SaveArgs or RestoreArgs. May be None, if the user did not
+      provide it.
+    mode: 'save' or 'restore'.
+
+  Returns:
+    A tree of SaveArgs or RestoreArgs with missing values filled in.
+  """
+
+  # Because of empty states, the user-provided args may not contain
+  # all necessary arguments. These should be filled in with default args.
+  def _maybe_set_default_save_args(_, leaf_args):
+    if isinstance(leaf_args, (SaveArgs, RestoreArgs)):
+      return leaf_args
+    elif mode == 'save':
+      return SaveArgs()
+    elif mode == 'restore':
+      return RestoreArgs()
+    else:
+      raise ValueError(f'Unknown mode: {mode}.')
+
+  return jax.tree_util.tree_map(
+      _maybe_set_default_save_args,
+      item,
+      item if args is None else args,
+      is_leaf=utils.is_empty_or_leaf,
+  )
+
+
 class BasePyTreeCheckpointHandler(
     async_checkpoint_handler.AsyncCheckpointHandler
 ):
   """A CheckpointHandler implementation for any PyTree structure.
 
-  See JAX documentation for more information on what consistutes a "PyTree".
-  This handler is capable of saving and restoring any leaf object for which a
-  `TypeHandler` (see documentation) is registered. By default, `TypeHandler`s
-  for standard types like `np.ndarray`, `jax.Array`, Python scalars, and others
-  are registered.
-
-  As with all `CheckpointHandler` subclasses, `BasePyTreeCheckpointHandler`
-  should only be used in conjunction with a `Checkpointer` (or subclass).
-  By itself, the `CheckpointHandler` is non-atomic.
-
-  Example::
-
-    ckptr = Checkpointer(BasePyTreeCheckpointHandler())
+  Largely serves as the implementation for `PyTreeCheckpointHandler`. Users are
+  advised not to use this class directly.
   """
 
   def __init__(
       self,
-      aggregate_filename: Optional[str] = None,
       concurrent_gb: int = 96,
       use_ocdbt: bool = True,
       use_zarr3: bool = False,
@@ -350,24 +252,16 @@
     """Creates BasePyTreeCheckpointHandler.
 
     Args:
-      aggregate_filename: name that the aggregated checkpoint should be saved
-        as.
       concurrent_gb: max concurrent GB that are allowed to be read. Can help to
         reduce the possibility of OOM's when large checkpoints are restored.
-      use_ocdbt: enables Tensorstore OCDBT driver. This option allows using a
-        different checkpoint format which is faster to read and write, as well
-        as more space efficient.
-      use_zarr3: If True, use Zarr ver3 otherwise Zarr ver2
-      primary_host: the host id of the primary host. Default to 0. If it's set
-        to None, then all hosts will be considered as primary. It's useful in
+      use_ocdbt: Whether to use OCDBT format for saving.
+      use_zarr3: If True, use Zarr version 3; otherwise use Zarr version 2.
+      primary_host: the host id of the primary host. Defaults to 0. If set to
+        None, all hosts will be considered primary. This is useful in
         the case that all hosts are only working with local storage.
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not specified, the global type handler registry will be used. """ - self._aggregate_handler = MsgpackHandler(primary_host=primary_host) - if aggregate_filename is None: - aggregate_filename = _CHECKPOINT_FILE - self._aggregate_filename = aggregate_filename self._concurrent_gb = concurrent_gb self._use_ocdbt = use_ocdbt self._use_zarr3 = use_zarr3 @@ -375,36 +269,24 @@ def __init__( self._type_handler_registry = type_handler_registry - if self._use_ocdbt: - jax.monitoring.record_event( - '/jax/orbax/pytree_checkpoint_handler/init/ocdbt' - ) + jax.monitoring.record_event( + '/jax/orbax/pytree_checkpoint_handler/init/ocdbt' + ) def get_param_names(self, item: PyTree) -> PyTree: """Gets parameter names for PyTree elements.""" return get_param_names(item) - def _skip_deserialize(self, value: Any, args: SaveArgs) -> bool: - """Returns True if _METADATA write is enabled and value is []/{}/None.""" - if type_handlers.is_supported_empty_aggregation_type(value): - # Skip deser if value is empty ([], {}, None) and _METADATA is enabled. - # We don't want to write TypeHandlers for empty values, so will simply - # identify them in metadata and skip deser. - return True - else: - # Follow aggregate based flow: requires TypeHandler registry if - # aggregate=False. Empty values will raise registry error if - # aggregate=False. Users will be prompted to enable _METADATA to avoid - # these errors for empty values. - return args.aggregate - def _get_param_infos( self, item: PyTree, directory: epath.Path, - save_args: PyTree, + *, + use_ocdbt: bool = True, + use_zarr3: Optional[bool] = None, ocdbt_target_data_file_size: Optional[int] = None, - ) -> Tuple[PyTree, bool]: + byte_limiter: Optional[LimitInFlightBytes] = None, + ) -> PyTree: """Returns parameter information for elements in `item`. At minimum, this method should extract the names of each parameter for @@ -413,51 +295,46 @@ def _get_param_infos( Args: item: a PyTree to extract information from. directory: a directory where checkpoint files are located. - save_args: PyTree matching item containing SaveArgs. + use_ocdbt: Whether to use OCDBT for writing or reading. + use_zarr3: Whether to use zarr3. ocdbt_target_data_file_size: Specifies the target size (in bytes) of each OCDBT data file. + byte_limiter: LimitInFlightBytes object. Returns: - A PyTree matching `item` of ParamInfo, and a bool indicating whether all - parameters were aggregated. The bool can enable us to skip some steps - later, potentially saving time. + A PyTree matching `item` of ParamInfo. 
""" if not item: raise ValueError('Found empty item') + if use_zarr3 is None: + use_zarr3 = self._use_zarr3 names = self.get_param_names(item) - all_params_aggregated = True ts_context = type_handlers.get_ts_context() - def _param_info(value, name, args): - nonlocal all_params_aggregated - all_params_aggregated &= args.aggregate - skip_deserialize = self._skip_deserialize(value, args) + def _param_info(name, value): + skip_deserialize = False + if isinstance(value, tree_metadata.ValueMetadataEntry): + skip_deserialize = value.skip_deserialize return ParamInfo( name=name, path=(directory / name), parent_dir=directory, skip_deserialize=skip_deserialize, - is_ocdbt_checkpoint=self._use_ocdbt, - use_zarr3=self._use_zarr3, + is_ocdbt_checkpoint=use_ocdbt, + use_zarr3=use_zarr3, ocdbt_target_data_file_size=ocdbt_target_data_file_size, + byte_limiter=byte_limiter, ts_context=ts_context, ) - return ( - jax.tree.map( - _param_info, - item, - names, - save_args, - is_leaf=tree_utils.is_empty_or_leaf, - ), - all_params_aggregated, + return jax.tree.map( + _param_info, names, item, is_leaf=utils.is_empty_or_leaf ) async def async_save( self, directory: epath.Path, - args: Optional['BasePyTreeSaveArgs'] = None, + args: 'BasePyTreeSaveArgs', ) -> Optional[List[future.Future]]: """Saves a PyTree to a given directory. @@ -466,12 +343,8 @@ async def async_save( constructor. Standard supported types include Python scalars, `np.ndarray`, `jax.Array`, and strings. - After saving, all files will be located in "directory/". The exact files - that are saved depend on the specific combination of options, including - `use_ocdbt`. A JSON metadata file will be present to store the - tree structure. - In addition, a msgpack file may be present, allowing users to store - aggregated values (see below). + After saving, all files will be located in "directory/". + A JSON metadata file will be present to store the tree structure. Example usage:: @@ -488,12 +361,6 @@ async def async_save( } # Note: save_args may be None if no customization is desired for saved # parameters. - # In this case, we "aggregate" small parameters into a single file to - # allow for greater file read/write efficiency (and potentially less) - # wasted space). With OCDBT format active, this parameter is obsolete. - save_args = - jax.tree.map( - lambda x: SaveArgs(aggregate=x.size < some_size), item) # Eventually calls through to `async_save`. ckptr.save(path, item, save_args) @@ -505,45 +372,25 @@ async def async_save( A Future that will commit the data to `directory` when awaited. Copying the data from its source will be awaited in this function. """ - args = args or BasePyTreeSaveArgs() item = args.item save_args = args.save_args ocdbt_target_data_file_size = args.ocdbt_target_data_file_size - if ocdbt_target_data_file_size is not None and not self._use_zarr3: raise ValueError('`ocdbt_target_data_file_size` only works with Zarr3') - # Because of empty states, the user-provided args may not contain - # all necessary arguments. These should be filled in with default args. - def _maybe_set_default_save_args(value, args_): - # If already set, return. - if isinstance(args_, SaveArgs): - return args_ - if type_handlers.is_supported_empty_aggregation_type(value): - # If _METADATA is enabled and value is empty ([], {}, None) then stop - # aggregating for a smooth SaveArgs.aggregate deprecation. Otherwise, we - # will need to write TypeHandlers for empty values. 
- return SaveArgs(aggregate=False) - # Empty values will still raise TypeHandler registry error if _METADATA is - # disabled. We will prompt users to enable _METADATA to avoid this error. - aggregate = not self._type_handler_registry.has(type(value)) - return SaveArgs(aggregate=aggregate) - - save_args = jax.tree.map( - _maybe_set_default_save_args, # pylint: disable=protected-access + save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save') + param_infos = self._get_param_infos( item, - item if save_args is None else save_args, - is_leaf=tree_utils.is_empty_or_leaf, - ) - param_infos, all_params_aggregated = self._get_param_infos( - item, directory, save_args, ocdbt_target_data_file_size + directory, + use_ocdbt=self._use_ocdbt, + ocdbt_target_data_file_size=ocdbt_target_data_file_size, ) assert all( leaf.parent_dir == directory for leaf in jax.tree.leaves(param_infos) ) - if not self._use_ocdbt and not all_params_aggregated: - if utils.is_primary_host(self._primary_host): + if not self._use_ocdbt: + if multihost.is_primary_host(self._primary_host): # Create directories in parallel. await asyncio.gather( *jax.tree.flatten( @@ -554,58 +401,39 @@ def _maybe_set_default_save_args(value, args_): ) )[0] ) - utils.sync_global_processes( - 'BasePyTreeCheckpointHandler:create_param_save_dirs' + multihost.sync_global_processes( + 'PyTreeCheckpointHandler:create_param_save_dirs' ) - - if all_params_aggregated: - commit_futures = [] - else: - serialize_ops = [] - batch_requests = _batched_serialization_requests( - item, - param_infos, - save_args, - self._type_handler_registry, - ) - for request in batch_requests: - serialize_ops += [ - request.handler.serialize( - request.values, request.infos, request.args - ) - ] - # Await copy futures. Returns list of lists. - commit_futures = await asyncio.gather(*serialize_ops) - commit_futures, _ = jax.tree.flatten(commit_futures) + serialize_ops = [] + batch_requests = batched_serialization_requests( + item, + param_infos, + save_args, + self._type_handler_registry, + ) + for request in batch_requests: + serialize_ops += [ + request.handler.serialize(request.values, request.infos, request.args) + ] + # Await copy futures. Returns list of lists. + commit_futures = await asyncio.gather(*serialize_ops) + commit_futures, _ = jax.tree.flatten(commit_futures) if logging.level_debug(): logging.debug('param_info: %s', param_infos) logging.debug('save_args: %s', save_args) - metadata_future = None - if utils.is_primary_host(self._primary_host): + if multihost.is_primary_host(self._primary_host): metadata_write_start_time = time.time() metadata_future = await self._write_metadata_file( directory, item, save_args, self._use_zarr3 ) + commit_futures += [metadata_future] jax.monitoring.record_event_duration_secs( '/jax/checkpoint/write/async/metadata_write_duration_secs', time.time() - metadata_write_start_time, ) - - aggregate_file_write_start_time = time.time() - aggregate_commit_future = await self._write_aggregate_file( - directory, item, param_infos, save_args - ) - jax.monitoring.record_event_duration_secs( - '/jax/checkpoint/write/async/aggregate_write_duration_secs', - time.time() - aggregate_file_write_start_time, - ) - return ( - commit_futures + [aggregate_commit_future] + [metadata_future] - if metadata_future is not None - else commit_futures + [aggregate_commit_future] - ) + return commit_futures def save(self, directory: epath.Path, *args, **kwargs): """Saves the provided item. 
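The `_fill_missing_save_or_restore_args` step above works because empty containers are treated as leaves during the tree map, so they receive default arguments like any other leaf. A minimal standalone sketch of that pattern using plain `jax.tree_util`; `MySaveArgs` and the `is_empty_or_leaf` helper below are illustrative stand-ins, not the patch's own definitions:

  import dataclasses
  from typing import Any, Optional

  import jax


  @dataclasses.dataclass
  class MySaveArgs:  # illustrative stand-in for orbax's SaveArgs
    dtype: Optional[Any] = None


  def is_empty_or_leaf(x) -> bool:
    # Treat {} and [] as leaves so that they also receive default args;
    # anything that is not a dict/list is an ordinary leaf (including None).
    if isinstance(x, (dict, list)):
      return not x
    return True


  def fill_missing_args(item, args):
    fill = lambda _, a: a if isinstance(a, MySaveArgs) else MySaveArgs()
    return jax.tree_util.tree_map(
        fill, item, item if args is None else args, is_leaf=is_empty_or_leaf
    )


  tree = {'w': 1.0, 'opt_state': {}, 'rng': None}
  print(fill_missing_args(tree, None))
  # {'opt_state': MySaveArgs(dtype=None), 'rng': MySaveArgs(dtype=None),
  #  'w': MySaveArgs(dtype=None)}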
@@ -631,26 +459,16 @@ async def async_save(*args, **kwargs): asyncio.run(async_save(directory, *args, **kwargs)) async def _maybe_deserialize( - self, structure: PyTree, param_infos: PyTree, restore_args: PyTree + self, + item: PyTree, + metadata: PyTree, + param_infos: PyTree, + restore_args: PyTree, ) -> PyTree: - """Deserializes values or gets them from the aggregate file.""" - - # Handle parameters from aggregate file. - def _process_aggregated_value(info, meta, args): - value = meta.aggregate_value - if info.skip_deserialize: - value = _try_array_cast(value, args.dtype) - value = _maybe_shard_array(value, args) - return value - - flat_aggregate = tree_utils.to_flat_dict( - jax.tree.map( - _process_aggregated_value, param_infos, structure, restore_args - ), - ) - - batch_requests = _batched_serialization_requests( - structure, + """Deserializes values or skips.""" + flat_metadata = tree_utils.to_flat_dict(metadata) + batch_requests = batched_serialization_requests( + metadata, param_infos, restore_args, self._type_handler_registry, @@ -667,11 +485,16 @@ def _process_aggregated_value(info, meta, args): for request, deserialized in zip(batch_requests, deserialized_batches): for key, value in zip(request.keys, deserialized): flat_restored[key] = value - # Add in any values which were not deserialized, coming from aggregate file. - for key in flat_aggregate.keys(): + # Add in empty nodes from the metadata tree. + for key in flat_metadata.keys(): if key not in flat_restored: - flat_restored[key] = flat_aggregate[key] - return tree_utils.from_flat_dict(flat_restored, target=structure) + flat_restored[key] = type_handlers.get_empty_value_from_typestr( + flat_metadata[key].value_type + ) + # Restore using `item` as the target structure. If there are any custom + # nodes (e.g. optax.EmptyState), these will replace None values in + # flat_restored. + return tree_utils.from_flat_dict(flat_restored, target=item) def restore( self, @@ -763,47 +586,43 @@ class TrainState: args = args or BasePyTreeRestoreArgs() item = args.item restore_args = args.restore_args + logging.debug('directory=%s, restore_args=%s', directory, restore_args) if not directory.exists(): raise FileNotFoundError( f'Requested directory for restore does not exist at {directory}' ) byte_limiter = get_byte_limiter(self._concurrent_gb) - structure, use_zarr3_metadata = self._get_internal_metadata(directory) - # `checkpoint_restore_args` has a structure relative to the checkpoint, - # while `restore_args` remains structured relative to the output. 
- param_infos, checkpoint_restore_args = _get_restore_parameters( - directory, - structure, - restore_args, - byte_limiter=byte_limiter, - use_zarr3=use_zarr3_metadata + metadata = self._read_metadata_file(directory) + use_zarr3_metadata = metadata.use_zarr3 + metadata = metadata.as_nested_tree(keep_empty_nodes=True) + if item is None: + item = metadata + restore_args = _fill_missing_save_or_restore_args( + item, restore_args, mode='restore' + ) + restore_args = tree_utils.serialize_tree( + restore_args, keep_empty_nodes=True + ) + use_zarr3 = ( + use_zarr3_metadata if use_zarr3_metadata is not None - else self._use_zarr3, + else self._use_zarr3 ) - - def _maybe_set_default_restore_types( - meta: _InternalValueMetadata, arg: RestoreArgs - ): - if not meta.skip_deserialize and meta.restore_type is None: - return dataclasses.replace( - meta, restore_type=type_handlers.default_restore_type(arg) - ) - return meta - - # If metadata file was missing in the checkpoint, we need to decide - # restore_type based on RestoreArgs. - structure = jax.tree.map( - _maybe_set_default_restore_types, structure, checkpoint_restore_args + param_infos = self._get_param_infos( + metadata, + directory, + use_ocdbt=type_handlers.is_ocdbt_checkpoint(directory), + use_zarr3=use_zarr3, + byte_limiter=byte_limiter, ) - restored_item = asyncio.run( - self._maybe_deserialize(structure, param_infos, checkpoint_restore_args) + self._maybe_deserialize(item, metadata, param_infos, restore_args) ) if logging.level_debug(): logging.debug('param_infos: %s', param_infos) - logging.debug('checkpoint_restore_args: %s', checkpoint_restore_args) + logging.debug('checkpoint_restore_args: %s', restore_args) logging.debug( 'restored_item: %s', jax.tree.structure(restored_item) ) @@ -812,34 +631,8 @@ def _maybe_set_default_restore_types( json.dumps(ts.experimental_collect_matching_metrics('/tensorstore/')), ) - if item is not None: - return tree_utils.deserialize_tree(restored_item, item) return restored_item - async def _write_aggregate_file( - self, - directory: epath.Path, - item: PyTree, - param_infos: PyTree, - save_args: PyTree, - ) -> future.Future: - ser_item = _get_tree_for_aggregation(param_infos, save_args, item) - return await self._aggregate_handler.serialize( - directory / self._aggregate_filename, ser_item - ) - - def _read_aggregate_file(self, directory: epath.Path) -> PyTree: - """Restores the aggregate file representing PyTree structure.""" - checkpoint_path = directory / self._aggregate_filename - if checkpoint_path.exists(): - return self._aggregate_handler.deserialize(checkpoint_path) - elif self._use_ocdbt: - raise FileNotFoundError( - f'Checkpoint structure file does not exist at {directory}.' - ) - else: - return utils.pytree_structure(directory) - async def _write_metadata_file( self, directory: epath.Path, @@ -877,7 +670,7 @@ def _read_metadata_file( directory: directory Returns: - Tree with _InternalValueMetadata as values. + orbax.checkpoint.metadata.TreeMetadata Raises: FileNotFoundError: if the metadata file is not found. @@ -890,139 +683,6 @@ def _read_metadata_file( ) return tree_metadata.TreeMetadata.from_json(json.loads(path.read_text())) - def _get_internal_metadata( - self, directory: epath.Path - ) -> Tuple[PyTree, Optional[bool]]: - """Gets limited information needed to fully restore the checkpoint. 
- - This information just consists of the restore type for each leaf, as well - as the aggregated value (from the msgpack file) if present, and determines - whether we need to deserialize the parameter using TypeHandler later. - - Args: - directory: directory - - Returns: - A PyTree of _InternalValueMetadata with the tree structure of the - checkpoint. - """ - aggregate_tree = self._read_aggregate_file(directory) - flat_aggregate = tree_utils.to_flat_dict( - aggregate_tree, keep_empty_nodes=True - ) - try: - metadata = self._read_metadata_file(directory) - metadata_tree = metadata.as_nested_tree(keep_empty_nodes=True) - flat_metadata = tree_utils.to_flat_dict( - metadata_tree, keep_empty_nodes=True - ) - use_zarr3 = metadata.use_zarr3 - except FileNotFoundError: - metadata_tree = None - flat_metadata = None - use_zarr3 = None - if flat_metadata is None: - flat_metadata = jax.tree.map( - lambda _: None, flat_aggregate, is_leaf=tree_utils.is_empty_or_leaf - ) - - def _get_internal_value_metadata(value_meta, value): - if value_meta is None: - if type_handlers.is_supported_empty_aggregation_type(value): - return value - restore_type = None - skip_deserialize = not utils.leaf_is_placeholder(value) - else: - if type_handlers.is_empty_typestr(value_meta.value_type): - return type_handlers.get_empty_value_from_typestr( - value_meta.value_type - ) - restore_type, skip_deserialize = ( - value_meta.value_type, - value_meta.skip_deserialize, - ) - return _InternalValueMetadata( - restore_type=restore_type, - skip_deserialize=skip_deserialize, - aggregate_value=value, - ) - - result = {} - for tuple_key in flat_metadata.keys(): - result[tuple_key] = _get_internal_value_metadata( - flat_metadata[tuple_key], flat_aggregate[tuple_key] - ) - target = metadata_tree if metadata_tree is not None else aggregate_tree - return tree_utils.from_flat_dict(result, target=target), use_zarr3 - - def _get_user_metadata(self, directory: epath.Path) -> PyTree: - """Reads metadata file and constructs user-friendly metadata. - - This will involve more file reads than are necessary for internal metadata. - Typically, we will need to perform extra reads in order to get metadata - about individual arrays. - - Args: - directory: directory - - Returns: - A PyTree of value_metadata.Metadata matching the checkpoint tree - structure. 
- """ - is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory) - ts_context = type_handlers.get_ts_context() - - flat_param_infos = {} - flat_restore_types = {} - metadata = self._read_metadata_file(directory) - metadata_tree = metadata.as_nested_tree(keep_empty_nodes=False) - for keypath, value_meta in tree_utils.to_flat_dict(metadata_tree).items(): - param_name = '.'.join(keypath) - restore_type, skip_deserialize = ( - value_meta.value_type, - value_meta.skip_deserialize, - ) - flat_param_infos[keypath] = ParamInfo( - name=param_name, - path=directory / param_name, - parent_dir=directory, - skip_deserialize=skip_deserialize, - is_ocdbt_checkpoint=is_ocdbt_checkpoint, - use_zarr3=metadata.use_zarr3, - ts_context=ts_context, - ) - flat_restore_types[keypath] = restore_type - - flat_metadatas = {} - batched_param_infos = collections.defaultdict(list) - batched_keypaths = collections.defaultdict(list) - for keypath in flat_param_infos: - param_info = flat_param_infos[keypath] - restore_type = flat_restore_types[keypath] - if param_info.skip_deserialize: - flat_metadatas[keypath] = value_metadata.Metadata( - name=param_info.name, directory=directory - ) - else: - batched_keypaths[restore_type].append(keypath) - batched_param_infos[restore_type].append(param_info) - - metadata_ops = [] - for restore_type, param_infos in batched_param_infos.items(): - handler = self._type_handler_registry.get(restore_type) - metadata_ops.append(handler.metadata(param_infos)) - - async def _get_metadata(): - return await asyncio.gather(*metadata_ops) - - batched_metadatas = asyncio.run(_get_metadata()) - for keypath_batch, metadata_batch in zip( - batched_keypaths.values(), batched_metadatas - ): - for keypath, value in zip(keypath_batch, metadata_batch): - flat_metadatas[keypath] = value - return tree_utils.from_flat_dict(flat_metadatas, target=metadata_tree) - def metadata(self, directory: epath.Path) -> Optional[PyTree]: """Returns tree metadata. @@ -1049,10 +709,10 @@ def metadata(self, directory: epath.Path) -> Optional[PyTree]: Returns: tree containing metadata. """ - try: - return self._get_user_metadata(directory) - except FileNotFoundError as e: - raise FileNotFoundError('Could not locate metadata file.') from e + is_ocdbt_checkpoint = type_handlers.is_ocdbt_checkpoint(directory) + return self._read_metadata_file(directory).as_user_metadata( + directory, self._type_handler_registry, use_ocdbt=is_ocdbt_checkpoint + ) def finalize(self, directory: epath.Path) -> None: """Finalization step. @@ -1064,8 +724,6 @@ def finalize(self, directory: epath.Path) -> None: Args: directory: Path where the checkpoint is located. """ - if not self._use_ocdbt: - return merge_start_time = time.time() ts_context = type_handlers.get_ts_context() type_handlers.merge_ocdbt_per_process_files( @@ -1078,7 +736,7 @@ def finalize(self, directory: epath.Path) -> None: def close(self): """Closes the handler. 
Called automatically by Checkpointer.""" - self._aggregate_handler.close() + pass @register_with_handler(BasePyTreeCheckpointHandler, for_save=True) diff --git a/checkpoint/orbax/checkpoint/checkpointer.py b/checkpoint/orbax/checkpoint/checkpointer.py index 831bba109..f834c530e 100644 --- a/checkpoint/orbax/checkpoint/checkpointer.py +++ b/checkpoint/orbax/checkpoint/checkpointer.py @@ -224,7 +224,7 @@ def restore(self, directory: epath.PathLike, *args, **kwargs) -> Any: raise FileNotFoundError(f'Checkpoint at {directory} not found.') if not utils.is_checkpoint_finalized(directory): raise ValueError(f'Found incomplete checkpoint at {directory}.') - logging.info('Restoring item from %s.', directory) + logging.info('Restoring checkpoint from %s.', directory) ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs) restored = self._handler.restore(directory, args=ckpt_args) logging.info('Finished restoring checkpoint from %s.', directory) diff --git a/checkpoint/orbax/checkpoint/metadata/tree.py b/checkpoint/orbax/checkpoint/metadata/tree.py index b6ce56be9..8920e036e 100644 --- a/checkpoint/orbax/checkpoint/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/metadata/tree.py @@ -14,11 +14,14 @@ """Utilities for working with Orbax metadata.""" +import asyncio +import collections import dataclasses import enum import functools import operator from typing import Any, Dict, Hashable, List, Optional, Tuple, TypeVar, Union +from etils import epath import jax from orbax.checkpoint import tree as tree_utils from orbax.checkpoint import type_handlers @@ -40,7 +43,7 @@ KeyPath = tuple[KeyEntry, ...] -class _KeyType(enum.Enum): +class KeyType(enum.Enum): """Enum representing PyTree key type.""" SEQUENCE = 1 @@ -50,25 +53,25 @@ def to_json(self) -> int: return self.value @classmethod - def from_json(cls, value: int) -> '_KeyType': + def from_json(cls, value: int) -> 'KeyType': return cls(value) -def _get_key_metadata_type(key: Any) -> _KeyType: +def _get_key_metadata_type(key: Any) -> KeyType: """Translates the JAX key class into a proto enum.""" if tree_utils.is_sequence_key(key): - return _KeyType.SEQUENCE + return KeyType.SEQUENCE elif tree_utils.is_dict_key(key): - return _KeyType.DICT + return KeyType.DICT else: raise ValueError(f'Unsupported KeyEntry: {type(key)}: "{key}"') -def _keypath_from_key_type(key_name: str, key_type: _KeyType) -> Any: +def _keypath_from_key_type(key_name: str, key_type: KeyType) -> Any: """Converts from Key in TreeMetadata to JAX keypath class.""" - if key_type == _KeyType.SEQUENCE: + if key_type == KeyType.SEQUENCE: return jax.tree_util.SequenceKey(int(key_name)) - elif key_type == _KeyType.DICT: + elif key_type == KeyType.DICT: return jax.tree_util.DictKey(key_name) else: raise ValueError(f'Unsupported KeyEntry: {key_type}') @@ -78,7 +81,7 @@ def _keypath_from_key_type(key_name: str, key_type: _KeyType) -> Any: class NestedKeyMetadataEntry: """Represents a key at a single level of nesting.""" nested_key_name: str - key_type: _KeyType + key_type: KeyType def to_json(self) -> Dict[str, Union[str, int]]: return { @@ -92,7 +95,7 @@ def from_json( ) -> 'NestedKeyMetadataEntry': return NestedKeyMetadataEntry( nested_key_name=json_dict[_KEY_NAME], - key_type=_KeyType.from_json(json_dict[_KEY_TYPE]), + key_type=KeyType.from_json(json_dict[_KEY_TYPE]), ) @@ -148,12 +151,16 @@ def build( cls, value: Any, save_arg: type_handlers.SaveArgs, - registry: type_handlers.TypeHandlerRegistry, + registry: Optional[type_handlers.TypeHandlerRegistry], ) -> 
'ValueMetadataEntry':
     """Builds a ValueMetadataEntry."""
     if type_handlers.is_supported_empty_aggregation_type(value):
       typestr = type_handlers.get_empty_value_typestr(value)
       skip_deserialize = True
+    elif registry is None:
+      return ValueMetadataEntry(
+          value_type=type_handlers.RESTORE_TYPE_UNKNOWN, skip_deserialize=False
+      )
     else:
       try:
         handler = registry.get(type(value))
@@ -202,7 +209,7 @@ def build(
       keypath: KeyPath,
       value: Any,
       save_arg: type_handlers.SaveArgs,
-      type_handler_registry: type_handlers.TypeHandlerRegistry,
+      type_handler_registry: Optional[type_handlers.TypeHandlerRegistry],
   ) -> 'TreeMetadataEntry':
     """Builds a TreeMetadataEntry."""
     key_metadata = KeyMetadataEntry.build(keypath)
@@ -236,7 +243,7 @@ def build(
       cls,
       tree: PyTree,
       *,
-      type_handler_registry: type_handlers.TypeHandlerRegistry,
+      type_handler_registry: Optional[type_handlers.TypeHandlerRegistry] = None,
       save_args: Optional[PyTree] = None,
       use_zarr3: bool = False,
   ) -> 'TreeMetadata':
@@ -272,8 +279,8 @@ def to_json(self) -> Dict[str, Any]:
         _TREE_METADATA_KEY: {
             "(top_level_key, lower_level_key)": {
                 _KEY_METADATA_KEY: (
-                    {_KEY_NAME: "top_level_key", _KEY_TYPE: <_KeyType (int)>},
-                    {_KEY_NAME: "lower_level_key", _KEY_TYPE: <_KeyType (int)>},
+                    {_KEY_NAME: "top_level_key", _KEY_TYPE: <KeyType (int)>},
+                    {_KEY_NAME: "lower_level_key", _KEY_TYPE: <KeyType (int)>},
                 )
                 _VALUE_METADATA_KEY: {
                     _VALUE_TYPE: "jax.Array",
@@ -330,3 +337,60 @@ def _maybe_as_empty_value(value_metadata: ValueMetadataEntry) -> Any:
         (entry.jax_keypath(), _maybe_as_empty_value(entry.value_metadata))
         for entry in self.tree_metadata_entries
     ])
+
+  def as_user_metadata(
+      self,
+      directory: epath.Path,
+      type_handler_registry: type_handlers.TypeHandlerRegistry,
+      *,
+      use_ocdbt: bool = True,
+  ) -> PyTree:
+    """Delegates to TypeHandlers to create user-facing metadata."""
+    flat_param_infos = {}
+    flat_restore_types = {}
+    metadata_tree = self.as_nested_tree(keep_empty_nodes=True)
+    ts_context = type_handlers.get_ts_context()
+    for keypath, value_meta in tree_utils.to_flat_dict(metadata_tree).items():
+      param_name = '.'.join(keypath)
+      if value_meta.skip_deserialize:
+        assert type_handlers.is_empty_typestr(value_meta.value_type)
+      flat_param_infos[keypath] = type_handlers.ParamInfo(
+          name=param_name,
+          path=directory / param_name,
+          parent_dir=directory,
+          skip_deserialize=value_meta.skip_deserialize,
+          is_ocdbt_checkpoint=use_ocdbt,
+          use_zarr3=self.use_zarr3,
+          ts_context=ts_context,
+      )
+      flat_restore_types[keypath] = value_meta.value_type
+
+    flat_metadatas = {}
+    batched_param_infos = collections.defaultdict(list)
+    batched_keypaths = collections.defaultdict(list)
+    for keypath in flat_param_infos:
+      param_info = flat_param_infos[keypath]
+      restore_type = flat_restore_types[keypath]
+      if type_handlers.is_empty_typestr(restore_type):
+        flat_metadatas[keypath] = type_handlers.get_empty_value_from_typestr(
+            restore_type
+        )
+      else:
+        batched_keypaths[restore_type].append(keypath)
+        batched_param_infos[restore_type].append(param_info)
+
+    metadata_ops = []
+    for restore_type, param_infos in batched_param_infos.items():
+      handler = type_handler_registry.get(restore_type)
+      metadata_ops.append(handler.metadata(param_infos))
+
+    async def _get_metadata():
+      return await asyncio.gather(*metadata_ops)
+
+    batched_metadatas = asyncio.run(_get_metadata())
+    for keypath_batch, metadata_batch in zip(
+        batched_keypaths.values(), batched_metadatas
+    ):
+      for keypath, value in zip(keypath_batch, metadata_batch):
+        flat_metadatas[keypath] = value
+    return
tree_utils.from_flat_dict(flat_metadatas, target=metadata_tree) diff --git a/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py index f116263ca..a172fa06c 100644 --- a/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py @@ -23,12 +23,14 @@ import dataclasses import json import re +import traceback import typing from typing import Any, Callable, Dict, List, Optional, Tuple, Union from absl import logging from etils import epath import jax +import numpy as np from orbax.checkpoint import aggregate_handlers from orbax.checkpoint import async_checkpoint_handler from orbax.checkpoint import base_pytree_checkpoint_handler @@ -37,9 +39,10 @@ from orbax.checkpoint import transform_utils from orbax.checkpoint import tree as tree_utils from orbax.checkpoint import type_handlers +from orbax.checkpoint import utils +from orbax.checkpoint.metadata import tree as tree_metadata import tensorstore as ts - PyTree = Any TupleKey = Tuple[str, ...] RestoreArgs = type_handlers.RestoreArgs @@ -63,7 +66,6 @@ BasePyTreeRestoreArgs = base_pytree_checkpoint_handler.BasePyTreeRestoreArgs get_byte_limiter = base_pytree_checkpoint_handler.get_byte_limiter LimitInFlightBytes = base_pytree_checkpoint_handler.LimitInFlightBytes -_InternalValueMetadata = base_pytree_checkpoint_handler._InternalValueMetadata # pylint: disable=protected-access get_param_names = base_pytree_checkpoint_handler.get_param_names _CHECKPOINT_FILE = 'checkpoint' @@ -71,6 +73,34 @@ _DEFAULT_CONCURRENT_GB = 96 +def _maybe_set_default_restore_args(args): + if isinstance(args, RestoreArgs): + return args + return RestoreArgs(restore_type=None) + + +def _try_array_cast(arr, dtype): + if dtype is not None: + if utils.is_scalar(arr): + arr = np.asarray(arr).astype(dtype).item() + else: + if hasattr(arr, 'astype'): + arr = arr.astype(dtype) + return arr + + +def _maybe_shard_array(value, args): + if hasattr(value, 'reshape') and isinstance(args, ArrayRestoreArgs): + value = value.reshape(args.global_shape) + sharding = args.sharding or jax.sharding.NamedSharding( + args.mesh, args.mesh_axes + ) + value = jax.make_array_from_callback( + value.shape, sharding, lambda idx: value[idx] + ) + return value + + def _keystr(key: Tuple[Any, ...]) -> str: return '/'.join(key) @@ -145,6 +175,7 @@ def _get_restore_parameters( directory: epath.Path, item: Optional[PyTree], structure: PyTree, + param_names: Optional[PyTree], transforms: Optional[PyTree], restore_args: Optional[PyTree], byte_limiter: Optional[LimitInFlightBytes] = None, @@ -178,6 +209,7 @@ def _get_restore_parameters( directory: Checkpoint directory. item: Optional reference item. structure: The structure of the original checkpoint. + param_names: Tree of parameter names. transforms: User-provided transformations. If None, they were not provided. Has the structure of the desired output tree. restore_args: User-provided restoration arguments. If None, they were not @@ -191,6 +223,9 @@ def _get_restore_parameters( Tuple of param_infos, and restore_args. 
""" flat_structure = tree_utils.to_flat_dict(structure, keep_empty_nodes=True) + if param_names is None: + param_names = get_param_names(structure) + flat_param_names = tree_utils.to_flat_dict(param_names, keep_empty_nodes=True) if restore_args is None: restore_args = jax.tree.map(lambda x: RestoreArgs(), structure) flat_restore_args = tree_utils.to_flat_dict( @@ -202,18 +237,22 @@ def _get_restore_parameters( ts_context = type_handlers.get_ts_context() def _get_param_info( - nested_name: Tuple[str, ...], - meta: _InternalValueMetadata, + name: str, + meta_or_value: Union[Any, tree_metadata.ValueMetadataEntry], ) -> Union[ParamInfo, Any]: - if type_handlers.is_supported_empty_aggregation_type(meta): + if type_handlers.is_supported_empty_aggregation_type(meta_or_value): # Empty node, ParamInfo should not be returned. - return meta - name = '.'.join(nested_name) + return meta_or_value + elif not isinstance(meta_or_value, tree_metadata.ValueMetadataEntry): + # Aggregated value. + skip_deserialize = True + else: + skip_deserialize = meta_or_value.skip_deserialize return ParamInfo( name=name, path=directory / name, parent_dir=directory, - skip_deserialize=meta.skip_deserialize, + skip_deserialize=skip_deserialize, is_ocdbt_checkpoint=is_ocdbt_checkpoint, byte_limiter=byte_limiter, use_zarr3=use_zarr3, @@ -222,9 +261,10 @@ def _get_param_info( if transforms is None: for key, meta in flat_structure.items(): - flat_param_infos[key] = _get_param_info(key, meta) + flat_param_infos[key] = _get_param_info(flat_param_names[key], meta) restore_args = tree_utils.serialize_tree( - restore_args, keep_empty_nodes=True + restore_args, + keep_empty_nodes=True, ) else: if item is None: @@ -240,7 +280,9 @@ def _get_param_info( input_key, flat_item, flat_transforms, flat_restore_args ) if maybe_input_args: - flat_param_infos[input_key] = _get_param_info(input_key, meta) + flat_param_infos[input_key] = _get_param_info( + flat_param_names[input_key], meta + ) flat_input_restore_args[input_key] = maybe_input_args elif input_key in flat_item and input_key in flat_structure: # Key is present in both input and output. @@ -256,13 +298,17 @@ def _get_param_info( # Specified `use_fallback`, but `transforms_default_to_original` # is False. This means we draw the value from the user-provided # `item`. - flat_param_infos[input_key] = _get_param_info(input_key, meta) + flat_param_infos[input_key] = _get_param_info( + flat_param_names[input_key], meta + ) flat_input_restore_args[input_key] = flat_restore_args[input_key] else: # Transform not specified. if transforms_default_to_original: # Key/value is carried over from the original unchanged. 
- flat_param_infos[input_key] = _get_param_info(input_key, meta) + flat_param_infos[input_key] = _get_param_info( + flat_param_names[input_key], meta + ) flat_input_restore_args[input_key] = flat_restore_args[input_key] else: # Take the value from the user-provided `item`, ignoring any value @@ -367,6 +413,21 @@ def _get_impl_save_args( item=item, save_args=save_args, ) + + def _overwrite_aggregate(sa: SaveArgs) -> SaveArgs: + if sa.aggregate: + logging.log_first_n( + logging.WARNING, + 'The `aggregate` option is deprecated and will be ignored.', + 5, + ) + sa.aggregate = False + return sa + + if args.save_args is not None: + args.save_args = jax.tree_util.tree_map( + _overwrite_aggregate, args.save_args + ) return BasePyTreeSaveArgs( item=args.item, save_args=args.save_args, @@ -435,7 +496,6 @@ def __init__( self._type_handler_registry = type_handler_registry self._handler_impl = handler_impl or BasePyTreeCheckpointHandler( - aggregate_filename=aggregate_filename, concurrent_gb=concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, @@ -512,6 +572,54 @@ def save( args = _get_impl_save_args(item, save_args, args) self._handler_impl.save(directory, args=args) + async def _maybe_deserialize( + self, + item: PyTree, + metadata: PyTree, + param_infos: PyTree, + restore_args: PyTree, + ) -> PyTree: + """Deserializes values or gets them from the aggregate file.""" + + # Handle parameters from aggregate file. + def _process_aggregated_value(meta_or_value, args): + if not isinstance(meta_or_value, tree_metadata.ValueMetadataEntry): + meta_or_value = _try_array_cast(meta_or_value, args.dtype) + meta_or_value = _maybe_shard_array(meta_or_value, args) + return meta_or_value + + flat_aggregate = tree_utils.to_flat_dict( + jax.tree_util.tree_map( + _process_aggregated_value, metadata, restore_args + ), + ) + + batch_requests = ( + base_pytree_checkpoint_handler.batched_serialization_requests( + metadata, + param_infos, + restore_args, + self._type_handler_registry, + ) + ) + deserialized_batches = [] + deserialized_batches_ops = [] + for request in batch_requests: + deserialized_batches_ops.append( + request.handler.deserialize(request.infos, request.args) + ) + deserialized_batches += await asyncio.gather(*deserialized_batches_ops) + + flat_restored = {} + for request, deserialized in zip(batch_requests, deserialized_batches): + for key, value in zip(request.keys, deserialized): + flat_restored[key] = value + # Add in any values which were not deserialized, coming from aggregate file. + for key in flat_aggregate.keys(): + if key not in flat_restored: + flat_restored[key] = flat_aggregate[key] + return tree_utils.from_flat_dict(flat_restored, target=item) + def restore( self, directory: epath.Path, @@ -609,6 +717,7 @@ class TrainState: ValueError: `transforms` is provided without `item`. ValueError: `transforms` contains elements with `multi_value_fn`. """ + logging.info(traceback.format_exc()) if not directory.exists(): raise FileNotFoundError( f'Requested directory for restore does not exist at {directory}.' @@ -626,16 +735,27 @@ class TrainState: transforms_default_to_original, legacy_transform_fn, ) + logging.info(args) + logging.info(list(directory.iterdir())) item = args.item restore_args = args.restore_args transforms = args.transforms transforms_default_to_original = args.transforms_default_to_original legacy_transform_fn = args.legacy_transform_fn - # Delegate to `TreeCheckpointHandler` as long as transformation options are - # not specified and metadata file exists. 
+ try: + can_ignore_aggregate_file = utils.all_leaves_are_placeholders( + self._read_aggregate_file(directory) + ) + except FileNotFoundError: + can_ignore_aggregate_file = True + + # Delegate to `BasePyTreeCheckpointHandler` as long as transformation + # options are not specified and metadata file exists and we do not need to + # read from aggregate file. if ( (directory / _METADATA_FILE).exists() + and can_ignore_aggregate_file and transforms is None and legacy_transform_fn is None ): @@ -651,15 +771,15 @@ class TrainState: f'Requested directory for restore does not exist at {directory}' ) byte_limiter = get_byte_limiter(self._concurrent_gb) - structure, use_zarr3_metadata = self._handler_impl._get_internal_metadata( # pylint: disable=protected-access - directory - ) + structure, use_zarr3_metadata = self._get_internal_metadata(directory) # `checkpoint_restore_args` has a structure relative to the checkpoint, # while `restore_args` remains structured relative to the output. + param_infos, checkpoint_restore_args = _get_restore_parameters( directory, item, structure, + self._handler_impl.get_param_names(structure), transforms, restore_args, byte_limiter=byte_limiter, @@ -679,14 +799,16 @@ class TrainState: restore_args = jax.tree.map(lambda x: RestoreArgs(), item) checkpoint_restore_args = restore_args - def _maybe_set_default_restore_types( - meta: _InternalValueMetadata, arg: RestoreArgs - ): - if not meta.skip_deserialize and meta.restore_type is None: + def _maybe_set_default_restore_types(value_meta: Any, arg: RestoreArgs): + if ( + isinstance(value_meta, tree_metadata.ValueMetadataEntry) + and not value_meta.skip_deserialize + and value_meta.value_type == type_handlers.RESTORE_TYPE_UNKNOWN + ): return dataclasses.replace( - meta, restore_type=type_handlers.default_restore_type(arg) + value_meta, value_type=type_handlers.default_restore_type(arg) ) - return meta + return value_meta # If metadata file was missing in the checkpoint, we need to decide # restore_type based on RestoreArgs. @@ -695,8 +817,8 @@ def _maybe_set_default_restore_types( ) restored_item = asyncio.run( - self._handler_impl._maybe_deserialize( # pylint: disable=protected-access - structure, param_infos, checkpoint_restore_args + self._maybe_deserialize( + structure, structure, param_infos, checkpoint_restore_args ) ) @@ -723,8 +845,120 @@ def _maybe_set_default_restore_types( return restored_item def _read_aggregate_file(self, directory: epath.Path) -> PyTree: - """Reads the aggregate file.""" - return self._handler_impl._read_aggregate_file(directory) # pylint: disable=protected-access + """Restores the aggregate file representing PyTree structure.""" + checkpoint_path = directory / self._aggregate_filename + if checkpoint_path.exists(): + return self._aggregate_handler.deserialize(checkpoint_path) + elif self._use_ocdbt: + raise FileNotFoundError( + f'Checkpoint structure file does not exist at {directory}.' + ) + else: + return utils.pytree_structure(directory) + + def _get_internal_metadata( + self, directory: epath.Path + ) -> Tuple[PyTree, Optional[bool]]: + """Gets limited information needed to fully restore the checkpoint. + + This information just consists of the restore type for each leaf, as well + as the aggregated value (from the msgpack file) if present, and determines + whether we need to deserialize the parameter using TypeHandler later. + + Args: + directory: directory + + Returns: + A PyTree with leaves of ValueMetadataEntry or real values if restored from + the aggregate file (or if empty nodes). 
+ + Raises: + FileNotFoundError: no structure could be identified for the checkpoint at + `directory`. + """ + # Try reading metadata file. + try: + metadata = self._handler_impl._read_metadata_file(directory) # pylint: disable=protected-access + use_zarr3 = metadata.use_zarr3 + metadata_tree = metadata.as_nested_tree(keep_empty_nodes=True) + flat_metadata = tree_utils.to_flat_dict( + metadata_tree, keep_empty_nodes=True + ) + except FileNotFoundError: + metadata_tree = None + flat_metadata = None + use_zarr3 = None + # Try reading aggregate file. + try: + aggregate_tree = self._read_aggregate_file(directory) + flat_aggregate = tree_utils.to_flat_dict( + aggregate_tree, keep_empty_nodes=True + ) + except FileNotFoundError: + aggregate_tree = None + flat_aggregate = None + + def _is_empty_aggregate_value(value): + return type_handlers.is_supported_empty_aggregation_type( + value + ) or not utils.leaf_is_placeholder(value) + + def _process_aggregate_leaf(value): + if _is_empty_aggregate_value(value): + return value + return tree_metadata.ValueMetadataEntry( + value_type=type_handlers.RESTORE_TYPE_UNKNOWN, + skip_deserialize=False, + ) + + def _process_metadata_and_aggregate_leaves(value_meta, value): + if _is_empty_aggregate_value(value): + return value + if type_handlers.is_empty_typestr(value_meta.value_type): + return type_handlers.get_empty_value_from_typestr(value_meta.value_type) + return value_meta + + # Handle cases of missing metadata and/or aggregate files. + structure_tree = metadata_tree or aggregate_tree + if flat_metadata is None and flat_aggregate is None: + raise FileNotFoundError( + f'No structure could be identified for the checkpoint at {directory}.' + ) + elif flat_metadata is None: + # Metadata file is missing. This is an older checkpoint. + flat_structure = jax.tree_util.tree_map( + _process_aggregate_leaf, + flat_aggregate, + is_leaf=tree_utils.is_empty_or_leaf, + ) + elif flat_aggregate is None: + # Aggregate file is missing, so we can just use the metadata_tree as the + # structure. This is a newer checkpoint. + return metadata_tree, use_zarr3 + else: + # Avoid tree_map because input trees may be mismatched (due to empty + # values missing from msgpack structure). + flat_structure = {} + for tuple_key in flat_metadata.keys(): + value_meta = flat_metadata[tuple_key] + if tuple_key in flat_aggregate: + flat_structure[tuple_key] = _process_metadata_and_aggregate_leaves( + value_meta, flat_aggregate[tuple_key] + ) + else: + if type_handlers.is_empty_typestr(value_meta.value_type): + flat_structure[tuple_key] = ( + type_handlers.get_empty_value_from_typestr( + value_meta.value_type + ) + ) + else: + flat_structure[tuple_key] = value_meta + + return ( + tree_utils.from_flat_dict(flat_structure, target=structure_tree), + use_zarr3, + ) def metadata(self, directory: epath.Path) -> Optional[PyTree]: """Returns tree metadata. 
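The structure-merging logic in `_get_internal_metadata` above leans on a typestr round trip for empty nodes: at save time each `{}` / `[]` / `None` is recorded in `_METADATA` as a typestr, and at restore time the concrete empty value is rebuilt from it. A rough self-contained sketch of that round trip; the typestr strings here are made up, and orbax's real helpers (`get_empty_value_typestr`, `is_empty_typestr`, `get_empty_value_from_typestr` in `type_handlers`) may differ in detail:

  from typing import Any

  # Hypothetical typestr table; the real string values are internal to orbax.
  _EMPTY_VALUES_BY_TYPESTR = {
      'empty_dict': {},
      'empty_list': [],
      'empty_none': None,
  }


  def get_empty_value_typestr(value: Any) -> str:
    # Map a supported empty value ({} / [] / None) to its typestr.
    for typestr, empty in _EMPTY_VALUES_BY_TYPESTR.items():
      if type(value) is type(empty) and value == empty:
        return typestr
    raise ValueError(f'Not a supported empty value: {value!r}')


  def is_empty_typestr(typestr: str) -> bool:
    return typestr in _EMPTY_VALUES_BY_TYPESTR


  def get_empty_value_from_typestr(typestr: str) -> Any:
    # Rebuild the concrete empty value recorded in the metadata file.
    if not is_empty_typestr(typestr):
      raise ValueError(f'Unknown empty typestr: {typestr!r}')
    return _EMPTY_VALUES_BY_TYPESTR[typestr]


  assert get_empty_value_from_typestr(get_empty_value_typestr({})) == {}
  assert get_empty_value_from_typestr(get_empty_value_typestr(None)) is None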
diff --git a/checkpoint/orbax/checkpoint/standard_checkpoint_handler_test_utils.py b/checkpoint/orbax/checkpoint/standard_checkpoint_handler_test_utils.py index 25433c402..1a3c32d5b 100644 --- a/checkpoint/orbax/checkpoint/standard_checkpoint_handler_test_utils.py +++ b/checkpoint/orbax/checkpoint/standard_checkpoint_handler_test_utils.py @@ -251,14 +251,20 @@ def test_empty_pytrees(self): def test_none_node_in_restore_args(self): """Test case.""" - - item = {'b': np.array([1, 2, 3])} + devices = np.asarray(jax.devices()) + mesh = jax.sharding.Mesh(devices, ('x',)) + mesh_axes = jax.sharding.PartitionSpec( + 'x', + ) + arr = test_utils.create_sharded_array(np.arange(16), mesh, mesh_axes) + item = {'b': arr} self.handler.save(self.directory, args=self.save_args_cls(item)) - with self.assertRaises(ValueError): - self.handler.restore( - self.directory, args=self.restore_args_cls({'b': None}) - ) + restored = self.handler.restore( + self.directory, + args=self.restore_args_cls({'b': None}), + ) + test_utils.assert_tree_equal(self, restored, {'b': None}) def test_masked_shape_dtype_struct(self): """Test case."""
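Net effect of the patch, as an end-to-end sketch against the public API: newly written checkpoints contain per-parameter data plus the `_METADATA` file and no msgpack `checkpoint` file, and empty nodes round-trip through metadata alone. The path and the `ocp.args.PyTreeSave` usage below are assumptions based on the 0.5.x-era orbax API, not part of the patch itself:

  import numpy as np
  import orbax.checkpoint as ocp

  item = {
      'params': np.arange(8),
      'opt_state': {},  # empty node: recorded only in _METADATA
      'rng': None,      # likewise
  }

  ckptr = ocp.Checkpointer(ocp.PyTreeCheckpointHandler())
  path = '/tmp/ckpt_demo'  # must not already exist
  ckptr.save(path, args=ocp.args.PyTreeSave(item))

  # No msgpack 'checkpoint' file is written; the tree structure, including
  # the empty nodes, is reconstructed from _METADATA on restore.
  restored = ckptr.restore(path)
  assert restored['opt_state'] == {} and restored['rng'] is None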