Skip to content

Commit

Permalink
Fix issue with dtype being dropped from RestoreArgs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615496686
  • Loading branch information
liangyaning33 authored and t5-copybara committed Mar 13, 2024
1 parent f6ec080 commit 962b9d5
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2081,7 +2081,7 @@ def _construct_restore_args(
"""Create RestoreArgs for Orbax restoration."""
if not isinstance(param_info, _OrbaxParamInfo): # from fallback
return ocp.RestoreArgs(dtype=dtype)
if param_info.name.split('.')[0] != 'target':
if param_info.name.split('/')[0] != 'target':
dtype = None
if param_info.mesh_axes is None:
return ocp.RestoreArgs(dtype=dtype)
Expand Down

0 comments on commit 962b9d5

Please sign in to comment.