async def transfer_shard_to_host(shard: jax.Shard) -> np.ndarray:
  """Asynchronously transfers a shard to host memory.

  Args:
    shard: The shard whose `data` should be brought to host.

  Returns:
    The result of `np.array(data, copy=False)` on the transferred data.
  """
  data = shard.data
  has_pinned_host = any(
      m.kind == 'pinned_host' for m in shard.device.addressable_memories()
  )
  if jax._src.config.enable_memories.value and has_pinned_host:  # pylint: disable=protected-access
    # If available, transfer to pinned host memory.
    sharding = jax.sharding.SingleDeviceSharding(
        shard.device, memory_kind='pinned_host'
    )
    data = jax.device_put(data, sharding)
  else:
    data.copy_to_host_async()
  # Allow other transfers to be scheduled simultaneously.
  await asyncio.sleep(0)
  # Ensure that jax.Array's internal numpy array can be zero-copied. This guards
  # against consumers like tensorstore that would otherwise copy silently.
  return np.array(data, copy=False)


def _get_copy_future(write_future):
  """Returns the `copy` attribute of a TensorStore write future."""
  return write_future.copy


def _get_commit_future(write_future):
  """Returns the `commit` attribute of a TensorStore write future."""
  return write_future.commit


async def _write_array(
    shard: jax.Shard,
    t: ts.TensorStore,
    commit_future: Optional[list[Any]],
    replica_id: int,
    can_reference_source_data_indefinitely: bool,
):
  """Writes a single array shard using TensorStore.

  Shards whose `replica_id` differs from `replica_id` are skipped, so each
  distinct shard index is written by exactly one replica.

  Args:
    shard: The shard to write. Its `index` selects the target region of `t`.
    t: The open TensorStore to write into.
    commit_future: If a list, the write's commit future is appended to it and
      only the copy future is awaited here, leaving the commit for the caller
      to await. If None, the commit future is awaited before returning.
    replica_id: The replica responsible for writing.
    can_reference_source_data_indefinitely: Forwarded to TensorStore's `write`.
      Set it only when the host data is immutable, so TensorStore may keep a
      reference instead of copying into its chunk cache.

  Raises:
    TypeError: If `commit_future` is neither None nor a list.
  """
  if shard.replica_id != replica_id:
    return
  # Validate before starting any transfer or write so that a bad argument does
  # not leave an in-flight write whose futures nobody awaits.
  if commit_future is not None and not isinstance(commit_future, list):
    raise TypeError(
        'commit_future must be a list or None, got'
        f' {type(commit_future).__name__}.'
    )

  data = await transfer_shard_to_host(shard)
  write_future = t[shard.index].write(
      data,
      # Avoid an additional copy of the input into the TensorStore chunk
      # cache. When the source is a jax.Array, the NumPy array obtained from
      # it is immutable, so it is safe to retain a reference indefinitely.
      can_reference_source_data_indefinitely=can_reference_source_data_indefinitely,
  )
  if commit_future is None:
    await _get_commit_future(write_future)
  else:
    commit_future.append(_get_commit_future(write_future))
    await _get_copy_future(write_future)
- can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array), - ) - if commit_future is not None: - assert isinstance(commit_future, list) - commit_future.append(write_future.commit) - await write_future.copy - else: - await write_future.commit - local_shards = arr_inp.addressable_shards - future_write_state = jax.tree_util.tree_map(_write_array, local_shards) + future_write_state = jax.tree_util.tree_map( + functools.partial( + _write_array, + t=t, + commit_future=commit_future, + replica_id=replica_id, + can_reference_source_data_indefinitely=isinstance(arr_inp, jax.Array), + ), + local_shards, + ) await asyncio.gather(*future_write_state) diff --git a/checkpoint/orbax/checkpoint/serialization_test.py b/checkpoint/orbax/checkpoint/serialization_test.py new file mode 100644 index 000000000..96aaf1e7a --- /dev/null +++ b/checkpoint/orbax/checkpoint/serialization_test.py @@ -0,0 +1,517 @@ +# Copyright 2024 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tests for serialization/deserialization."""

import asyncio
import logging
import math
import os
import pathlib
import tracemalloc as tm
from typing import Any
import unittest

from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import dtypes as _dtypes
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import future
from orbax.checkpoint import serialization
from orbax.checkpoint import test_utils
import tensorstore as ts


GSPMDSharding = jax.sharding.GSPMDSharding
NamedSharding = jax.sharding.NamedSharding
P = jax.sharding.PartitionSpec

jax.config.update('jax_enable_x64', True)


def _dtype(x):
  """Returns the dtype of `x`, covering arrays, Python scalars and others."""
  if hasattr(x, 'dtype'):
    return x.dtype
  elif type(x) in _dtypes.python_scalar_dtypes:
    return np.dtype(_dtypes.python_scalar_dtypes[type(x)])
  else:
    return np.asarray(x).dtype


def serialize(arrs, tspecs, commit_future=None):
  """Serializes each array to its spec concurrently, then syncs processes."""

  async def _serialize():
    await asyncio.gather(*[
        serialization.async_serialize(arr, tspec, commit_future=commit_future)
        for arr, tspec in zip(arrs, tspecs)
    ])

  asyncio.run(_serialize())
  test_utils.sync_global_processes('serialization_complete')


def deserialize(shardings, tensorstore_specs, global_shapes=None, dtypes=None):
  """Deserializes each spec concurrently and returns the list of results.

  `global_shapes` and `dtypes` default to None for every spec when omitted.
  """
  if global_shapes is None:
    global_shapes = [None for _ in tensorstore_specs]
  if dtypes is None:
    dtypes = [None for _ in tensorstore_specs]

  async def _deserialize():
    return await asyncio.gather(*[
        serialization.async_deserialize(sharding, tspec, shape, dtype)
        for sharding, tspec, shape, dtype in zip(
            shardings, tensorstore_specs, global_shapes, dtypes
        )
    ])

  result = asyncio.run(_deserialize())
  test_utils.sync_global_processes('deserialization_complete')
  return result


class FutureWithSpeedbump(future.Future):
  """Awaitable wrapper that sleeps `speedbump` seconds before awaiting `f`."""

  def __init__(self, f, speedbump):
    self._f = f
    self._speedbump = speedbump
    assert self._speedbump >= 0

  def result(self, timeout: int | None = None) -> Any:
    # Only the awaitable interface is supported by this test helper.
    raise NotImplementedError()

  async def _sleep_and_result(self):
    await asyncio.sleep(self._speedbump)
    return await self._f

  def __await__(self):
    return self._sleep_and_result().__await__()


def create_global_mesh(mesh_shape, axis_names):
  """Builds a mesh of `mesh_shape` from the lowest-id devices.

  Raises:
    unittest.SkipTest: If fewer devices are available than the mesh needs.
  """
  size = math.prod(mesh_shape)
  if len(jax.devices()) < size:
    raise unittest.SkipTest(f'Test requires {size} global devices.')
  devices = sorted(jax.devices(), key=lambda d: d.id)
  mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
  global_mesh = jax.sharding.Mesh(mesh_devices, axis_names)
  return global_mesh


class CheckpointTest(parameterized.TestCase):
  """Round-trip tests for array serialization and tensorstore spec helpers."""

  def setUp(self):
    super().setUp()
    self.ckpt_dir = pathlib.Path(self.create_tempdir('ckpt').full_path)
    test_utils.sync_global_processes('CheckpointTest:setup_complete')

  def tearDown(self):
    test_utils.sync_global_processes('CheckpointTest:tests_complete')
    super().tearDown()

  def assertArraysEqual(
      self,
      x,
      y,
      *,
      check_dtypes=True,
      err_msg='',
      allow_object_dtype=False,
      verbose=True,
  ):
    """Assert that x and y arrays are exactly equal."""
    if check_dtypes:
      self.assertDtypesMatch(x, y)
    x = np.asarray(x)
    y = np.asarray(y)

    if (not allow_object_dtype) and (x.dtype == object or y.dtype == object):
      # See https://github.com/google/jax/issues/17867
      raise TypeError(
          'assertArraysEqual may be poorly behaved when np.asarray casts to'
          ' dtype=object. If comparing PRNG keys, consider'
          ' random_test.KeyArrayTest.assertKeysEqual. If comparing collections'
          ' of arrays, consider using assertAllClose. To let this test proceed'
          ' anyway, pass allow_object_dtype=True.'
      )

    # Work around https://github.com/numpy/numpy/issues/18992
    with np.errstate(over='ignore'):
      np.testing.assert_array_equal(x, y, err_msg=err_msg, verbose=verbose)

  def assertDtypesMatch(self, x, y):
    self.assertEqual(_dtype(x), _dtype(y))

  def test_memory_consumption(self):
    global_mesh = create_global_mesh((2, 4), ('x', 'y'))
    inp_shape = (2_048, 4_096)
    pspec = P('x', 'y')
    num = math.prod(inp_shape)
    sharding = NamedSharding(global_mesh, pspec)
    # 2048 * 4096 ~= 8.4e6 int32 elements (32 MiB).
    src = jnp.arange(num, dtype=np.int32).reshape(inp_shape)
    inp = jax.make_array_from_callback(
        inp_shape, sharding, lambda idx: src[idx]
    )
    tspec = serialization.get_tensorstore_spec(str(self.ckpt_dir))

    serialize(
        [inp],
        [tspec],
    )

    async def deserialize_with_byte_limit():
      r = await serialization.async_deserialize(
          sharding,
          tspec,
          inp_shape,
          byte_limiter=serialization._LimitInFlightBytes(4_200_000),
      )
      r.block_until_ready()

    tm.start()
    # Always stop tracing, even when an assertion fails, so that tracemalloc
    # does not stay enabled for the tests that run afterwards.
    try:
      asyncio.run(deserialize_with_byte_limit())
      unused_current, peak = tm.get_traced_memory()
      # NB: some padding + tensorstore overhead. It should always be
      # less than array size (2048 * 4096 * 4 = 32M)
      self.assertLess(peak, 10_000_000)
      deserialize_wo_limit = serialization.async_deserialize(
          sharding, tspec, inp_shape)
      tm.clear_traces()
      # NB: calling block_until_ready() is important here and above
      # because otherwise this leads to a race condition and segfault with
      # tensorstore attempting to dealloc using tracemalloc which is already
      # destroyed.
      asyncio.run(deserialize_wo_limit).block_until_ready()

      unused_current, peak = tm.get_traced_memory()
      # We load entire array in memory here.
      self.assertGreater(peak, 30_000_000)
    finally:
      tm.stop()

  def test_checkpointing_jax_array(self):
    global_mesh = create_global_mesh((4, 2), ('x', 'y'))
    inp_shape = (8, 2)
    pspec = P('x', 'y')
    num = math.prod(inp_shape)

    # First Array
    global_input_data1 = np.arange(num, dtype=np.int32).reshape(inp_shape)
    a1 = jax.make_array_from_callback(
        inp_shape,
        NamedSharding(global_mesh, pspec),
        lambda idx: global_input_data1[idx],
    )

    # Second Array
    global_input_data2 = np.arange(num, num + num, dtype=np.int32).reshape(
        inp_shape
    )
    a2 = jax.make_array_from_callback(
        inp_shape,
        NamedSharding(global_mesh, pspec),
        lambda idx: global_input_data2[idx],
    )

    # Third Array
    def cb3(_):
      return np.array([], dtype=np.float32)

    global_mesh1d = create_global_mesh((8,), ('x',))
    a3 = jax.make_array_from_callback(
        (0,), NamedSharding(global_mesh1d, P(None)), cb3
    )

    ckpt_paths = [
        self.create_tempdir(f'{self.ckpt_dir}/{i}').full_path for i in range(3)
    ]
    test_utils.sync_global_processes(
        'test_checkpointing_jax_array:create_arr_paths'
    )
    tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths)

    serialize([a1, a2, a3], tspecs)

    m1, m2, m3 = deserialize(
        [
            NamedSharding(global_mesh, pspec),
            NamedSharding(global_mesh, P('x')),
            NamedSharding(global_mesh1d, P(None)),
        ],
        tspecs,
    )

    logging.info(m1.addressable_shards)
    logging.info(m2.addressable_shards)
    logging.info(m3.addressable_shards)
    self.assertIsInstance(m1, jax.Array)
    self.assertArraysEqual(
        np.asarray(m1.addressable_shards[0].data),
        np.array([[0], [2]], dtype=np.int32),
    )
    self.assertArraysEqual(
        np.asarray(m1.addressable_shards[1].data),
        np.array([[1], [3]], dtype=np.int32),
    )
    self.assertEqual(m1.addressable_shards[0].data.shape, (2, 1))
    self.assertEqual(m1.dtype, np.int32)

    self.assertIsInstance(m2, jax.Array)
    self.assertArraysEqual(
        np.asarray(m2.addressable_shards[0].data),
        np.array([[16, 17], [18, 19]], dtype=np.int32),
    )
    self.assertArraysEqual(
        np.asarray(m2.addressable_shards[1].data),
        np.array([[16, 17], [18, 19]], dtype=np.int32),
    )
    self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
    self.assertEqual(m2.dtype, np.int32)

    self.assertIsInstance(m3, jax.Array)
    for i, s in enumerate(m3.addressable_shards):
      self.assertEqual(s.index, (slice(None),))
      self.assertEqual(s.replica_id, i)
      self.assertArraysEqual(np.asarray(s.data), np.array([], dtype=np.float32))
    self.assertEqual(m3.dtype, np.float32)

  @parameterized.product(input_dtype=[np.int32, jnp.bfloat16])
  def test_checkpointing_with_bigger_shape_jax_array(self, input_dtype):
    global_mesh = create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    num = math.prod(global_input_shape)

    global_input_data1 = np.arange(num, dtype=input_dtype).reshape(
        global_input_shape
    )

    def cb1(index):
      return global_input_data1[index]

    arr = jax.make_array_from_callback(
        global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb1
    )
    ckpt_paths = [str(self.ckpt_dir)]
    tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths)

    serialize([arr], tspecs)

    ds = NamedSharding(create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))

    (m1,) = deserialize([ds], tspecs, [(12, 2)], [np.float32])

    expected_data = {
        0: np.array([[0], [2], [4]], dtype=np.float32),
        1: np.array([[1], [3], [5]], dtype=np.float32),
        2: np.array([[6], [8], [10]], dtype=np.float32),
        3: np.array([[7], [9], [11]], dtype=np.float32),
        4: np.array([[12], [14], [0]], dtype=np.float32),
        5: np.array([[13], [15], [0]], dtype=np.float32),
        6: np.array([[0], [0], [0]], dtype=np.float32),
        7: np.array([[0], [0], [0]], dtype=np.float32),
    }

    for l in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])

    new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
    (m2,) = deserialize([new_ds], tspecs, [(8, 2)], [np.float32])
    for l in m2.addressable_shards:
      self.assertArraysEqual(l.data, global_input_data1.astype('float32'))

  @parameterized.product(input_dtype=[jnp.int4, jnp.int8])
  def test_checkpointing_with_int4(self, input_dtype):
    global_mesh = create_global_mesh((2, 2), ('x', 'y'))
    global_input_shape = (8, 2)
    num = math.prod(global_input_shape)

    global_input_data = np.arange(num, dtype=input_dtype).reshape(
        global_input_shape
    )

    def cb(index):
      return global_input_data[index]

    arr = jax.make_array_from_callback(
        global_input_shape, NamedSharding(global_mesh, P('x', 'y')), cb
    )
    ckpt_paths = [str(self.ckpt_dir)]
    tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths)

    serialize([arr], tspecs)

    ds = NamedSharding(create_global_mesh((4, 2), ('x', 'y')), P('x', 'y'))

    target_dtype = jnp.dtype('int4')
    (m1,) = deserialize([ds], tspecs, [(12, 2)], [target_dtype])

    # values bigger than 7 are converted properly.
    expected_data = {
        0: jnp.array([[0], [2], [4]], dtype=target_dtype),
        1: jnp.array([[1], [3], [5]], dtype=target_dtype),
        2: jnp.array([[6], [8], [10]], dtype=target_dtype),
        3: jnp.array([[7], [9], [11]], dtype=target_dtype),
        4: jnp.array([[12], [14], [0]], dtype=target_dtype),
        5: jnp.array([[13], [15], [0]], dtype=target_dtype),
        6: jnp.array([[0], [0], [0]], dtype=target_dtype),
        7: jnp.array([[0], [0], [0]], dtype=target_dtype),
    }

    for l in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(l.data), expected_data[l.device.id])

    new_ds = GSPMDSharding.get_replicated(list(global_mesh.devices.flat))
    (m2,) = deserialize([new_ds], tspecs, [(8, 2)], [target_dtype])
    for l in m2.addressable_shards:
      self.assertArraysEqual(l.data, global_input_data.astype(target_dtype))

  def test_checkpointing_scalar_jax_array(self):
    global_mesh = create_global_mesh((2,), 'x')
    global_input_shape = ()
    data = np.array(4)
    s = NamedSharding(global_mesh, P(None))
    array1 = jax.make_array_from_callback(
        global_input_shape, s, lambda idx: data[idx]
    )
    ckpt_paths = [str(self.ckpt_dir)]
    tspecs = jax.tree.map(serialization.get_tensorstore_spec, ckpt_paths)

    serialize([array1], tspecs)

    ds = NamedSharding(global_mesh, P(None))

    (m1,) = deserialize([ds], tspecs, [()], [np.float32])

    for l in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(l.data), data.astype(np.float32))

  def test_deserialize_tensorstore_array_jax_array(self):
    global_mesh = create_global_mesh((2,), 'x')
    data = np.arange(1024)
    tspec = ts.array(data).spec()
    (m1,) = deserialize([NamedSharding(global_mesh, P(None))], [tspec])
    for l in m1.addressable_shards:
      self.assertArraysEqual(np.asarray(l.data), data)

  def test_spec_has_metadata(self):
    spec = {
        'a': {
            'b': 1,
            'c': 2,
        },
        'd': 3,
        'e': {
            'a': 2,
            'metadata': 3
        },
        'f': 4
    }
    self.assertTrue(serialization._spec_has_metadata(spec))
    self.assertTrue(
        serialization._spec_has_metadata({
            'driver': 'zarr',
            'kvstore': 'gfile',
            'metadata': {
                'chunks': 4,
                'shape': (32, 64)
            },
            'one_more': 'thing'
        }))

  def test_spec_has_no_metadata(self):
    spec = {
        'a': {
            'b': 1,
            'c': 2,
        },
        'd': 3,
        'e': {
            'a': 2,
        },
        'f': 4
    }
    self.assertFalse(serialization._spec_has_metadata(spec))

  def test_empty_spec_has_no_metadata(self):
    spec = {}
    self.assertFalse(serialization._spec_has_metadata(spec))

  @parameterized.named_parameters(
      ('gcs', 'gs://my/ckpt/dir/path'),
      ('file', '/my/ckpt/dir/path')
  )
  def test_get_tensorstore_spec_ocdbt(self, path):
    spec = serialization.get_tensorstore_spec(path, ocdbt=True)
    is_gcs_path = path.startswith('gs://')
    if is_gcs_path:
      self.assertEqual(spec['kvstore']['base'], os.path.dirname(path))
    else:
      self.assertEqual(
          spec['kvstore']['base'],
          {
              'driver': serialization._DEFAULT_DRIVER,
              'path': os.path.dirname(path),
          },
      )
    self.assertEqual(spec['kvstore']['path'], 'path')

  def test_get_tensorstore_spec_not_absolute_path(self):
    path = 'my/ckpt/path'
    with self.assertRaisesRegex(
        ValueError, 'Checkpoint path should be absolute'
    ):
      serialization.get_tensorstore_spec(path, ocdbt=True)

  def test_maybe_cloud_storage(self):
    gs_path = 'gs://some-buck/path'
    gs_spec = serialization.get_tensorstore_spec(gs_path, ocdbt=True)
    self.assertTrue(serialization.is_remote_storage(gs_spec))

    local_path = '/tmp/checkpoint'
    local_spec = serialization.get_tensorstore_spec(local_path, ocdbt=True)
    self.assertFalse(serialization.is_remote_storage(local_spec))

    nested_tspec = {
        'driver': 'cast',
        'dtype': 'int32',
        'base': {
            'driver': 'zarr',
            'kvstore': {'driver': 'ocdbt', 'base': 's3://some-bucket/path'},
        },
    }
    self.assertTrue(serialization.is_remote_storage(nested_tspec))

  def test_deserialization_with_int4(self):
    dtype = jnp.int4
    shape = (8, 2)
    arr = jnp.arange(np.prod(shape)).reshape(shape).astype(dtype)

    # Run serialization.
    sharding = jax.sharding.GSPMDSharding.get_replicated(jax.devices())
    # Pass the path as a string, consistent with the other tests in this class.
    tspecs = jax.tree.map(
        serialization.get_tensorstore_spec, [str(self.ckpt_dir)]
    )

    serialize([arr], tspecs)

    # Run deserialization.
    (deserialized_arr,) = deserialize(
        shardings=[sharding],
        tensorstore_specs=tspecs,
        global_shapes=[shape],
        dtypes=[dtype],
    )

    out = deserialized_arr.astype(jnp.int8)  # doesn't crash
    self.assertEqual(out.dtype, jnp.int8)
    self.assertArraysEqual(out + out, out * 2)


if __name__ == '__main__':
  absltest.main()