Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HandlerRegistry to CheckpointManager. #1023

Merged
merged 1 commit into from
Aug 14, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 72 additions & 9 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import proto_checkpoint_handler
from orbax.checkpoint import utils
from orbax.checkpoint.handlers import handler_registration
from orbax.checkpoint.logging import abstract_logger
from orbax.checkpoint.logging import standard_logger
from orbax.checkpoint.logging import step_statistics
Expand Down Expand Up @@ -69,6 +70,7 @@
CheckpointHandler = checkpoint_handler.CheckpointHandler
CheckpointArgs = checkpoint_args.CheckpointArgs
CheckpointHandlersDict = Mapping[str, CheckpointHandler]
CheckpointHandlerRegistry = handler_registration.CheckpointHandlerRegistry

AsyncOptions = options_lib.AsyncOptions
MultiprocessingOptions = options_lib.MultiprocessingOptions
Expand Down Expand Up @@ -415,6 +417,7 @@ def __init__(
Union[CheckpointHandler, CheckpointHandlersDict]
] = None,
logger: Optional[abstract_logger.AbstractLogger] = None,
handler_registry: Optional[CheckpointHandlerRegistry] = None,
):
"""CheckpointManager constructor.

Expand Down Expand Up @@ -519,6 +522,10 @@ def __init__(
Alternatively, a single CheckpointHandler may be provided, in which case
`save` and `restore` should always be called in a single item context.
logger: A logger to log checkpointing events.
handler_registry: A registry of handlers to use for checkpointing. This
option is mutually exclusive with `checkpointers`, `item_handlers`, and
`item_names`. See :py:class:`CheckpointHandlerRegistry` for more
details.
"""
jax.monitoring.record_event('/jax/orbax/checkpoint_manager/init')

Expand All @@ -529,6 +536,7 @@ def __init__(
raise ValueError('`best_mode` must be one of: "min", "max"')

self._logger = logger or standard_logger.StandardLogger()
self._handler_registry = handler_registry

if checkpointers and item_names:
raise ValueError(
Expand All @@ -545,6 +553,26 @@ def __init__(
'`item_handlers` in single item mode and `item_names` should not be'
' provided together.'
)
if checkpointers is not None and handler_registry is not None:
raise ValueError(
'Deprecated `checkpointers` can not be used with `handler_registry`.'
' Please follow the instructions at'
' https://orbax.readthedocs.io/en/latest/api_refactor.html to'
' migrate.'
)

if item_handlers is not None and handler_registry is not None:
raise ValueError(
'`item_handlers` and `handler_registry` are mutually exclusive -'
' prefer configuring the handler registry.'
)

if item_names is not None and handler_registry is not None:
raise ValueError(
'`item_names` and `handler_registry` are mutually exclusive - prefer'
' configuring the handler registry.'
)

# For async_checkpointer.
self._non_blocking_checkpoint_metadata_store = (
checkpoint.checkpoint_metadata_store(enable_write=True)
Expand All @@ -555,26 +583,42 @@ def __init__(
enable_write=True, blocking_write=True
)
)
if checkpointers:

if checkpointers is not None:
jax.monitoring.record_event(
'/jax/orbax/deprecation/checkpoint_manager_legacy_init'
)
logging.warning(
'Configured `CheckpointManager` using deprecated legacy API. Please'
' follow the instructions at'
' https://orbax.readthedocs.io/en/latest/api_refactor.html to'
' migrate by August 1st, 2024.'
' migrate.'
)
self._single_item = isinstance(checkpointers, AbstractCheckpointer)
self._checkpointer = self._configure_checkpointer_legacy_init(
checkpointers, self._options
)
elif self._handler_registry is not None:
# TODO: b/357913991 - Add support for single-item mode with handler
# registry. This could be done using a "lazy" model where the single-item
# mode is determined when the first save/restore/item_metadata call is
# made.
self._single_item = False
self._checkpointer = self._configure_checkpointer_from_handler_registry(
handler_registry,
self._options,
)
else:
self._single_item = isinstance(item_handlers, CheckpointHandler) or (
item_names is None and item_handlers is None
)
self._checkpointer = self._configure_checkpointer(
item_names, item_handlers, self._options, self._single_item
self._checkpointer = (
self._configure_checkpointer_from_item_names_and_handlers(
item_names,
item_handlers,
self._options,
self._single_item,
)
)

self._directory = epath.Path(directory)
Expand Down Expand Up @@ -750,7 +794,7 @@ def _validate_handler(self, handler):
f'handler[{type(handler)}]={handler._primary_host} ' # pylint: disable=protected-access
)

def _configure_checkpointer(
def _configure_checkpointer_from_item_names_and_handlers(
self,
item_names: Optional[Sequence[str]],
item_handlers: Optional[Union[CheckpointHandler, CheckpointHandlersDict]],
Expand All @@ -772,8 +816,6 @@ def _configure_checkpointer(
if isinstance(item_handlers, CheckpointHandler)
else None
)
if item_handler:
self._validate_handler(item_handler)
all_item_handlers = {DEFAULT_ITEM_NAME: item_handler}
else:
# Initialize all_item_handlers with None or empty.
Expand All @@ -784,10 +826,11 @@ def _configure_checkpointer(
# Update all_item_handlers with provided CheckpointHandlers.
if item_handlers and isinstance(item_handlers, Mapping):
for item_name, handler in item_handlers.items():
self._validate_handler(handler)
all_item_handlers[item_name] = handler

for item_name in all_item_handlers:
for item_name, handler in all_item_handlers.items():
if handler is not None:
self._validate_handler(handler)
if item_name in RESERVED_ITEM_NAMES:
raise ValueError(
f'Found {item_name} in `checkpointers`; this is a reserved key.'
Expand All @@ -811,6 +854,26 @@ def _configure_checkpointer(
options.enable_async_checkpointing,
)

def _configure_checkpointer_from_handler_registry(
    self,
    handler_registry: CheckpointHandlerRegistry,
    options: CheckpointManagerOptions,
) -> Checkpointer:
  """Initializes _CompositeCheckpointer given a `handler_registry`.

  Args:
    handler_registry: The registry handed to the `CompositeCheckpointHandler`.
    options: Manager options; supplies `multiprocessing_options` for the
      composite handler and the `enable_async_checkpointing` flag.

  Returns:
    The checkpointer produced by `_configure_checkpointer_common`.
  """
  composite_options = composite_checkpoint_handler.CompositeOptions(
      multiprocessing_options=options.multiprocessing_options,
  )
  # CompositeCheckpointHandler defers per-item handler creation until
  # save/restore time.
  composite_handler = CompositeCheckpointHandler(
      composite_options=composite_options,
      handler_registry=handler_registry,
  )
  return self._configure_checkpointer_common(
      composite_handler,
      options,
      options.enable_async_checkpointing,
  )

@property
def directory(self) -> epath.Path:
"""See superclass documentation."""
Expand Down
Loading