diff --git a/checkpoint/CHANGELOG.md b/checkpoint/CHANGELOG.md
index 3a3c752f..f5441b49 100644
--- a/checkpoint/CHANGELOG.md
+++ b/checkpoint/CHANGELOG.md
@@ -22,6 +22,8 @@
 entries will not return information about array properties).
 ### Changed
 - Allow one directory creation request per item rather than 1 per item per host.
 - Make atomicity logic configurable, and encapsulate it within a class.
+- Move all work for `_write_metadata_file` into a background thread to avoid
+blocking on O(n) computation when building metadata.
 ### Fixed
 - Refactor ts.Context usage to be per-operation (save/restore) rather than a
diff --git a/checkpoint/orbax/checkpoint/array_checkpoint_handler.py b/checkpoint/orbax/checkpoint/array_checkpoint_handler.py
index da60db75..ad53cd4b 100644
--- a/checkpoint/orbax/checkpoint/array_checkpoint_handler.py
+++ b/checkpoint/orbax/checkpoint/array_checkpoint_handler.py
@@ -88,13 +88,14 @@ async def async_save(
         )
     ]

+    type_handler = type_handlers.get_type_handler(type(item))
     info = type_handlers.ParamInfo(
         name=self._checkpoint_name,
         path=directory / self._checkpoint_name,
         parent_dir=directory,
         is_ocdbt_checkpoint=False,
+        value_typestr=type_handler.typestr(),
     )
-    type_handler = type_handlers.get_type_handler(type(item))
     futures = await type_handler.serialize([item], [info], args=[save_args])

     return list(futures)
diff --git a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
index 441e2c24..74c559dd 100644
--- a/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
+++ b/checkpoint/orbax/checkpoint/base_pytree_checkpoint_handler.py
@@ -20,9 +20,9 @@
 """

 import asyncio
+from concurrent import futures
 import dataclasses
 import json
-import os
 import time
 from typing import Any, List, Optional, Tuple, Union

@@ -273,6 +273,8 @@ def __init__(
           '/jax/orbax/pytree_checkpoint_handler/init/ocdbt'
       )

+    self._thread_pool = futures.ThreadPoolExecutor(max_workers=1)
+
   def get_param_names(self, item: PyTree) -> PyTree:
     """Gets parameter names for PyTree elements."""
     return get_param_names(item)
@@ -304,8 +306,6 @@ def _get_param_infos(
     Returns:
       A PyTree matching `item` of ParamInfo.
     """
-    if not item:
-      raise ValueError('Found empty item')
     if use_zarr3 is None:
       use_zarr3 = self._use_zarr3
     names = self.get_param_names(item)
@@ -325,6 +325,9 @@ def _param_info(name, value):
           ocdbt_target_data_file_size=ocdbt_target_data_file_size,
           byte_limiter=byte_limiter,
           ts_context=ts_context,
+          value_typestr=type_handlers.get_param_typestr(
+              value, self._type_handler_registry
+          ),
       )

     return jax.tree.map(
@@ -373,6 +376,8 @@ async def async_save(
       the data from its source will be awaited in this function.
""" item = args.item + if not item: + raise ValueError('Found empty item.') save_args = args.save_args ocdbt_target_data_file_size = args.ocdbt_target_data_file_size if ocdbt_target_data_file_size is not None and not self._use_zarr3: @@ -425,8 +430,8 @@ async def async_save( if multihost.is_primary_host(self._primary_host): metadata_write_start_time = time.time() - metadata_future = await self._write_metadata_file( - directory, item, save_args, self._use_zarr3 + metadata_future = self._write_metadata_file( + directory, param_infos, save_args, self._use_zarr3 ) commit_futures += [metadata_future] jax.monitoring.record_event_duration_secs( @@ -609,6 +614,8 @@ class TrainState: if use_zarr3_metadata is not None else self._use_zarr3 ) + if not metadata: + raise ValueError('Found empty metadata.') param_infos = self._get_param_infos( metadata, directory, @@ -633,33 +640,25 @@ class TrainState: return restored_item - async def _write_metadata_file( + def _write_metadata_file( self, directory: epath.Path, - item: PyTree, + param_infos: PyTree, save_args: PyTree, use_zarr3: bool = False, ) -> future.Future: - tspec = type_handlers._get_tensorstore_spec( # pylint: disable=protected-access - os.fspath(directory), name=METADATA_FILE, use_ocdbt=False - )['kvstore'] - txn = ts.Transaction() - metadata_ts_context = type_handlers.get_ts_context() - t = await ts.KvStore.open( - tspec, context=metadata_ts_context - ) - metadata_content = tree_metadata.TreeMetadata.build( - item, - save_args=save_args, - type_handler_registry=self._type_handler_registry, - use_zarr3=use_zarr3, - ) - write_future = t.with_transaction(txn).write( - '', json.dumps(metadata_content.to_json()) - ) - await write_future - commit_future = txn.commit_async() - return commit_future + def _save_fn(): + if utils.is_primary_host(self._primary_host): + path = directory / METADATA_FILE + metadata_content = tree_metadata.TreeMetadata.build( + param_infos, + save_args=save_args, + use_zarr3=use_zarr3, + ) + path.write_text(json.dumps(metadata_content.to_json())) + return 0 + + return self._thread_pool.submit(_save_fn) def _read_metadata_file( self, directory: epath.Path diff --git a/checkpoint/orbax/checkpoint/metadata/tree.py b/checkpoint/orbax/checkpoint/metadata/tree.py index c3b51b6d..e87d5fa5 100644 --- a/checkpoint/orbax/checkpoint/metadata/tree.py +++ b/checkpoint/orbax/checkpoint/metadata/tree.py @@ -150,32 +150,18 @@ def from_json(cls, json_dict: Dict[str, Any]) -> 'ValueMetadataEntry': @classmethod def build( cls, - value: Any, + info: type_handlers.ParamInfo, save_arg: type_handlers.SaveArgs, - registry: Optional[type_handlers.TypeHandlerRegistry], ) -> 'ValueMetadataEntry': """Builds a ValueMetadataEntry.""" - if type_handlers.is_supported_empty_aggregation_type(value): - typestr = type_handlers.get_empty_value_typestr(value) - skip_deserialize = True - elif registry is None: - return ValueMetadataEntry( - value_type=type_handlers.RESTORE_TYPE_UNKNOWN, skip_deserialize=False + del save_arg + if info.value_typestr is None: + raise AssertionError( + 'Must set `value_typestr` in `ParamInfo` when saving.' ) - else: - try: - handler = registry.get(type(value)) - typestr = handler.typestr() - skip_deserialize = save_arg.aggregate - except ValueError: - # Not an error because users' training states often have a bunch of - # random unserializable objects in them (empty states, optimizer - # objects, etc.). An error occurring due to a missing TypeHandler - # will be surfaced elsewhere. 
-        typestr = type_handlers.RESTORE_TYPE_NONE
-        skip_deserialize = True
+    skip_deserialize = type_handlers.is_empty_typestr(info.value_typestr)
     return ValueMetadataEntry(
-        value_type=typestr, skip_deserialize=skip_deserialize
+        value_type=info.value_typestr, skip_deserialize=skip_deserialize
     )

@@ -208,15 +194,12 @@ def from_json(
   def build(
       cls,
       keypath: KeyPath,
-      value: Any,
+      info: type_handlers.ParamInfo,
       save_arg: type_handlers.SaveArgs,
-      type_handler_registry: Optional[type_handlers.TypeHandlerRegistry],
   ) -> 'TreeMetadataEntry':
     """Builds a TreeMetadataEntry."""
     key_metadata_entry = KeyMetadataEntry.build(keypath)
-    value_metadata_entry = ValueMetadataEntry.build(
-        value, save_arg, type_handler_registry
-    )
+    value_metadata_entry = ValueMetadataEntry.build(info, save_arg)
     return TreeMetadataEntry(
         str(tuple([str(tree_utils.get_key_name(k)) for k in keypath])),
         key_metadata_entry,
@@ -242,9 +225,8 @@ class TreeMetadata:
   @classmethod
   def build(
       cls,
-      tree: PyTree,
+      param_infos: PyTree,
       *,
-      type_handler_registry: Optional[type_handlers.TypeHandlerRegistry] = None,
       save_args: Optional[PyTree] = None,
       use_zarr3: bool = False,
   ) -> 'TreeMetadata':
@@ -252,23 +234,21 @@ def build(
     if save_args is None:
       save_args = jax.tree.map(
           lambda _: type_handlers.SaveArgs(),
-          tree,
+          param_infos,
           is_leaf=tree_utils.is_empty_or_leaf,
       )
     flat_with_keys, _ = jax.tree_util.tree_flatten_with_path(
-        tree, is_leaf=tree_utils.is_empty_or_leaf
+        param_infos, is_leaf=tree_utils.is_empty_or_leaf
     )
     flat_save_args_with_keys, _ = jax.tree_util.tree_flatten_with_path(
         save_args, is_leaf=tree_utils.is_empty_or_leaf
     )
     tree_metadata_entries = []
-    for (keypath, value), (_, save_arg) in zip(
+    for (keypath, info), (_, save_arg) in zip(
         flat_with_keys, flat_save_args_with_keys
     ):
       tree_metadata_entries.append(
-          TreeMetadataEntry.build(
-              keypath, value, save_arg, type_handler_registry
-          )
+          TreeMetadataEntry.build(keypath, info, save_arg)
       )
     return TreeMetadata(tree_metadata_entries, use_zarr3)

diff --git a/checkpoint/orbax/checkpoint/metadata/tree_test.py b/checkpoint/orbax/checkpoint/metadata/tree_test.py
index 52f1779a..896b0ad6 100644
--- a/checkpoint/orbax/checkpoint/metadata/tree_test.py
+++ b/checkpoint/orbax/checkpoint/metadata/tree_test.py
@@ -16,6 +16,7 @@
 import jax
 from jax import numpy as jnp
 import numpy as np
+from orbax.checkpoint import tree as tree_utils
 from orbax.checkpoint import type_handlers
 from orbax.checkpoint.metadata import tree as tree_metadata

@@ -26,7 +27,15 @@ def setUp(self):
     super().setUp()
     arr = jnp.arange(8)
     assert isinstance(arr, jax.Array)
-    self.tree = {'a': 1, 'b': {'c': 'hi', 'd': 3.4}, 'e': [np.arange(8), arr]}
+    self.tree = {
+        'a': 1,
+        'b': {'c': 'hi', 'd': 3.4},
+        'e': [np.arange(8), arr],
+        'f': None,
+        'g': {},
+        'h': [],
+        'i': tuple([]),
+    }
     self.tree_json = {
         'tree_metadata': {
             "('a',)": {
@@ -76,36 +85,53 @@ def setUp(self):
                 'skip_deserialize': False,
             },
         },
+            "('f',)": {
+                'key_metadata': ({'key': 'f', 'key_type': 2},),
+                'value_metadata': {
+                    'value_type': 'None',
+                    'skip_deserialize': True,
+                },
+            },
+            "('g',)": {
+                'key_metadata': ({'key': 'g', 'key_type': 2},),
+                'value_metadata': {
+                    'value_type': 'Dict',
+                    'skip_deserialize': True,
+                },
+            },
+            "('h',)": {
+                'key_metadata': ({'key': 'h', 'key_type': 2},),
+                'value_metadata': {
+                    'value_type': 'List',
+                    'skip_deserialize': True,
+                },
+            },
+            "('i',)": {
+                'key_metadata': ({'key': 'i', 'key_type': 2},),
+                'value_metadata': {
+                    'value_type': 'None',
+                    'skip_deserialize': True,
+                },
+            },
         },
         'use_zarr3': True,
     }
-
-  def test_json_conversion(self):
-    metadata = tree_metadata.TreeMetadata.build(
+    self.param_infos = jax.tree.map(
+        # Other properties are not relevant.
+        lambda x: type_handlers.ParamInfo(
+            value_typestr=type_handlers.get_param_typestr(
+                x, type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY
+            )
+        ),
         self.tree,
-        type_handler_registry=type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
-        use_zarr3=True,
-    )
-    self.assertDictEqual(self.tree_json, metadata.to_json())
-    self.assertEqual(
-        metadata, tree_metadata.TreeMetadata.from_json(self.tree_json)
+        is_leaf=tree_utils.is_empty_or_leaf,
     )

-  def test_aggregate_option(self):
-    save_args = jax.tree.map(
-        lambda x: type_handlers.SaveArgs(), self.tree
-    )
-    save_args['a'] = type_handlers.SaveArgs(aggregate=True)
+  def test_json_conversion(self):
     metadata = tree_metadata.TreeMetadata.build(
-        self.tree,
-        type_handler_registry=type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY,
-        save_args=save_args,
+        self.param_infos,
         use_zarr3=True,
     )
-    tree_json = dict(**self.tree_json)
-    tree_json['tree_metadata']["('a',)"]['value_metadata'][
-        'skip_deserialize'
-    ] = True
     self.assertDictEqual(self.tree_json, metadata.to_json())
     self.assertEqual(
         metadata, tree_metadata.TreeMetadata.from_json(self.tree_json)
diff --git a/checkpoint/orbax/checkpoint/type_handlers.py b/checkpoint/orbax/checkpoint/type_handlers.py
index 4fe73ac2..04a01997 100644
--- a/checkpoint/orbax/checkpoint/type_handlers.py
+++ b/checkpoint/orbax/checkpoint/type_handlers.py
@@ -182,6 +182,8 @@ class ParamInfo:
       Specifies the target size (in bytes) of each OCDBT data file.
     ts_context: Tensorstore context to use for reading/writing.
+    value_typestr: stores the original value's typestr (from TypeHandler).
+      Only required when saving.
   """

   name: Optional[str] = None
@@ -193,6 +195,7 @@
   use_zarr3: Optional[bool] = False
   ocdbt_target_data_file_size: Optional[int] = None
   ts_context: Optional[ts.Context] = None
+  value_typestr: Optional[str] = None


 @dataclasses.dataclass
@@ -2056,3 +2059,20 @@ def default_restore_type(args: RestoreArgs) -> Any:
     return np.ndarray
   else:
     raise ValueError(f'Unsupported restore_args type: {type(args)}')
+
+
+def get_param_typestr(value: Any, registry: TypeHandlerRegistry) -> str:
+  """Retrieves the typestr for a given value."""
+  if is_supported_empty_aggregation_type(value):
+    typestr = get_empty_value_typestr(value)
+  else:
+    try:
+      handler = registry.get(type(value))
+      typestr = handler.typestr()
+    except ValueError:
+      # Not an error because users' training states often have a bunch of
+      # random unserializable objects in them (empty states, optimizer
+      # objects, etc.). An error occurring due to a missing TypeHandler
+      # will be surfaced elsewhere.
+      typestr = RESTORE_TYPE_NONE
+  return typestr
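
For reference, below is a minimal, self-contained sketch of the save-side metadata flow this patch introduces: typestrs are resolved on the calling thread and stored in `ParamInfo.value_typestr`, and the O(n) `TreeMetadata.build` plus the JSON write are handed to a single-worker thread pool. It is illustrative only: `write_tree_metadata`, the module-level `_thread_pool`, and the `'_METADATA'` file name are stand-ins for the handler's `_write_metadata_file`, `self._thread_pool`, and `METADATA_FILE`; only APIs used in this change are assumed.

```python
# Hypothetical recap of the pattern introduced above; not part of the patch.
import json
from concurrent import futures

from etils import epath
import jax
from orbax.checkpoint import tree as tree_utils
from orbax.checkpoint import type_handlers
from orbax.checkpoint.metadata import tree as tree_metadata

# A single worker keeps metadata writes ordered (mirrors self._thread_pool).
_thread_pool = futures.ThreadPoolExecutor(max_workers=1)


def write_tree_metadata(directory: epath.Path, item) -> futures.Future:
  """Illustrative stand-in for the handler's `_write_metadata_file`."""
  # Resolve each leaf's typestr on the caller's thread, so the background
  # thread never has to inspect live values.
  param_infos = jax.tree.map(
      lambda x: type_handlers.ParamInfo(
          value_typestr=type_handlers.get_param_typestr(
              x, type_handlers.GLOBAL_TYPE_HANDLER_REGISTRY
          )
      ),
      item,
      is_leaf=tree_utils.is_empty_or_leaf,
  )

  def _save_fn():
    # The O(n) TreeMetadata.build and the file write both happen here,
    # off the critical path of async_save.
    metadata = tree_metadata.TreeMetadata.build(param_infos, use_zarr3=False)
    # '_METADATA' is assumed to be the value of METADATA_FILE.
    (directory / '_METADATA').write_text(json.dumps(metadata.to_json()))
    return 0

  # The returned concurrent.futures.Future can be collected alongside the
  # handler's other commit futures and awaited before finalizing the save.
  return _thread_pool.submit(_save_fn)
```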