Skip to content

Commit

Permalink
Set aggregate=False under t5x.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 643142781
  • Loading branch information
liangyaning33 authored and t5-copybara committed Jun 13, 2024
1 parent 9e0093e commit 645af9e
Showing 1 changed file with 2 additions and 8 deletions.
10 changes: 2 additions & 8 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def latest_step(checkpoints_dir: str) -> Optional[int]:
def get_local_data(x):
"""Get local buffer for input data."""
if isinstance(x, jax.Array) and not isinstance(x, jax.core.Tracer):
return np.asarray(x.addressable_data(0))
return x.addressable_data(0)
else:
return x

Expand Down Expand Up @@ -2101,7 +2101,7 @@ def _construct_save_args(
"""Create SaveArgs for Orbax saving."""
if param_info.name.split('.')[0] != 'target':
dtype = None
return ocp.SaveArgs(aggregate=param_info.mesh_axes is None, dtype=dtype)
return ocp.SaveArgs(aggregate=False, dtype=dtype)


def _construct_restore_args(
Expand Down Expand Up @@ -2419,12 +2419,6 @@ def save(
functools.partial(_construct_save_args, dtype=self._save_dtype),
param_infos,
)
# If the params are to be aggregated, then get locally addressable data.
state_dict = jax.tree_util.tree_map(
lambda v, arg: get_local_data(v) if arg.aggregate else v,
state_dict,
save_args,
)

# Separate savable items.
args = {
Expand Down

0 comments on commit 645af9e

Please sign in to comment.