Skip to content

Commit

Permalink
Parallelize.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650747775
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Jul 11, 2024
1 parent 1e06498 commit a7c32e2
Showing 1 changed file with 67 additions and 20 deletions.
87 changes: 67 additions & 20 deletions checkpoint/orbax/checkpoint/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import asyncio
from collections.abc import Awaitable
import functools
import os
import re
from typing import Any, Callable, Dict, Optional, Sequence, Union
Expand Down Expand Up @@ -180,6 +181,62 @@ async def release_bytes(self, requested_bytes):
self._cv.notify_all()


async def transfer_shard_to_host(shard: jax.Shard) -> np.ndarray:
  """Asynchronously transfers a shard to host memory.

  When JAX memories are enabled and the shard's device lists a `pinned_host`
  memory, the data is placed in that memory with `jax.device_put`. Otherwise a
  host copy is started with `copy_to_host_async`. In both cases control is
  yielded to the event loop once before the data is converted to NumPy.

  Args:
    shard: The shard whose `data` should be made available on the host.

  Returns:
    The shard's data as a `np.ndarray`, obtained via
    `np.array(..., copy=False)`.
  """
  host_data = shard.data
  device = shard.device

  # Look for a pinned host memory space among the device's memories.
  pinned_host_available = False
  for memory in device.addressable_memories():
    if memory.kind == 'pinned_host':
      pinned_host_available = True
      break

  memories_enabled = jax._src.config.enable_memories.value  # pylint: disable=protected-access
  if not (memories_enabled and pinned_host_available):
    # No usable pinned host memory: start the device-to-host copy instead.
    host_data.copy_to_host_async()
  else:
    # Pinned host memory is available, so place the data there.
    pinned_sharding = jax.sharding.SingleDeviceSharding(
        device, memory_kind='pinned_host'
    )
    host_data = jax.device_put(host_data, pinned_sharding)

  # Yield once so that transfers for other shards can be scheduled
  # simultaneously.
  await asyncio.sleep(0)

  # Ensure that jax.Array's internal numpy array can be zero-copied. This
  # guards against consumers like tensorstore that would otherwise copy
  # silently.
  return np.array(host_data, copy=False)


def _get_copy_future(write_future):
  """Returns the `copy` attribute of `write_future`.

  Args:
    write_future: An object exposing a `copy` attribute. Presumably the value
      returned by a TensorStore `write` call; confirm against callers.

  Returns:
    `write_future.copy`, unchanged.
  """
  copy_future = write_future.copy
  return copy_future


def _get_commit_future(write_future):
  """Returns the `commit` attribute of `write_future`.

  Args:
    write_future: An object exposing a `commit` attribute. Presumably the
      value returned by a TensorStore `write` call; confirm against callers.

  Returns:
    `write_future.commit`, unchanged.
  """
  commit_future = write_future.commit
  return commit_future


async def _write_array(
    shard: jax.Shard,
    t: ts.TensorStore,
    commit_future: Optional[list[Any]],
    replica_id: int,
    can_reference_source_data_indefinitely: bool,
):
  """Writes a single shard using TensorStore.

  The shard is written only when its `replica_id` equals `replica_id`; any
  other shard is skipped without touching `t`.

  Args:
    shard: The shard to write. Its `index` selects the region of `t` that
      receives the data.
    t: The TensorStore to write into.
    commit_future: If not None, a list to which the write's commit future is
      appended; in that case only the copy future is awaited here. If None,
      the commit future is awaited before returning.
    replica_id: Only shards with this replica id are written.
    can_reference_source_data_indefinitely: Forwarded to the TensorStore
      `write` call. When true, TensorStore may avoid an additional copy of
      the data into its chunk cache, so it must only be set when the source
      data is immutable and safe to reference indefinitely (for example, data
      converted from a jax.Array).
  """
  if shard.replica_id != replica_id:
    # Shards belonging to any other replica id are not written here.
    return

  host_array = await transfer_shard_to_host(shard)
  pending_write = t[shard.index].write(
      host_array,
      can_reference_source_data_indefinitely=(
          can_reference_source_data_indefinitely
      ),
  )

  if commit_future is None:
    # Nobody collects the commit future, so wait for the commit right here.
    await _get_commit_future(pending_write)
    return

  assert isinstance(commit_future, list)
  # Hand the commit future to the caller and wait only for the copy.
  commit_future.append(_get_commit_future(pending_write))
  await _get_copy_future(pending_write)


async def async_serialize(
arr_inp,
tensorstore_spec,
Expand Down Expand Up @@ -256,27 +313,17 @@ async def async_serialize(
context=context,
transaction=transaction,
)

async def _write_array(shard):
if shard.replica_id == replica_id:
write_future = t[shard.index].write(
shard.data,
# Avoid additional copy of input array into the TensorStore chunk
# cache. If `arr_inp` is a jax.Array, the result of converting
# it to a NumPy array, as is done internally by TensorStore, is
# guaranteed to be immutable and therefore it is safe to retain a
# reference indefinitely.
can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array),
)
if commit_future is not None:
assert isinstance(commit_future, list)
commit_future.append(write_future.commit)
await write_future.copy
else:
await write_future.commit

local_shards = arr_inp.addressable_shards
future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
future_write_state = jax.tree_util.tree_map(
functools.partial(
_write_array,
t=t,
commit_future=commit_future,
replica_id=replica_id,
can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array),
),
local_shards,
)
await asyncio.gather(*future_write_state)


Expand Down

0 comments on commit a7c32e2

Please sign in to comment.