Skip to content

Commit

Permalink
Propagate use_orbax_format to _get_optimizer_state_dict.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 611509017
  • Loading branch information
liangyaning33 authored and t5-copybara committed Feb 29, 2024
1 parent 65ce7e9 commit dcbe3e7
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,7 +1119,9 @@ def restore(
ckpt_contents = _maybe_update_ts_from_gcs_to_file(ckpt_contents)

ckpt_state_dict = self._get_optimizer_state_dict(
ckpt_contents, state_transformation_fns
ckpt_contents,
state_transformation_fns,
use_orbax_format=ckpt_type is checkpoint_utils.CheckpointTypes.ORBAX,
)

# The state dict may contain TensorStore specs that need to be read.
Expand Down Expand Up @@ -1363,9 +1365,13 @@ def _get_optimizer_state_dict(
self,
ckpt_contents: PyTree,
state_transformation_fns: Sequence[RestoreStateTransformationFn],
use_orbax_format: bool = False,
):
return _get_optimizer_state_dict(
ckpt_contents, self._train_state.state_dict(), state_transformation_fns
ckpt_contents,
self._train_state.state_dict(),
state_transformation_fns,
use_orbax_format,
)


Expand Down

0 comments on commit dcbe3e7

Please sign in to comment.