diff --git a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py index a8dac3a3a..95592a6d7 100644 --- a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py @@ -25,6 +25,7 @@ import json import time from typing import Any, List, Optional, Tuple, Union +import uuid from absl import logging from etils import epath @@ -40,6 +41,7 @@ import tensorstore as ts + PyTree = Any TupleKey = Tuple[str, ...] RestoreArgs = type_handlers.RestoreArgs @@ -58,6 +60,7 @@ METADATA_FILE = '_METADATA' + def get_byte_limiter(concurrent_gb: int): async def _create_byte_limiter(): # Wrap creation in async function to avoid issues on python<=3.9. @@ -248,7 +251,8 @@ def __init__( to None, then all hosts will be considered as primary. It's useful in the case that all hosts are only working with local storage. type_handler_registry: a type_handlers.TypeHandlerRegistry. If not - specified, the global type handler registry will be used. + specified, the global type handler registry will be used. # BEGIN + enable_descriptor: If True, logs a Descriptor proto that contains lineage """ self._concurrent_gb = concurrent_gb self._use_ocdbt = use_ocdbt @@ -261,7 +265,7 @@ def __init__( '/jax/orbax/pytree_checkpoint_handler/init/ocdbt' ) - self._thread_pool = futures.ThreadPoolExecutor(max_workers=1) + self._thread_pool = futures.ThreadPoolExecutor(max_workers=2) def get_param_names(self, item: PyTree) -> PyTree: """Gets parameter names for PyTree elements.""" @@ -418,10 +422,11 @@ async def async_save( if multihost.is_primary_host(self._primary_host): metadata_write_start_time = time.time() - metadata_future = self._write_metadata_file( - directory, param_infos, save_args, self._use_zarr3 + commit_futures.append( + self._write_metadata_file( + directory, param_infos, save_args, self._use_zarr3 + ) ) - commit_futures += [metadata_future] jax.monitoring.record_event_duration_secs( '/jax/checkpoint/write/async/metadata_write_duration_secs', time.time() - metadata_write_start_time, @@ -723,7 +728,7 @@ def finalize(self, directory: epath.Path) -> None: def close(self): """Closes the handler. Called automatically by Checkpointer.""" - pass + self._thread_pool.shutdown() @register_with_handler(BasePyTreeCheckpointHandler, for_save=True) diff --git a/checkpoint/orbax/checkpoint/standard_checkpoint_handler_test_utils.py b/checkpoint/orbax/checkpoint/standard_checkpoint_handler_test_utils.py index 1a3c32d5b..c783d6e53 100644 --- a/checkpoint/orbax/checkpoint/standard_checkpoint_handler_test_utils.py +++ b/checkpoint/orbax/checkpoint/standard_checkpoint_handler_test_utils.py @@ -230,11 +230,13 @@ def make_params(): ) test_utils.assert_tree_equal(self, params, restored) - def test_empty_pytrees(self): + def test_empty_error(self): """Test case.""" with self.assertRaises(ValueError): self.handler.save(self.directory, args=self.save_args_cls({})) + def test_empty_dict_node(self): + """Test case.""" item = {'a': {}, 'b': 3} self.handler.save(self.directory, args=self.save_args_cls(item)) restored = self.handler.restore( @@ -242,6 +244,8 @@ def test_empty_pytrees(self): ) self.assertDictEqual(restored, item) + def test_empty_none_node(self): + """Test case.""" item = {'c': None, 'd': 2} self.handler.save(self.directory, args=self.save_args_cls(item)) restored = self.handler.restore(