Skip to content

Commit

Permalink
DO NOT SUBMIT
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654697025
  • Loading branch information
Orbax Authors committed Aug 5, 2024
1 parent f84f454 commit 07bc9db
Show file tree
Hide file tree
Showing 4 changed files with 662 additions and 7 deletions.
72 changes: 65 additions & 7 deletions checkpoint/orbax/checkpoint/composite_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import checkpoint_handler
from orbax.checkpoint import future
from orbax.checkpoint import handler_registration
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import proto_checkpoint_handler
from orbax.checkpoint.path import atomicity
Expand Down Expand Up @@ -273,6 +274,9 @@ def __init__(
self,
*item_names: str,
composite_options: CompositeOptions = CompositeOptions(),
handler_registry: Optional[
handler_registration.CheckpointHandlerRegistry
] = None,
**items_and_handlers: CheckpointHandler,
):
"""Constructor.
Expand All @@ -282,19 +286,55 @@ def __init__(
Args:
*item_names: A list of string item names that this handler will manage.
composite_options: Options.
handler_registry: A `CheckpointHandlerRegistry` instance. If provided, the
`CompositeCheckpointHandler` will use this registry to determine the
`CheckpointHandler` for each item.
**items_and_handlers: A mapping of item name to `CheckpointHandler`
instance, which will be used as the handler for objects of the
corresponding name.
"""
self._known_handlers: Dict[str, Optional[CheckpointHandler]] = (
items_and_handlers
)
if handler_registry is not None and items_and_handlers:
raise ValueError(
'Both `handler_registry` and `items_and_handlers` were provided. '
'Please specify only one of the two.'
)

if handler_registry is not None:
self._handler_registry = handler_registry
self._known_handlers = {
item: handler
for (
item,
unused_args_type,
), handler in handler_registry.get_all_entries().items()
# The mapping is from item to handlers, so only include items that
# have handlers.
if item is not None
}
elif items_and_handlers:
logging.info(
'Prefer using `handler_registry` instead of `items_and_handlers`.'
)
self._handler_registry = None
self._known_handlers = items_and_handlers
else:
# If no handler registry or items_and_handlers are provided, we will
# default to the global registry.
self._handler_registry = None
self._known_handlers = {}

for item in item_names:
_maybe_raise_reserved_item_error(item)
if item not in self._known_handlers:
self._known_handlers[item] = None

for item_name, handler in self._known_handlers.items():
_maybe_raise_reserved_item_error(item_name)
if handler and not checkpoint_args.has_registered_args(handler):
if self._handler_registry is not None:
raise ValueError(
'Handler registry has been provided, but no registered'
f' `CheckpointArgs` found for handler type: {type(handler)}.'
)
logging.warning(
'No registered CheckpointArgs found for handler type: %s',
type(handler),
Expand All @@ -316,6 +356,23 @@ def _get_or_set_handler(
item_name: str,
args: Optional[CheckpointArgs],
) -> CheckpointHandler:
if self._handler_registry is not None:
if args is not None:
try:
# PyType assumes that the handler is always `None` for some reason,
# despite the fact that we check if it is not `None` above.
handler = self._handler_registry.get(item_name, type(args)) # pytype: disable=attribute-error
if self._known_handlers[item_name] is None:
self._known_handlers[item_name] = handler
return handler
except handler_registration.NoEntryError:
logging.info(
'No entry found in handler registry for item: %s and args: %s.'
' Falling back to global handler registry',
item_name,
args,
)

if item_name not in self._known_handlers:
raise ValueError(
f'Unknown key "{item_name}". Please make sure that this key was'
Expand Down Expand Up @@ -344,9 +401,10 @@ def _get_or_set_handler(
self._known_handlers[item_name] = handler
if not isinstance(handler, registered_handler_cls_for_args):
raise ValueError(
f'For item, "{item_name}", CheckpointHandler {type(handler)} does not'
f' match with registered handler {registered_handler_cls_for_args}'
f' for provided args of type: {type(args)}'
f'For item, "{item_name}", CheckpointHandler {type(handler)} does'
' not match with registered handler'
f' {registered_handler_cls_for_args} for provided args of type:'
f' {type(args)}'
)
return handler

Expand Down
184 changes: 184 additions & 0 deletions checkpoint/orbax/checkpoint/composite_checkpoint_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from jax import numpy as jnp
from orbax.checkpoint import args as args_lib
from orbax.checkpoint import composite_checkpoint_handler
from orbax.checkpoint import handler_registration
from orbax.checkpoint import json_checkpoint_handler
from orbax.checkpoint import multihost
from orbax.checkpoint import proto_checkpoint_handler
Expand All @@ -40,6 +41,14 @@
CompositeOptions = composite_checkpoint_handler.CompositeOptions


class _TestSaveArgs(standard_checkpoint_handler.StandardSaveArgs):
...


class _TestRestoreArgs(standard_checkpoint_handler.StandardRestoreArgs):
...


class CompositeArgsTest(absltest.TestCase):

def test_args(self):
Expand Down Expand Up @@ -153,6 +162,58 @@ def test_save_restore(self):
self.assertDictEqual(restored.state, state)
self.assertDictEqual(restored.metadata, metadata)

def test_save_restore_handler_registry(self):
handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
handler_registry.add(
'state',
_TestSaveArgs,
StandardCheckpointHandler(),
)
handler_registry.add(
'state',
_TestRestoreArgs,
StandardCheckpointHandler(),
)
handler_registry.add(
'metadata',
json_checkpoint_handler.JsonSaveArgs,
JsonCheckpointHandler(),
)
handler_registry.add(
'metadata',
json_checkpoint_handler.JsonRestoreArgs,
JsonCheckpointHandler(),
)

handler = CompositeCheckpointHandler(
'state',
'metadata',
handler_registry=handler_registry,
)

state = {'a': 1, 'b': 2}
dummy_state = {'a': 0, 'b': 0}
metadata = {'lang': 'en', 'version': 1.0}
self.save(
handler,
self.directory,
CompositeArgs(
state=_TestSaveArgs(state),
metadata=args_lib.JsonSave(metadata),
),
)
self.assertTrue((self.directory / 'state').exists())
self.assertTrue((self.directory / 'metadata').exists())
restored = handler.restore(
self.directory,
CompositeArgs(
state=_TestRestoreArgs(dummy_state),
metadata=args_lib.JsonRestore(),
),
)
self.assertDictEqual(restored.state, state)
self.assertDictEqual(restored.metadata, metadata)

def test_save_restore_partial(self):
handler = CompositeCheckpointHandler('state', 'opt_state', 'metadata')
state = {'a': 1, 'b': 2}
Expand Down Expand Up @@ -279,6 +340,40 @@ def test_no_restore_args_partial_save(self):
self.assertDictEqual(restored.state, state)
self.assertIsNone(restored.metadata)

def test_no_restore_args_partial_save_handler_registry(self):
handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
handler_registry.add(
'metadata',
standard_checkpoint_handler.StandardSaveArgs,
JsonCheckpointHandler(),
)
handler = CompositeCheckpointHandler(
'state', handler_registry=handler_registry
)

state = {'a': 1, 'b': 2}
dummy_state = {'a': 0, 'b': 0}
self.save(
handler,
self.directory,
CompositeArgs(
state=args_lib.StandardSave(state),
),
)
self.assertTrue((self.directory / 'state').exists())
self.assertFalse((self.directory / 'metadata').exists())

restored = handler.restore(self.directory)
self.assertDictEqual(restored.state, state)
self.assertIsNone(restored.metadata)

restored = handler.restore(
self.directory,
CompositeArgs(),
)
self.assertDictEqual(restored.state, state)
self.assertIsNone(restored.metadata)

def test_no_restore_args_handler_unspecified(self):
handler = CompositeCheckpointHandler('state', 'metadata')
state = {'a': 1, 'b': 2}
Expand All @@ -304,6 +399,36 @@ def test_no_restore_args_handler_unspecified(self):
CompositeArgs(),
)

def test_no_restore_args_handler_unspecified_handler_registry(self):
handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
handler = CompositeCheckpointHandler(
'state',
'metadata',
handler_registry=handler_registry,
)
state = {'a': 1, 'b': 2}
dummy_state = {'a': 0, 'b': 0}
metadata = {'lang': 'en', 'version': 1.0}
self.save(
handler,
self.directory,
CompositeArgs(
state=args_lib.StandardSave(state),
metadata=args_lib.JsonSave(metadata),
),
)
self.assertTrue((self.directory / 'state').exists())
self.assertTrue((self.directory / 'metadata').exists())

handler = CompositeCheckpointHandler('state', 'metadata')
with self.assertRaises(ValueError):
handler.restore(self.directory)
with self.assertRaises(ValueError):
handler.restore(
self.directory,
CompositeArgs(),
)

def test_metadata(self):
handler = CompositeCheckpointHandler(
'extra',
Expand Down Expand Up @@ -338,6 +463,51 @@ def test_metadata(self):
self.assertIsNone(metadata.metadata)
self.assertNotIn('extra', metadata.items())

def test_metadata_handler_registry(self):
handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
handler_registry.add(
'state',
standard_checkpoint_handler.StandardSaveArgs,
StandardCheckpointHandler(),
)
handler_registry.add(
'metadata',
standard_checkpoint_handler.StandardSaveArgs,
JsonCheckpointHandler(),
)
handler = CompositeCheckpointHandler(
'extra',
handler_registry=handler_registry,
)

metadata = handler.metadata(self.directory)
self.assertIsNone(metadata.state)
self.assertIsNone(metadata.metadata)
self.assertNotIn('extra', metadata.items())

state = {'a': 1, 'b': 2}
self.save(
handler,
self.directory,
CompositeArgs(
state=args_lib.StandardSave(state),
),
)
metadata = handler.metadata(self.directory)
self.assertDictEqual(
metadata.state,
{
'a': value_metadata.ScalarMetadata(
name='a', directory=self.directory / 'state', dtype=jnp.int64
),
'b': value_metadata.ScalarMetadata(
name='b', directory=self.directory / 'state', dtype=jnp.int64
),
},
)
self.assertIsNone(metadata.metadata)
self.assertNotIn('extra', metadata.items())

def test_finalize(self):
state_handler = mock.create_autospec(StandardCheckpointHandler)
metadata_handler = mock.create_autospec(JsonCheckpointHandler)
Expand Down Expand Up @@ -452,6 +622,20 @@ def test_items_exist_temp(self):
tmp_dirs['metadata'].get().as_posix(),
)

def test_handler_registry_and_items_and_handlers_raises_error(self):

with self.assertRaisesRegex(ValueError, 'Both'):
handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
CompositeCheckpointHandler(
'state',
'metadata',
handler_registry=handler_registry,
items_and_handlers={
'state': StandardCheckpointHandler(),
'metadata': JsonCheckpointHandler(),
},
)


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 07bc9db

Please sign in to comment.