Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 650366520
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Jul 11, 2024
1 parent a704408 commit c23e9ad
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import os
import time
from typing import Any, Callable, List, Optional, Tuple, Union
import uuid

from absl import logging
from etils import epath
Expand All @@ -35,6 +36,7 @@
from orbax.checkpoint import async_checkpoint_handler
from orbax.checkpoint import checkpoint_args
from orbax.checkpoint import future
from orbax.checkpoint import proto_checkpoint_handler
from orbax.checkpoint import transform_utils
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
Expand All @@ -43,6 +45,7 @@
from orbax.checkpoint.metadata import value as value_metadata
import tensorstore as ts


PyTree = Any
TupleKey = Tuple[str, ...]
RestoreArgs = type_handlers.RestoreArgs
Expand All @@ -66,6 +69,7 @@
_CHECKPOINT_FILE = 'checkpoint'



def get_byte_limiter(concurrent_gb: int):
async def _create_byte_limiter():
# Wrap creation in async function to avoid issues on python<=3.9.
Expand Down Expand Up @@ -362,7 +366,8 @@ def __init__(
to None, then all hosts will be considered as primary. It's useful in
the case that all hosts are only working with local storage.
type_handler_registry: a type_handlers.TypeHandlerRegistry. If not
specified, the global type handler registry will be used.
specified, the global type handler registry will be used. # BEGIN
enable_descriptor: If True, logs a Descriptor proto that contains lineage
"""
self._aggregate_handler = MsgpackHandler(primary_host=primary_host)
if aggregate_filename is None:
Expand Down Expand Up @@ -582,30 +587,28 @@ def _maybe_set_default_save_args(value, args_):
logging.debug('param_info: %s', param_infos)
logging.debug('save_args: %s', save_args)

metadata_future = None
if utils.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
metadata_future = await self._write_metadata_file(
directory, item, save_args, self._use_zarr3
)
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/metadata_write_duration_secs',
time.time() - metadata_write_start_time,
)
metadata_write_start_time = time.time()
commit_futures.append(
await self._write_metadata_file(
directory, item, save_args, self._use_zarr3
)
)
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/metadata_write_duration_secs',
time.time() - metadata_write_start_time,
)

aggregate_file_write_start_time = time.time()
aggregate_commit_future = await self._write_aggregate_file(
directory, item, param_infos, save_args
commit_futures.append(
await self._write_aggregate_file(
directory, item, param_infos, save_args
)
)
jax.monitoring.record_event_duration_secs(
'/jax/checkpoint/write/async/aggregate_write_duration_secs',
time.time() - aggregate_file_write_start_time,
)
return (
commit_futures + [aggregate_commit_future] + [metadata_future]
if metadata_future is not None
else commit_futures + [aggregate_commit_future]
)
return commit_futures

def save(self, directory: epath.Path, *args, **kwargs):
"""Saves the provided item.
Expand Down Expand Up @@ -840,13 +843,17 @@ def _read_aggregate_file(self, directory: epath.Path) -> PyTree:
else:
return utils.pytree_structure(directory)


async def _write_metadata_file(
self,
directory: epath.Path,
item: PyTree,
save_args: PyTree,
use_zarr3: bool = False,
) -> future.Future:
if not utils.is_primary_host(self._primary_host):
return future.NoopFuture()

tspec = type_handlers._get_tensorstore_spec( # pylint: disable=protected-access
os.fspath(directory), name=METADATA_FILE, use_ocdbt=False
)['kvstore']
Expand Down

0 comments on commit c23e9ad

Please sign in to comment.