Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move all work for _write_metadata_file into a background thread to avoid building metadata in the main thread. This is not all that costly, but it is O(n) where n is the number of arrays in the tree, so it can start to add up for trees with a lot of parameters. #1001

Merged
merged 1 commit into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ entries will not return information about array properties).
### Changed
- Allow one directory creation request per item rather than 1 per item per host.
- Make atomicity logic configurable, and encapsulate it within a class.
- Move all work for `_write_metadata_file` into a background thread so that
  the O(n) metadata construction no longer blocks the main thread.

### Fixed
- Refactor ts.Context usage to be per-operation (save/restore) rather than a
Expand Down
3 changes: 2 additions & 1 deletion checkpoint/orbax/checkpoint/array_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,14 @@ async def async_save(
)
]

type_handler = type_handlers.get_type_handler(type(item))
info = type_handlers.ParamInfo(
name=self._checkpoint_name,
path=directory / self._checkpoint_name,
parent_dir=directory,
is_ocdbt_checkpoint=False,
value_typestr=type_handler.typestr(),
)
type_handler = type_handlers.get_type_handler(type(item))
futures = await type_handler.serialize([item], [info], args=[save_args])
return list(futures)

Expand Down
53 changes: 26 additions & 27 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
"""

import asyncio
from concurrent import futures
import dataclasses
import json
import os
import time
from typing import Any, List, Optional, Tuple, Union

Expand Down Expand Up @@ -273,6 +273,8 @@ def __init__(
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
)

self._thread_pool = futures.ThreadPoolExecutor(max_workers=1)

def get_param_names(self, item: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
return get_param_names(item)
Expand Down Expand Up @@ -304,8 +306,6 @@ def _get_param_infos(
Returns:
A PyTree matching `item` of ParamInfo.
"""
if not item:
raise ValueError('Found empty item')
if use_zarr3 is None:
use_zarr3 = self._use_zarr3
names = self.get_param_names(item)
Expand All @@ -325,6 +325,9 @@ def _param_info(name, value):
ocdbt_target_data_file_size=ocdbt_target_data_file_size,
byte_limiter=byte_limiter,
ts_context=ts_context,
value_typestr=type_handlers.get_param_typestr(
value, self._type_handler_registry
),
)

return jax.tree.map(
Expand Down Expand Up @@ -373,6 +376,8 @@ async def async_save(
the data from its source will be awaited in this function.
"""
item = args.item
if not item:
raise ValueError('Found empty item.')
save_args = args.save_args
ocdbt_target_data_file_size = args.ocdbt_target_data_file_size
if ocdbt_target_data_file_size is not None and not self._use_zarr3:
Expand Down Expand Up @@ -425,8 +430,8 @@ async def async_save(

if multihost.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
metadata_future = await self._write_metadata_file(
directory, item, save_args, self._use_zarr3
metadata_future = self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
)
commit_futures += [metadata_future]
jax.monitoring.record_event_duration_secs(
Expand Down Expand Up @@ -609,6 +614,8 @@ class TrainState:
if use_zarr3_metadata is not None
else self._use_zarr3
)
if not metadata:
raise ValueError('Found empty metadata.')
param_infos = self._get_param_infos(
metadata,
directory,
Expand All @@ -633,33 +640,25 @@ class TrainState:

return restored_item

async def _write_metadata_file(
def _write_metadata_file(
self,
directory: epath.Path,
item: PyTree,
param_infos: PyTree,
save_args: PyTree,
use_zarr3: bool = False,
) -> future.Future:
tspec = type_handlers._get_tensorstore_spec( # pylint: disable=protected-access
os.fspath(directory), name=METADATA_FILE, use_ocdbt=False
)['kvstore']
txn = ts.Transaction()
metadata_ts_context = type_handlers.get_ts_context()
t = await ts.KvStore.open(
tspec, context=metadata_ts_context
)
metadata_content = tree_metadata.TreeMetadata.build(
item,
save_args=save_args,
type_handler_registry=self._type_handler_registry,
use_zarr3=use_zarr3,
)
write_future = t.with_transaction(txn).write(
'', json.dumps(metadata_content.to_json())
)
await write_future
commit_future = txn.commit_async()
return commit_future
def _save_fn():
if utils.is_primary_host(self._primary_host):
path = directory / METADATA_FILE
metadata_content = tree_metadata.TreeMetadata.build(
param_infos,
save_args=save_args,
use_zarr3=use_zarr3,
)
path.write_text(json.dumps(metadata_content.to_json()))
return 0

return self._thread_pool.submit(_save_fn)

def _read_metadata_file(
self, directory: epath.Path
Expand Down
48 changes: 14 additions & 34 deletions checkpoint/orbax/checkpoint/metadata/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,32 +150,18 @@ def from_json(cls, json_dict: Dict[str, Any]) -> 'ValueMetadataEntry':
@classmethod
def build(
cls,
value: Any,
info: type_handlers.ParamInfo,
save_arg: type_handlers.SaveArgs,
registry: Optional[type_handlers.TypeHandlerRegistry],
) -> 'ValueMetadataEntry':
"""Builds a ValueMetadataEntry."""
if type_handlers.is_supported_empty_aggregation_type(value):
typestr = type_handlers.get_empty_value_typestr(value)
skip_deserialize = True
elif registry is None:
return ValueMetadataEntry(
value_type=type_handlers.RESTORE_TYPE_UNKNOWN, skip_deserialize=False
del save_arg
if info.value_typestr is None:
raise AssertionError(
'Must set `value_typestr` in `ParamInfo` when saving.'
)
else:
try:
handler = registry.get(type(value))
typestr = handler.typestr()
skip_deserialize = save_arg.aggregate
except ValueError:
# Not an error because users' training states often have a bunch of
# random unserializable objects in them (empty states, optimizer
# objects, etc.). An error occurring due to a missing TypeHandler
# will be surfaced elsewhere.
typestr = type_handlers.RESTORE_TYPE_NONE
skip_deserialize = True
skip_deserialize = type_handlers.is_empty_typestr(info.value_typestr)
return ValueMetadataEntry(
value_type=typestr, skip_deserialize=skip_deserialize
value_type=info.value_typestr, skip_deserialize=skip_deserialize
)


Expand Down Expand Up @@ -208,15 +194,12 @@ def from_json(
def build(
cls,
keypath: KeyPath,
value: Any,
info: type_handlers.ParamInfo,
save_arg: type_handlers.SaveArgs,
type_handler_registry: Optional[type_handlers.TypeHandlerRegistry],
) -> 'TreeMetadataEntry':
"""Builds a TreeMetadataEntry."""
key_metadata_entry = KeyMetadataEntry.build(keypath)
value_metadata_entry = ValueMetadataEntry.build(
value, save_arg, type_handler_registry
)
value_metadata_entry = ValueMetadataEntry.build(info, save_arg)
return TreeMetadataEntry(
str(tuple([str(tree_utils.get_key_name(k)) for k in keypath])),
key_metadata_entry,
Expand All @@ -242,33 +225,30 @@ class TreeMetadata:
@classmethod
def build(
cls,
tree: PyTree,
param_infos: PyTree,
*,
type_handler_registry: Optional[type_handlers.TypeHandlerRegistry] = None,
save_args: Optional[PyTree] = None,
use_zarr3: bool = False,
) -> 'TreeMetadata':
"""Builds the tree metadata."""
if save_args is None:
save_args = jax.tree.map(
lambda _: type_handlers.SaveArgs(),
tree,
param_infos,
is_leaf=tree_utils.is_empty_or_leaf,
)
flat_with_keys, _ = jax.tree_util.tree_flatten_with_path(
tree, is_leaf=tree_utils.is_empty_or_leaf
param_infos, is_leaf=tree_utils.is_empty_or_leaf
)
flat_save_args_with_keys, _ = jax.tree_util.tree_flatten_with_path(
save_args, is_leaf=tree_utils.is_empty_or_leaf
)
tree_metadata_entries = []
for (keypath, value), (_, save_arg) in zip(
for (keypath, info), (_, save_arg) in zip(
flat_with_keys, flat_save_args_with_keys
):
tree_metadata_entries.append(
TreeMetadataEntry.build(
keypath, value, save_arg, type_handler_registry
)
TreeMetadataEntry.build(keypath, info, save_arg)
)
return TreeMetadata(tree_metadata_entries, use_zarr3)

Expand Down
70 changes: 48 additions & 22 deletions checkpoint/orbax/checkpoint/metadata/tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import jax
from jax import numpy as jnp
import numpy as np
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
from orbax.checkpoint.metadata import tree as tree_metadata

Expand All @@ -26,7 +27,15 @@ def setUp(self):
super().setUp()
arr = jnp.arange(8)
assert isinstance(arr, jax.Array)
self.tree = {'a': 1, 'b': {'c': 'hi', 'd': 3.4}, 'e': [np.arange(8), arr]}
self.tree = {
'a': 1,
'b': {'c': 'hi', 'd': 3.4},
'e': [np.arange(8), arr],
'f': None,
'g': {},
'h': [],
'i': tuple([]),
}
self.tree_json = {
'tree_metadata': {
"('a',)": {
Expand Down Expand Up @@ -76,36 +85,53 @@ def setUp(self):
'skip_deserialize': False,
},
},
"('f',)": {
'key_metadata': ({'key': 'f', 'key_type': 2},),
'value_metadata': {
'value_type': 'None',
'skip_deserialize': True,
},
},
"('g',)": {
'key_metadata': ({'key': 'g', 'key_type': 2},),
'value_metadata': {
'value_type': 'Dict',
'skip_deserialize': True,
},
},
"('h',)": {
'key_metadata': ({'key': 'h', 'key_type': 2},),
'value_metadata': {
'value_type': 'List',
'skip_deserialize': True,
},
},
"('i',)": {
'key_metadata': ({'key': 'i', 'key_type': 2},),
'value_metadata': {
'value_type': 'None',
'skip_deserialize': True,
},
},
},
'use_zarr3': True,
}

def test_json_conversion(self):
metadata = tree_metadata.TreeMetadata.build(
self.param_infos = jax.tree.map(
# Other properties are not relevant.
lambda x: type_handlers.ParamInfo(
value_typestr=type_handlers.get_param_typestr(
x, type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY
)
),
self.tree,
type_handler_registry=type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
use_zarr3=True,
)
self.assertDictEqual(self.tree_json, metadata.to_json())
self.assertEqual(
metadata, tree_metadata.TreeMetadata.from_json(self.tree_json)
is_leaf=tree_utils.is_empty_or_leaf,
)

def test_aggregate_option(self):
save_args = jax.tree.map(
lambda x: type_handlers.SaveArgs(), self.tree
)
save_args['a'] = type_handlers.SaveArgs(aggregate=True)
def test_json_conversion(self):
metadata = tree_metadata.TreeMetadata.build(
self.tree,
type_handler_registry=type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
save_args=save_args,
self.param_infos,
use_zarr3=True,
)
tree_json = dict(**self.tree_json)
tree_json['tree_metadata']["('a',)"]['value_metadata'][
'skip_deserialize'
] = True
self.assertDictEqual(self.tree_json, metadata.to_json())
self.assertEqual(
metadata, tree_metadata.TreeMetadata.from_json(self.tree_json)
Expand Down
20 changes: 20 additions & 0 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class ParamInfo:
Specifies the target size (in bytes) of each OCDBT data file.
ts_context:
Tensorstore context to use for reading/writing.
value_typestr: stores the original value's typestr (from TypeHandler).
Only required when saving.
"""

name: Optional[str] = None
Expand All @@ -193,6 +195,7 @@ class ParamInfo:
use_zarr3: Optional[bool] = False
ocdbt_target_data_file_size: Optional[int] = None
ts_context: Optional[ts.Context] = None
value_typestr: Optional[str] = None


@dataclasses.dataclass
Expand Down Expand Up @@ -2056,3 +2059,20 @@ def default_restore_type(args: RestoreArgs) -> Any:
return np.ndarray
else:
raise ValueError(f'Unsupported restore_args type: {type(args)}')


def get_param_typestr(value: Any, registry: TypeHandlerRegistry) -> str:
  """Retrieves the typestr for a given value.

  Empty aggregation types (e.g. empty dict/list/None) get their dedicated
  empty-value typestr. Otherwise the registry's handler for the value's type
  provides the typestr; values without a registered handler fall back to
  `RESTORE_TYPE_NONE`.

  Args:
    value: the value whose typestr is wanted.
    registry: the TypeHandlerRegistry used to resolve a handler.

  Returns:
    The typestr string for `value`.
  """
  if is_supported_empty_aggregation_type(value):
    return get_empty_value_typestr(value)
  try:
    return registry.get(type(value)).typestr()
  except ValueError:
    # Not an error because users' training states often have a bunch of
    # random unserializable objects in them (empty states, optimizer
    # objects, etc.). An error occurring due to a missing TypeHandler
    # will be surfaced elsewhere.
    return RESTORE_TYPE_NONE
Loading