From ac2b3bbc181ce59ff42b96a716714481e9308bfd Mon Sep 17 00:00:00 2001 From: Orbax Authors Date: Tue, 23 Jul 2024 07:16:39 -0700 Subject: [PATCH] DO NOT SUBMIT PiperOrigin-RevId: 655153651 --- checkpoint/orbax/__init__.py | 2 + checkpoint/orbax/checkpoint/__init__.py | 2 + .../orbax/checkpoint/checkpoint_manager.py | 81 +++++- .../composite_checkpoint_handler.py | 74 ++++- .../composite_checkpoint_handler_test.py | 124 ++++++++ .../orbax/checkpoint/handler_registration.py | 203 +++++++++++++ .../checkpoint/handler_registration_test.py | 273 ++++++++++++++++++ 7 files changed, 743 insertions(+), 16 deletions(-) create mode 100644 checkpoint/orbax/checkpoint/handler_registration.py create mode 100644 checkpoint/orbax/checkpoint/handler_registration_test.py diff --git a/checkpoint/orbax/__init__.py b/checkpoint/orbax/__init__.py index 99afddddc..a2b7b081e 100644 --- a/checkpoint/orbax/__init__.py +++ b/checkpoint/orbax/__init__.py @@ -47,6 +47,8 @@ from orbax.checkpoint.checkpointer import Checkpointer from orbax.checkpoint.composite_checkpoint_handler import CompositeCheckpointHandler from orbax.checkpoint.future import Future +from orbax.checkpoint.handler_registration import CheckpointHandlerRegistry +from orbax.checkpoint.handler_registration import DefaultCheckpointHandlerRegistry from orbax.checkpoint.json_checkpoint_handler import JsonCheckpointHandler from orbax.checkpoint.proto_checkpoint_handler import ProtoCheckpointHandler from orbax.checkpoint.pytree_checkpoint_handler import ArrayRestoreArgs diff --git a/checkpoint/orbax/checkpoint/__init__.py b/checkpoint/orbax/checkpoint/__init__.py index 99afddddc..a2b7b081e 100644 --- a/checkpoint/orbax/checkpoint/__init__.py +++ b/checkpoint/orbax/checkpoint/__init__.py @@ -47,6 +47,8 @@ from orbax.checkpoint.checkpointer import Checkpointer from orbax.checkpoint.composite_checkpoint_handler import CompositeCheckpointHandler from orbax.checkpoint.future import Future +from orbax.checkpoint.handler_registration import CheckpointHandlerRegistry +from orbax.checkpoint.handler_registration import DefaultCheckpointHandlerRegistry from orbax.checkpoint.json_checkpoint_handler import JsonCheckpointHandler from orbax.checkpoint.proto_checkpoint_handler import ProtoCheckpointHandler from orbax.checkpoint.pytree_checkpoint_handler import ArrayRestoreArgs diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index db3e3109e..1417f19e1 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -35,6 +35,7 @@ from orbax.checkpoint import checkpoint_handler from orbax.checkpoint import checkpointer as checkpointer_lib from orbax.checkpoint import composite_checkpoint_handler +from orbax.checkpoint import handler_registration from orbax.checkpoint import json_checkpoint_handler from orbax.checkpoint import multihost from orbax.checkpoint import options as options_lib @@ -69,6 +70,7 @@ CheckpointHandler = checkpoint_handler.CheckpointHandler CheckpointArgs = checkpoint_args.CheckpointArgs CheckpointHandlersDict = Mapping[str, CheckpointHandler] +CheckpointHandlerRegistry = handler_registration.CheckpointHandlerRegistry AsyncOptions = options_lib.AsyncOptions MultiprocessingOptions = options_lib.MultiprocessingOptions @@ -393,6 +395,7 @@ def __init__( Union[CheckpointHandler, CheckpointHandlersDict] ] = None, logger: Optional[abstract_logger.AbstractLogger] = None, + handler_registry: Optional[CheckpointHandlerRegistry] = None, ): """CheckpointManager constructor. @@ -488,7 +491,8 @@ def __init__( metadata. item_names: Names of distinct items that may be saved/restored with this `CheckpointManager`. `item_names` and `checkpointers` are mutually - exclusive - do not use together. Also see `item_handlers` below. + exclusive - do not use together. Also see `item_handlers` below. Prefer + using the `handler_registry`. item_handlers: A mapping of item name to `CheckpointHandler`. The mapped CheckpointHandler must be registered against the `CheckpointArgs` input in save/restore operations. Please don't use `checkpointers` and @@ -496,7 +500,9 @@ def __init__( The item name key may or may not be present in `item_names`. Alternatively, a single CheckpointHandler may be provided, in which case `save` and `restore` should always be called in a single item context. + Prefer using the `handler_registry`. logger: A logger to log checkpointing events. + handler_registry: A registry of handlers to use for checkpointing. """ jax.monitoring.record_event('/jax/orbax/checkpoint_manager/init') @@ -508,6 +514,7 @@ def __init__( self._options.multiprocessing_options or MultiprocessingOptions() ) self._logger = logger or standard_logger.StandardLogger() + self._handler_registry = handler_registry if checkpointers and item_names: raise ValueError( @@ -524,6 +531,20 @@ def __init__( '`item_handlers` in single item mode and `item_names` should not be' ' provided together.' ) + if checkpointers is not None and handler_registry is not None: + raise ValueError( + 'Deprecated `checkpointers` can not be used with `handler_registry`.' + ' Please follow the instructions at' + ' https://orbax.readthedocs.io/en/latest/api_refactor.html to' + ' migrate by August 1st, 2024.' + ) + + if item_handlers is not None and handler_registry is not None: + raise ValueError( + '`item_handlers` and `handler_registry` are mutually exclusive -' + ' prefer configuring the handler registry.' + ) + # For async_checkpointer. self._non_blocking_checkpoint_metadata_store = ( checkpoint.checkpoint_metadata_store(enable_write=True) @@ -534,7 +555,8 @@ def __init__( enable_write=True, blocking_write=True ) ) - if checkpointers: + + if checkpointers is not None: logging.warning( 'Configured `CheckpointManager` using deprecated legacy API. Please' ' follow the instructions at' @@ -545,12 +567,28 @@ def __init__( self._checkpointer = self._configure_checkpointer_legacy_init( checkpointers, self._options ) + elif self._handler_registry is not None: + self._single_item = item_names is None + self._checkpointer = self._configure_checkpointer_from_handler_registry( + item_names, + handler_registry, + self._options, + ) else: + logging.warning( + 'Prefer using the `handler_registry` instead of `item_names` and' + ' `item_handlers`.' + ) self._single_item = isinstance(item_handlers, CheckpointHandler) or ( item_names is None and item_handlers is None ) - self._checkpointer = self._configure_checkpointer( - item_names, item_handlers, self._options, self._single_item + self._checkpointer = ( + self._configure_checkpointer_from_item_names_and_handlers( + item_names, + item_handlers, + self._options, + self._single_item, + ) ) self._directory = epath.Path(directory) @@ -745,7 +783,7 @@ def _validate_handler(self, handler): f'handler[{type(handler)}]={handler._primary_host} ' # pylint: disable=protected-access ) - def _configure_checkpointer( + def _configure_checkpointer_from_item_names_and_handlers( self, item_names: Optional[Sequence[str]], item_handlers: Optional[Union[CheckpointHandler, CheckpointHandlersDict]], @@ -767,8 +805,6 @@ def _configure_checkpointer( if isinstance(item_handlers, CheckpointHandler) else None ) - if item_handler: - self._validate_handler(item_handler) all_item_handlers = {DEFAULT_ITEM_NAME: item_handler} else: # Initialize all_item_handlers with None or empty. @@ -779,10 +815,11 @@ def _configure_checkpointer( # Update all_item_handlers with provided CheckpointHandlers. if item_handlers and isinstance(item_handlers, Mapping): for item_name, handler in item_handlers.items(): - self._validate_handler(handler) all_item_handlers[item_name] = handler - for item_name in all_item_handlers: + for item_name, handler in all_item_handlers.items(): + if handler is not None: + self._validate_handler(handler) if item_name in RESERVED_ITEM_NAMES: raise ValueError( f'Found {item_name} in `checkpointers`; this is a reserved key.' @@ -808,6 +845,32 @@ def _configure_checkpointer( options.enable_async_checkpointing, ) + def _configure_checkpointer_from_handler_registry( + self, + item_names: Optional[Sequence[str]], + handler_registry: CheckpointHandlerRegistry, + options: CheckpointManagerOptions, + ) -> Checkpointer: + """Initializes _CompositeCheckpointer given a `handler_registry`.""" + + # If `item_names`` is None, we will default to a single item. + items = [DEFAULT_ITEM_NAME] if item_names is None else item_names + # CompositeCheckpointHandler defers per-item handler creation until + # save/restore time. + return self._configure_checkpointer_common( + CompositeCheckpointHandler( + *items, + composite_options=composite_checkpoint_handler.CompositeOptions( + primary_host=self._multiprocessing_options.primary_host, + active_processes=self._multiprocessing_options.active_processes, + barrier_sync_key_prefix=self._multiprocessing_options.barrier_sync_key_prefix, + ), + handler_registry=handler_registry, + ), + options, + options.enable_async_checkpointing, + ) + @property def directory(self) -> epath.Path: """See superclass documentation.""" diff --git a/checkpoint/orbax/checkpoint/composite_checkpoint_handler.py b/checkpoint/orbax/checkpoint/composite_checkpoint_handler.py index 73d01ca34..e94231a04 100644 --- a/checkpoint/orbax/checkpoint/composite_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/composite_checkpoint_handler.py @@ -61,6 +61,7 @@ from orbax.checkpoint import checkpoint_args from orbax.checkpoint import checkpoint_handler from orbax.checkpoint import future +from orbax.checkpoint import handler_registration from orbax.checkpoint import options as options_lib from orbax.checkpoint import proto_checkpoint_handler from orbax.checkpoint.path import atomicity @@ -272,6 +273,9 @@ def __init__( self, *item_names: str, composite_options: CompositeOptions = CompositeOptions(), + handler_registry: Optional[ + handler_registration.CheckpointHandlerRegistry + ] = None, **items_and_handlers: CheckpointHandler, ): """Constructor. @@ -281,19 +285,55 @@ def __init__( Args: *item_names: A list of string item names that this handler will manage. composite_options: Options. + handler_registry: A `CheckpointHandlerRegistry` instance. If provided, the + `CompositeCheckpointHandler` will use this registry to determine the + `CheckpointHandler` for each item. **items_and_handlers: A mapping of item name to `CheckpointHandler` instance, which will be used as the handler for objects of the corresponding name. """ - self._known_handlers: Dict[str, Optional[CheckpointHandler]] = ( - items_and_handlers - ) + if handler_registry is not None and items_and_handlers: + raise ValueError( + 'Both `handler_registry` and `items_and_handlers` were provided. ' + 'Please specify only one of the two.' + ) + + if handler_registry is not None: + self._handler_registry = handler_registry + self._known_handlers = { + item: handler + for ( + item, + unused_args_type, + ), handler in handler_registry.get_all_entries().items() + # The mapping is from item to handlers, so only include items that + # have handlers. + if item is not None + } + elif items_and_handlers: + logging.info( + 'Prefer using `handler_registry` instead of `items_and_handlers`.' + ) + self._handler_registry = None + self._known_handlers = items_and_handlers + else: + # If no handler registry or items_and_handlers are provided, we will + # default to the global registry. + self._handler_registry = None + self._known_handlers = {} + for item in item_names: + _maybe_raise_reserved_item_error(item) if item not in self._known_handlers: self._known_handlers[item] = None + for item_name, handler in self._known_handlers.items(): - _maybe_raise_reserved_item_error(item_name) if handler and not checkpoint_args.has_registered_args(handler): + if self._handler_registry is not None: + raise ValueError( + 'Handler registry has been provided, but no registered' + f' `CheckpointArgs` found for handler type: {type(handler)}.' + ) logging.warning( 'No registered CheckpointArgs found for handler type: %s', type(handler), @@ -315,6 +355,17 @@ def _get_or_set_handler( item_name: str, args: Optional[CheckpointArgs], ) -> CheckpointHandler: + # Use the handler registry if available. + if self._handler_registry is not None: + if args is not None: + if self._known_handlers[item_name] is None: + # PyType assumes that the handler is always `None` for some reason, + # despite the fact that we check if it is not `None` above. + self._known_handlers[item_name] = self._handler_registry.get( # pytype: disable=attribute-error + item_name, type(args) + ) + return self._handler_registry.get(item_name, type(args)) + if item_name not in self._known_handlers: raise ValueError( f'Unknown key "{item_name}". Please make sure that this key was' @@ -335,6 +386,14 @@ def _get_or_set_handler( ) return handler + if self._handler_registry is not None: + # Do not fall back to the global registry if a handler registry was + # provided. + raise ValueError( + 'Handler registry has been provided, but no registered' + f' `CheckpointArgs` found for handler type: {type(handler)}.' + ) + registered_handler_cls_for_args = ( checkpoint_args.get_registered_handler_cls(args) ) @@ -343,9 +402,10 @@ def _get_or_set_handler( self._known_handlers[item_name] = handler if not isinstance(handler, registered_handler_cls_for_args): raise ValueError( - f'For item, "{item_name}", CheckpointHandler {type(handler)} does not' - f' match with registered handler {registered_handler_cls_for_args}' - f' for provided args of type: {type(args)}' + f'For item, "{item_name}", CheckpointHandler {type(handler)} does' + ' not match with registered handler' + f' {registered_handler_cls_for_args} for provided args of type:' + f' {type(args)}' ) return handler diff --git a/checkpoint/orbax/checkpoint/composite_checkpoint_handler_test.py b/checkpoint/orbax/checkpoint/composite_checkpoint_handler_test.py index 467c5e7af..bcbe0ac86 100644 --- a/checkpoint/orbax/checkpoint/composite_checkpoint_handler_test.py +++ b/checkpoint/orbax/checkpoint/composite_checkpoint_handler_test.py @@ -20,6 +20,7 @@ from jax import numpy as jnp from orbax.checkpoint import args as args_lib from orbax.checkpoint import composite_checkpoint_handler +from orbax.checkpoint import handler_registration from orbax.checkpoint import json_checkpoint_handler from orbax.checkpoint import multihost from orbax.checkpoint import proto_checkpoint_handler @@ -279,6 +280,40 @@ def test_no_restore_args_partial_save(self): self.assertDictEqual(restored.state, state) self.assertIsNone(restored.metadata) + def test_no_restore_args_partial_save_handler_registry(self): + handler_registry = handler_registration.DefaultCheckpointHandlerRegistry() + handler_registry.add( + 'metadata', + standard_checkpoint_handler.StandardSaveArgs, + JsonCheckpointHandler(), + ) + handler = CompositeCheckpointHandler( + 'state', handler_registry=handler_registry + ) + + state = {'a': 1, 'b': 2} + dummy_state = {'a': 0, 'b': 0} + self.save( + handler, + self.directory, + CompositeArgs( + state=args_lib.StandardSave(state), + ), + ) + self.assertTrue((self.directory / 'state').exists()) + self.assertFalse((self.directory / 'metadata').exists()) + + restored = handler.restore(self.directory) + self.assertDictEqual(restored.state, state) + self.assertIsNone(restored.metadata) + + restored = handler.restore( + self.directory, + CompositeArgs(), + ) + self.assertDictEqual(restored.state, state) + self.assertIsNone(restored.metadata) + def test_no_restore_args_handler_unspecified(self): handler = CompositeCheckpointHandler('state', 'metadata') state = {'a': 1, 'b': 2} @@ -304,6 +339,36 @@ def test_no_restore_args_handler_unspecified(self): CompositeArgs(), ) + def test_no_restore_args_handler_unspecified_handler_registry(self): + handler_registry = handler_registration.DefaultCheckpointHandlerRegistry() + handler = CompositeCheckpointHandler( + 'state', + 'metadata', + handler_registry=handler_registry, + ) + state = {'a': 1, 'b': 2} + dummy_state = {'a': 0, 'b': 0} + metadata = {'lang': 'en', 'version': 1.0} + self.save( + handler, + self.directory, + CompositeArgs( + state=args_lib.StandardSave(state), + metadata=args_lib.JsonSave(metadata), + ), + ) + self.assertTrue((self.directory / 'state').exists()) + self.assertTrue((self.directory / 'metadata').exists()) + + handler = CompositeCheckpointHandler('state', 'metadata') + with self.assertRaises(ValueError): + handler.restore(self.directory) + with self.assertRaises(ValueError): + handler.restore( + self.directory, + CompositeArgs(), + ) + def test_metadata(self): handler = CompositeCheckpointHandler( 'extra', @@ -338,6 +403,51 @@ def test_metadata(self): self.assertIsNone(metadata.metadata) self.assertNotIn('extra', metadata.items()) + def test_metadata_handler_registry(self): + handler_registry = handler_registration.DefaultCheckpointHandlerRegistry() + handler_registry.add( + 'state', + standard_checkpoint_handler.StandardSaveArgs, + StandardCheckpointHandler(), + ) + handler_registry.add( + 'metadata', + standard_checkpoint_handler.StandardSaveArgs, + JsonCheckpointHandler(), + ) + handler = CompositeCheckpointHandler( + 'extra', + handler_registry=handler_registry, + ) + + metadata = handler.metadata(self.directory) + self.assertIsNone(metadata.state) + self.assertIsNone(metadata.metadata) + self.assertNotIn('extra', metadata.items()) + + state = {'a': 1, 'b': 2} + self.save( + handler, + self.directory, + CompositeArgs( + state=args_lib.StandardSave(state), + ), + ) + metadata = handler.metadata(self.directory) + self.assertDictEqual( + metadata.state, + { + 'a': value_metadata.ScalarMetadata( + name='a', directory=self.directory / 'state', dtype=jnp.int64 + ), + 'b': value_metadata.ScalarMetadata( + name='b', directory=self.directory / 'state', dtype=jnp.int64 + ), + }, + ) + self.assertIsNone(metadata.metadata) + self.assertNotIn('extra', metadata.items()) + def test_finalize(self): state_handler = mock.create_autospec(StandardCheckpointHandler) metadata_handler = mock.create_autospec(JsonCheckpointHandler) @@ -452,6 +562,20 @@ def test_items_exist_temp(self): tmp_dirs['metadata'].get().as_posix(), ) + def test_handler_registry_and_items_and_handlers_raises_error(self): + + with self.assertRaisesRegex(ValueError, 'Both'): + handler_registry = handler_registration.DefaultCheckpointHandlerRegistry() + CompositeCheckpointHandler( + 'state', + 'metadata', + handler_registry=handler_registry, + items_and_handlers={ + 'state': StandardCheckpointHandler(), + 'metadata': JsonCheckpointHandler(), + }, + ) + if __name__ == '__main__': absltest.main() diff --git a/checkpoint/orbax/checkpoint/handler_registration.py b/checkpoint/orbax/checkpoint/handler_registration.py new file mode 100644 index 000000000..1067d85ad --- /dev/null +++ b/checkpoint/orbax/checkpoint/handler_registration.py @@ -0,0 +1,203 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Global registry for `CheckpointHandler`s.""" + +import logging +from typing import MutableMapping, Optional, Protocol, Type, Union + +from orbax.checkpoint import checkpoint_args +from orbax.checkpoint import checkpoint_handler + +CheckpointArgs = checkpoint_args.CheckpointArgs +CheckpointHandler = checkpoint_handler.CheckpointHandler +HandlerRegistryMapping = MutableMapping[ + tuple[Optional[str], type[CheckpointArgs]], CheckpointHandler +] + +_SAVE_ARG_TO_HANDLER = checkpoint_args._SAVE_ARG_TO_HANDLER # pylint: disable=protected-access +_RESTORE_ARG_TO_HANDLER = checkpoint_args._RESTORE_ARG_TO_HANDLER # pylint: disable=protected-access + + +class CheckpointHandlerRegistry(Protocol): + """Protocol for `CheckpointHandlerRegistry`.""" + + def add( + self, + item: Optional[str], + args_type: Type[CheckpointArgs], + handler: Union[CheckpointHandler, Type[CheckpointHandler]], + **kwargs, + ): + ... + + def get( + self, + item: Optional[str], + args_type: Type[CheckpointArgs], + ) -> CheckpointHandler: + ... + + def has( + self, + item: Optional[str], + args_type: Type[CheckpointArgs], + ) -> bool: + ... + + def get_all_entries( + self, + ) -> HandlerRegistryMapping: + ... + + +class AlreadyExistsError(ValueError): + """Raised when an entry already exists in the registry.""" + + +class NoEntryError(KeyError): + """Raised when no entry exists in the registry.""" + + +class DefaultCheckpointHandlerRegistry(CheckpointHandlerRegistry): + """Default implementation of `CheckpointHandlerRegistry`. + + Inherits from globally registered `CheckpointHandler`s on construction. + """ + + def __init__( + self, other_registry: Optional[CheckpointHandlerRegistry] = None + ): + self._registry: HandlerRegistryMapping = {} + + # Inherit from globally registered handlers. + for args_type, handler_class in _SAVE_ARG_TO_HANDLER.items(): + self._add_entry(None, args_type, handler_class) + for args_type, handler_class in _RESTORE_ARG_TO_HANDLER.items(): + if not self.has(None, args_type): + self._add_entry(None, args_type, handler_class) + + # Initialize the registry with entries from other registry. + if other_registry: + for ( + item, + args_type, + ), handler in other_registry.get_all_entries().items(): + if self.has(item, args_type): + if isinstance(handler, type(self.get(item, args_type))): + continue + else: + raise AlreadyExistsError( + f'Entry for item={item} and args_type={args_type} already' + ' exists in the registry, but with a different handler from' + ' the other registry.' + ) + + self._add_entry(item, args_type, handler) + + def _add_entry( + self, + item: Optional[str], + args_type: type[CheckpointArgs], + handler: Union[CheckpointHandler, Type[CheckpointHandler]], + ): + """Adds an entry to the registry. + + Args: + item: The item name. If None, the entry will be added as a general + `args_type` entry. + args_type: The args type. + handler: The handler. If a type is provided, an instance of the type will + be added to the registry. + + Raises: + AlreadyExistsError: If an entry for the given item and args type already + exists in the registry. + """ + if self.has(item, args_type): + raise AlreadyExistsError( + f'Entry for item={item} and args_type={args_type} already' + ' exists in the registry.' + ) + else: + try: + handler_instance = handler() if isinstance(handler, type) else handler + except TypeError: + logging.warning( + 'Failed to instantiate handler for item=%s and args_type=%s.', + item, + args_type, + ) + return + self._registry[(item, args_type)] = handler_instance + + def add( + self, + item: Optional[str], + args_type: Type[CheckpointArgs], + handler: Union[CheckpointHandler, Type[CheckpointHandler]], + ): + self._add_entry(item, args_type, handler) + + def get( + self, + item: Optional[str], + args_type: Type[CheckpointArgs], + ) -> CheckpointHandler: + """Returns the handler for the given item and args type. + + Args: + item: The item name. If None, the entry will be added as a general + `args_type` entry. + args_type: The args type. + + If item the item has not been registered, the general `args_type` entry will + be returned if it exists. + + Raises: + NoEntryError: If no entry for the given item and args type exists in the + registry. + """ + + if (item, args_type) in self._registry: + return self._registry[(item, args_type)] + + # Fall back to general `args_type` if item is not specified. + if item is not None: + if (None, args_type) in self._registry: + return self.get(None, args_type) + + raise NoEntryError( + f'No entry for item={item} and args_ty={args_type} in the registry.' + ) + + def has(self, item: Optional[str], args_type: type[CheckpointArgs]) -> bool: + """Returns whether an entry for the given item and args type exists in the registry. + + Args: + item: The item name or None. + args_type: The args type. + + Does not check for fall back to general `args_type` entry. + """ + return ( + item, + args_type, + ) in self._registry + + def get_all_entries( + self, + ) -> HandlerRegistryMapping: + # Return all entries in self._registry + return self._registry diff --git a/checkpoint/orbax/checkpoint/handler_registration_test.py b/checkpoint/orbax/checkpoint/handler_registration_test.py new file mode 100644 index 000000000..60967a3dd --- /dev/null +++ b/checkpoint/orbax/checkpoint/handler_registration_test.py @@ -0,0 +1,273 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +from typing import Optional, Type, Union + +from absl.testing import absltest +from absl.testing import parameterized +from etils import epath +from orbax.checkpoint import checkpoint_args +from orbax.checkpoint import checkpoint_handler +from orbax.checkpoint import handler_registration + + +CheckpointHandler = checkpoint_handler.CheckpointHandler +DefaultCheckpointHandlerRegistry = ( + handler_registration.DefaultCheckpointHandlerRegistry +) +AlreadyExistsError = handler_registration.AlreadyExistsError +NoEntryError = handler_registration.NoEntryError + + +class _TestCheckpointHandler(CheckpointHandler): + """No-op checkpoint handler for testing.""" + + def save(self, directory: epath.Path, *args, **kwargs) -> None: + del directory, args, kwargs + + def restore(self, directory: epath.Path, *args, **kwargs) -> None: + del directory, args, kwargs + + +@dataclasses.dataclass +class _TestArgs(checkpoint_args.CheckpointArgs): + """No-op checkpoint args for testing.""" + + ... + + +class HandlerRegistryTest(parameterized.TestCase): + + def setUp(self): + # Reset the global handler registry before each test. + self._restore_arg_to_handler_copy = ( + checkpoint_args._RESTORE_ARG_TO_HANDLER.copy() + ) + self._save_arg_to_handler_copy = checkpoint_args._SAVE_ARG_TO_HANDLER.copy() + + checkpoint_args._RESTORE_ARG_TO_HANDLER.clear() + checkpoint_args._SAVE_ARG_TO_HANDLER.clear() + super().setUp() + + def tearDown(self): + # Restore the global handler registry. + for k, v in self._restore_arg_to_handler_copy.items(): + checkpoint_args._RESTORE_ARG_TO_HANDLER[k] = v + + for k, v in self._save_arg_to_handler_copy.items(): + checkpoint_args._SAVE_ARG_TO_HANDLER[k] = v + super().tearDown() + + @parameterized.product( + handler=(_TestCheckpointHandler, _TestCheckpointHandler()), + item=(None, 'item'), + ) + def test_add_and_get_entry( + self, + handler: Union[CheckpointHandler, Type[CheckpointHandler]], + item: Optional[str], + ): + args_type = _TestArgs + registry = DefaultCheckpointHandlerRegistry() + + registry.add( + item, + args_type, + handler, + ) + + # Check that the entry is added to the registry. + self.assertTrue(registry.has(item, args_type)) + # Check that the handler is returned and that it is initialized as an + # object. + self.assertIsInstance( + registry.get(item, args_type), + _TestCheckpointHandler, + ) + + def test_add_entry_with_existing_item_and_args_type_raises_error(self): + item = 'item' + args_type = _TestArgs + registry = DefaultCheckpointHandlerRegistry() + + registry.add(item, args_type, _TestCheckpointHandler) + + with self.assertRaisesRegex( + AlreadyExistsError, r'already exists in the registry' + ): + registry.add(item, args_type, _TestCheckpointHandler) + + def test_get_all_entries(self): + item1 = 'item1' + item2 = 'item2' + args_type = _TestArgs + handler = _TestCheckpointHandler + registry = DefaultCheckpointHandlerRegistry() + + registry.add(item1, args_type, handler) + registry.add(item2, args_type, handler) + + entries = registry.get_all_entries() + self.assertLen(entries, 2) + self.assertIsInstance( + entries[(item1, args_type)], + handler, + ) + self.assertIsInstance( + entries[(item2, args_type)], + handler, + ) + + def test_instanziate_registry_from_another_registry(self): + item1 = 'item1' + item2 = 'item2' + args_type = _TestArgs + handler = _TestCheckpointHandler + + registry1 = DefaultCheckpointHandlerRegistry() + registry1.add(item1, args_type, handler) + registry2 = DefaultCheckpointHandlerRegistry(registry1) + registry2.add(item2, args_type, handler) + + entries = registry2.get_all_entries() + self.assertLen(entries, 2) + self.assertIsInstance( + entries[(item1, args_type)], + handler, + ) + self.assertIsInstance( + entries[(item2, args_type)], + handler, + ) + + @parameterized.product( + item=(None, 'item'), + ) + def test_raise_error_when_no_entry_found(self, item: Optional[str]): + registry = DefaultCheckpointHandlerRegistry() + + with self.assertRaisesRegex( + NoEntryError, + r'No entry for item=.* and args_ty=.* in the registry', + ): + registry.get(item, _TestArgs) + + def test_concrete_item_takes_precedence_over_general_args_type(self): + none_item = None + item = 'item' + args_type = _TestArgs + + class _TestCheckpointHandlerA(_TestCheckpointHandler): + pass + + class _TestCheckpointHandlerB(_TestCheckpointHandler): + pass + + registry = DefaultCheckpointHandlerRegistry() + registry.add(none_item, args_type, _TestCheckpointHandlerA) + registry.add(item, args_type, _TestCheckpointHandlerB) + + self.assertTrue(registry.has(none_item, args_type)) + self.assertTrue(registry.has(item, args_type)) + self.assertIsInstance( + registry.get(none_item, args_type), + _TestCheckpointHandlerA, + ) + self.assertIsInstance( + registry.get(item, args_type), + _TestCheckpointHandlerB, + ) + + def test_falls_back_to_general_args_type(self): + none_item = None + registered_item = 'registered_item' + item_without_registration = 'item_without_registration' + args_type = _TestArgs + + class _TestCheckpointHandlerA(_TestCheckpointHandler): + pass + + class _TestCheckpointHandlerB(_TestCheckpointHandler): + pass + + registry = DefaultCheckpointHandlerRegistry() + registry.add(none_item, args_type, _TestCheckpointHandlerA) + registry.add(registered_item, args_type, _TestCheckpointHandlerB) + + self.assertTrue(registry.has(none_item, args_type)) + self.assertTrue(registry.has(registered_item, args_type)) + self.assertFalse(registry.has(item_without_registration, args_type)) + + self.assertIsInstance( + registry.get(none_item, args_type), + _TestCheckpointHandlerA, + ) + self.assertIsInstance( + registry.get(item_without_registration, args_type), + _TestCheckpointHandlerA, + ) + self.assertIsInstance( + registry.get(registered_item, args_type), + _TestCheckpointHandlerB, + ) + + def test_multiple_hadnlers_for_same_item(self): + item = 'item' + + class _TestArgsA(checkpoint_args.CheckpointArgs): + pass + + class _TestArgsB(checkpoint_args.CheckpointArgs): + pass + + registry = DefaultCheckpointHandlerRegistry() + registry.add(item, _TestArgsA, _TestCheckpointHandler) + registry.add(item, _TestArgsB, _TestCheckpointHandler) + + self.assertIsInstance( + registry.get(item, _TestArgsA), + _TestCheckpointHandler, + ) + self.assertIsInstance( + registry.get(item, _TestArgsB), + _TestCheckpointHandler, + ) + + def test_inherits_from_globally_registered_handlers(self): + + @checkpoint_args.register_with_handler( + _TestCheckpointHandler, for_save=True, for_restore=True + ) + class _GloballyRegisteredTestArgs(checkpoint_args.CheckpointArgs): + pass + + self.assertIn( + _GloballyRegisteredTestArgs, checkpoint_args._SAVE_ARG_TO_HANDLER + ) + self.assertIn( + _GloballyRegisteredTestArgs, checkpoint_args._RESTORE_ARG_TO_HANDLER + ) + + registry = DefaultCheckpointHandlerRegistry() + + self.assertTrue(registry.has(None, _GloballyRegisteredTestArgs)) + self.assertIsInstance( + registry.get(None, _GloballyRegisteredTestArgs), + _TestCheckpointHandler, + ) + + +if __name__ == '__main__': + absltest.main()