Ensure D2H copies are parallelized, following https://github.com/google/jax/pull/22169. #997

Merged: 1 commit, Jul 16, 2024
3 changes: 3 additions & 0 deletions checkpoint/CHANGELOG.md
@@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

+## [0.5.21] - 2024-07-12

### Added
- Rolled forward change to improve TensorStore I/O efficiency.
- Memory efficient broadcasting from one model replica to others.
@@ -24,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
Add logic to the barrier-compatible test fixture that allows each test case to
have its own module-level counter, to avoid problems that arise when
multiprocess tests run in inconsistent orders.
+- Ensure D2H transfers are parallelized.

## [0.5.20] - 2024-06-20

2 changes: 1 addition & 1 deletion checkpoint/orbax/__init__.py
@@ -74,4 +74,4 @@

# A new PyPI release will be pushed everytime `__version__` is increased.
# Also modify version and date in CHANGELOG.
-__version__ = '0.5.20'
+__version__ = '0.5.21'
2 changes: 1 addition & 1 deletion checkpoint/orbax/checkpoint/__init__.py
@@ -74,4 +74,4 @@

# A new PyPI release will be pushed everytime `__version__` is increased.
# Also modify version and date in CHANGELOG.
-__version__ = '0.5.20'
+__version__ = '0.5.21'
87 changes: 67 additions & 20 deletions checkpoint/orbax/checkpoint/serialization.py
@@ -19,6 +19,7 @@

import asyncio
from collections.abc import Awaitable
+import functools
import os
import re
from typing import Any, Callable, Dict, Optional, Sequence, Union
@@ -180,6 +181,62 @@ async def release_bytes(self, requested_bytes):
      self._cv.notify_all()


+async def transfer_shard_to_host(shard: jax.Shard) -> np.ndarray:
+  """Asynchronously transfers a shard to host memory."""
+  data = shard.data
+  has_pinned_host = any(
+      m.kind == 'pinned_host' for m in shard.device.addressable_memories()
+  )
+  if jax._src.config.enable_memories.value and has_pinned_host:  # pylint: disable=protected-access
+    # If available, transfer to pinned host memory
+    sharding = jax.sharding.SingleDeviceSharding(
+        shard.device, memory_kind='pinned_host'
+    )
+    data = jax.device_put(data, sharding)
+  else:
+    data.copy_to_host_async()
+  # Allow other transfers to be scheduled simultaneously.
+  await asyncio.sleep(0)
+  # Ensure that jax.Array's internal numpy array can be zero-copied. This guards
+  # against consumers like tensorstore that would otherwise copy silently.
+  return np.array(data, copy=False)
+
+
+def _get_copy_future(write_future):
+  return write_future.copy
+
+
+def _get_commit_future(write_future):
+  return write_future.commit
+
+
+async def _write_array(
+    shard: jax.Shard,
+    t: ts.TensorStore,
+    commit_future: Optional[list[Any]],
+    replica_id: int,
+    can_reference_source_data_indefinitely: bool,
+):
+  """Writes a single array using TensorStore."""
+  if shard.replica_id == replica_id:
+    data = await transfer_shard_to_host(shard)
+    write_future = t[shard.index].write(
+        data,
+        # Avoid additional copy of input array into the TensorStore chunk
+        # cache. If `arr_inp` is a jax.Array, the result of converting
+        # it to a NumPy array, as is done internally by TensorStore, is
+        # guaranteed to be immutable and therefore it is safe to retain a
+        # reference indefinitely.
+        can_reference_source_data_indefinitely=can_reference_source_data_indefinitely,
+    )
+    if commit_future is not None:
+      assert isinstance(commit_future, list)
+      commit_future.append(_get_commit_future(write_future))
+      await _get_copy_future(write_future)
+    else:
+      await _get_commit_future(write_future)

async def async_serialize(
    arr_inp,
    tensorstore_spec,
@@ -256,27 +313,17 @@ async def async_serialize(
      context=context,
      transaction=transaction,
  )

-  async def _write_array(shard):
-    if shard.replica_id == replica_id:
-      write_future = t[shard.index].write(
-          shard.data,
-          # Avoid additional copy of input array into the TensorStore chunk
-          # cache. If `arr_inp` is a jax.Array, the result of converting
-          # it to a NumPy array, as is done internally by TensorStore, is
-          # guaranteed to be immutable and therefore it is safe to retain a
-          # reference indefinitely.
-          can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array),
-      )
-      if commit_future is not None:
-        assert isinstance(commit_future, list)
-        commit_future.append(write_future.commit)
-        await write_future.copy
-      else:
-        await write_future.commit

  local_shards = arr_inp.addressable_shards
-  future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
+  future_write_state = jax.tree_util.tree_map(
+      functools.partial(
+          _write_array,
+          t=t,
+          commit_future=commit_future,
+          replica_id=replica_id,
+          can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array),
+      ),
+      local_shards,
+  )
  await asyncio.gather(*future_write_state)

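For context on what "parallelized" means here: each per-shard coroutine now starts a non-blocking D2H copy and yields control (`await asyncio.sleep(0)`) before blocking on the result, so `asyncio.gather` gets every shard's copy in flight before any coroutine waits. A minimal standalone sketch of that pattern; `save_shards` and `_shard_to_host` are illustrative names, not part of the Orbax API:

```python
import asyncio

import jax
import numpy as np


async def _shard_to_host(shard: jax.Shard) -> np.ndarray:
  # Enqueue a non-blocking D2H copy for this shard.
  shard.data.copy_to_host_async()
  # Yield to the event loop so copies for the remaining shards are
  # enqueued before we block waiting on any single one.
  await asyncio.sleep(0)
  # Converting to numpy blocks until the copy finishes, but by now all
  # shard copies are in flight, so they overlap rather than serialize.
  return np.asarray(shard.data)


async def save_shards(arr: jax.Array) -> list[np.ndarray]:
  # One coroutine per addressable shard, run concurrently.
  return await asyncio.gather(
      *(_shard_to_host(s) for s in arr.addressable_shards)
  )


host_shards = asyncio.run(save_shards(jax.numpy.arange(16)))
print([s.shape for s in host_shards])
```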

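The `pinned_host` branch in `transfer_shard_to_host` only fires when JAX memory kinds are enabled and the device actually exposes pinned host memory. A hedged standalone probe, assuming a JAX version with memory-kind support (the snippet itself is illustrative, not Orbax API):

```python
import jax

device = jax.devices()[0]
# Does this device expose a page-locked ('pinned') host memory space?
has_pinned_host = any(
    m.kind == 'pinned_host' for m in device.addressable_memories()
)
print('pinned_host available:', has_pinned_host)

if has_pinned_host:
  # Placing the host copy in pinned memory lets the device DMA directly,
  # avoiding an extra staging copy on the host.
  sharding = jax.sharding.SingleDeviceSharding(
      device, memory_kind='pinned_host'
  )
  x_host = jax.device_put(jax.numpy.arange(8), sharding)
  print(x_host.sharding.memory_kind)  # -> 'pinned_host'
```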
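`_write_array` also leans on TensorStore's two-stage write future: `copy` resolves once the data has been captured from the source buffer, while `commit` resolves once the write is durable. A small sketch against an in-memory store; the spec values are illustrative:

```python
import numpy as np
import tensorstore as ts

t = ts.open(
    {
        'driver': 'zarr',
        'kvstore': {'driver': 'memory'},
        'metadata': {'shape': [8], 'chunks': [8], 'dtype': '<i4'},
    },
    create=True,
).result()

write_future = t.write(np.arange(8, dtype=np.int32))
# 'copy' completes once the source data has been captured; the caller may
# reuse (or release) the source buffer after this point.
write_future.copy.result()
# 'commit' completes once the data is durably written to the kvstore.
write_future.commit.result()
```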