Skip to content

Commit

Permalink
DO NOT SUBMIT
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654697025
  • Loading branch information
Orbax Authors committed Jul 26, 2024
1 parent 8750e13 commit 1bc2272
Show file tree
Hide file tree
Showing 4 changed files with 667 additions and 7 deletions.
74 changes: 67 additions & 7 deletions checkpoint/orbax/checkpoint/composite_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import checkpoint_handler
from orbax.checkpoint import future
from orbax.checkpoint import handler_registration
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import proto_checkpoint_handler
from orbax.checkpoint.path import atomicity
Expand Down Expand Up @@ -272,6 +273,9 @@ def __init__(
self,
*item_names: str,
composite_options: CompositeOptions = CompositeOptions(),
handler_registry: Optional[
handler_registration.CheckpointHandlerRegistry
] = None,
**items_and_handlers: CheckpointHandler,
):
"""Constructor.
Expand All @@ -281,19 +285,55 @@ def __init__(
Args:
*item_names: A list of string item names that this handler will manage.
composite_options: Options.
handler_registry: A `CheckpointHandlerRegistry` instance. If provided, the
`CompositeCheckpointHandler` will use this registry to determine the
`CheckpointHandler` for each item.
**items_and_handlers: A mapping of item name to `CheckpointHandler`
instance, which will be used as the handler for objects of the
corresponding name.
"""
self._known_handlers: Dict[str, Optional[CheckpointHandler]] = (
items_and_handlers
)
if handler_registry is not None and items_and_handlers:
raise ValueError(
'Both `handler_registry` and `items_and_handlers` were provided. '
'Please specify only one of the two.'
)

if handler_registry is not None:
self._handler_registry = handler_registry
self._known_handlers = {
item: handler
for (
item,
unused_args_type,
), handler in handler_registry.get_all_entries().items()
# The mapping is from item to handlers, so only include items that
# have handlers.
if item is not None
}
elif items_and_handlers:
logging.info(
'Prefer using `handler_registry` instead of `items_and_handlers`.'
)
self._handler_registry = None
self._known_handlers = items_and_handlers
else:
# If no handler registry or items_and_handlers are provided, we will
# default to the global registry.
self._handler_registry = None
self._known_handlers = {}

for item in item_names:
_maybe_raise_reserved_item_error(item)
if item not in self._known_handlers:
self._known_handlers[item] = None

for item_name, handler in self._known_handlers.items():
_maybe_raise_reserved_item_error(item_name)
if handler and not checkpoint_args.has_registered_args(handler):
if self._handler_registry is not None:
raise ValueError(
'Handler registry has been provided, but no registered'
f' `CheckpointArgs` found for handler type: {type(handler)}.'
)
logging.warning(
'No registered CheckpointArgs found for handler type: %s',
type(handler),
Expand All @@ -315,6 +355,17 @@ def _get_or_set_handler(
item_name: str,
args: Optional[CheckpointArgs],
) -> CheckpointHandler:
# Use the handler registry if available.
if self._handler_registry is not None:
if args is not None:
if self._known_handlers[item_name] is None:
# PyType assumes that the handler is always `None` for some reason,
# despite the fact that we check if it is not `None` above.
self._known_handlers[item_name] = self._handler_registry.get( # pytype: disable=attribute-error
item_name, type(args)
)
return self._handler_registry.get(item_name, type(args))

if item_name not in self._known_handlers:
raise ValueError(
f'Unknown key "{item_name}". Please make sure that this key was'
Expand All @@ -335,6 +386,14 @@ def _get_or_set_handler(
)
return handler

if self._handler_registry is not None:
# Do not fall back to the global registry if a handler registry was
# provided.
raise ValueError(
'Handler registry has been provided, but no registered'
f' `CheckpointArgs` found for handler type: {type(handler)}.'
)

registered_handler_cls_for_args = (
checkpoint_args.get_registered_handler_cls(args)
)
Expand All @@ -343,9 +402,10 @@ def _get_or_set_handler(
self._known_handlers[item_name] = handler
if not isinstance(handler, registered_handler_cls_for_args):
raise ValueError(
f'For item, "{item_name}", CheckpointHandler {type(handler)} does not'
f' match with registered handler {registered_handler_cls_for_args}'
f' for provided args of type: {type(args)}'
f'For item, "{item_name}", CheckpointHandler {type(handler)} does'
' not match with registered handler'
f' {registered_handler_cls_for_args} for provided args of type:'
f' {type(args)}'
)
return handler

Expand Down
124 changes: 124 additions & 0 deletions checkpoint/orbax/checkpoint/composite_checkpoint_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from jax import numpy as jnp
from orbax.checkpoint import args as args_lib
from orbax.checkpoint import composite_checkpoint_handler
from orbax.checkpoint import handler_registration
from orbax.checkpoint import json_checkpoint_handler
from orbax.checkpoint import multihost
from orbax.checkpoint import proto_checkpoint_handler
Expand Down Expand Up @@ -279,6 +280,40 @@ def test_no_restore_args_partial_save(self):
self.assertDictEqual(restored.state, state)
self.assertIsNone(restored.metadata)

def test_no_restore_args_partial_save_handler_registry(self):
  """Restores a partially saved checkpoint through a registry-backed handler.

  Only `state` is saved. `metadata` is registered but never written, so it
  must be absent on disk and come back as `None`, both when no restore args
  are given and when an empty `CompositeArgs` is given.
  """
  handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
  # NOTE(review): this pairs `JsonCheckpointHandler` with `StandardSaveArgs`;
  # confirm that this args type is the intended registration key.
  handler_registry.add(
      'metadata',
      standard_checkpoint_handler.StandardSaveArgs,
      JsonCheckpointHandler(),
  )
  handler = CompositeCheckpointHandler(
      'state', handler_registry=handler_registry
  )

  state = {'a': 1, 'b': 2}
  self.save(
      handler,
      self.directory,
      CompositeArgs(
          state=args_lib.StandardSave(state),
      ),
  )
  # Only the item that was actually saved gets a subdirectory.
  self.assertTrue((self.directory / 'state').exists())
  self.assertFalse((self.directory / 'metadata').exists())

  # Restore with no args at all.
  restored = handler.restore(self.directory)
  self.assertDictEqual(restored.state, state)
  self.assertIsNone(restored.metadata)

  # Restore with an empty `CompositeArgs`.
  restored = handler.restore(
      self.directory,
      CompositeArgs(),
  )
  self.assertDictEqual(restored.state, state)
  self.assertIsNone(restored.metadata)

def test_no_restore_args_handler_unspecified(self):
handler = CompositeCheckpointHandler('state', 'metadata')
state = {'a': 1, 'b': 2}
Expand All @@ -304,6 +339,36 @@ def test_no_restore_args_handler_unspecified(self):
CompositeArgs(),
)

def test_no_restore_args_handler_unspecified_handler_registry(self):
  """Saves through a registry-backed handler, then restores without args.

  Both items are saved with explicit save args. A second handler that was
  given no per-item handlers is then expected to raise `ValueError` when
  asked to restore without restore args.
  """
  handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
  save_handler = CompositeCheckpointHandler(
      'state',
      'metadata',
      handler_registry=handler_registry,
  )
  state = {'a': 1, 'b': 2}
  metadata = {'lang': 'en', 'version': 1.0}
  self.save(
      save_handler,
      self.directory,
      CompositeArgs(
          state=args_lib.StandardSave(state),
          metadata=args_lib.JsonSave(metadata),
      ),
  )
  self.assertTrue((self.directory / 'state').exists())
  self.assertTrue((self.directory / 'metadata').exists())

  # NOTE(review): this handler is built WITHOUT `handler_registry`, so the
  # restore assertions below do not exercise the registry path. Confirm
  # whether `handler_registry=handler_registry` was intended here.
  restore_handler = CompositeCheckpointHandler('state', 'metadata')
  with self.assertRaises(ValueError):
    restore_handler.restore(self.directory)
  with self.assertRaises(ValueError):
    restore_handler.restore(
        self.directory,
        CompositeArgs(),
    )

def test_metadata(self):
handler = CompositeCheckpointHandler(
'extra',
Expand Down Expand Up @@ -338,6 +403,51 @@ def test_metadata(self):
self.assertIsNone(metadata.metadata)
self.assertNotIn('extra', metadata.items())

def test_metadata_handler_registry(self):
  """Checks `metadata()` for a registry-backed handler before and after save.

  Before anything is saved, `state` and `metadata` are both reported as
  `None`. After saving only `state`, its scalar metadata is returned while
  `metadata` stays `None`. The unsaved `extra` item never appears.
  """
  registry = handler_registration.DefaultCheckpointHandlerRegistry()
  registry.add(
      'state',
      standard_checkpoint_handler.StandardSaveArgs,
      StandardCheckpointHandler(),
  )
  registry.add(
      'metadata',
      standard_checkpoint_handler.StandardSaveArgs,
      JsonCheckpointHandler(),
  )
  composite_handler = CompositeCheckpointHandler(
      'extra',
      handler_registry=registry,
  )

  # Nothing has been saved yet.
  before_save = composite_handler.metadata(self.directory)
  self.assertIsNone(before_save.state)
  self.assertIsNone(before_save.metadata)
  self.assertNotIn('extra', before_save.items())

  saved_state = {'a': 1, 'b': 2}
  save_args = CompositeArgs(
      state=args_lib.StandardSave(saved_state),
  )
  self.save(composite_handler, self.directory, save_args)

  # After saving `state` only.
  after_save = composite_handler.metadata(self.directory)
  state_dir = self.directory / 'state'
  expected_state_metadata = {
      key: value_metadata.ScalarMetadata(
          name=key, directory=state_dir, dtype=jnp.int64
      )
      for key in ('a', 'b')
  }
  self.assertDictEqual(after_save.state, expected_state_metadata)
  self.assertIsNone(after_save.metadata)
  self.assertNotIn('extra', after_save.items())

def test_finalize(self):
state_handler = mock.create_autospec(StandardCheckpointHandler)
metadata_handler = mock.create_autospec(JsonCheckpointHandler)
Expand Down Expand Up @@ -452,6 +562,20 @@ def test_items_exist_temp(self):
tmp_dirs['metadata'].get().as_posix(),
)

def test_handler_registry_and_items_and_handlers_raises_error(self):
  """Supplying both a registry and per-item handlers raises `ValueError`."""
  # Build the fixtures outside the `assertRaisesRegex` block so that only the
  # constructor call under test can satisfy the assertion.
  handler_registry = handler_registration.DefaultCheckpointHandlerRegistry()
  state_handler = StandardCheckpointHandler()
  metadata_handler = JsonCheckpointHandler()

  with self.assertRaisesRegex(ValueError, 'Both'):
    # Per-item handlers are accepted as `**items_and_handlers`, so they are
    # passed as real keyword arguments. A literal `items_and_handlers={...}`
    # keyword would instead register one item named "items_and_handlers"
    # whose "handler" is a dict.
    CompositeCheckpointHandler(
        'state',
        'metadata',
        handler_registry=handler_registry,
        state=state_handler,
        metadata=metadata_handler,
    )


if __name__ == '__main__':
absltest.main()
Loading

0 comments on commit 1bc2272

Please sign in to comment.