diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md
index 0d2f69362..117a9e281 100644
--- a/checkpoint/CHANGELOG.md
+++ b/checkpoint/CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ## [Unreleased]
 
+## [0.5.21] - 2024-07-12
+
 ### Added
 - Rolled forward change to improve TensorStore I/O efficiency.
 - Memory efficient broadcasting from one model replica to others.
@@ -24,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   Add logic to the barrier-compatible test fixture that allows each test case
   to have its own module-level counter, to avoid problems that arise when
   multiprocess tests run in inconsistent orders.
+- Ensure D2H transfers are parallelized.
 
 ## [0.5.20] - 2024-06-20
 
diff --git a/checkpoint/orbax/__init__.py b/checkpoint/orbax/__init__.py
index 600c6ca40..1b6ddd412 100644
--- a/checkpoint/orbax/__init__.py
+++ b/checkpoint/orbax/__init__.py
@@ -74,4 +74,4 @@
 
 # A new PyPI release will be pushed everytime `__version__` is increased.
 # Also modify version and date in CHANGELOG.
-__version__ = '0.5.20'
+__version__ = '0.5.21'
diff --git a/checkpoint/orbax/checkpoint/__init__.py b/checkpoint/orbax/checkpoint/__init__.py
index 600c6ca40..1b6ddd412 100644
--- a/checkpoint/orbax/checkpoint/__init__.py
+++ b/checkpoint/orbax/checkpoint/__init__.py
@@ -74,4 +74,4 @@
 
 # A new PyPI release will be pushed everytime `__version__` is increased.
 # Also modify version and date in CHANGELOG.
-__version__ = '0.5.20'
+__version__ = '0.5.21'
diff --git a/checkpoint/orbax/checkpoint/serialization.py b/checkpoint/orbax/checkpoint/serialization.py
index 023483df9..4d436b43f 100644
--- a/checkpoint/orbax/checkpoint/serialization.py
+++ b/checkpoint/orbax/checkpoint/serialization.py
@@ -19,6 +19,7 @@
 
 import asyncio
 from collections.abc import Awaitable
+import functools
 import os
 import re
 from typing import Any, Callable, Dict, Optional, Sequence, Union
@@ -180,6 +181,62 @@ async def release_bytes(self, requested_bytes):
       self._cv.notify_all()
 
 
+async def transfer_shard_to_host(shard: jax.Shard) -> np.ndarray:
+  """Asynchronously transfers a shard to host memory."""
+  data = shard.data
+  has_pinned_host = any(
+      m.kind == 'pinned_host' for m in shard.device.addressable_memories()
+  )
+  if jax._src.config.enable_memories.value and has_pinned_host:  # pylint: disable=protected-access
+    # If available, transfer to pinned host memory.
+    sharding = jax.sharding.SingleDeviceSharding(
+        shard.device, memory_kind='pinned_host'
+    )
+    data = jax.device_put(data, sharding)
+  else:
+    data.copy_to_host_async()
+  # Allow other transfers to be scheduled simultaneously.
+  await asyncio.sleep(0)
+  # Ensure that jax.Array's internal numpy array can be zero-copied. This
+  # guards against consumers like tensorstore that would otherwise copy silently.
+  return np.array(data, copy=False)
+
+
+def _get_copy_future(write_future):
+  return write_future.copy
+
+
+def _get_commit_future(write_future):
+  return write_future.commit
+
+
+async def _write_array(
+    shard: jax.Shard,
+    t: ts.TensorStore,
+    commit_future: Optional[list[Any]],
+    replica_id: int,
+    can_reference_source_data_indefinitely: bool,
+):
+  """Writes a single shard using TensorStore."""
+  if shard.replica_id == replica_id:
+    data = await transfer_shard_to_host(shard)
+    write_future = t[shard.index].write(
+        data,
+        # Avoid an additional copy of the input array into the TensorStore
+        # chunk cache. If the source is a jax.Array, the result of converting
+        # it to a NumPy array, as is done internally by TensorStore, is
+        # guaranteed to be immutable and therefore it is safe to retain a
+        # reference indefinitely.
+        can_reference_source_data_indefinitely=can_reference_source_data_indefinitely,
+    )
+    if commit_future is not None:
+      assert isinstance(commit_future, list)
+      commit_future.append(_get_commit_future(write_future))
+      await _get_copy_future(write_future)
+    else:
+      await _get_commit_future(write_future)
+
+
 async def async_serialize(
     arr_inp,
     tensorstore_spec,
@@ -256,27 +313,17 @@ async def async_serialize(
       context=context,
       transaction=transaction,
   )
-
-  async def _write_array(shard):
-    if shard.replica_id == replica_id:
-      write_future = t[shard.index].write(
-          shard.data,
-          # Avoid additional copy of input array into the TensorStore chunk
-          # cache. If `arr_inp` is a jax.Array, the result of converting
-          # it to a NumPy array, as is done internally by TensorStore, is
-          # guaranteed to be immutable and therefore it is safe to retain a
-          # reference indefinitely.
-          can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array),
-      )
-      if commit_future is not None:
-        assert isinstance(commit_future, list)
-        commit_future.append(write_future.commit)
-        await write_future.copy
-      else:
-        await write_future.commit
-
   local_shards = arr_inp.addressable_shards
-  future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
+  future_write_state = jax.tree_util.tree_map(
+      functools.partial(
+          _write_array,
+          t=t,
+          commit_future=commit_future,
+          replica_id=replica_id,
+          can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array),
+      ),
+      local_shards,
+  )
   await asyncio.gather(*future_write_state)
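
To see why this refactor parallelizes device-to-host (D2H) transfers: the old nested `_write_array` handed `shard.data` straight to TensorStore, so each shard's conversion to a host NumPy array happened serially inside the write. The new per-shard coroutine starts a non-blocking copy via `copy_to_host_async()` (or `jax.device_put` to pinned host memory) and then yields with `await asyncio.sleep(0)`, letting the event loop issue every other shard's copy before any coroutine blocks on completion. Below is a minimal, self-contained sketch of that scheduling pattern, runnable on a CPU-only JAX install; `fake_write` is a hypothetical stand-in for `t[shard.index].write(...)`, not an Orbax or TensorStore API.

```python
import asyncio

import jax
import jax.numpy as jnp
import numpy as np


async def fake_write(index, data: np.ndarray) -> None:
  """Hypothetical stand-in for awaiting the TensorStore write futures."""
  await asyncio.sleep(0)  # Pretend to await the copy/commit futures.
  print(f'wrote shard {index} with shape {data.shape}')


async def write_shard(shard: jax.Shard) -> None:
  # Start the D2H copy without blocking the event loop.
  shard.data.copy_to_host_async()
  # Yield so the copies for the remaining shards can be scheduled too.
  await asyncio.sleep(0)
  # By now the transfer is in flight (or done); take a host view of the data.
  host_data = np.asarray(shard.data)
  await fake_write(shard.index, host_data)


async def main() -> None:
  arr = jnp.arange(16).reshape(4, 4)
  # gather() interleaves the per-shard coroutines, so every D2H copy is
  # issued before any coroutine blocks on a write.
  await asyncio.gather(*(write_shard(s) for s in arr.addressable_shards))


if __name__ == '__main__':
  asyncio.run(main())
```

Because each coroutine yields immediately after starting its copy, the transfers overlap instead of running back to back, which is what the changelog entry "Ensure D2H transfers are parallelized." records.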