Skip to content

Commit

Permalink
Move all work for _write_metadata_file into a background thread to …
Browse files Browse the repository at this point in the history
…avoid building metadata in the main thread. This is not all that costly, but it is O(n) where n is the number of arrays in the tree, so it can start to add up for trees with a lot of parameters.

PiperOrigin-RevId: 651805233
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Jul 12, 2024
1 parent a704408 commit 729c2ac
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 23 deletions.
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Changed
- Allow one directory creation request per item rather than 1 per item per host.
- Make atomicity logic configurable, and encapsulate it within a class.
- Move all work for `_write_metadata_file` into a background thread to avoid
O(n) metadata-building computation on the main thread.

### Fixed
- Refactor ts.Context usage to be per-operation (save/restore) rather than a
Expand Down
41 changes: 18 additions & 23 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

import asyncio
import collections
from concurrent import futures
import dataclasses
import json
import os
import time
from typing import Any, Callable, List, Optional, Tuple, Union

Expand Down Expand Up @@ -380,6 +380,8 @@ def __init__(
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
)

self._thread_pool = futures.ThreadPoolExecutor(max_workers=1)

def get_param_names(self, item: PyTree) -> PyTree:
  """Returns the parameter names for the elements of `item`.

  Delegates to the module-scope `get_param_names` function; no instance state
  is read.
  """
  param_names = get_param_names(item)
  return param_names
Expand Down Expand Up @@ -585,7 +587,7 @@ def _maybe_set_default_save_args(value, args_):
metadata_future = None
if utils.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
metadata_future = await self._write_metadata_file(
metadata_future = self._write_metadata_file(
directory, item, save_args, self._use_zarr3
)
jax.monitoring.record_event_duration_secs(
Expand Down Expand Up @@ -840,33 +842,26 @@ def _read_aggregate_file(self, directory: epath.Path) -> PyTree:
else:
return utils.pytree_structure(directory)

async def _write_metadata_file(
def _write_metadata_file(
self,
directory: epath.Path,
item: PyTree,
save_args: PyTree,
use_zarr3: bool = False,
) -> future.Future:
tspec = type_handlers._get_tensorstore_spec( # pylint: disable=protected-access
os.fspath(directory), name=METADATA_FILE, use_ocdbt=False
)['kvstore']
txn = ts.Transaction()
metadata_ts_context = type_handlers.get_ts_context()
t = await ts.KvStore.open(
tspec, context=metadata_ts_context
)
metadata_content = tree_metadata.TreeMetadata.build(
item,
save_args=save_args,
type_handler_registry=self._type_handler_registry,
use_zarr3=use_zarr3,
)
write_future = t.with_transaction(txn).write(
'', json.dumps(metadata_content.to_json())
)
await write_future
commit_future = txn.commit_async()
return commit_future
def _save_fn():
  """Builds the tree metadata and writes it to the metadata file as JSON."""
  # Hosts other than the primary write nothing and simply report 0.
  if not utils.is_primary_host(self._primary_host):
    return 0
  metadata_path = directory / METADATA_FILE
  tree_meta = tree_metadata.TreeMetadata.build(
      item,
      save_args=save_args,
      type_handler_registry=self._type_handler_registry,
      use_zarr3=use_zarr3,
  )
  serialized = json.dumps(tree_meta.to_json())
  metadata_path.write_text(serialized)
  return 0

return self._thread_pool.submit(_save_fn)

def _read_metadata_file(
self, directory: epath.Path
Expand Down

0 comments on commit 729c2ac

Please sign in to comment.