diff --git a/t5x/checkpoint_utils.py b/t5x/checkpoint_utils.py index d817b61af..4ed726df3 100644 --- a/t5x/checkpoint_utils.py +++ b/t5x/checkpoint_utils.py @@ -20,11 +20,10 @@ import enum import os -from typing import Any, BinaryIO, Optional, Tuple, Union +from typing import Any, BinaryIO, Optional, Union from absl import logging from etils import epath -import jax import msgpack import orbax.checkpoint as ocp from tensorflow.io import gfile @@ -257,27 +256,21 @@ def _is_supported_empty_value(value: Any) -> bool: return ocp.type_handlers.is_supported_empty_value(value) -def get_restore_parameters( - directory: epath.Path, - structure: PyTree, -) -> Tuple[PyTree, PyTree]: - """Construct parameters needed for restoration. +def get_restore_parameters(directory: epath.Path, structure: PyTree) -> PyTree: + """Construct ParamInfos tree needed for restoration. - ParamInfos are - constructed from the structure of the original checkpoint, and restore_args - are serialized to a tree structure compatible with param_infos and structure. + ParamInfos are constructed from the structure of the original checkpoint. Args: directory: Checkpoint directory. structure: The structure of the original checkpoint. Returns: - Tuple of param_infos, and restore_args. + PyTree of `ParamInfo`. """ flat_structure = ocp.tree.to_flat_dict(structure, keep_empty_nodes=True) param_names = ocp.tree.get_param_names(structure) flat_param_names = ocp.tree.to_flat_dict(param_names, keep_empty_nodes=True) - restore_args = jax.tree.map(lambda x: ocp.RestoreArgs(), structure) flat_param_infos = {} is_ocdbt_checkpoint = ocp.type_handlers.is_ocdbt_checkpoint(directory) ts_context = ocp.type_handlers.get_ts_context() @@ -305,9 +298,5 @@ def _get_param_info( for key, meta in flat_structure.items(): flat_param_infos[key] = _get_param_info(flat_param_names[key], meta) - restore_args = ocp.tree.serialize_tree(restore_args, keep_empty_nodes=True) - return ( - ocp.tree.from_flat_dict(flat_param_infos, target=structure), - restore_args, - ) + return ocp.tree.from_flat_dict(flat_param_infos, target=structure) diff --git a/t5x/checkpoints.py b/t5x/checkpoints.py index 315604c87..ada1b4e7a 100644 --- a/t5x/checkpoints.py +++ b/t5x/checkpoints.py @@ -2166,7 +2166,7 @@ def _modify_orbax_param_info(info, value): return info item_ = jax.tree.map(_make_orbax_internal_metadata, item_, restore_args) - param_infos_, _ = checkpoint_utils.get_restore_parameters(directory_, item_) + param_infos_ = checkpoint_utils.get_restore_parameters(directory_, item_) param_infos_ = jax.tree.map( _modify_orbax_param_info, param_infos_, state_dict_to_restore )