diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md index 0d2f69362..dcb874624 100644 --- a/checkpoint/CHANGELOG.md +++ b/checkpoint/CHANGELOG.md @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Allow one directory creation request per item rather than 1 per item per host. - Make atomicity logic configurable, and encapsulate it within a class. +- Move all work for `_write_metadata_file` into a background thread to avoid +blocking on the O(n) computation of building metadata. ### Fixed - Refactor ts.Context usage to be per-operation (save/restore) rather than a diff --git a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py index 07efbee08..db3b16c1c 100644 --- a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py +++ b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py @@ -21,9 +21,9 @@ import asyncio import collections +from concurrent import futures import dataclasses import json -import os import time from typing import Any, Callable, List, Optional, Tuple, Union @@ -380,6 +380,8 @@ def __init__( '/jax/orbax/pytree_checkpoint_handler/init/ocdbt' ) + self._thread_pool = futures.ThreadPoolExecutor(max_workers=1) + def get_param_names(self, item: PyTree) -> PyTree: """Gets parameter names for PyTree elements.""" return get_param_names(item) @@ -585,7 +587,7 @@ def _maybe_set_default_save_args(value, args_): metadata_future = None if utils.is_primary_host(self._primary_host): metadata_write_start_time = time.time() - metadata_future = await self._write_metadata_file( + metadata_future = self._write_metadata_file( directory, item, save_args, self._use_zarr3 ) jax.monitoring.record_event_duration_secs( @@ -840,33 +842,26 @@ def _read_aggregate_file(self, directory: epath.Path) -> PyTree: else: return utils.pytree_structure(directory) - async def _write_metadata_file( + def _write_metadata_file( self, directory: 
epath.Path, item: PyTree, save_args: PyTree, use_zarr3: bool = False, ) -> future.Future: - tspec = type_handlers._get_tensorstore_spec( # pylint: disable=protected-access - os.fspath(directory), name=METADATA_FILE, use_ocdbt=False - )['kvstore'] - txn = ts.Transaction() - metadata_ts_context = type_handlers.get_ts_context() - t = await ts.KvStore.open( - tspec, context=metadata_ts_context - ) - metadata_content = tree_metadata.TreeMetadata.build( - item, - save_args=save_args, - type_handler_registry=self._type_handler_registry, - use_zarr3=use_zarr3, - ) - write_future = t.with_transaction(txn).write( - '', json.dumps(metadata_content.to_json()) - ) - await write_future - commit_future = txn.commit_async() - return commit_future + def _save_fn(): + if utils.is_primary_host(self._primary_host): + path = directory / METADATA_FILE + metadata_content = tree_metadata.TreeMetadata.build( + item, + save_args=save_args, + type_handler_registry=self._type_handler_registry, + use_zarr3=use_zarr3, + ) + path.write_text(json.dumps(metadata_content.to_json())) + return 0 + + return self._thread_pool.submit(_save_fn) def _read_metadata_file( self, directory: epath.Path