Skip to content

Commit

Permalink
End support for legacy Pax formats with directory structure:
Browse files Browse the repository at this point in the history
```
checkpoints/
  checkpoint_0  # msgpack file corresponding to step
```

Users can manually restructure to the following:

```
checkpoints/
  checkpoint_0/
    checkpoint  # renamed msgpack file
```

to ensure continued support.

PiperOrigin-RevId: 634413475
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed May 16, 2024
1 parent ba16e95 commit 595839d
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 14 deletions.
7 changes: 6 additions & 1 deletion checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.5.12] - 2024-05-15

### Added
- Introduce `should_save_fn` in `OrbaxCheckpointManagerOptions`.
- Introduce `StepAlreadyExistsError` to be raised on save with existing step.

### Changed
- Delegate to BasePyTreeCheckpointHandler rather than inheriting from it.

### Fixed
- Fix empty metadata file error: Expecting value: line 1 column 1 (char 0)


## [0.5.10] - 2024-05-10
## [0.5.11] - 2024-05-10

### Added
- Implement restoration in emergency.CheckpointManager.
Expand Down
3 changes: 2 additions & 1 deletion checkpoint/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from orbax.checkpoint.proto_checkpoint_handler import ProtoCheckpointHandler
from orbax.checkpoint.pytree_checkpoint_handler import ArrayRestoreArgs
from orbax.checkpoint.pytree_checkpoint_handler import PyTreeCheckpointHandler
from orbax.checkpoint.base_pytree_checkpoint_handler import BasePyTreeCheckpointHandler
from orbax.checkpoint.pytree_checkpoint_handler import RestoreArgs
from orbax.checkpoint.pytree_checkpoint_handler import SaveArgs
from orbax.checkpoint.pytree_checkpointer import PyTreeCheckpointer
Expand All @@ -72,4 +73,4 @@

# A new PyPI release will be pushed everytime `__version__` is increased.
# Also modify version and date in CHANGELOG.
__version__ = '0.5.11'
__version__ = '0.5.12'
3 changes: 2 additions & 1 deletion checkpoint/orbax/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from orbax.checkpoint.proto_checkpoint_handler import ProtoCheckpointHandler
from orbax.checkpoint.pytree_checkpoint_handler import ArrayRestoreArgs
from orbax.checkpoint.pytree_checkpoint_handler import PyTreeCheckpointHandler
from orbax.checkpoint.base_pytree_checkpoint_handler import BasePyTreeCheckpointHandler
from orbax.checkpoint.pytree_checkpoint_handler import RestoreArgs
from orbax.checkpoint.pytree_checkpoint_handler import SaveArgs
from orbax.checkpoint.pytree_checkpointer import PyTreeCheckpointer
Expand All @@ -72,4 +73,4 @@

# A new PyPI release will be pushed everytime `__version__` is increased.
# Also modify version and date in CHANGELOG.
__version__ = '0.5.11'
__version__ = '0.5.12'
4 changes: 2 additions & 2 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def __init__(
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
)

def _get_param_names(self, item: PyTree) -> PyTree:
def get_param_names(self, item: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
return get_param_names(item)

Expand Down Expand Up @@ -420,7 +420,7 @@ def _get_param_infos(
"""
if not item:
raise ValueError('Found empty item')
names = self._get_param_names(item)
names = self.get_param_names(item)
all_params_aggregated = True

def _param_info(value, name, args):
Expand Down
68 changes: 59 additions & 9 deletions checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from etils import epath
import jax
from orbax.checkpoint import aggregate_handlers
from orbax.checkpoint import async_checkpoint_handler
from orbax.checkpoint import base_pytree_checkpoint_handler
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
Expand Down Expand Up @@ -63,6 +64,7 @@
get_byte_limiter = base_pytree_checkpoint_handler.get_byte_limiter
LimitInFlightBytes = base_pytree_checkpoint_handler.LimitInFlightBytes
_InternalValueMetadata = base_pytree_checkpoint_handler._InternalValueMetadata # pylint: disable=protected-access
get_param_names = base_pytree_checkpoint_handler.get_param_names

_CHECKPOINT_FILE = 'checkpoint'
_METADATA_FILE = base_pytree_checkpoint_handler.METADATA_FILE
Expand Down Expand Up @@ -365,9 +367,7 @@ def _get_impl_save_args(
)


# TODO(b/339456651) Avoid inheriting from BasePyTreeCheckpointHandler when
# when possible.
class PyTreeCheckpointHandler(BasePyTreeCheckpointHandler):
class PyTreeCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
"""A CheckpointHandler implementation for any PyTree structure.
See JAX documentation for more information on what consistutes a "PyTree".
Expand All @@ -393,6 +393,7 @@ def __init__(
use_zarr3: bool = False,
primary_host: Optional[int] = 0,
type_handler_registry: TypeHandlerRegistry = type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
handler_impl: Optional[BasePyTreeCheckpointHandler] = None,
):
"""Creates PyTreeCheckpointHandler.
Expand All @@ -410,6 +411,7 @@ def __init__(
the case that all hosts are only working with local storage.
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
specified, the global type handler registry will be used.
handler_impl: Allows overriding the internal implementation.
"""
self._aggregate_handler = MsgpackHandler(primary_host=primary_host)
if aggregate_filename is None:
Expand All @@ -421,7 +423,7 @@ def __init__(
self._primary_host = primary_host
self._type_handler_registry = type_handler_registry

super().__init__(
self._handler_impl = handler_impl or BasePyTreeCheckpointHandler(
aggregate_filename=aggregate_filename,
concurrent_gb=concurrent_gb,
use_ocdbt=use_ocdbt,
Expand Down Expand Up @@ -486,7 +488,7 @@ async def async_save(
the data from its source will be awaited in this function.
"""
args = _get_impl_save_args(item, save_args, args)
return await super().async_save(directory, args=args)
return await self._handler_impl.async_save(directory, args=args)

def save(
self,
Expand All @@ -497,7 +499,7 @@ def save(
):
"""Saves the provided item. See async_save."""
args = _get_impl_save_args(item, save_args, args)
super().save(directory, args=args)
self._handler_impl.save(directory, args=args)

def restore(
self,
Expand Down Expand Up @@ -630,15 +632,17 @@ class TrainState:
item,
restore_args=restore_args,
)
return super().restore(directory, args=args)
return self._handler_impl.restore(directory, args=args)

logging.debug('directory=%s, restore_args=%s', directory, restore_args)
if not directory.exists():
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}'
)
byte_limiter = get_byte_limiter(self._concurrent_gb)
structure, use_zarr3_metadata = self._get_internal_metadata(directory)
structure, use_zarr3_metadata = self._handler_impl._get_internal_metadata( # pylint: disable=protected-access
directory
)
# `checkpoint_restore_args` has a structure relative to the checkpoint,
# while `restore_args` remains structured relative to the output.
param_infos, checkpoint_restore_args = _get_restore_parameters(
Expand Down Expand Up @@ -680,7 +684,9 @@ def _maybe_set_default_restore_types(
)

restored_item = asyncio.run(
self._maybe_deserialize(structure, param_infos, checkpoint_restore_args)
self._handler_impl._maybe_deserialize( # pylint: disable=protected-access
structure, param_infos, checkpoint_restore_args
)
)

if not legacy_transform_fn:
Expand All @@ -705,6 +711,50 @@ def _maybe_set_default_restore_types(

return restored_item

def metadata(self, directory: epath.Path) -> Optional[PyTree]:
"""Returns tree metadata.
The result will be a PyTree matching the structure of the saved checkpoint.
Note that if the item saved was a custom class, the restored metadata will
be returned as a nested dictionary representation.
Example::
{
'layer0': {
'w': ArrayMetadata(dtype=jnp.float32, shape=(8, 8), shards=(1, 2)),
'b': ArrayMetadata(dtype=jnp.float32, shape=(8,), shards=(1,)),
},
'step': ScalarMetadata(dtype=jnp.int64),
}
If the required metadata file is not present, this method will raise an
error.
Args:
directory: checkpoint location.
Returns:
tree containing metadata.
"""
return self._handler_impl.metadata(directory)

def finalize(self, directory: epath.Path) -> None:
"""Finalization step.
Called automatically by the Checkpointer/AsyncCheckpointer just before the
checkpoint is considered "finalized" in the sense of ensuring atomicity. See
documentation for `type_handlers.merge_ocdbt_per_process_files`.
Args:
directory: Path where the checkpoint is located.
"""
self._handler_impl.finalize(directory)

def close(self):
"""Closes the handler. Called automatically by Checkpointer."""
self._handler_impl.close()


@register_with_handler(PyTreeCheckpointHandler, for_save=True)
@dataclasses.dataclass
Expand Down

0 comments on commit 595839d

Please sign in to comment.