Skip to content

Commit

Permalink
DO NOT SUBMIT
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 654736313
  • Loading branch information
Orbax Authors committed Aug 5, 2024
1 parent f84f454 commit f4b6b3c
Show file tree
Hide file tree
Showing 5 changed files with 734 additions and 16 deletions.
81 changes: 72 additions & 9 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from orbax.checkpoint import checkpoint_handler
from orbax.checkpoint import checkpointer as checkpointer_lib
from orbax.checkpoint import composite_checkpoint_handler
from orbax.checkpoint import handler_registration
from orbax.checkpoint import json_checkpoint_handler
from orbax.checkpoint import multihost
from orbax.checkpoint import options as options_lib
Expand Down Expand Up @@ -69,6 +70,7 @@
CheckpointHandler = checkpoint_handler.CheckpointHandler
CheckpointArgs = checkpoint_args.CheckpointArgs
CheckpointHandlersDict = Mapping[str, CheckpointHandler]
CheckpointHandlerRegistry = handler_registration.CheckpointHandlerRegistry

AsyncOptions = options_lib.AsyncOptions
MultiprocessingOptions = options_lib.MultiprocessingOptions
Expand Down Expand Up @@ -393,6 +395,7 @@ def __init__(
Union[CheckpointHandler, CheckpointHandlersDict]
] = None,
logger: Optional[abstract_logger.AbstractLogger] = None,
handler_registry: Optional[CheckpointHandlerRegistry] = None,
):
"""CheckpointManager constructor.
Expand Down Expand Up @@ -488,15 +491,18 @@ def __init__(
metadata.
item_names: Names of distinct items that may be saved/restored with this
`CheckpointManager`. `item_names` and `checkpointers` are mutually
exclusive - do not use together. Also see `item_handlers` below.
exclusive - do not use together. Also see `item_handlers` below. Prefer
using the `handler_registry`.
item_handlers: A mapping of item name to `CheckpointHandler`. The mapped
CheckpointHandler must be registered against the `CheckpointArgs` input
in save/restore operations. Please don't use `checkpointers` and
`item_handlers` together. It can be used with or without `item_names`.
The item name key may or may not be present in `item_names`.
Alternatively, a single CheckpointHandler may be provided, in which case
`save` and `restore` should always be called in a single item context.
Prefer using the `handler_registry`.
logger: A logger to log checkpointing events.
handler_registry: A registry of handlers to use for checkpointing.
"""
jax.monitoring.record_event('/jax/orbax/checkpoint_manager/init')

Expand All @@ -508,6 +514,7 @@ def __init__(
self._options.multiprocessing_options or MultiprocessingOptions()
)
self._logger = logger or standard_logger.StandardLogger()
self._handler_registry = handler_registry

if checkpointers and item_names:
raise ValueError(
Expand All @@ -524,6 +531,20 @@ def __init__(
'`item_handlers` in single item mode and `item_names` should not be'
' provided together.'
)
if checkpointers is not None and handler_registry is not None:
raise ValueError(
'Deprecated `checkpointers` can not be used with `handler_registry`.'
' Please follow the instructions at'
' https://orbax.readthedocs.io/en/latest/api_refactor.html to'
' migrate by August 1st, 2024.'
)

if item_handlers is not None and handler_registry is not None:
raise ValueError(
'`item_handlers` and `handler_registry` are mutually exclusive -'
' prefer configuring the handler registry.'
)

# For async_checkpointer.
self._non_blocking_checkpoint_metadata_store = (
checkpoint.checkpoint_metadata_store(enable_write=True)
Expand All @@ -534,7 +555,8 @@ def __init__(
enable_write=True, blocking_write=True
)
)
if checkpointers:

if checkpointers is not None:
logging.warning(
'Configured `CheckpointManager` using deprecated legacy API. Please'
' follow the instructions at'
Expand All @@ -545,12 +567,28 @@ def __init__(
self._checkpointer = self._configure_checkpointer_legacy_init(
checkpointers, self._options
)
elif self._handler_registry is not None:
self._single_item = item_names is None
self._checkpointer = self._configure_checkpointer_from_handler_registry(
item_names,
handler_registry,
self._options,
)
else:
logging.warning(
'Prefer using the `handler_registry` instead of `item_names` and'
' `item_handlers`.'
)
self._single_item = isinstance(item_handlers, CheckpointHandler) or (
item_names is None and item_handlers is None
)
self._checkpointer = self._configure_checkpointer(
item_names, item_handlers, self._options, self._single_item
self._checkpointer = (
self._configure_checkpointer_from_item_names_and_handlers(
item_names,
item_handlers,
self._options,
self._single_item,
)
)

self._directory = epath.Path(directory)
Expand Down Expand Up @@ -745,7 +783,7 @@ def _validate_handler(self, handler):
f'handler[{type(handler)}]={handler._primary_host} ' # pylint: disable=protected-access
)

def _configure_checkpointer(
def _configure_checkpointer_from_item_names_and_handlers(
self,
item_names: Optional[Sequence[str]],
item_handlers: Optional[Union[CheckpointHandler, CheckpointHandlersDict]],
Expand All @@ -767,8 +805,6 @@ def _configure_checkpointer(
if isinstance(item_handlers, CheckpointHandler)
else None
)
if item_handler:
self._validate_handler(item_handler)
all_item_handlers = {DEFAULT_ITEM_NAME: item_handler}
else:
# Initialize all_item_handlers with None or empty.
Expand All @@ -779,10 +815,11 @@ def _configure_checkpointer(
# Update all_item_handlers with provided CheckpointHandlers.
if item_handlers and isinstance(item_handlers, Mapping):
for item_name, handler in item_handlers.items():
self._validate_handler(handler)
all_item_handlers[item_name] = handler

for item_name in all_item_handlers:
for item_name, handler in all_item_handlers.items():
if handler is not None:
self._validate_handler(handler)
if item_name in RESERVED_ITEM_NAMES:
raise ValueError(
f'Found {item_name} in `checkpointers`; this is a reserved key.'
Expand All @@ -808,6 +845,32 @@ def _configure_checkpointer(
options.enable_async_checkpointing,
)

def _configure_checkpointer_from_handler_registry(
    self,
    item_names: Optional[Sequence[str]],
    handler_registry: CheckpointHandlerRegistry,
    options: CheckpointManagerOptions,
) -> Checkpointer:
  """Builds the checkpointer for a manager configured via `handler_registry`.

  Args:
    item_names: Names of the items to manage. When `None`, a single item
      named `DEFAULT_ITEM_NAME` is used.
    handler_registry: Registry passed through to `CompositeCheckpointHandler`.
    options: Manager options; forwarded to `_configure_checkpointer_common`
      together with its `enable_async_checkpointing` flag.

  Returns:
    The result of `_configure_checkpointer_common` for a
    `CompositeCheckpointHandler` built from the items and the registry.
  """
  # If `item_names` is None, default to a single item.
  if item_names is None:
    items = [DEFAULT_ITEM_NAME]
  else:
    items = item_names

  # Multiprocessing settings are taken from the manager, not from `options`.
  mp_options = self._multiprocessing_options
  composite_options = composite_checkpoint_handler.CompositeOptions(
      primary_host=mp_options.primary_host,
      active_processes=mp_options.active_processes,
      barrier_sync_key_prefix=mp_options.barrier_sync_key_prefix,
  )

  # CompositeCheckpointHandler defers per-item handler creation until
  # save/restore time.
  composite_handler = CompositeCheckpointHandler(
      *items,
      composite_options=composite_options,
      handler_registry=handler_registry,
  )
  return self._configure_checkpointer_common(
      composite_handler,
      options,
      options.enable_async_checkpointing,
  )

@property
def directory(self) -> epath.Path:
"""See superclass documentation."""
Expand Down
72 changes: 65 additions & 7 deletions checkpoint/orbax/checkpoint/composite_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import checkpoint_handler
from orbax.checkpoint import future
from orbax.checkpoint import handler_registration
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import proto_checkpoint_handler
from orbax.checkpoint.path import atomicity
Expand Down Expand Up @@ -273,6 +274,9 @@ def __init__(
self,
*item_names: str,
composite_options: CompositeOptions = CompositeOptions(),
handler_registry: Optional[
handler_registration.CheckpointHandlerRegistry
] = None,
**items_and_handlers: CheckpointHandler,
):
"""Constructor.
Expand All @@ -282,19 +286,55 @@ def __init__(
Args:
*item_names: A list of string item names that this handler will manage.
composite_options: Options.
handler_registry: A `CheckpointHandlerRegistry` instance. If provided, the
`CompositeCheckpointHandler` will use this registry to determine the
`CheckpointHandler` for each item.
**items_and_handlers: A mapping of item name to `CheckpointHandler`
instance, which will be used as the handler for objects of the
corresponding name.
"""
self._known_handlers: Dict[str, Optional[CheckpointHandler]] = (
items_and_handlers
)
if handler_registry is not None and items_and_handlers:
raise ValueError(
'Both `handler_registry` and `items_and_handlers` were provided. '
'Please specify only one of the two.'
)

if handler_registry is not None:
self._handler_registry = handler_registry
self._known_handlers = {
item: handler
for (
item,
unused_args_type,
), handler in handler_registry.get_all_entries().items()
# The mapping is from item to handlers, so only include items that
# have handlers.
if item is not None
}
elif items_and_handlers:
logging.info(
'Prefer using `handler_registry` instead of `items_and_handlers`.'
)
self._handler_registry = None
self._known_handlers = items_and_handlers
else:
# If neither a handler registry nor `items_and_handlers` is provided, we
# will default to the global registry.
self._handler_registry = None
self._known_handlers = {}

for item in item_names:
_maybe_raise_reserved_item_error(item)
if item not in self._known_handlers:
self._known_handlers[item] = None

for item_name, handler in self._known_handlers.items():
_maybe_raise_reserved_item_error(item_name)
if handler and not checkpoint_args.has_registered_args(handler):
if self._handler_registry is not None:
raise ValueError(
'Handler registry has been provided, but no registered'
f' `CheckpointArgs` found for handler type: {type(handler)}.'
)
logging.warning(
'No registered CheckpointArgs found for handler type: %s',
type(handler),
Expand All @@ -316,6 +356,23 @@ def _get_or_set_handler(
item_name: str,
args: Optional[CheckpointArgs],
) -> CheckpointHandler:
if self._handler_registry is not None:
if args is not None:
try:
# Pytype infers that `self._handler_registry` is always `None` here, even
# though the enclosing `is not None` check guarantees that it is set.
handler = self._handler_registry.get(item_name, type(args)) # pytype: disable=attribute-error
if self._known_handlers[item_name] is None:
self._known_handlers[item_name] = handler
return handler
except handler_registration.NoEntryError:
logging.info(
'No entry found in handler registry for item: %s and args: %s.'
' Falling back to global handler registry',
item_name,
args,
)

if item_name not in self._known_handlers:
raise ValueError(
f'Unknown key "{item_name}". Please make sure that this key was'
Expand Down Expand Up @@ -344,9 +401,10 @@ def _get_or_set_handler(
self._known_handlers[item_name] = handler
if not isinstance(handler, registered_handler_cls_for_args):
raise ValueError(
f'For item, "{item_name}", CheckpointHandler {type(handler)} does not'
f' match with registered handler {registered_handler_cls_for_args}'
f' for provided args of type: {type(args)}'
f'For item, "{item_name}", CheckpointHandler {type(handler)} does'
' not match with registered handler'
f' {registered_handler_cls_for_args} for provided args of type:'
f' {type(args)}'
)
return handler

Expand Down
Loading

0 comments on commit f4b6b3c

Please sign in to comment.