Skip to content

Commit

Permalink
Set aggregate to False under SaveArgs within t5x.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 642417224
  • Loading branch information
liangyaning33 authored and t5-copybara committed Jun 11, 2024
1 parent 7f5ed21 commit 9e0093e
Showing 1 changed file with 11 additions and 26 deletions.
37 changes: 11 additions & 26 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def latest_step(checkpoints_dir: str) -> Optional[int]:
def get_local_data(x):
"""Get local buffer for input data."""
if isinstance(x, jax.Array) and not isinstance(x, jax.core.Tracer):
return x.addressable_data(0)
return np.asarray(x.addressable_data(0))
else:
return x

Expand Down Expand Up @@ -2101,7 +2101,7 @@ def _construct_save_args(
"""Create SaveArgs for Orbax saving."""
if param_info.name.split('.')[0] != 'target':
dtype = None
return ocp.SaveArgs(aggregate=False, dtype=dtype)
return ocp.SaveArgs(aggregate=param_info.mesh_axes is None, dtype=dtype)


def _construct_restore_args(
Expand Down Expand Up @@ -2154,7 +2154,7 @@ def _construct_orbax_restoration_transforms(
)
assert state_subdir.is_dir(), state_subdir
use_orbax_format = state_subdir.stem == _STATE_KEY # Standard Orbax format
structure, _ = state_handler._handler_impl._get_internal_metadata( # pylint: disable=protected-access
structure = state_handler._handler_impl._read_aggregate_file( # pylint: disable=protected-access
state_subdir
)
# Note: Ideally we would use Orbax's `transform_fn` to do this logic, but
Expand Down Expand Up @@ -2187,13 +2187,7 @@ def _transform_fn(
del structure_, param_infos_

def _make_orbax_internal_metadata(value: Any, args: ocp.RestoreArgs):
if isinstance(
value, ocp.pytree_checkpoint_handler._InternalValueMetadata # pylint: disable=protected-access
):
if value.restore_type == 'scalar':
return ocp.pytree_checkpoint_handler._InternalValueMetadata( # pylint: disable=protected-access
restore_type='scalar'
)
if ocp.utils.leaf_is_placeholder(value):
if isinstance(args, ocp.ArrayRestoreArgs):
restore_type = 'jax.Array'
else:
Expand Down Expand Up @@ -2425,6 +2419,13 @@ def save(
functools.partial(_construct_save_args, dtype=self._save_dtype),
param_infos,
)
# If the params are to be aggregated, then get locally addressable data.
state_dict = jax.tree_util.tree_map(
lambda v, arg: get_local_data(v) if arg.aggregate else v,
state_dict,
save_args,
)

# Separate savable items.
args = {
_STATE_KEY: ocp.args.PyTreeSave(
Expand Down Expand Up @@ -2578,22 +2579,6 @@ def _maybe_make_sharded_array_helper(arr, info):
_maybe_make_sharded_array_helper, state_dict, param_infos
)

def convert_scalars(v, metadata):
if (
isinstance(
metadata, ocp.pytree_checkpoint_handler._InternalValueMetadata # pylint: disable=protected-access
)
and metadata.restore_type == 'scalar'
):
return np.int32(v)
return v

state_dict = jax.tree_util.tree_map(
convert_scalars,
state_dict,
state_dict_to_restore,
)

train_state = self._train_state.restore_state(state_dict)

end_time = time.time()
Expand Down

0 comments on commit 9e0093e

Please sign in to comment.