diff --git a/checkpoint/orbax/checkpoint/checkpoint_manager.py b/checkpoint/orbax/checkpoint/checkpoint_manager.py index 9faf0f4f..1b697620 100644 --- a/checkpoint/orbax/checkpoint/checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/checkpoint_manager.py @@ -40,6 +40,7 @@ from orbax.checkpoint import options as options_lib from orbax.checkpoint import proto_checkpoint_handler from orbax.checkpoint import utils +from orbax.checkpoint.handlers import handler_registration from orbax.checkpoint.logging import abstract_logger from orbax.checkpoint.logging import standard_logger from orbax.checkpoint.logging import step_statistics @@ -69,6 +70,7 @@ CheckpointHandler = checkpoint_handler.CheckpointHandler CheckpointArgs = checkpoint_args.CheckpointArgs CheckpointHandlersDict = Mapping[str, CheckpointHandler] +CheckpointHandlerRegistry = handler_registration.CheckpointHandlerRegistry AsyncOptions = options_lib.AsyncOptions MultiprocessingOptions = options_lib.MultiprocessingOptions @@ -415,6 +417,7 @@ def __init__( Union[CheckpointHandler, CheckpointHandlersDict] ] = None, logger: Optional[abstract_logger.AbstractLogger] = None, + handler_registry: Optional[CheckpointHandlerRegistry] = None, ): """CheckpointManager constructor. @@ -519,6 +522,10 @@ def __init__( Alternatively, a single CheckpointHandler may be provided, in which case `save` and `restore` should always be called in a single item context. logger: A logger to log checkpointing events. + handler_registry: A registry of handlers to use for checkpointing. This + option is mutually exclusive with `checkpointers`, `item_handlers`, and + `item_names`. See :py:class:`CheckpointHandlerRegistry` for more + details. 
""" jax.monitoring.record_event('/jax/orbax/checkpoint_manager/init') @@ -529,6 +536,7 @@ def __init__( raise ValueError('`best_mode` must be one of: "min", "max"') self._logger = logger or standard_logger.StandardLogger() + self._handler_registry = handler_registry if checkpointers and item_names: raise ValueError( @@ -545,6 +553,26 @@ def __init__( '`item_handlers` in single item mode and `item_names` should not be' ' provided together.' ) + if checkpointers is not None and handler_registry is not None: + raise ValueError( + 'Deprecated `checkpointers` can not be used with `handler_registry`.' + ' Please follow the instructions at' + ' https://orbax.readthedocs.io/en/latest/api_refactor.html to' + ' migrate.' + ) + + if item_handlers is not None and handler_registry is not None: + raise ValueError( + '`item_handlers` and `handler_registry` are mutually exclusive -' + ' prefer configuring the handler registry.' + ) + + if item_names is not None and handler_registry is not None: + raise ValueError( + '`item_names` and `handler_registry` are mutually exclusive - prefer' + ' configuring the handler registry.' + ) + # For async_checkpointer. self._non_blocking_checkpoint_metadata_store = ( checkpoint.checkpoint_metadata_store(enable_write=True) @@ -555,7 +583,8 @@ def __init__( enable_write=True, blocking_write=True ) ) - if checkpointers: + + if checkpointers is not None: jax.monitoring.record_event( '/jax/orbax/deprecation/checkpoint_manager_legacy_init' ) @@ -563,18 +592,33 @@ def __init__( 'Configured `CheckpointManager` using deprecated legacy API. Please' ' follow the instructions at' ' https://orbax.readthedocs.io/en/latest/api_refactor.html to' - ' migrate by August 1st, 2024.' + ' migrate.' 
) self._single_item = isinstance(checkpointers, AbstractCheckpointer) self._checkpointer = self._configure_checkpointer_legacy_init( checkpointers, self._options ) + elif self._handler_registry is not None: + # TODO: b/357913991 - Add support for single item mode with handler + # registry. This could be done using a "lazy" model where the single-item + # mode is determined when the first save/restore/item_metadata call is + # made. + self._single_item = False + self._checkpointer = self._configure_checkpointer_from_handler_registry( + handler_registry, + self._options, + ) else: self._single_item = isinstance(item_handlers, CheckpointHandler) or ( item_names is None and item_handlers is None ) - self._checkpointer = self._configure_checkpointer( - item_names, item_handlers, self._options, self._single_item + self._checkpointer = ( + self._configure_checkpointer_from_item_names_and_handlers( + item_names, + item_handlers, + self._options, + self._single_item, + ) ) self._directory = epath.Path(directory) @@ -750,7 +794,7 @@ def _validate_handler(self, handler): f'handler[{type(handler)}]={handler._primary_host} ' # pylint: disable=protected-access ) - def _configure_checkpointer( + def _configure_checkpointer_from_item_names_and_handlers( self, item_names: Optional[Sequence[str]], item_handlers: Optional[Union[CheckpointHandler, CheckpointHandlersDict]], @@ -772,8 +816,6 @@ def _configure_checkpointer( if isinstance(item_handlers, CheckpointHandler) else None ) - if item_handler: - self._validate_handler(item_handler) all_item_handlers = {DEFAULT_ITEM_NAME: item_handler} else: # Initialize all_item_handlers with None or empty. @@ -784,10 +826,11 @@ def _configure_checkpointer( # Update all_item_handlers with provided CheckpointHandlers. 
if item_handlers and isinstance(item_handlers, Mapping): for item_name, handler in item_handlers.items(): - self._validate_handler(handler) all_item_handlers[item_name] = handler - for item_name in all_item_handlers: + for item_name, handler in all_item_handlers.items(): + if handler is not None: + self._validate_handler(handler) if item_name in RESERVED_ITEM_NAMES: raise ValueError( f'Found {item_name} in `checkpointers`; this is a reserved key.' @@ -811,6 +854,26 @@ def _configure_checkpointer( options.enable_async_checkpointing, ) + def _configure_checkpointer_from_handler_registry( + self, + handler_registry: CheckpointHandlerRegistry, + options: CheckpointManagerOptions, + ) -> Checkpointer: + """Initializes _CompositeCheckpointer given a `handler_registry`.""" + + # CompositeCheckpointHandler defers per-item handler creation until + # save/restore time. + return self._configure_checkpointer_common( + CompositeCheckpointHandler( + composite_options=composite_checkpoint_handler.CompositeOptions( + multiprocessing_options=options.multiprocessing_options, + ), + handler_registry=handler_registry, + ), + options, + options.enable_async_checkpointing, + ) + @property def directory(self) -> epath.Path: """See superclass documentation."""