Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Internal change. #994

Merged
merged 1 commit into from
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import json
import time
from typing import Any, List, Optional, Tuple, Union
import uuid

from absl import logging
from etils import epath
Expand All @@ -40,6 +41,7 @@
import tensorstore as ts



PyTree = Any
TupleKey = Tuple[str, ...]
RestoreArgs = type_handlers.RestoreArgs
Expand All @@ -58,6 +60,7 @@
METADATA_FILE = '_METADATA'



def get_byte_limiter(concurrent_gb: int):
async def _create_byte_limiter():
# Wrap creation in async function to avoid issues on python<=3.9.
Expand Down Expand Up @@ -248,7 +251,8 @@ def __init__(
to None, then all hosts will be considered as primary. It's useful in
the case that all hosts are only working with local storage.
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
specified, the global type handler registry will be used.
specified, the global type handler registry will be used. # BEGIN
enable_descriptor: If True, logs a Descriptor proto that contains lineage
"""
self._concurrent_gb = concurrent_gb
self._use_ocdbt = use_ocdbt
Expand All @@ -261,7 +265,7 @@ def __init__(
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
)

self._thread_pool = futures.ThreadPoolExecutor(max_workers=1)
self._thread_pool = futures.ThreadPoolExecutor(max_workers=2)

def get_param_names(self, item: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
Expand Down Expand Up @@ -418,10 +422,11 @@ async def async_save(

if multihost.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
metadata_future = self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
commit_futures.append(
self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
)
)
commit_futures += [metadata_future]
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/metadata_write_duration_secs',
time.time() - metadata_write_start_time,
Expand Down Expand Up @@ -723,7 +728,7 @@ def finalize(self, directory: epath.Path) -> None:

def close(self):
"""Closes the handler. Called automatically by Checkpointer."""
pass
self._thread_pool.shutdown()


@register_with_handler(BasePyTreeCheckpointHandler, for_save=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,18 +230,22 @@ def make_params():
)
test_utils.assert_tree_equal(self, params, restored)

def test_empty_pytrees(self):
def test_empty_error(self):
"""Test case."""
with self.assertRaises(ValueError):
self.handler.save(self.directory, args=self.save_args_cls({}))

def test_empty_dict_node(self):
"""Test case."""
item = {'a': {}, 'b': 3}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
self.directory, args=self.restore_args_cls(item)
)
self.assertDictEqual(restored, item)

def test_empty_none_node(self):
"""Test case."""
item = {'c': None, 'd': 2}
self.handler.save(self.directory, args=self.save_args_cls(item))
restored = self.handler.restore(
Expand Down
Loading