Improve blocking save performance by ~50% by moving file open calls (via TensorStore) into the background thread.

Note that the size of this improvement can vary widely, since the baseline time depends on the bandwidth of writes to the filesystem.

PiperOrigin-RevId: 656510748
cpgaffney1 authored and Orbax Authors committed Jul 26, 2024
1 parent 8750e13 commit e1703f9
Showing 7 changed files with 256 additions and 151 deletions.
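The core idea of the commit, sketched below in a minimal, self-contained form: the blocking save phase only schedules work, and the expensive file opens happen on a background thread's event loop. All names here (`save`, `_open_file`) are hypothetical stand-ins, not the Orbax or TensorStore API.

```python
import asyncio
import threading

async def _open_file(path: str):
    # Stand-in for the TensorStore open call: simulate the slow filesystem
    # round trip that this commit moves off the blocking path.
    await asyncio.sleep(0.1)
    return open(path, "wb")

async def _open_and_write(path: str, data: bytes) -> None:
    f = await _open_file(path)
    with f:
        f.write(data)

async def _commit(items: dict) -> None:
    # Open and write all files concurrently.
    await asyncio.gather(*(_open_and_write(p, d) for p, d in items.items()))

def save(items: dict) -> threading.Thread:
    # Blocking phase: hand the work to a background thread and return
    # immediately; the caller no longer waits on file opens.
    thread = threading.Thread(target=lambda: asyncio.run(_commit(items)))
    thread.start()
    return thread

t = save({"/tmp/a.bin": b"a", "/tmp/b.bin": b"b"})
t.join()  # file opens and writes complete in the background
```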
4 changes: 4 additions & 0 deletions checkpoint/CHANGELOG.md
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed
- Improvements to blocking save time, as a result of moving file open operations
into the background.

## [0.5.23] - 2024-07-26

### Changed
36 changes: 21 additions & 15 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
@@ -322,6 +322,25 @@ def _param_info(name, value):
_param_info, names, item, is_leaf=utils.is_empty_or_leaf
)

async def _maybe_create_param_directories(
self, param_infos: PyTree, save_args: PyTree
):
if not self._use_ocdbt:
if multihost.is_primary_host(self._primary_host):
# Create directories in parallel.
await asyncio.gather(
*jax.tree.flatten(
jax.tree.map(
_create_param_save_dir,
param_infos,
save_args,
)
)[0]
)
multihost.sync_global_processes(
'PyTreeCheckpointHandler:create_param_save_dirs'
)

async def async_save(
self,
directory: epath.Path,
@@ -382,21 +401,8 @@ async def async_save(
leaf.parent_dir == directory
for leaf in jax.tree.leaves(param_infos)
)
if not self._use_ocdbt:
if multihost.is_primary_host(self._primary_host):
# Create directories in parallel.
await asyncio.gather(
*jax.tree.flatten(
jax.tree.map(
_create_param_save_dir,
param_infos,
save_args,
)
)[0]
)
multihost.sync_global_processes(
'PyTreeCheckpointHandler:create_param_save_dirs'
)
await self._maybe_create_param_directories(param_infos, save_args)

serialize_ops = []
batch_requests = batched_serialization_requests(
item,
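The refactored directory creation fans out one coroutine per parameter via `jax.tree.map` and runs them all concurrently with `asyncio.gather`. A minimal sketch of that shape, with a stand-in `_create_dir` in place of the real `_create_param_save_dir`:

```python
import asyncio
import jax

async def _create_dir(name: str) -> None:
    # Stand-in for _create_param_save_dir; the real code creates one
    # directory per parameter, and only on the primary host.
    await asyncio.sleep(0)
    print(f"created {name}")

async def create_all(param_names) -> None:
    # jax.tree.map builds a coroutine per leaf; flatten yields the list
    # that asyncio.gather runs concurrently.
    coros = jax.tree.map(_create_dir, param_names)
    await asyncio.gather(*jax.tree.flatten(coros)[0])

asyncio.run(create_all({"layer0": {"w": "w0", "b": "b0"}, "layer1": {"w": "w1"}}))
```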
9 changes: 6 additions & 3 deletions checkpoint/orbax/checkpoint/composite_checkpoint_handler.py
@@ -56,6 +56,7 @@

from absl import logging
from etils import epath
import jax
import nest_asyncio
from orbax.checkpoint import async_checkpoint_handler
from orbax.checkpoint import checkpoint_args
@@ -376,7 +377,6 @@ async def async_save(
self, directory: epath.Path, args: 'CompositeArgs'
) -> Optional[List[Future]]:
"""Saves multiple items to individual subdirectories."""
futures = []
# Sort keys to maintain consistent ordering across processes, otherwise
# we may hit timeouts if processes wait at different barriers in per-item
# handlers.
@@ -399,15 +399,18 @@
for path in self._current_temporary_paths.values():
path.create()

save_ops = []
for item_name, item_directory in self._current_temporary_paths.items():
arg = args[item_name]
_maybe_raise_reserved_item_error(item_name)
handler = self._get_or_set_handler(item_name, arg)
if isinstance(handler, AsyncCheckpointHandler):
futures.extend(await handler.async_save(item_directory.get(), args=arg))
save_ops.append(handler.async_save(item_directory.get(), args=arg))
else:
# Blocking save.
handler.save(item_directory.get(), args=arg)
return futures

return jax.tree.flatten(await asyncio.gather(*save_ops))[0]

def save(self, *args, **kwargs):
"""Saves synchronously."""
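The change above starts every item's `async_save` before awaiting any of them, then flattens the per-item future lists with `jax.tree.flatten` (which also drops `None` returns). A simplified sketch of that flow, using a hypothetical stand-in handler:

```python
import asyncio
import jax

async def async_save_item(name: str):
    # Stand-in for AsyncCheckpointHandler.async_save: returns the list of
    # futures the caller later awaits during the commit phase.
    await asyncio.sleep(0)
    return [f"{name}-future"]

async def save_all(item_names):
    # Kick off all per-item saves concurrently instead of one at a time.
    save_ops = [async_save_item(n) for n in item_names]
    results = await asyncio.gather(*save_ops)  # list of lists (or None)
    return jax.tree.flatten(results)[0]        # flatten; Nones are dropped

print(asyncio.run(save_all(["state", "dataset"])))
# ['state-future', 'dataset-future']
```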
17 changes: 17 additions & 0 deletions checkpoint/orbax/checkpoint/future.py
@@ -14,6 +14,7 @@

"""Orbax Future class used for duck typing."""

import threading
from typing import Any, Optional
from typing_extensions import Protocol

@@ -41,3 +42,19 @@ class NoopFuture:
def result(self, timeout: Optional[int] = None) -> Any:
del timeout
return None


class ThreadRaisingException(threading.Thread):
"""Thread that raises an exception if it encounters an error."""
_exception: Optional[Exception] = None

def run(self):
try:
super().run()
except Exception as e: # pylint: disable=broad-exception-caught
self._exception = e

def join(self, timeout=None):
super().join(timeout=timeout)
if self._exception is not None:
raise self._exception
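Usage sketch for the new class: an error raised inside the background save surfaces at `join` time instead of dying silently with the thread. The `commit` target here is hypothetical.

```python
def commit():
    raise IOError("disk full")  # simulate a failure in the background save

t = ThreadRaisingException(target=commit)
t.start()
try:
    t.join()  # re-raises the exception captured in run()
except IOError as e:
    print(f"background save failed: {e}")
```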
142 changes: 100 additions & 42 deletions checkpoint/orbax/checkpoint/serialization.py
@@ -19,6 +19,7 @@

import asyncio
from collections.abc import Awaitable
import dataclasses
import functools
import os
import re
@@ -181,6 +182,14 @@ async def release_bytes(self, requested_bytes):
self._cv.notify_all()


@dataclasses.dataclass
class Shards:
  """Host-local shards of a jax.Array, with its global shape, dtype, and sharding."""
  shards: Sequence[jax.Shard]
shape: Shape
dtype: jnp.dtype
sharding: jax.sharding.Sharding


async def transfer_shard_to_host(shard: jax.Shard) -> np.ndarray:
"""Asynchronously transfers a shard to host memory."""
data = shard.data
Expand All @@ -202,6 +211,33 @@ async def transfer_shard_to_host(shard: jax.Shard) -> np.ndarray:
return np.array(data, copy=False)


async def transfer_array_to_host(arr: jax.Array, replica_id: int) -> Shards:
"""Transfers a jax.Array to host memory."""
transfer_ops = []
dedup_shards = [
shard
for shard in arr.addressable_shards
if shard.replica_id == replica_id
]
for shard in dedup_shards:
transfer_ops.append(transfer_shard_to_host(shard))
numpy_arrs = await asyncio.gather(*transfer_ops)
return Shards(
[
jax.Shard(
device=shard.device,
data=data,
sharding=arr.sharding,
global_shape=arr.shape,
)
for data, shard in zip(numpy_arrs, dedup_shards)
],
shape=arr.shape,
dtype=arr.dtype,
sharding=arr.sharding,
)


def _get_copy_future(write_future):
return write_future.copy

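`transfer_array_to_host` keeps one copy of each shard (filtering by `replica_id`) and moves all of them to host memory concurrently, so the later write phase never touches device memory. A self-contained sketch of that dedup-and-transfer flow, with a stand-in for `transfer_shard_to_host`:

```python
import asyncio
import jax.numpy as jnp
import numpy as np

async def to_host(shard):
    # Stand-in for transfer_shard_to_host: copy device data to NumPy.
    return np.asarray(shard.data)

async def gather_host_shards(arr, replica_id: int = 0):
    # Dedup by replica_id so each piece of a replicated array is written
    # by exactly one process, then transfer all shards concurrently.
    dedup = [s for s in arr.addressable_shards if s.replica_id == replica_id]
    return await asyncio.gather(*(to_host(s) for s in dedup))

host_data = asyncio.run(gather_host_shards(jnp.arange(8.0)))
print([a.shape for a in host_data])  # [(8,)] on a single device
```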
@@ -210,38 +246,29 @@ def _get_commit_future(write_future):
return write_future.commit


async def _write_array(
async def _write_shard(
shard: jax.Shard,
t: ts.TensorStore,
commit_future: Optional[list[Any]],
replica_id: int,
can_reference_source_data_indefinitely: bool,
):
"""Writes a single array using TensorStore."""
if shard.replica_id == replica_id:
data = await transfer_shard_to_host(shard)
write_future = t[shard.index].write(
data,
# Avoid additional copy of input array into the TensorStore chunk
# cache. If `arr_inp` is a jax.Array, the result of converting
# it to a NumPy array, as is done internally by TensorStore, is
# guaranteed to be immutable and therefore it is safe to retain a
# reference indefinitely.
can_reference_source_data_indefinitely=can_reference_source_data_indefinitely,
)
if commit_future is not None:
assert isinstance(commit_future, list)
commit_future.append(_get_commit_future(write_future))
await _get_copy_future(write_future)
else:
await _get_commit_future(write_future)
"""Writes a single array using TensorStore. No copy is performed."""
assert shard.replica_id == replica_id
assert isinstance(shard.data, np.ndarray)
await t[shard.index].write(
shard.data,
# Avoid additional copy of input array into the TensorStore chunk
# cache. If `arr_inp` is a jax.Array, the result of converting
# it to a NumPy array, as is done internally by TensorStore, is
# guaranteed to be immutable and therefore it is safe to retain a
# reference indefinitely.
can_reference_source_data_indefinitely=True,
)


async def async_serialize(
arr_inp,
tensorstore_spec,
commit_future=None,
context=TS_CONTEXT,
arr_inp: jax.Array,
tensorstore_spec: Dict[str, Any],
context: ts.Context = TS_CONTEXT,
primary_host: Optional[int] = 0,
replica_id: int = 0,
transaction: Optional[ts.Transaction] = None,
Expand All @@ -251,9 +278,6 @@ async def async_serialize(
Args:
arr_inp: The array to serialize.
tensorstore_spec: The tensorstore spec to use.
commit_future: A list of futures that will be appended to. The futures can
be awaited asynchronously. If None, the futures will be awaited
synchronously by this method.
context: ts.Context instance.
primary_host: Primary host, which indicates the host that will be treated as
the "leader". If None, all hosts are treated as the primary. DO NOT USE
@@ -283,22 +307,59 @@
if 'dtype' not in tensorstore_spec:
tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name

host_shards = await transfer_array_to_host(arr_inp, replica_id)
await async_serialize_shards(
host_shards,
tensorstore_spec,
context=context,
primary_host=primary_host,
replica_id=replica_id,
transaction=transaction,
)


async def async_serialize_shards(
shards: Shards,
tensorstore_spec: Dict[str, Any],
*,
context: ts.Context = TS_CONTEXT,
primary_host: Optional[int] = 0,
replica_id: int = 0,
transaction: Optional[ts.Transaction] = None,
):
"""Serialize a host-local shards using TensorStore.
Args:
shards: A Shards object. Individual shards are expected to be host-local.
tensorstore_spec: The tensorstore spec to use.
context: ts.Context instance.
primary_host: Primary host, which indicates the host that will be treated as
the "leader". If None, all hosts are treated as the primary. DO NOT USE
unless you are sure you know what you are doing.
replica_id: Allows overriding the shard replica id that will be saved. DO
NOT USE unless you are sure you know what you are doing.
transaction: TensorStore transaction to use for opening and writing the
array. If not specified, a non-transactional write will be used.
Raises:
KeyError: If `metadata` or `dtype` is not found in the tensorstore spec.
"""
if not _spec_has_metadata(tensorstore_spec):
raise KeyError('`metadata` not found in tensorstore spec.')
# Require dtype to be present in the spec.
if 'dtype' not in tensorstore_spec:
raise KeyError('`dtype` not found in tensorstore spec.')

# If primary_host is None, all hosts will checkpoint. This is used
# for checkpointing to local filesystem.
if primary_host is None or multihost.process_index() == primary_host:
open_future = ts.open(
await ts.open(
ts.Spec(tensorstore_spec),
create=True,
open=True,
context=context,
transaction=transaction,
)
# Asynchronous case.
if commit_future is not None:
assert isinstance(commit_future, list)
commit_future.append(open_future)
else:
await open_future

# `ts.open` runs twice for process `primary_host` because for the first time,
# we just get the future to be awaited upon in the background thread. The
@@ -313,18 +374,15 @@
context=context,
transaction=transaction,
)
local_shards = arr_inp.addressable_shards
future_write_state = jax.tree_util.tree_map(
write_shard_coros = jax.tree.map(
functools.partial(
_write_array,
_write_shard,
t=t,
commit_future=commit_future,
replica_id=replica_id,
can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array),
),
local_shards,
shards.shards,
)
await asyncio.gather(*future_write_state)
await asyncio.gather(*write_shard_coros)


def estimate_read_memory_footprint(t: ts.TensorStore,
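Splitting `async_serialize` this way lets a caller keep only the device-to-host copy on the blocking path and push `ts.open` plus all shard writes (via `async_serialize_shards`) onto a background thread. A hedged sketch of that two-phase flow; it assumes the two new functions are importable from `orbax.checkpoint.serialization` and that `tensorstore_spec` already carries the required `metadata` and `dtype` entries:

```python
import asyncio
import threading

from orbax.checkpoint import serialization

def save_in_background(arr, tensorstore_spec) -> threading.Thread:
    # Blocking phase: device-to-host copy only, bounded by device
    # bandwidth rather than filesystem latency.
    host_shards = asyncio.run(
        serialization.transfer_array_to_host(arr, replica_id=0)
    )

    def commit():
        # Background phase: ts.open and all shard writes happen here,
        # which is where the blocking-save improvement comes from.
        asyncio.run(
            serialization.async_serialize_shards(host_shards, tensorstore_spec)
        )

    thread = threading.Thread(target=commit)
    thread.start()
    return thread
```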
@@ -110,7 +110,6 @@ def handler(self) -> StandardCheckpointHandler:

def test_basic(self):
self.handler.save(self.directory, args=self.save_args_cls(self.pytree))
test_utils.print_directory(self.directory)
self.assertTrue(
(self.directory / type_handlers._OCDBT_MANIFEST_FILE).exists() # pylint: disable=protected-access
)
@@ -143,17 +142,6 @@ def test_shape_dtype_struct(self):
)
test_utils.assert_tree_equal(self, self.mixed_pytree, restored)

def test_save_aggregate(self):
def _save_args(arr):
return SaveArgs(aggregate=(np.asarray(arr).ndim < 2))

save_args = jax.tree.map(_save_args, self.numpy_pytree)
with self.assertRaisesRegex(ValueError, 'Unsupported option `aggregate`'):
self.handler.save(
self.directory,
args=self.save_args_cls(self.numpy_pytree, save_args=save_args),
)

def test_save_unsupported_type(self):
pytree = {'str_key': 'str_value', **self.pytree}
with self.assertRaisesRegex(ValueError, 'Unsupported type'):