Commit
Consolidate usages of MultiprocessingOptions and AsyncOptions. Formalize `StandardCheckpointer` as an `AsyncCheckpointer` that doesn't require a `CheckpointArgs` object, and instead allows directly passing the state and extra args.

PiperOrigin-RevId: 658106830
cpgaffney1 authored and Orbax Authors committed Aug 5, 2024
1 parent f84f454 commit cd13afb
Showing 19 changed files with 306 additions and 150 deletions.
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Improve logging by adding jax_process, error logs in threads and more...
- Improvements to blocking save time, as a result of moving file open operations into the background.
- Consolidate usages of `MultiprocessingOptions` and `AsyncOptions`.
- Formalize `StandardCheckpointer` as an `AsyncCheckpointer` that doesn't require a `CheckpointArgs` object, and instead allows directly passing the state and extra args.


## [0.5.23] - 2024-07-26
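The two new changelog entries summarize the user-facing effect of this commit. As a rough illustration, the sketch below shows the `StandardCheckpointer` half: the state is passed to `save`/`restore` directly instead of being wrapped in a `CheckpointArgs` object. The call pattern is inferred from the changelog wording and the top-level `orbax.checkpoint` exports; the path and state values are placeholders.

```python
# Minimal sketch of the direct-state StandardCheckpointer usage described in the
# changelog entries above. The path and state are illustrative only.
import numpy as np
import orbax.checkpoint as ocp

state = {'step': np.int32(0), 'params': {'w': np.zeros((2, 2))}}

ckptr = ocp.StandardCheckpointer()       # now formally an AsyncCheckpointer
ckptr.save('/tmp/orbax_state', state)    # no ocp.args.StandardSave(state) wrapper
ckptr.wait_until_finished()              # saving completes in a background thread
restored = ckptr.restore('/tmp/orbax_state')
```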
1 change: 1 addition & 0 deletions checkpoint/orbax/__init__.py
@@ -26,6 +26,7 @@
from orbax.checkpoint import metadata
from orbax.checkpoint import msgpack_utils
from orbax.checkpoint import multihost
from orbax.checkpoint import options
from orbax.checkpoint import path
from orbax.checkpoint import test_utils
from orbax.checkpoint import transform_utils
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/__init__.py
@@ -26,6 +26,7 @@
from orbax.checkpoint import metadata
from orbax.checkpoint import msgpack_utils
from orbax.checkpoint import multihost
from orbax.checkpoint import options
from orbax.checkpoint import path
from orbax.checkpoint import test_utils
from orbax.checkpoint import transform_utils
31 changes: 16 additions & 15 deletions checkpoint/orbax/checkpoint/async_checkpointer.py
@@ -17,7 +17,7 @@
import asyncio
import threading
import time
from typing import Any, Callable, Optional, Sequence, Set, Type
from typing import Any, Callable, Optional, Sequence, Type

from absl import logging
from etils import epath
@@ -27,6 +27,7 @@
from orbax.checkpoint import checkpointer
from orbax.checkpoint import future as future_lib
from orbax.checkpoint import multihost
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import utils
from orbax.checkpoint.metadata import checkpoint
from orbax.checkpoint.path import atomicity
@@ -221,13 +222,10 @@ class AsyncCheckpointer(checkpointer.Checkpointer):
def __init__(
self,
handler: async_checkpoint_handler.AsyncCheckpointHandler,
timeout_secs: int = 300,
timeout_secs: Optional[int] = None,
*,
primary_host: Optional[int] = 0,
active_processes: Optional[Set[int]] = None,
barrier_sync_fn: Optional[multihost.BarrierSyncFn] = None,
barrier_sync_key_prefix: Optional[str] = None,
post_finalization_callback: Optional[Callable[[], None]] = None,
async_options: options_lib.AsyncOptions = options_lib.AsyncOptions(),
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
path_permission_mode: Optional[int] = None,
checkpoint_metadata_store: Optional[
checkpoint.CheckpointMetadataStore
@@ -245,14 +243,14 @@ def __init__(
handler, async_checkpoint_handler.AsyncCheckpointHandler
)
self._handler = handler
self._primary_host = primary_host
self._active_processes = active_processes
self._post_finalization_callback = post_finalization_callback
self._primary_host = multiprocessing_options.primary_host
self._active_processes = multiprocessing_options.active_processes
self._post_finalization_callback = async_options.post_finalization_callback
unique_class_id = self._unique_operation_id()
barrier_sync_key_prefix = (
f'{unique_class_id}'
if barrier_sync_key_prefix is None
else f'{barrier_sync_key_prefix}.{unique_class_id}'
if multiprocessing_options.barrier_sync_key_prefix is None
else f'{multiprocessing_options.barrier_sync_key_prefix}.{unique_class_id}'
)
self._barrier_sync_key_prefix = barrier_sync_key_prefix
self._path_permission_mode = path_permission_mode # e.g. 0o750
@@ -261,13 +259,16 @@ def __init__(
or checkpoint.checkpoint_metadata_store(enable_write=True)
)
self._temporary_path_class = temporary_path_class
timeout_secs = timeout_secs or async_options.timeout_secs

# TODO(dicentra): consider folding into AsyncCheckpointer directly.
self._async_manager = _AsyncManager(
barrier_sync_fn=barrier_sync_fn
or multihost.get_barrier_sync_fn(processes=active_processes),
barrier_sync_fn=async_options.barrier_sync_fn
or multihost.get_barrier_sync_fn(
processes=multiprocessing_options.active_processes
),
timeout_secs=timeout_secs,
primary_host=primary_host,
primary_host=multiprocessing_options.primary_host,
barrier_sync_key_prefix=barrier_sync_key_prefix,
)

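In the new constructor above, the loose multihost keyword arguments are replaced by two dataclasses from `orbax.checkpoint.options`. Below is a hedged construction sketch using only fields that appear in this diff (`timeout_secs`, `post_finalization_callback`, `primary_host`, `barrier_sync_key_prefix`); the handler choice and concrete values are illustrative.

```python
# Sketch of the consolidated AsyncCheckpointer constructor shown above. Option
# field names come from this diff; the handler and values are assumptions.
import orbax.checkpoint as ocp
from orbax.checkpoint import options as options_lib

ckptr = ocp.AsyncCheckpointer(
    ocp.StandardCheckpointHandler(),
    async_options=options_lib.AsyncOptions(
        timeout_secs=600,                        # was the old `timeout_secs` kwarg
        post_finalization_callback=None,         # was a standalone kwarg
    ),
    multiprocessing_options=options_lib.MultiprocessingOptions(
        primary_host=0,                          # was a standalone kwarg
        barrier_sync_key_prefix='train_state',   # was a standalone kwarg
    ),
)
```

The legacy `timeout_secs` argument is still accepted and, per the diff, falls back to `async_options.timeout_secs` when left as `None`.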
18 changes: 8 additions & 10 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
@@ -34,6 +34,7 @@
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import multihost
from orbax.checkpoint import options as options_lib
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
@@ -124,6 +125,7 @@ def batched_serialization_requests(
) -> List[_BatchRequest]:
"""Gets a list of batched serialization or deserialization requests."""
grouped = {}

def _group_value(
keypath: Tuple[Any, ...],
info: ParamInfo,
@@ -234,10 +236,11 @@ class BasePyTreeCheckpointHandler(

def __init__(
self,
*,
concurrent_gb: int = 96,
use_ocdbt: bool = True,
use_zarr3: bool = False,
primary_host: Optional[int] = 0,
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
type_handler_registry: TypeHandlerRegistry = type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
):
"""Creates BasePyTreeCheckpointHandler.
@@ -247,17 +250,15 @@ def __init__(
reduce the possibility of OOM's when large checkpoints are restored.
use_ocdbt: Whether to use OCDBT format for saving.
use_zarr3: If True, use Zarr ver3 otherwise Zarr ver2.
primary_host: the host id of the primary host. Default to 0. If it's set
to None, then all hosts will be considered as primary. It's useful in
the case that all hosts are only working with local storage.
multiprocessing_options: See orbax.checkpoint.options.
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
specified, the global type handler registry will be used. # BEGIN
enable_descriptor: If True, logs a Descriptor proto that contains lineage
"""
self._concurrent_gb = concurrent_gb
self._use_ocdbt = use_ocdbt
self._use_zarr3 = use_zarr3
self._primary_host = primary_host
self._primary_host = multiprocessing_options.primary_host
self._type_handler_registry = type_handler_registry


@@ -402,8 +403,7 @@ async def async_save(
ocdbt_target_data_file_size=ocdbt_target_data_file_size,
)
assert all(
leaf.parent_dir == directory
for leaf in jax.tree.leaves(param_infos)
leaf.parent_dir == directory for leaf in jax.tree.leaves(param_infos)
)
await self._maybe_create_param_directories(param_infos, save_args)

@@ -629,9 +629,7 @@ class TrainState:
if logging.level_debug():
logging.debug('param_infos: %s', param_infos)
logging.debug('checkpoint_restore_args: %s', restore_args)
logging.debug(
'restored_item: %s', jax.tree.structure(restored_item)
)
logging.debug('restored_item: %s', jax.tree.structure(restored_item))
logging.debug(
'ts_metrics: %s',
json.dumps(ts.experimental_collect_matching_metrics('/tensorstore/')),
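The handler-level change follows the same pattern: `primary_host` moves into `MultiprocessingOptions`, and the constructor becomes keyword-only. A small construction sketch follows, with `primary_host=None` mirroring the local-storage case described in the removed docstring text; the other arguments keep the defaults shown in this diff.

```python
# Construction sketch for the keyword-only BasePyTreeCheckpointHandler signature
# above. primary_host=None treats every host as primary (the local-storage case
# mentioned in the old docstring).
from orbax.checkpoint import base_pytree_checkpoint_handler
from orbax.checkpoint import options as options_lib

handler = base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler(
    use_ocdbt=True,
    use_zarr3=False,
    multiprocessing_options=options_lib.MultiprocessingOptions(primary_host=None),
)
```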
56 changes: 16 additions & 40 deletions checkpoint/orbax/checkpoint/checkpoint_manager.py
@@ -501,12 +501,11 @@ def __init__(
jax.monitoring.record_event('/jax/orbax/checkpoint_manager/init')

self._options = options or CheckpointManagerOptions()
self._multiprocessing_options = self._options.multiprocessing_options

if self._options.best_mode not in ['min', 'max']:
raise ValueError('`best_mode` must be one of: "min", "max"')

self._multiprocessing_options = (
self._options.multiprocessing_options or MultiprocessingOptions()
)
self._logger = logger or standard_logger.StandardLogger()

if checkpointers and item_names:
@@ -576,11 +575,9 @@ def __init__(

self._metadata_checkpointer = Checkpointer(
JsonCheckpointHandler(
primary_host=self._multiprocessing_options.primary_host
multiprocessing_options=self._multiprocessing_options
),
primary_host=self._multiprocessing_options.primary_host,
active_processes=self._multiprocessing_options.active_processes,
barrier_sync_key_prefix=self._multiprocessing_options.barrier_sync_key_prefix,
multiprocessing_options=self._options.multiprocessing_options,
path_permission_mode=self._options.file_options.path_permission_mode,
checkpoint_metadata_store=self._blocking_checkpoint_metadata_store,
temporary_path_class=self._options.temporary_path_class,
@@ -621,35 +618,18 @@ def _configure_checkpointer_common(
use_async: bool,
) -> Checkpointer:
if use_async:
if options.async_options is not None:
return async_checkpointer.AsyncCheckpointer(
handler,
timeout_secs=options.async_options.timeout_secs,
primary_host=self._multiprocessing_options.primary_host,
barrier_sync_fn=options.async_options.barrier_sync_fn,
active_processes=self._multiprocessing_options.active_processes,
barrier_sync_key_prefix=self._multiprocessing_options.barrier_sync_key_prefix,
post_finalization_callback=options.async_options.post_finalization_callback,
path_permission_mode=options.file_options.path_permission_mode,
checkpoint_metadata_store=self._non_blocking_checkpoint_metadata_store,
temporary_path_class=options.temporary_path_class,
)
else:
return async_checkpointer.AsyncCheckpointer(
handler,
primary_host=self._multiprocessing_options.primary_host,
active_processes=self._multiprocessing_options.active_processes,
barrier_sync_key_prefix=self._multiprocessing_options.barrier_sync_key_prefix,
path_permission_mode=options.file_options.path_permission_mode,
checkpoint_metadata_store=self._non_blocking_checkpoint_metadata_store,
temporary_path_class=options.temporary_path_class,
)
return async_checkpointer.AsyncCheckpointer(
handler,
multiprocessing_options=options.multiprocessing_options,
async_options=options.async_options or AsyncOptions(),
path_permission_mode=options.file_options.path_permission_mode,
checkpoint_metadata_store=self._non_blocking_checkpoint_metadata_store,
temporary_path_class=options.temporary_path_class,
)
else:
return Checkpointer(
handler,
primary_host=self._multiprocessing_options.primary_host,
active_processes=self._multiprocessing_options.active_processes,
barrier_sync_key_prefix=self._multiprocessing_options.barrier_sync_key_prefix,
multiprocessing_options=options.multiprocessing_options,
path_permission_mode=options.file_options.path_permission_mode,
checkpoint_metadata_store=self._blocking_checkpoint_metadata_store,
temporary_path_class=options.temporary_path_class,
@@ -722,9 +702,7 @@ def _configure_checkpointer_legacy_init(
return self._configure_checkpointer_common(
CompositeCheckpointHandler(
composite_options=composite_checkpoint_handler.CompositeOptions(
primary_host=self._multiprocessing_options.primary_host,
active_processes=self._multiprocessing_options.active_processes,
barrier_sync_key_prefix=self._multiprocessing_options.barrier_sync_key_prefix,
multiprocessing_options=options.multiprocessing_options,
file_options=options.file_options,
),
**item_handlers,
@@ -790,16 +768,14 @@ def _configure_checkpointer(
if options.best_fn:
all_item_handlers[METRIC_ITEM_NAME] = JsonCheckpointHandler(
filename=METRIC_ITEM_NAME,
primary_host=self._multiprocessing_options.primary_host,
multiprocessing_options=self._multiprocessing_options,
)
# CompositeCheckpointHandler defers per-item handler creation until
# save/restore time.
return self._configure_checkpointer_common(
CompositeCheckpointHandler(
composite_options=composite_checkpoint_handler.CompositeOptions(
primary_host=self._multiprocessing_options.primary_host,
active_processes=self._multiprocessing_options.active_processes,
barrier_sync_key_prefix=self._multiprocessing_options.barrier_sync_key_prefix,
multiprocessing_options=options.multiprocessing_options,
file_options=options.file_options,
),
**all_item_handlers,
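From the `CheckpointManager` side, the effect of the simplification above is that the option objects set on `CheckpointManagerOptions` are handed to the underlying `Checkpointer`/`AsyncCheckpointer` as-is rather than being unpacked field by field. A hedged usage sketch; the directory, `max_to_keep`, and field values are illustrative.

```python
# Sketch of driving the consolidated wiring above from user code. The
# async_options / multiprocessing_options fields appear in this diff; the
# directory and other values are placeholders.
import orbax.checkpoint as ocp
from orbax.checkpoint import options as options_lib

mngr_options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    async_options=options_lib.AsyncOptions(timeout_secs=600),
    multiprocessing_options=options_lib.MultiprocessingOptions(
        barrier_sync_key_prefix='train',
    ),
)
mngr = ocp.CheckpointManager('/tmp/ckpts', options=mngr_options)
```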
14 changes: 7 additions & 7 deletions checkpoint/orbax/checkpoint/checkpointer.py
@@ -15,7 +15,7 @@
"""Synchronous Checkpointer implementation."""

import time
from typing import Any, Iterable, Optional, Set, Type
from typing import Any, Iterable, Optional, Type

from absl import logging
from etils import epath
@@ -101,9 +101,7 @@ def __init__(
self,
handler: checkpoint_handler.CheckpointHandler,
*,
primary_host: Optional[int] = 0,
active_processes: Optional[Set[int]] = None,
barrier_sync_key_prefix: Optional[str] = None,
multiprocessing_options: options_lib.MultiprocessingOptions = options_lib.MultiprocessingOptions(),
path_permission_mode: Optional[int] = None,
checkpoint_metadata_store: Optional[
checkpoint.CheckpointMetadataStore
@@ -117,9 +115,11 @@
)
handler = get_legacy_handler_wrapper(handler)
self._handler = handler
self._primary_host = primary_host
self._active_processes = active_processes
self._barrier_sync_key_prefix = barrier_sync_key_prefix
self._primary_host = multiprocessing_options.primary_host
self._active_processes = multiprocessing_options.active_processes
self._barrier_sync_key_prefix = (
multiprocessing_options.barrier_sync_key_prefix
)
self._path_permission_mode = path_permission_mode # e.g. 0o750
self._temporary_path_class = temporary_path_class

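The synchronous `Checkpointer` mirrors the `AsyncCheckpointer` change: a single `MultiprocessingOptions` object replaces the `primary_host` / `active_processes` / `barrier_sync_key_prefix` keywords. A minimal construction sketch; the handler choice is an arbitrary example, and `active_processes=None` keeps the old default behavior of having every process participate.

```python
# Sketch of the consolidated synchronous Checkpointer constructor above.
# JsonCheckpointHandler is used only as an example handler.
import orbax.checkpoint as ocp
from orbax.checkpoint import options as options_lib

ckptr = ocp.Checkpointer(
    ocp.JsonCheckpointHandler(),
    multiprocessing_options=options_lib.MultiprocessingOptions(
        active_processes=None,  # None: every process participates in barriers
    ),
)
```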