Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

End support for legacy Pax formats with directory structure: #882

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.5.12] - 2024-05-15

### Added
- Introduce `should_save_fn` in `OrbaxCheckpointManagerOptions`.
- Introduce `StepAlreadyExistsError` to be raised on save with existing step.

### Changed
- `PyTreeCheckpointHandler` now delegates to `BasePyTreeCheckpointHandler` rather than inheriting from it.

### Fixed
- Fix empty metadata file error: `Expecting value: line 1 column 1 (char 0)`.


## [0.5.10] - 2024-05-10
## [0.5.11] - 2024-05-10

### Added
- Implement restoration in emergency.CheckpointManager.
Expand Down
3 changes: 2 additions & 1 deletion checkpoint/orbax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from orbax.checkpoint.proto_checkpoint_handler import ProtoCheckpointHandler
from orbax.checkpoint.pytree_checkpoint_handler import ArrayRestoreArgs
from orbax.checkpoint.pytree_checkpoint_handler import PyTreeCheckpointHandler
from orbax.checkpoint.base_pytree_checkpoint_handler import BasePyTreeCheckpointHandler
from orbax.checkpoint.pytree_checkpoint_handler import RestoreArgs
from orbax.checkpoint.pytree_checkpoint_handler import SaveArgs
from orbax.checkpoint.pytree_checkpointer import PyTreeCheckpointer
Expand All @@ -72,4 +73,4 @@

# A new PyPI release will be pushed every time `__version__` is increased.
# Also modify version and date in CHANGELOG.
__version__ = '0.5.11'
__version__ = '0.5.12'
3 changes: 2 additions & 1 deletion checkpoint/orbax/checkpoint/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from orbax.checkpoint.proto_checkpoint_handler import ProtoCheckpointHandler
from orbax.checkpoint.pytree_checkpoint_handler import ArrayRestoreArgs
from orbax.checkpoint.pytree_checkpoint_handler import PyTreeCheckpointHandler
from orbax.checkpoint.base_pytree_checkpoint_handler import BasePyTreeCheckpointHandler
from orbax.checkpoint.pytree_checkpoint_handler import RestoreArgs
from orbax.checkpoint.pytree_checkpoint_handler import SaveArgs
from orbax.checkpoint.pytree_checkpointer import PyTreeCheckpointer
Expand All @@ -72,4 +73,4 @@

# A new PyPI release will be pushed every time `__version__` is increased.
# Also modify version and date in CHANGELOG.
__version__ = '0.5.11'
__version__ = '0.5.12'
4 changes: 2 additions & 2 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def __init__(
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
)

def _get_param_names(self, item: PyTree) -> PyTree:
def get_param_names(self, item: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
return get_param_names(item)

Expand Down Expand Up @@ -420,7 +420,7 @@ def _get_param_infos(
"""
if not item:
raise ValueError('Found empty item')
names = self._get_param_names(item)
names = self.get_param_names(item)
all_params_aggregated = True

def _param_info(value, name, args):
Expand Down
68 changes: 59 additions & 9 deletions checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from etils import epath
import jax
from orbax.checkpoint import aggregate_handlers
from orbax.checkpoint import async_checkpoint_handler
from orbax.checkpoint import base_pytree_checkpoint_handler
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
Expand Down Expand Up @@ -63,6 +64,7 @@
get_byte_limiter = base_pytree_checkpoint_handler.get_byte_limiter
LimitInFlightBytes = base_pytree_checkpoint_handler.LimitInFlightBytes
_InternalValueMetadata = base_pytree_checkpoint_handler._InternalValueMetadata # pylint: disable=protected-access
get_param_names = base_pytree_checkpoint_handler.get_param_names

_CHECKPOINT_FILE = 'checkpoint'
_METADATA_FILE = base_pytree_checkpoint_handler.METADATA_FILE
Expand Down Expand Up @@ -365,9 +367,7 @@ def _get_impl_save_args(
)


# TODO(b/339456651) Avoid inheriting from BasePyTreeCheckpointHandler when
# possible.
class PyTreeCheckpointHandler(BasePyTreeCheckpointHandler):
class PyTreeCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
"""A CheckpointHandler implementation for any PyTree structure.

See JAX documentation for more information on what constitutes a "PyTree".
Expand All @@ -393,6 +393,7 @@ def __init__(
use_zarr3: bool = False,
primary_host: Optional[int] = 0,
type_handler_registry: TypeHandlerRegistry = type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
handler_impl: Optional[BasePyTreeCheckpointHandler] = None,
):
"""Creates PyTreeCheckpointHandler.

Expand All @@ -410,6 +411,7 @@ def __init__(
the case that all hosts are only working with local storage.
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
specified, the global type handler registry will be used.
handler_impl: Optional BasePyTreeCheckpointHandler that this handler
delegates to. If not specified, one is constructed from the arguments above.
"""
self._aggregate_handler = MsgpackHandler(primary_host=primary_host)
if aggregate_filename is None:
Expand All @@ -421,7 +423,7 @@ def __init__(
self._primary_host = primary_host
self._type_handler_registry = type_handler_registry

super().__init__(
self._handler_impl = handler_impl or BasePyTreeCheckpointHandler(
aggregate_filename=aggregate_filename,
concurrent_gb=concurrent_gb,
use_ocdbt=use_ocdbt,
Expand Down Expand Up @@ -486,7 +488,7 @@ async def async_save(
the data from its source will be awaited in this function.
"""
args = _get_impl_save_args(item, save_args, args)
return await super().async_save(directory, args=args)
return await self._handler_impl.async_save(directory, args=args)

def save(
self,
Expand All @@ -497,7 +499,7 @@ def save(
):
"""Saves the provided item. See async_save."""
args = _get_impl_save_args(item, save_args, args)
super().save(directory, args=args)
self._handler_impl.save(directory, args=args)

def restore(
self,
Expand Down Expand Up @@ -630,15 +632,17 @@ class TrainState:
item,
restore_args=restore_args,
)
return super().restore(directory, args=args)
return self._handler_impl.restore(directory, args=args)

logging.debug('directory=%s, restore_args=%s', directory, restore_args)
if not directory.exists():
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}'
)
byte_limiter = get_byte_limiter(self._concurrent_gb)
structure, use_zarr3_metadata = self._get_internal_metadata(directory)
structure, use_zarr3_metadata = self._handler_impl._get_internal_metadata( # pylint: disable=protected-access
directory
)
# `checkpoint_restore_args` has a structure relative to the checkpoint,
# while `restore_args` remains structured relative to the output.
param_infos, checkpoint_restore_args = _get_restore_parameters(
Expand Down Expand Up @@ -680,7 +684,9 @@ def _maybe_set_default_restore_types(
)

restored_item = asyncio.run(
self._maybe_deserialize(structure, param_infos, checkpoint_restore_args)
self._handler_impl._maybe_deserialize( # pylint: disable=protected-access
structure, param_infos, checkpoint_restore_args
)
)

if not legacy_transform_fn:
Expand All @@ -705,6 +711,50 @@ def _maybe_set_default_restore_types(

return restored_item

def metadata(self, directory: epath.Path) -> Optional[PyTree]:
"""Returns tree metadata.

The result will be a PyTree matching the structure of the saved checkpoint.
Note that if the item saved was a custom class, the restored metadata will
be returned as a nested dictionary representation.

Example::

{
'layer0': {
'w': ArrayMetadata(dtype=jnp.float32, shape=(8, 8), shards=(1, 2)),
'b': ArrayMetadata(dtype=jnp.float32, shape=(8,), shards=(1,)),
},
'step': ScalarMetadata(dtype=jnp.int64),
}

If the required metadata file is not present, this method will raise an
error.

This method forwards the call unchanged to the wrapped handler
implementation (`self._handler_impl`) and returns its result.

Args:
directory: checkpoint location.

Returns:
tree containing metadata.
"""
return self._handler_impl.metadata(directory)

def finalize(self, directory: epath.Path) -> None:
"""Finalization step.

Called automatically by the Checkpointer/AsyncCheckpointer just before the
checkpoint is considered "finalized" in the sense of ensuring atomicity. See
documentation for `type_handlers.merge_ocdbt_per_process_files`.

This method forwards the call unchanged to the wrapped handler
implementation (`self._handler_impl`).

Args:
directory: Path where the checkpoint is located.
"""
self._handler_impl.finalize(directory)

def close(self):
"""Closes the handler. Called automatically by Checkpointer.

This method forwards the call to the wrapped handler implementation
(`self._handler_impl`).
"""
self._handler_impl.close()


@register_with_handler(PyTreeCheckpointHandler, for_save=True)
@dataclasses.dataclass
Expand Down
Loading