Commit 484fbb2
Add memory limiting functionality for save, in addition to restore.
PiperOrigin-RevId: 643168487
cpgaffney1 authored and Orbax Authors committed Aug 2, 2024
1 parent f84f454 commit 484fbb2
Showing 10 changed files with 264 additions and 109 deletions.
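
For context, a minimal usage sketch of the new memory budget (not part of the commit itself): the parameter names `concurrent_gb` and `concurrent_bytes` come from the diffs below, while `ocp.Checkpointer`, the example pytree, and the `/tmp/ckpt` path are assumptions.

```python
import numpy as np
import orbax.checkpoint as ocp
from orbax.checkpoint import base_pytree_checkpoint_handler

# PyTreeCheckpointHandler keeps its GB-based argument; per the diff it now
# converts concurrent_gb to bytes and forwards it to the base handler, so the
# limit applies to saves as well as restores.
handler = ocp.PyTreeCheckpointHandler(concurrent_gb=8)

# The lower-level handler now accepts a byte count directly (keyword-only).
base_handler = base_pytree_checkpoint_handler.BasePyTreeCheckpointHandler(
    concurrent_bytes=8 * 10**9,
)

# Hypothetical end-to-end call (exact save API may differ by Orbax version):
# the write path is now subject to the same in-flight byte limit that the
# read path already had.
ckptr = ocp.Checkpointer(handler)
ckptr.save('/tmp/ckpt', {'w': np.zeros((1024, 1024), np.float32)})
```
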
5 changes: 4 additions & 1 deletion checkpoint/CHANGELOG.md
@@ -14,6 +14,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Improve logging by adding jax_process, error logs in threads and more...
- Improvements to blocking save time, as a result of moving file open operations into the background.

### Added
- Add memory-based rate limiting support during save.


## [0.5.23] - 2024-07-26

@@ -39,7 +42,6 @@ entries will not return information about array properties).
### Added
- Rolled forward change to improve TensorStore I/O efficiency.
- Memory efficient broadcasting from one model replica to others.
- Ability to check if a checkpoint save is in progress.

### Changed
- Allow one directory creation request per item rather than 1 per item per host.
@@ -76,6 +78,7 @@ entries will not return information about array properties).
### Fixed
- Rolled back change in previous release to improve TensorStore I/O efficiency.
This change caused some unexpected failures on certain storage systems.
- Add memory-based rate limiting support during save.

## [0.5.16] - 2024-06-11

1 change: 1 addition & 0 deletions checkpoint/orbax/__init__.py
@@ -27,6 +27,7 @@
from orbax.checkpoint import msgpack_utils
from orbax.checkpoint import multihost
from orbax.checkpoint import path
from orbax.checkpoint import serialization
from orbax.checkpoint import test_utils
from orbax.checkpoint import transform_utils
from orbax.checkpoint import tree
1 change: 1 addition & 0 deletions checkpoint/orbax/checkpoint/__init__.py
@@ -27,6 +27,7 @@
from orbax.checkpoint import msgpack_utils
from orbax.checkpoint import multihost
from orbax.checkpoint import path
from orbax.checkpoint import serialization
from orbax.checkpoint import test_utils
from orbax.checkpoint import transform_utils
from orbax.checkpoint import tree
36 changes: 17 additions & 19 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
@@ -34,6 +34,7 @@
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import multihost
from orbax.checkpoint import serialization
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
from orbax.checkpoint import utils
@@ -58,22 +59,10 @@
get_param_names = tree_utils.get_param_names

METADATA_FILE = '_METADATA'
DEFAULT_CONCURRENT_GB = 96



def get_byte_limiter(concurrent_gb: int):
async def _create_byte_limiter():
# Wrap creation in async function to avoid issues on python<=3.9.
concurrent_bytes = concurrent_gb * 10**9
# Construction must take place here so that it is within the same async
# method, to prevent errors resulting from different event loops, and
# cannot be created below this level because there must be a single object
# for the entire restore call.
return LimitInFlightBytes(concurrent_bytes) # pylint: disable=protected-access

return asyncio.run(_create_byte_limiter())


async def _create_param_save_dir(param_info: ParamInfo, args: SaveArgs):
# Directory will be unused.
path = param_info.path
@@ -234,7 +223,8 @@ class BasePyTreeCheckpointHandler(

def __init__(
self,
concurrent_gb: int = 96,
*,
concurrent_bytes: Optional[int] = None,
use_ocdbt: bool = True,
use_zarr3: bool = False,
primary_host: Optional[int] = 0,
@@ -243,8 +233,9 @@ def __init__(
"""Creates BasePyTreeCheckpointHandler.
Args:
concurrent_gb: max concurrent GB that are allowed to be read. Can help to
reduce the possibility of OOM's when large checkpoints are restored.
concurrent_bytes: max concurrent bytes that are allowed to be written or
read. Can help to reduce the possibility of OOM's when large checkpoints
are restored.
use_ocdbt: Whether to use OCDBT format for saving.
use_zarr3: If True, use Zarr ver3 otherwise Zarr ver2.
primary_host: the host id of the primary host. Default to 0. If it's set
@@ -254,7 +245,9 @@ def __init__(
specified, the global type handler registry will be used. # BEGIN
enable_descriptor: If True, logs a Descriptor proto that contains lineage
"""
self._concurrent_gb = concurrent_gb
if concurrent_bytes is None:
concurrent_bytes = DEFAULT_CONCURRENT_GB * 10**9
self._concurrent_bytes = concurrent_bytes
self._use_ocdbt = use_ocdbt
self._use_zarr3 = use_zarr3
self._primary_host = primary_host
@@ -395,11 +388,13 @@ async def async_save(
raise ValueError('`ocdbt_target_data_file_size` only works with Zarr3')

save_args = _fill_missing_save_or_restore_args(item, save_args, mode='save')
byte_limiter = serialization.get_byte_limiter(self._concurrent_bytes)
param_infos = self._get_param_infos(
item,
directory,
use_ocdbt=self._use_ocdbt,
ocdbt_target_data_file_size=ocdbt_target_data_file_size,
byte_limiter=byte_limiter,
)
assert all(
leaf.parent_dir == directory
@@ -471,6 +466,11 @@ async def _maybe_deserialize(
) -> PyTree:
"""Deserializes values or skips."""
flat_metadata = tree_utils.to_flat_dict(metadata)
byte_limiter = serialization.get_byte_limiter(self._concurrent_bytes)
param_infos = jax.tree.map(
lambda info: dataclasses.replace(info, byte_limiter=byte_limiter),
param_infos,
)
batch_requests = batched_serialization_requests(
metadata,
param_infos,
@@ -596,7 +596,6 @@ class TrainState:
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}'
)
byte_limiter = get_byte_limiter(self._concurrent_gb)
metadata = self._read_metadata_file(directory)
use_zarr3_metadata = metadata.use_zarr3
metadata = metadata.as_nested_tree(keep_empty_nodes=True)
@@ -620,7 +619,6 @@
directory,
use_ocdbt=type_handlers.is_ocdbt_checkpoint(directory),
use_zarr3=use_zarr3,
byte_limiter=byte_limiter,
)
restored_item = asyncio.run(
self._maybe_deserialize(item, metadata, param_infos, restore_args)
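
The `get_byte_limiter` helper removed from this file is now called via `orbax.checkpoint.serialization` (imported at the top of the diff). Conceptually, the limiter bounds the total bytes in flight across concurrent reads and writes. Below is a rough sketch of such a limiter using an asyncio condition variable; it is not Orbax's actual `LimitInFlightBytes` implementation, only an illustration of the mechanism.

```python
import asyncio


class InFlightByteLimiter:
  """Blocks new requests that would push in-flight bytes over the budget."""

  def __init__(self, max_bytes: int):
    self._max_bytes = max_bytes
    self._in_flight = 0
    self._cv = asyncio.Condition()

  async def wait_for_bytes(self, requested: int):
    # Wait until the requested bytes fit under the budget, then reserve them.
    async with self._cv:
      await self._cv.wait_for(
          lambda: self._in_flight + requested <= self._max_bytes
      )
      self._in_flight += requested

  async def release_bytes(self, requested: int):
    # Return reserved bytes and wake up any waiting readers/writers.
    async with self._cv:
      self._in_flight -= requested
      self._cv.notify_all()
```
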
19 changes: 12 additions & 7 deletions checkpoint/orbax/checkpoint/pytree_checkpoint_handler.py
@@ -35,6 +35,7 @@
from orbax.checkpoint import base_pytree_checkpoint_handler
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import serialization
from orbax.checkpoint import transform_utils
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
@@ -63,13 +64,12 @@
)
BasePyTreeSaveArgs = base_pytree_checkpoint_handler.BasePyTreeSaveArgs
BasePyTreeRestoreArgs = base_pytree_checkpoint_handler.BasePyTreeRestoreArgs
get_byte_limiter = base_pytree_checkpoint_handler.get_byte_limiter
LimitInFlightBytes = base_pytree_checkpoint_handler.LimitInFlightBytes
get_param_names = base_pytree_checkpoint_handler.get_param_names

_CHECKPOINT_FILE = 'checkpoint'
_METADATA_FILE = base_pytree_checkpoint_handler.METADATA_FILE
_DEFAULT_CONCURRENT_GB = 96
DEFAULT_CONCURRENT_GB = base_pytree_checkpoint_handler.DEFAULT_CONCURRENT_GB


def _maybe_set_default_restore_args(args):
@@ -487,15 +487,17 @@ def __init__(
aggregate_filename = _CHECKPOINT_FILE
self._aggregate_filename = aggregate_filename
if concurrent_gb is None:
concurrent_gb = _DEFAULT_CONCURRENT_GB
self._concurrent_gb = concurrent_gb
concurrent_bytes = DEFAULT_CONCURRENT_GB * 10**9
else:
concurrent_bytes = concurrent_gb * 10**9
self._use_ocdbt = use_ocdbt
self._use_zarr3 = use_zarr3
self._primary_host = primary_host
self._type_handler_registry = type_handler_registry

self._concurrent_bytes = concurrent_bytes
self._handler_impl = handler_impl or BasePyTreeCheckpointHandler(
concurrent_gb=concurrent_gb,
concurrent_bytes=self._concurrent_bytes,
use_ocdbt=use_ocdbt,
use_zarr3=use_zarr3,
primary_host=primary_host,
@@ -579,6 +581,11 @@ async def _maybe_deserialize(
restore_args: PyTree,
) -> PyTree:
"""Deserializes values or gets them from the aggregate file."""
byte_limiter = serialization.get_byte_limiter(self._concurrent_bytes)
param_infos = jax.tree.map(
lambda info: dataclasses.replace(info, byte_limiter=byte_limiter),
param_infos,
)

# Handle parameters from aggregate file.
def _process_aggregated_value(meta_or_value, args):
@@ -766,7 +773,6 @@ class TrainState:
raise FileNotFoundError(
f'Requested directory for restore does not exist at {directory}'
)
byte_limiter = get_byte_limiter(self._concurrent_gb)
structure, use_zarr3_metadata = self._get_internal_metadata(directory)
# `checkpoint_restore_args` has a structure relative to the checkpoint,
# while `restore_args` remains structured relative to the output.
@@ -778,7 +784,6 @@ class TrainState:
self._handler_impl.get_param_names(structure),
transforms,
restore_args,
byte_limiter=byte_limiter,
transforms_default_to_original=transforms_default_to_original,
use_zarr3=use_zarr3_metadata
if use_zarr3_metadata is not None
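
Both handlers now follow the same pattern: create one byte limiter per save or restore call and attach it to every `ParamInfo` (directly in `_get_param_infos` on the save path, and via `jax.tree.map` in `_maybe_deserialize` on the restore path). The helper below is illustrative only; the function name `attach_byte_limiter` is hypothetical, while the calls mirror the diffs above.

```python
import dataclasses

import jax
from orbax.checkpoint import serialization


def attach_byte_limiter(param_infos, concurrent_bytes):
  """Returns a copy of each ParamInfo sharing one in-flight byte limiter."""
  byte_limiter = serialization.get_byte_limiter(concurrent_bytes)
  return jax.tree.map(
      lambda info: dataclasses.replace(info, byte_limiter=byte_limiter),
      param_infos,
  )
```
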