Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653631904
  • Loading branch information
cpgaffney1 authored and Orbax Authors committed Jul 18, 2024
1 parent f517a2f commit 097a23b
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 84 deletions.
2 changes: 2 additions & 0 deletions checkpoint/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ entries will not return information about array properties).
### Changed
- Allow one directory creation request per item rather than one per item per host.
- Make atomicity logic configurable, and encapsulate it within a class.
- Move all work for `_write_metadata_file` into a background thread to avoid
O(n) computation in building metadata.

### Fixed
- Refactor ts.Context usage to be per-operation (save/restore) rather than a
Expand Down
3 changes: 2 additions & 1 deletion checkpoint/orbax/checkpoint/array_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,14 @@ async def async_save(
)
]

type_handler = type_handlers.get_type_handler(type(item))
info = type_handlers.ParamInfo(
name=self._checkpoint_name,
path=directory / self._checkpoint_name,
parent_dir=directory,
is_ocdbt_checkpoint=False,
value_typestr=type_handler.typestr(),
)
type_handler = type_handlers.get_type_handler(type(item))
futures = await type_handler.serialize([item], [info], args=[save_args])
return list(futures)

Expand Down
53 changes: 26 additions & 27 deletions checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@
"""

import asyncio
from concurrent import futures
import dataclasses
import json
import os
import time
from typing import Any, List, Optional, Tuple, Union

Expand Down Expand Up @@ -273,6 +273,8 @@ def __init__(
'/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
)

self._thread_pool = futures.ThreadPoolExecutor(max_workers=1)

def get_param_names(self, item: PyTree) -> PyTree:
"""Gets parameter names for PyTree elements."""
return get_param_names(item)
Expand Down Expand Up @@ -304,8 +306,6 @@ def _get_param_infos(
Returns:
A PyTree matching `item` of ParamInfo.
"""
if not item:
raise ValueError('Found empty item')
if use_zarr3 is None:
use_zarr3 = self._use_zarr3
names = self.get_param_names(item)
Expand All @@ -325,6 +325,9 @@ def _param_info(name, value):
ocdbt_target_data_file_size=ocdbt_target_data_file_size,
byte_limiter=byte_limiter,
ts_context=ts_context,
value_typestr=type_handlers.get_param_typestr(
value, self._type_handler_registry
),
)

return jax.tree.map(
Expand Down Expand Up @@ -373,6 +376,8 @@ async def async_save(
the data from its source will be awaited in this function.
"""
item = args.item
if not item:
raise ValueError('Found empty item.')
save_args = args.save_args
ocdbt_target_data_file_size = args.ocdbt_target_data_file_size
if ocdbt_target_data_file_size is not None and not self._use_zarr3:
Expand Down Expand Up @@ -425,8 +430,8 @@ async def async_save(

if multihost.is_primary_host(self._primary_host):
metadata_write_start_time = time.time()
metadata_future = await self._write_metadata_file(
directory, item, save_args, self._use_zarr3
metadata_future = self._write_metadata_file(
directory, param_infos, save_args, self._use_zarr3
)
commit_futures += [metadata_future]
jax.monitoring.record_event_duration_secs(
Expand Down Expand Up @@ -609,6 +614,8 @@ class TrainState:
if use_zarr3_metadata is not None
else self._use_zarr3
)
if not metadata:
raise ValueError('Found empty metadata.')
param_infos = self._get_param_infos(
metadata,
directory,
Expand All @@ -633,33 +640,25 @@ class TrainState:

return restored_item

def _write_metadata_file(
    self,
    directory: epath.Path,
    param_infos: PyTree,
    save_args: PyTree,
    use_zarr3: bool = False,
) -> future.Future:
  """Writes the checkpoint tree metadata file in a background thread.

  Building the metadata is O(n) in the number of tree leaves, so all work
  (metadata construction, JSON serialization, and the file write) is
  deferred to the handler's single-worker thread pool rather than blocking
  the caller.

  Args:
    directory: checkpoint directory in which `METADATA_FILE` is created.
    param_infos: PyTree of ParamInfo; each leaf must already have
      `value_typestr` set (done during `_get_param_infos`).
    save_args: PyTree of SaveArgs matching `param_infos`.
    use_zarr3: recorded in the metadata so restore can match the format.

  Returns:
    A future that resolves once the metadata file has been written (or
    immediately on non-primary hosts, which write nothing).
  """

  def _save_fn():
    # Only the primary host writes the file; other hosts no-op so the
    # returned future still completes everywhere.
    if utils.is_primary_host(self._primary_host):
      path = directory / METADATA_FILE
      metadata_content = tree_metadata.TreeMetadata.build(
          param_infos,
          save_args=save_args,
          use_zarr3=use_zarr3,
      )
      path.write_text(json.dumps(metadata_content.to_json()))
    return 0

  return self._thread_pool.submit(_save_fn)

def _read_metadata_file(
self, directory: epath.Path
Expand Down
48 changes: 14 additions & 34 deletions checkpoint/orbax/checkpoint/metadata/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,32 +150,18 @@ def from_json(cls, json_dict: Dict[str, Any]) -> 'ValueMetadataEntry':
@classmethod
def build(
    cls,
    info: type_handlers.ParamInfo,
    save_arg: type_handlers.SaveArgs,
) -> 'ValueMetadataEntry':
  """Builds a ValueMetadataEntry from a ParamInfo.

  The value's typestr is computed when param infos are constructed (see
  `type_handlers.get_param_typestr`) and carried on `info`, so no type
  registry lookup is needed here.

  Args:
    info: ParamInfo with `value_typestr` populated.
    save_arg: unused; retained for call-site compatibility.

  Returns:
    A ValueMetadataEntry recording the value type and whether
    deserialization should be skipped (true for empty values).

  Raises:
    AssertionError: if `info.value_typestr` was not set.
  """
  del save_arg
  if info.value_typestr is None:
    raise AssertionError(
        'Must set `value_typestr` in `ParamInfo` when saving.'
    )
  skip_deserialize = type_handlers.is_empty_typestr(info.value_typestr)
  return ValueMetadataEntry(
      value_type=info.value_typestr, skip_deserialize=skip_deserialize
  )


Expand Down Expand Up @@ -208,15 +194,12 @@ def from_json(
def build(
cls,
keypath: KeyPath,
value: Any,
info: type_handlers.ParamInfo,
save_arg: type_handlers.SaveArgs,
type_handler_registry: Optional[type_handlers.TypeHandlerRegistry],
) -> 'TreeMetadataEntry':
"""Builds a TreeMetadataEntry."""
key_metadata_entry = KeyMetadataEntry.build(keypath)
value_metadata_entry = ValueMetadataEntry.build(
value, save_arg, type_handler_registry
)
value_metadata_entry = ValueMetadataEntry.build(info, save_arg)
return TreeMetadataEntry(
str(tuple([str(tree_utils.get_key_name(k)) for k in keypath])),
key_metadata_entry,
Expand All @@ -242,33 +225,30 @@ class TreeMetadata:
@classmethod
def build(
    cls,
    param_infos: PyTree,
    *,
    save_args: Optional[PyTree] = None,
    use_zarr3: bool = False,
) -> 'TreeMetadata':
  """Builds the tree metadata.

  Args:
    param_infos: PyTree of ParamInfo, one per leaf (empty nodes such as
      `None`, `{}`, `[]` count as leaves via `is_empty_or_leaf`).
    save_args: optional PyTree of SaveArgs matching `param_infos`;
      defaults to `SaveArgs()` for every leaf.
    use_zarr3: recorded on the resulting metadata.

  Returns:
    A TreeMetadata with one TreeMetadataEntry per leaf of `param_infos`.
  """
  if save_args is None:
    save_args = jax.tree.map(
        lambda _: type_handlers.SaveArgs(),
        param_infos,
        is_leaf=tree_utils.is_empty_or_leaf,
    )
  # Flatten both trees with the same leaf predicate so entries pair up.
  flat_with_keys, _ = jax.tree_util.tree_flatten_with_path(
      param_infos, is_leaf=tree_utils.is_empty_or_leaf
  )
  flat_save_args_with_keys, _ = jax.tree_util.tree_flatten_with_path(
      save_args, is_leaf=tree_utils.is_empty_or_leaf
  )
  tree_metadata_entries = []
  for (keypath, info), (_, save_arg) in zip(
      flat_with_keys, flat_save_args_with_keys
  ):
    tree_metadata_entries.append(
        TreeMetadataEntry.build(keypath, info, save_arg)
    )
  return TreeMetadata(tree_metadata_entries, use_zarr3)

Expand Down
70 changes: 48 additions & 22 deletions checkpoint/orbax/checkpoint/metadata/tree_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import jax
from jax import numpy as jnp
import numpy as np
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
from orbax.checkpoint.metadata import tree as tree_metadata

Expand All @@ -26,7 +27,15 @@ def setUp(self):
super().setUp()
arr = jnp.arange(8)
assert isinstance(arr, jax.Array)
self.tree = {'a': 1, 'b': {'c': 'hi', 'd': 3.4}, 'e': [np.arange(8), arr]}
self.tree = {
'a': 1,
'b': {'c': 'hi', 'd': 3.4},
'e': [np.arange(8), arr],
'f': None,
'g': {},
'h': [],
'i': tuple([]),
}
self.tree_json = {
'tree_metadata': {
"('a',)": {
Expand Down Expand Up @@ -76,36 +85,53 @@ def setUp(self):
'skip_deserialize': False,
},
},
"('f',)": {
'key_metadata': ({'key': 'f', 'key_type': 2},),
'value_metadata': {
'value_type': 'None',
'skip_deserialize': True,
},
},
"('g',)": {
'key_metadata': ({'key': 'g', 'key_type': 2},),
'value_metadata': {
'value_type': 'Dict',
'skip_deserialize': True,
},
},
"('h',)": {
'key_metadata': ({'key': 'h', 'key_type': 2},),
'value_metadata': {
'value_type': 'List',
'skip_deserialize': True,
},
},
"('i',)": {
'key_metadata': ({'key': 'i', 'key_type': 2},),
'value_metadata': {
'value_type': 'None',
'skip_deserialize': True,
},
},
},
'use_zarr3': True,
}

def test_json_conversion(self):
metadata = tree_metadata.TreeMetadata.build(
self.param_infos = jax.tree.map(
# Other properties are not relevant.
lambda x: type_handlers.ParamInfo(
value_typestr=type_handlers.get_param_typestr(
x, type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY
)
),
self.tree,
type_handler_registry=type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
use_zarr3=True,
)
self.assertDictEqual(self.tree_json, metadata.to_json())
self.assertEqual(
metadata, tree_metadata.TreeMetadata.from_json(self.tree_json)
is_leaf=tree_utils.is_empty_or_leaf,
)

def test_aggregate_option(self):
save_args = jax.tree.map(
lambda x: type_handlers.SaveArgs(), self.tree
)
save_args['a'] = type_handlers.SaveArgs(aggregate=True)
def test_json_conversion(self):
metadata = tree_metadata.TreeMetadata.build(
self.tree,
type_handler_registry=type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
save_args=save_args,
self.param_infos,
use_zarr3=True,
)
tree_json = dict(**self.tree_json)
tree_json['tree_metadata']["('a',)"]['value_metadata'][
'skip_deserialize'
] = True
self.assertDictEqual(self.tree_json, metadata.to_json())
self.assertEqual(
metadata, tree_metadata.TreeMetadata.from_json(self.tree_json)
Expand Down
20 changes: 20 additions & 0 deletions checkpoint/orbax/checkpoint/type_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ class ParamInfo:
Specifies the target size (in bytes) of each OCDBT data file.
ts_context:
Tensorstore context to use for reading/writing.
value_typestr: stores the original value's typestr (from TypeHandler).
Only required when saving.
"""

name: Optional[str] = None
Expand All @@ -193,6 +195,7 @@ class ParamInfo:
use_zarr3: Optional[bool] = False
ocdbt_target_data_file_size: Optional[int] = None
ts_context: Optional[ts.Context] = None
value_typestr: Optional[str] = None


@dataclasses.dataclass
Expand Down Expand Up @@ -2056,3 +2059,20 @@ def default_restore_type(args: RestoreArgs) -> Any:
return np.ndarray
else:
raise ValueError(f'Unsupported restore_args type: {type(args)}')


def get_param_typestr(value: Any, registry: TypeHandlerRegistry) -> str:
  """Retrieves the typestr for a given value."""
  # Empty aggregation values (None, {}, [], ()) have dedicated typestrs.
  if is_supported_empty_aggregation_type(value):
    return get_empty_value_typestr(value)
  try:
    return registry.get(type(value)).typestr()
  except ValueError:
    # Not an error because users' training states often have a bunch of
    # random unserializable objects in them (empty states, optimizer
    # objects, etc.). An error occurring due to a missing TypeHandler
    # will be surfaced elsewhere.
    return RESTORE_TYPE_NONE

0 comments on commit 097a23b

Please sign in to comment.