Skip to content

Commit

Permalink
Fix issue with restoring orbax checkpoints through transformation_fn
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 576571694
  • Loading branch information
liangyaning33 authored and t5-copybara committed Oct 25, 2023
1 parent 0a56776 commit 8a294fb
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2361,7 +2361,11 @@ def restore(
self.directory,
restore_dtype=self._restore_dtype,
)
return legacy_checkpointer.restore(path=path)
return legacy_checkpointer.restore(
path=path,
fallback_state=fallback_state,
state_transformation_fns=state_transformation_fns,
)

state_dict = self._train_state.state_dict()
# Returns a state dict rather than a train state.
Expand Down

0 comments on commit 8a294fb

Please sign in to comment.