Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657266226
  • Loading branch information
Orbax Authors committed Jul 29, 2024
1 parent 4bc8308 commit 0d0893d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 16 deletions.
17 changes: 6 additions & 11 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import json
import time
from typing import Any, List, Optional, Tuple, Union
import uuid

from absl import logging
from etils import epath
Expand All @@ -41,7 +40,6 @@
import tensorstore as ts



PyTree = Any
TupleKey = Tuple[str, ...]
RestoreArgs = type_handlers.RestoreArgs
Expand All @@ -60,7 +58,6 @@
METADATA_FILE = '_METADATA'



def get_byte_limiter(concurrent_gb: int):
async def _create_byte_limiter():
# Wrap creation in async function to avoid issues on python<=3.9.
Expand Down Expand Up @@ -251,8 +248,7 @@ def __init__(
to None, then all hosts will be considered as primary. It's useful in
the case that all hosts are only working with local storage.
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
specified, the global type handler registry will be used. # BEGIN
enable_descriptor: If True, logs a Descriptor proto that contains lineage
specified, the global type handler registry will be used.
"""
self._concurrent_gb = concurrent_gb
self._use_ocdbt = use_ocdbt
Expand All @@ -265,7 +261,7 @@ def __init__(
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
)

self._thread_pool = futures.ThreadPoolExecutor(max_workers=2)
self._thread_pool = futures.ThreadPoolExecutor(max_workers=1)

def get_param_names(self, item: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
Expand Down Expand Up @@ -422,11 +418,10 @@ async def async_save(

if multihost.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
commit_futures.append(
self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
)
metadata_future = self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
)
commit_futures += [metadata_future]
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/metadata_write_duration_secs',
time.time() - metadata_write_start_time,
Expand Down Expand Up @@ -728,7 +723,7 @@ def finalize(self, directory: epath.Path) -> None:

def close(self):
"""Closes the handler. Called automatically by Checkpointer."""
self._thread_pool.shutdown()
pass


@register_with_handler(BasePyTreeCheckpointHandler, for_save=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,22 +230,18 @@ def make_params():
)
test_utils.assert_tree_equal(self, params, restored)

def test_empty_error(self):
def test_empty_pytrees(self):
"""Test case."""
with self.assertRaises(ValueError):
self.handler.save(self.directory, args=self.save_args_cls({}))

def test_empty_dict_node(self):
"""Test case."""
item = {'a': {}, 'b': 3}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
self.directory, args=self.restore_args_cls(item)
)
self.assertDictEqual(restored, item)

def test_empty_none_node(self):
"""Test case."""
item = {'c': None, 'd': 2}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand Down

0 comments on commit 0d0893d

Please sign in to comment.