diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 6c7bc936..c66c451d 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -7,15 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.5.12] - 2024-05-15 + ### Added - Introduce `should_save_fn` in `OrbaxCheckpointManagerOptions`. - Introduce `StepAlreadyExistsError` to be raised on save with existing step. +### Changed +- Delegate to BasePyTreeCheckpointHandler rather than inheriting from it. + ### Fixed - Fix empty metadata file error: Expecting value: line 1 column 1 (char 0) -## [0.5.10] - 2024-05-10 +## [0.5.11] - 2024-05-10 ### Added - Implement restoration in emergency.CheckpointManager. diff --git a/checkpoint/orbax/__init__.py b/checkpoint/orbax/__init__.py index b6be08d3..6b5093fb 100644 --- a/checkpoint/orbax/__init__.py +++ b/checkpoint/orbax/__init__.py @@ -50,6 +50,7 @@ from orbax.checkpoint.proto_checkpoint_handler import ProtoCheckpointHandler from orbax.checkpoint.pytree_checkpoint_handler import ArrayRestoreArgs from orbax.checkpoint.pytree_checkpoint_handler import PyTreeCheckpointHandler +from orbax.checkpoint.base_pytree_checkpoint_handler import BasePyTreeCheckpointHandler from orbax.checkpoint.pytree_checkpoint_handler import RestoreArgs from orbax.checkpoint.pytree_checkpoint_handler import SaveArgs from orbax.checkpoint.pytree_checkpointer import PyTreeCheckpointer @@ -72,4 +73,4 @@ # A new PyPI release will be pushed everytime `__version__` is increased. # Also modify version and date in CHANGELOG. -__version__ = '0.5.11' +__version__ = '0.5.12' diff --git a/checkpoint/orbax/checkpoint/__init__.py b/checkpoint/orbax/checkpoint/__init__.py index b6be08d3..6b5093fb 100644 --- a/checkpoint/orbax/checkpoint/__init__.py +++ b/checkpoint/orbax/checkpoint/__init__.py @@ -50,6 +50,7 @@ from orbax.checkpoint.proto_checkpoint_handler import ProtoCheckpointHandler from orbax.checkpoint.pytree_checkpoint_handler import ArrayRestoreArgs from orbax.checkpoint.pytree_checkpoint_handler import PyTreeCheckpointHandler +from orbax.checkpoint.base_pytree_checkpoint_handler import BasePyTreeCheckpointHandler from orbax.checkpoint.pytree_checkpoint_handler import RestoreArgs from orbax.checkpoint.pytree_checkpoint_handler import SaveArgs from orbax.checkpoint.pytree_checkpointer import PyTreeCheckpointer @@ -72,4 +73,4 @@ # A new PyPI release will be pushed everytime `__version__` is increased. # Also modify version and date in CHANGELOG. -__version__ = '0.5.11' +__version__ = '0.5.12' diff --git a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py index a00ab746..d8fac1ca 100644 --- a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py @@ -376,7 +376,7 @@ def __init__( '/jax/orbax/pytree_checkpoint_handler/init/ocdbt' ) - def _get_param_names(self, item: PyTree) -> PyTree: + def get_param_names(self, item: PyTree) -> PyTree: """Gets parameter names for PyTree elements.""" return get_param_names(item) @@ -420,7 +420,7 @@ def _get_param_infos( """ if not item: raise ValueError('Found empty item') - names = self._get_param_names(item) + names = self.get_param_names(item) all_params_aggregated = True def _param_info(value, name, args): diff --git a/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py index 89b31f3d..4766f00c 100644 --- a/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py @@ -30,6 +30,7 @@ from etils import epath import jax from orbax.checkpoint import aggregate_handlers +from orbax.checkpoint import async_checkpoint_handler from orbax.checkpoint import base_pytree_checkpoint_handler from orbax.checkpoint import checkpoint_args from orbax.checkpoint import future @@ -63,6 +64,7 @@ get_byte_limiter = base_pytree_checkpoint_handler.get_byte_limiter LimitInFlightBytes = base_pytree_checkpoint_handler.LimitInFlightBytes _InternalValueMetadata = base_pytree_checkpoint_handler._InternalValueMetadata # pylint: disable=protected-access +get_param_names = base_pytree_checkpoint_handler.get_param_names _CHECKPOINT_FILE = 'checkpoint' _METADATA_FILE = base_pytree_checkpoint_handler.METADATA_FILE @@ -365,9 +367,7 @@ def _get_impl_save_args( ) -# TODO(b/339456651) Avoid inheriting from BasePyTreeCheckpointHandler when -# when possible. -class PyTreeCheckpointHandler(BasePyTreeCheckpointHandler): +class PyTreeCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler): """A CheckpointHandler implementation for any PyTree structure. See JAX documentation for more information on what consistutes a "PyTree". @@ -393,6 +393,7 @@ def __init__( use_zarr3: bool = False, primary_host: Optional[int] = 0, type_handler_registry: TypeHandlerRegistry = type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY, + handler_impl: Optional[BasePyTreeCheckpointHandler] = None, ): """Creates PyTreeCheckpointHandler. @@ -410,6 +411,7 @@ def __init__( the case that all hosts are only working with local storage. type_handler_registry: a type_handlers.TypeHandlerRegistry. If not specified, the global type handler registry will be used. + handler_impl: Allows overriding the internal implementation. """ self._aggregate_handler = MsgpackHandler(primary_host=primary_host) if aggregate_filename is None: @@ -421,7 +423,7 @@ def __init__( self._primary_host = primary_host self._type_handler_registry = type_handler_registry - super().__init__( + self._handler_impl = handler_impl or BasePyTreeCheckpointHandler( aggregate_filename=aggregate_filename, concurrent_gb=concurrent_gb, use_ocdbt=use_ocdbt, @@ -486,7 +488,7 @@ async def async_save( the data from its source will be awaited in this function. """ args = _get_impl_save_args(item, save_args, args) - return await super().async_save(directory, args=args) + return await self._handler_impl.async_save(directory, args=args) def save( self, @@ -497,7 +499,7 @@ def save( ): """Saves the provided item. See async_save.""" args = _get_impl_save_args(item, save_args, args) - super().save(directory, args=args) + self._handler_impl.save(directory, args=args) def restore( self, @@ -630,7 +632,7 @@ class TrainState: item, restore_args=restore_args, ) - return super().restore(directory, args=args) + return self._handler_impl.restore(directory, args=args) logging.debug('directory=%s, restore_args=%s', directory, restore_args) if not directory.exists(): @@ -638,7 +640,9 @@ class TrainState: f'Requested directory for restore does not exist at {directory}' ) byte_limiter = get_byte_limiter(self._concurrent_gb) - structure, use_zarr3_metadata = self._get_internal_metadata(directory) + structure, use_zarr3_metadata = self._handler_impl._get_internal_metadata( # pylint: disable=protected-access + directory + ) # `checkpoint_restore_args` has a structure relative to the checkpoint, # while `restore_args` remains structured relative to the output. param_infos, checkpoint_restore_args = _get_restore_parameters( @@ -680,7 +684,9 @@ def _maybe_set_default_restore_types( ) restored_item = asyncio.run( - self._maybe_deserialize(structure, param_infos, checkpoint_restore_args) + self._handler_impl._maybe_deserialize( # pylint: disable=protected-access + structure, param_infos, checkpoint_restore_args + ) ) if not legacy_transform_fn: @@ -705,6 +711,50 @@ def _maybe_set_default_restore_types( return restored_item + def metadata(self, directory: epath.Path) -> Optional[PyTree]: + """Returns tree metadata. + + The result will be a PyTree matching the structure of the saved checkpoint. + Note that if the item saved was a custom class, the restored metadata will + be returned as a nested dictionary representation. + + Example:: + + { + 'layer0': { + 'w': ArrayMetadata(dtype=jnp.float32, shape=(8, 8), shards=(1, 2)), + 'b': ArrayMetadata(dtype=jnp.float32, shape=(8,), shards=(1,)), + }, + 'step': ScalarMetadata(dtype=jnp.int64), + } + + If the required metadata file is not present, this method will raise an + error. + + Args: + directory: checkpoint location. + + Returns: + tree containing metadata. + """ + return self._handler_impl.metadata(directory) + + def finalize(self, directory: epath.Path) -> None: + """Finalization step. + + Called automatically by the Checkpointer/AsyncCheckpointer just before the + checkpoint is considered "finalized" in the sense of ensuring atomicity. See + documentation for `type_handlers.merge_ocdbt_per_process_files`. + + Args: + directory: Path where the checkpoint is located. + """ + self._handler_impl.finalize(directory) + + def close(self): + """Closes the handler. Called automatically by Checkpointer.""" + self._handler_impl.close() + @register_with_handler(PyTreeCheckpointHandler, for_save=True) @dataclasses.dataclass