Add memory limiting functionality for save, in addition to restore.
PiperOrigin-RevId: 643168487
cpgaffney1 authored and Orbax Authors committed Aug 5, 2024
1 parent f84f454 commit cb61318
Showing 10 changed files with 334 additions and 120 deletions.
5 changes: 4 additions & 1 deletion checkpoint/CHANGELOG.md
@@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Improve logging by adding jax_process, error logs in threads and more...
- Improvements to blocking save time, as a result of moving file open operations into the background.

### Added
- Add memory-based rate limiting support during save.


## [0.5.23] - 2024-07-26

@@ -39,7 +42,6 @@ entries will not return information about array properties).
### Added
- Rolled forward change to improve TensorStore I/O efficiency.
- Memory efficient broadcasting from one model replica to others.
- Ability to check if a checkpoint save is in progress.

### Changed
- Allow one directory creation request per item rather than 1 per item per host.
@@ -76,6 +78,7 @@ entries will not return information about array properties).
### Fixed
- Rolled back change in previous release to improve TensorStore I/O efficiency.
This change caused some unexpected failures on certain storage systems.
- Add memory-based rate limiting support during save.

## [0.5.16] - 2024-06-11

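For reference, here is a minimal usage sketch of the new save-side limit described in the changelog entry above. It is a sketch under assumptions: that the top-level `orbax.checkpoint` package exposes `Checkpointer` and `PyTreeCheckpointHandler`, that the handler accepts the keyword-only `save_concurrent_gb`/`restore_concurrent_gb` arguments introduced in the diffs below, and that the path and checkpoint contents are placeholders.

```python
import numpy as np
import orbax.checkpoint as ocp

# save_concurrent_gb is the new knob from this commit; restore_concurrent_gb
# mirrors the restore-side memory limit that already existed.
handler = ocp.PyTreeCheckpointHandler(
    save_concurrent_gb=8,  # cap concurrent writes at ~8 GB in flight
    restore_concurrent_gb=16,  # cap concurrent reads at ~16 GB in flight
)
ckptr = ocp.Checkpointer(handler)

train_state = {'params': np.ones((1024, 1024), dtype=np.float32)}
ckptr.save('/tmp/ckpt', train_state)
restored = ckptr.restore('/tmp/ckpt')
```

Exact calling conventions for `save`/`restore` vary across Orbax versions, so treat the last three lines as schematic rather than canonical.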
1 change: 1 addition & 0 deletions checkpoint/orbax/__init__.py
@@ -27,6 +27,7 @@
from orbax.checkpoint import msgpack_utils
from orbax.checkpoint import multihost
from orbax.checkpoint import path
from orbax.checkpoint import serialization
from orbax.checkpoint import test_utils
from orbax.checkpoint import transform_utils
from orbax.checkpoint import tree
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/__init__.py
@@ -27,6 +27,7 @@
from orbax.checkpoint import msgpack_utils
from orbax.checkpoint import multihost
from orbax.checkpoint import path
from orbax.checkpoint import serialization
from orbax.checkpoint import test_utils
from orbax.checkpoint import transform_utils
from orbax.checkpoint import tree
45 changes: 24 additions & 21 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
@@ -34,6 +34,7 @@
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import multihost
from orbax.checkpoint import serialization
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
@@ -58,22 +59,10 @@
get_param_names = tree_utils.get_param_names

METADATA_FILE = '_METADATA'
DEFAULT_CONCURRENT_GB = 96



def get_byte_limiter(concurrent_gb: int):
async def _create_byte_limiter():
# Wrap creation in async function to avoid issues on python<=3.9.
concurrent_bytes = concurrent_gb * 10**9
# Construction must take place here so that it is within the same async
# method, to prevent errors resulting from different event loops, and
# cannot be created below this level because there must be a single object
# for the entire restore call.
return LimitInFlightBytes(concurrent_bytes) # pylint: disable=protected-access

return asyncio.run(_create_byte_limiter())


async def _create_param_save_dir(param_info: ParamInfo, args: SaveArgs):
# Directory will be unused.
path = param_info.path
@@ -234,7 +223,9 @@ class BasePyTreeCheckpointHandler(

def __init__(
self,
concurrent_gb: int = 96,
*,
save_concurrent_bytes: Optional[int] = None,
restore_concurrent_bytes: Optional[int] = None,
use_ocdbt: bool = True,
use_zarr3: bool = False,
primary_host: Optional[int] = 0,
@@ -243,8 +234,12 @@ def __init__(
"""Creates BasePyTreeCheckpointHandler.
Args:
concurrent_gb: max concurrent GB that are allowed to be read. Can help to
reduce the possibility of OOMs when large checkpoints are restored.
save_concurrent_bytes: max concurrent bytes that are allowed to be
written. Can help to reduce the possibility of OOMs when large
checkpoints are saved.
restore_concurrent_bytes: max concurrent bytes that are allowed to be
restored. Can help to reduce the possibility of OOMs when large
checkpoints are restored.
use_ocdbt: Whether to use OCDBT format for saving.
use_zarr3: If True, use Zarr ver3 otherwise Zarr ver2.
primary_host: the host id of the primary host. Default to 0. If it's set
@@ -254,7 +249,8 @@ def __init__(
specified, the global type handler registry will be used. # BEGIN
enable_descriptor: If True, logs a Descriptor proto that contains lineage
"""
self._concurrent_gb = concurrent_gb
self._save_concurrent_bytes = save_concurrent_bytes
self._restore_concurrent_bytes = restore_concurrent_bytes
self._use_ocdbt = use_ocdbt
self._use_zarr3 = use_zarr3
self._primary_host = primary_host
@@ -279,7 +275,7 @@ def _get_param_infos(
use_ocdbt: bool = True,
use_zarr3: Optional[bool] = None,
ocdbt_target_data_file_size: Optional[int] = None,
byte_limiter: Optional[LimitInFlightBytes] = None,
byte_limiter: Optional[serialization.ByteLimiter] = None,
) -> PyTree:
"""Returns parameter information for elements in `item`.
@@ -293,7 +289,7 @@
use_zarr3: Whether to use zarr3.
ocdbt_target_data_file_size: Specifies the target size (in bytes) of each
OCDBT data file.
byte_limiter: LimitInFlightBytes object.
byte_limiter: ByteLimiter object.
Returns:
A PyTree matching `item` of ParamInfo.
@@ -395,11 +391,13 @@ async def async_save(
raise ValueError('`ocdbt_target_data_file_size` only works with Zarr3')

save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save')
byte_limiter = serialization.get_byte_limiter(self._save_concurrent_bytes)
param_infos = self._get_param_infos(
item,
directory,
use_ocdbt=self._use_ocdbt,
ocdbt_target_data_file_size=ocdbt_target_data_file_size,
byte_limiter=byte_limiter,
)
assert all(
leaf.parent_dir == directory
@@ -471,6 +469,13 @@ async def _maybe_deserialize(
) -> PyTree:
"""Deserializes values or skips."""
flat_metadata = tree_utils.to_flat_dict(metadata)
byte_limiter = serialization.get_byte_limiter(
self._restore_concurrent_bytes
)
param_infos = jax.tree.map(
lambda info: dataclasses.replace(info, byte_limiter=byte_limiter),
param_infos,
)
batch_requests = batched_serialization_requests(
metadata,
param_infos,
@@ -596,7 +601,6 @@ class TrainState:
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}'
)
byte_limiter = get_byte_limiter(self._concurrent_gb)
metadata = self._read_metadata_file(directory)
use_zarr3_metadata = metadata.use_zarr3
metadata = metadata.as_nested_tree(keep_empty_nodes=True)
@@ -620,7 +624,6 @@
directory,
use_ocdbt=type_handlers.is_ocdbt_checkpoint(directory),
use_zarr3=use_zarr3,
byte_limiter=byte_limiter,
)
restored_item = asyncio.run(
self._maybe_deserialize(item, metadata, param_infos, restore_args)
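The `serialization.get_byte_limiter` calls added above replace the handler-local `get_byte_limiter` removed from this file. For readers unfamiliar with the mechanism, the sketch below shows the in-flight-bytes pattern such a limiter implements: reserve bytes before each read or write, block while the budget is exhausted, and release afterwards. This is a simplified stand-in, not Orbax's actual `LimitInFlightBytes`, and the method names are assumptions.

```python
import asyncio


class InFlightBytesLimiter:
  """Caps the total bytes simultaneously being read or written."""

  def __init__(self, max_bytes: int):
    self._max_bytes = max_bytes
    self._in_flight = 0
    self._cv = asyncio.Condition()

  async def wait_for_bytes(self, requested: int):
    if requested > self._max_bytes:
      # A request that can never fit would otherwise block forever.
      raise ValueError('Requested more bytes than the configured maximum.')
    async with self._cv:
      # Block until admitting this request keeps us under the cap.
      await self._cv.wait_for(
          lambda: self._in_flight + requested <= self._max_bytes
      )
      self._in_flight += requested

  async def release_bytes(self, requested: int):
    async with self._cv:
      self._in_flight -= requested
      self._cv.notify_all()  # wake waiters now that budget is free
```

Because asyncio primitives bind to an event loop at construction on Python 3.9 and earlier, the limiter must be created inside the same loop that uses it, which is why the removed `get_byte_limiter` wrapped construction in an async function.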
43 changes: 30 additions & 13 deletions checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py
@@ -35,6 +35,7 @@
from orbax.checkpoint import base_pytree_checkpoint_handler
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import serialization
from orbax.checkpoint import transform_utils
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
@@ -63,13 +64,12 @@
)
BasePyTreeSaveArgs = base_pytree_checkpoint_handler.BasePyTreeSaveArgs
BasePyTreeRestoreArgs = base_pytree_checkpoint_handler.BasePyTreeRestoreArgs
get_byte_limiter = base_pytree_checkpoint_handler.get_byte_limiter
LimitInFlightBytes = base_pytree_checkpoint_handler.LimitInFlightBytes
get_param_names = base_pytree_checkpoint_handler.get_param_names

_CHECKPOINT_FILE = 'checkpoint'
_METADATA_FILE = base_pytree_checkpoint_handler.METADATA_FILE
_DEFAULT_CONCURRENT_GB = 96
DEFAULT_CONCURRENT_GB = base_pytree_checkpoint_handler.DEFAULT_CONCURRENT_GB


def _maybe_set_default_restore_args(args):
@@ -434,6 +434,13 @@ def _overwrite_aggregate(sa: SaveArgs) -> SaveArgs:
)


def _concurrent_bytes(concurrent_gb: Optional[int]) -> int:
if concurrent_gb is None:
return DEFAULT_CONCURRENT_GB * 10**9
else:
return concurrent_gb * 10**9


class PyTreeCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
"""A CheckpointHandler implementation for any PyTree structure.
@@ -450,14 +457,16 @@ class PyTreeCheckpointHandler(async_checkpoint_handler.AsyncCheckpointHandler):
Example::
ckptr = Checkpointer(PyTreeCheckpointHandler())
# TODO(cpgaffney) Cut down on the protected methods accessed by this class.
"""

def __init__(
self,
aggregate_filename: Optional[str] = None,
concurrent_gb: Optional[int] = None,
*,
save_concurrent_gb: Optional[int] = None,
restore_concurrent_gb: Optional[int] = None,
use_ocdbt: bool = True,
use_zarr3: bool = False,
primary_host: Optional[int] = 0,
@@ -469,8 +478,12 @@ def __init__(
Args:
aggregate_filename: name that the aggregated checkpoint should be saved
as.
concurrent_gb: max concurrent GB that are allowed to be read. Can help to
reduce the possibility of OOMs when large checkpoints are restored.
save_concurrent_gb: max concurrent GB that are allowed for writing. Can
help to reduce the possibility of OOMs when large checkpoints are
saved.
restore_concurrent_gb: max concurrent GB that are allowed for reading. Can
help to reduce the possibility of OOMs when large checkpoints are
restored.
use_ocdbt: enables Tensorstore OCDBT driver. This option allows using a
different checkpoint format which is faster to read and write, as well
as more space efficient.
@@ -486,16 +499,15 @@ def __init__(
if aggregate_filename is None:
aggregate_filename = _CHECKPOINT_FILE
self._aggregate_filename = aggregate_filename
if concurrent_gb is None:
concurrent_gb = _DEFAULT_CONCURRENT_GB
self._concurrent_gb = concurrent_gb
self._use_ocdbt = use_ocdbt
self._use_zarr3 = use_zarr3
self._primary_host = primary_host
self._type_handler_registry = type_handler_registry

self._save_concurrent_bytes = _concurrent_bytes(save_concurrent_gb)
self._restore_concurrent_bytes = _concurrent_bytes(restore_concurrent_gb)
self._handler_impl = handler_impl or BasePyTreeCheckpointHandler(
concurrent_gb=concurrent_gb,
save_concurrent_bytes=self._save_concurrent_bytes,
restore_concurrent_bytes=self._restore_concurrent_bytes,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
primary_host=primary_host,
@@ -579,6 +591,13 @@ async def _maybe_deserialize(
restore_args: PyTree,
) -> PyTree:
"""Deserializes values or gets them from the aggregate file."""
byte_limiter = serialization.get_byte_limiter(
self._restore_concurrent_bytes
)
param_infos = jax.tree.map(
lambda info: dataclasses.replace(info, byte_limiter=byte_limiter),
param_infos,
)

# Handle parameters from aggregate file.
def _process_aggregated_value(meta_or_value, args):
@@ -766,7 +785,6 @@ class TrainState:
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}'
)
byte_limiter = get_byte_limiter(self._concurrent_gb)
structure, use_zarr3_metadata = self._get_internal_metadata(directory)
# `checkpoint_restore_args` has a structure relative to the checkpoint,
# while `restore_args` remains structured relative to the output.
@@ -778,7 +796,6 @@
self._handler_impl.get_param_names(structure),
transforms,
restore_args,
byte_limiter=byte_limiter,
transforms_default_to_original=transforms_default_to_original,
use_zarr3=use_zarr3_metadata
if use_zarr3_metadata is not None
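One restore-path detail worth noting: rather than threading the limiter through `_get_param_infos` as before, `_maybe_deserialize` now attaches it after the fact with `jax.tree.map` and `dataclasses.replace`, so every `ParamInfo` leaf carries the same limiter and all reads draw from one shared byte budget. A self-contained sketch of that pattern, using a simplified stand-in for `ParamInfo`:

```python
import dataclasses
from typing import Any, Optional

import jax


@dataclasses.dataclass(frozen=True)
class ParamInfo:
  """Stand-in for Orbax's ParamInfo; only the relevant field is modeled."""
  name: str
  byte_limiter: Optional[Any] = None


limiter = object()  # stand-in for serialization.get_byte_limiter(...)
param_infos = {'w': ParamInfo('w'), 'b': ParamInfo('b')}

# Each dataclass instance is a pytree leaf, so jax.tree.map rewrites every
# leaf to reference the *same* limiter object.
param_infos = jax.tree.map(
    lambda info: dataclasses.replace(info, byte_limiter=limiter),
    param_infos,
)
assert all(p.byte_limiter is limiter for p in param_infos.values())
```

Sharing a single limiter object is what makes the cap global across parameters; constructing one limiter per leaf would multiply the effective memory budget by the number of parameters.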
