
Commit fca60dd

No public description
PiperOrigin-RevId: 658514298
cpgaffney1 authored and Orbax Authors committed Aug 1, 2024
1 parent 53250b3 commit fca60dd
Showing 7 changed files with 278 additions and 184 deletions.
1 change: 1 addition & 0 deletions checkpoint/CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed
- Improve logging by adding jax_process, error logs in threads and more...
- Improvements to blocking save time, as a result of moving file open operations into the background.


## [0.5.23] - 2024-07-26
36 changes: 21 additions & 15 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
@@ -322,6 +322,25 @@ def _param_info(name, value):
_param_info, names, item, is_leaf=utils.is_empty_or_leaf
)

async def _maybe_create_param_directories(
self, param_infos: PyTree, save_args: PyTree
):
if not self._use_ocdbt:
if multihost.is_primary_host(self._primary_host):
# Create directories in parallel.
await asyncio.gather(
*jax.tree.flatten(
jax.tree.map(
_create_param_save_dir,
param_infos,
save_args,
)
)[0]
)
multihost.sync_global_processes(
'PyTreeCheckpointHandler:create_param_save_dirs'
)

async def async_save(
self,
directory: epath.Path,
@@ -382,21 +401,8 @@ async def async_save(
leaf.parent_dir == directory
for leaf in jax.tree.leaves(param_infos)
)
if not self._use_ocdbt:
if multihost.is_primary_host(self._primary_host):
# Create directories in parallel.
await asyncio.gather(
*jax.tree.flatten(
jax.tree.map(
_create_param_save_dir,
param_infos,
save_args,
)
)[0]
)
multihost.sync_global_processes(
'PyTreeCheckpointHandler:create_param_save_dirs'
)
await self._maybe_create_param_directories(param_infos, save_args)

serialize_ops = []
batch_requests = batched_serialization_requests(
item,
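
The new `_maybe_create_param_directories` helper above fans out per-parameter directory creation with `asyncio.gather` over a flattened tree of coroutines, and only on the primary host. A minimal standalone sketch of the same fan-out pattern, with a hypothetical `make_dir` coroutine standing in for `_create_param_save_dir` and plain lists instead of pytrees:

```python
import asyncio
from pathlib import Path


async def make_dir(path: Path) -> None:
  # Hypothetical stand-in for _create_param_save_dir: run the blocking
  # mkdir in a worker thread so the coroutines can overlap.
  await asyncio.to_thread(path.mkdir, parents=True, exist_ok=True)


async def create_all(paths: list[Path]) -> None:
  # One coroutine per parameter directory, awaited together, mirroring
  # asyncio.gather(*jax.tree.flatten(...)[0]) in the diff.
  await asyncio.gather(*(make_dir(p) for p in paths))


asyncio.run(create_all([Path('/tmp/ckpt/param_a'), Path('/tmp/ckpt/param_b')]))
```
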
14 changes: 7 additions & 7 deletions checkpoint/orbax/checkpoint/composite_checkpoint_handler.py
@@ -56,6 +56,7 @@

from absl import logging
from etils import epath
import jax
import nest_asyncio
from orbax.checkpoint import async_checkpoint_handler
from orbax.checkpoint import checkpoint_args
@@ -376,7 +377,6 @@ async def async_save(
self, directory: epath.Path, args: 'CompositeArgs'
) -> Optional[List[Future]]:
"""Saves multiple items to individual subdirectories."""
futures = []
# Sort keys to maintain consistent ordering across processes, otherwise
# we may hit timeouts if processes wait at different barriers in per-item
# handlers.
@@ -399,19 +399,19 @@
for path in self._current_temporary_paths.values():
path.create()

save_ops = []
for item_name, item_directory in self._current_temporary_paths.items():
arg = args[item_name]
_maybe_raise_reserved_item_error(item_name)
handler = self._get_or_set_handler(item_name, arg)
if isinstance(handler, AsyncCheckpointHandler):
commit_futures = await handler.async_save(
item_directory.get(), args=arg
)
if commit_futures is not None:
futures.extend(commit_futures)
save_ops.append(handler.async_save(item_directory.get(), args=arg))
else:
# Blocking save.
handler.save(item_directory.get(), args=arg)
return futures

commit_futures = jax.tree.flatten(await asyncio.gather(*save_ops))[0]
return commit_futures or []

def save(self, *args, **kwargs):
"""Saves synchronously."""
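
The composite-handler change above replaces sequential `await`s with a single `asyncio.gather` over the per-item `async_save` coroutines, so item saves overlap, then flattens the resulting commit futures. A simplified sketch of that gather-and-flatten pattern; the handler and its string "futures" are placeholders, not Orbax classes:

```python
import asyncio
import itertools


async def fake_async_save(item_name: str) -> list[str]:
  # Placeholder for AsyncCheckpointHandler.async_save; returns this
  # item's commit "futures" (strings here, for illustration).
  await asyncio.sleep(0)
  return [f'{item_name}/commit']


async def save_all(item_names: list[str]) -> list[str]:
  save_ops = [fake_async_save(name) for name in item_names]
  # gather yields one list of futures per item; flatten them into a
  # single list, as the diff does with jax.tree.flatten.
  per_item_futures = await asyncio.gather(*save_ops)
  return list(itertools.chain.from_iterable(per_item_futures))


print(asyncio.run(save_all(['state', 'dataset'])))
# ['state/commit', 'dataset/commit']
```
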
17 changes: 17 additions & 0 deletions checkpoint/orbax/checkpoint/future.py
@@ -14,6 +14,7 @@

"""Orbax Future class used for duck typing."""

import threading
from typing import Any, Optional
from typing_extensions import Protocol

@@ -41,3 +42,19 @@ class NoopFuture:
def result(self, timeout: Optional[int] = None) -> Any:
del timeout
return None


class ThreadRaisingException(threading.Thread):
"""Thread that raises an exception if it encounters an error."""
_exception: Optional[Exception] = None

def run(self):
try:
super().run()
except Exception as e: # pylint: disable=broad-exception-caught
self._exception = e

def join(self, timeout=None):
super().join(timeout=timeout)
if self._exception is not None:
raise self._exception
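
`ThreadRaisingException` captures any exception raised by the thread's target and re-raises it on `join`, so a failure in a background commit surfaces in the main thread instead of being silently dropped. A usage sketch with an illustrative failing target:

```python
def failing_commit():
  raise RuntimeError('disk full')  # Illustrative failure.


t = ThreadRaisingException(target=failing_commit)
t.start()
try:
  t.join()  # Re-raises the RuntimeError captured in run().
except RuntimeError as e:
  print(f'background save failed: {e}')
```
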
183 changes: 121 additions & 62 deletions checkpoint/orbax/checkpoint/serialization.py
@@ -19,10 +19,11 @@

import asyncio
from collections.abc import Awaitable
import dataclasses
import functools
import os
import re
from typing import Any, Callable, Dict, Optional, Sequence, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Union

import jax
import jax.numpy as jnp
@@ -181,8 +182,30 @@ async def release_bytes(self, requested_bytes):
self._cv.notify_all()


async def transfer_shard_to_host(shard: jax.Shard) -> np.ndarray:
"""Asynchronously transfers a shard to host memory."""
@dataclasses.dataclass
class Shards:
"""Basic representation of host-local set of shards."""
shards: List[jax.Shard]
shape: Shape
dtype: jnp.dtype
sharding: jax.sharding.Sharding

def block_until_ready(self):
jax.block_until_ready([shard.data for shard in self.shards])
# Ensure that jax.Array's internal numpy array can be zero-copied. This
# guards against consumers like tensorstore that would otherwise copy
# silently.
for i, shard in enumerate(self.shards):
self.shards[i] = jax.Shard(
device=shard.device,
data=np.array(shard.data, copy=False),
sharding=self.sharding,
global_shape=self.shape,
)


def _transfer_shard_to_host(shard: jax.Shard) -> jax.Array:
"""Asynchronously transfers a shard to host memory. Does not block."""
data = shard.data
has_pinned_host = any(
m.kind == 'pinned_host' for m in shard.device.addressable_memories()
@@ -195,11 +218,34 @@ async def transfer_shard_to_host(shard: jax.Shard) -> np.ndarray:
data = jax.device_put(data, sharding)
else:
data.copy_to_host_async()
# Allow other transfers to be scheduled simultaneously.
await asyncio.sleep(0)
# Ensure that jax.Array's internal numpy array can be zero-copied. This guards
# against consumers like tensorstore that would otherwise copy silently.
return np.array(data, copy=False)
return data


def transfer_array_to_host(arr: jax.Array, replica_id: int) -> Shards:
"""Transfers a jax.Array to host memory."""
shard_data = []
dedup_shards = [
shard
for shard in arr.addressable_shards
if shard.replica_id == replica_id
]
for shard in dedup_shards:
shard_data.append(_transfer_shard_to_host(shard))

return Shards(
[
jax.Shard(
device=shard.device,
data=data,
sharding=arr.sharding,
global_shape=arr.shape,
)
for data, shard in zip(shard_data, dedup_shards)
],
shape=arr.shape,
dtype=arr.dtype,
sharding=arr.sharding,
)


def _get_copy_future(write_future):
@@ -210,50 +256,40 @@ def _get_commit_future(write_future):
return write_future.commit


async def _write_array(
async def _write_shard(
shard: jax.Shard,
t: ts.TensorStore,
commit_future: Optional[list[Any]],
replica_id: int,
can_reference_source_data_indefinitely: bool,
):
"""Writes a single array using TensorStore."""
if shard.replica_id == replica_id:
data = await transfer_shard_to_host(shard)
write_future = t[shard.index].write(
data,
# Avoid additional copy of input array into the TensorStore chunk
# cache. If `arr_inp` is a jax.Array, the result of converting
# it to a NumPy array, as is done internally by TensorStore, is
# guaranteed to be immutable and therefore it is safe to retain a
# reference indefinitely.
can_reference_source_data_indefinitely=can_reference_source_data_indefinitely,
)
if commit_future is not None:
assert isinstance(commit_future, list)
commit_future.append(_get_commit_future(write_future))
await _get_copy_future(write_future)
else:
await _get_commit_future(write_future)
"""Writes a single array using TensorStore. No copy is performed."""
assert shard.replica_id == replica_id
assert isinstance(shard.data, np.ndarray)
await t[shard.index].write(
shard.data,
# Avoid additional copy of input array into the TensorStore chunk
# cache. The data array of a shard is guaranteed to be immutable and
# therefore it is safe to retain a reference indefinitely.
can_reference_source_data_indefinitely=True,
)


async def async_serialize(
arr_inp,
tensorstore_spec,
commit_future=None,
context=TS_CONTEXT,
arr_inp: jax.Array,
tensorstore_spec: Dict[str, Any],
context: ts.Context = TS_CONTEXT,
primary_host: Optional[int] = 0,
replica_id: int = 0,
transaction: Optional[ts.Transaction] = None,
):
"""Serialize an array using TensorStore.
Performs a D2H transfer of the array. Prefer to use `async_serialize_shards`
by separately performing a D2H transfer, and then starting the serialization
in a background thread.
Args:
arr_inp: The array to serialize.
tensorstore_spec: The tensorstore spec to use.
commit_future: A list of futures that will be appended to. The futures can
be awaited asynchronously. If None, the futures will be awaited
synchronously by this method.
context: ts.Context instance.
primary_host: Primary host, which indicates the host that will be treated as
the "leader". If None, all hosts are treated as the primary. DO NOT USE
@@ -263,42 +299,68 @@ def async_serialize(
transaction: TensorStore transaction to use for opening and writing the
array. If not specified, a non-transactional write will be used.
"""
if (
isinstance(arr_inp, jax.Array)
and jax.process_count() > 1
and arr_inp.is_fully_addressable
):
raise ValueError(
f'Passing fully addressable arrays to a multiprocess '
f'serialization is not allowed, as this may lead to a race condition '
f'between processes. Serialization have failed for the array with '
f'the path "{tensorstore_spec["kvstore"]["path"]}".')

# 'metadata' may not be present at the top level (for example, if we are using
# a 'cast' driver).
if not _spec_has_metadata(tensorstore_spec):
tensorstore_spec['metadata'] = _get_metadata(arr_inp)

# Set dtype if it's not in spec
if 'dtype' not in tensorstore_spec:
tensorstore_spec['dtype'] = jnp.dtype(arr_inp.dtype).name
# Start D2H transfer in parallel for each array.
host_shards = transfer_array_to_host(arr_inp, replica_id)
host_shards.block_until_ready()
await async_serialize_shards(
host_shards,
tensorstore_spec,
context=context,
primary_host=primary_host,
replica_id=replica_id,
transaction=transaction,
)


async def async_serialize_shards(
shards: Shards,
tensorstore_spec: Dict[str, Any],
*,
context: ts.Context = TS_CONTEXT,
primary_host: Optional[int] = 0,
replica_id: int = 0,
transaction: Optional[ts.Transaction] = None,
):
"""Serialize a host-local shards using TensorStore.
Args:
shards: A Shards object. Individual shards are expected to be host-local.
tensorstore_spec: The tensorstore spec to use.
context: ts.Context instance.
primary_host: Primary host, which indicates the host that will be treated as
the "leader". If None, all hosts are treated as the primary. DO NOT USE
unless you are sure you know what you are doing.
replica_id: Allows overriding the shard replica id that will be saved. DO
NOT USE unless you are sure you know what you are doing.
transaction: TensorStore transaction to use for opening and writing the
array. If not specified, a non-transactional write will be used.
Raises:
KeyError: If `metadata` or `dtype` is not found in the tensorstore spec.
"""
if not _spec_has_metadata(tensorstore_spec):
raise KeyError('`metadata` not found in tensorstore spec.')
# Set dtype if it's not in spec
if 'dtype' not in tensorstore_spec:
raise KeyError('`dtype` not found in tensorstore spec.')

# If primary_host is None, all hosts will checkpoint. This is used
# for checkpointing to local filesystem.
if primary_host is None or multihost.process_index() == primary_host:
open_future = ts.open(
await ts.open(
ts.Spec(tensorstore_spec),
create=True,
open=True,
context=context,
transaction=transaction,
)
# Asynchronous case.
if commit_future is not None:
assert isinstance(commit_future, list)
commit_future.append(open_future)
else:
await open_future

# `ts.open` runs twice for process `primary_host` because for the first time,
# we just get the future to be awaited upon in the background thread. The
@@ -313,18 +375,15 @@ def async_serialize(
context=context,
transaction=transaction,
)
local_shards = arr_inp.addressable_shards
future_write_state = jax.tree_util.tree_map(
write_shard_coros = jax.tree.map(
functools.partial(
_write_array,
_write_shard,
t=t,
commit_future=commit_future,
replica_id=replica_id,
can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array),
),
local_shards,
shards.shards,
)
await asyncio.gather(*future_write_state)
await asyncio.gather(*write_shard_coros)


def estimate_read_memory_footprint(t: ts.TensorStore,
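
Taken together, the new entry points split a save into an eager device-to-host transfer and a TensorStore write that can run off the main thread, as the `async_serialize` docstring recommends. A sketch of that composition, assuming a single-process run and a spec that already contains `metadata` and `dtype` (otherwise `async_serialize_shards` raises `KeyError`); the module paths follow this diff, but `save_in_background` itself is illustrative, not an Orbax API:

```python
import asyncio

import jax
from orbax.checkpoint import serialization
from orbax.checkpoint.future import ThreadRaisingException


def save_in_background(
    arr: jax.Array, tensorstore_spec: dict
) -> ThreadRaisingException:
  # The D2H transfer runs eagerly on the calling thread...
  host_shards = serialization.transfer_array_to_host(arr, replica_id=0)
  host_shards.block_until_ready()

  def _commit():
    # ...while the TensorStore open/write runs in the background;
    # join() re-raises any error captured here.
    asyncio.run(
        serialization.async_serialize_shards(host_shards, tensorstore_spec)
    )

  thread = ThreadRaisingException(target=_commit)
  thread.start()
  return thread
```
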
(Diffs for the remaining 2 changed files not shown.)
