switch to create_checkpoint_manager_and_restore for export_lib.
PiperOrigin-RevId: 613349654
liangyaning33 authored and t5-copybara committed Mar 6, 2024
1 parent 6ec92a8 commit 20e3e03
Showing 1 changed file with 21 additions and 3 deletions.
t5x/export_lib.py: 24 changes (21 additions & 3 deletions)
@@ -235,7 +235,7 @@ def preproc_func(self, *args):
 
 def get_train_state_initializer(
     model: models.BaseTransformerModel,
-    partitioner: Optional[partitioning.BasePartitioner],
+    partitioner: partitioning.BasePartitioner,
     task_feature_lengths: Mapping[str, int],
     batch_size: Optional[int],
     trailing_shapes: Optional[Mapping[str, Tuple[int, ...]]] = None,
@@ -415,12 +415,29 @@ def inference_fn(
 def load_params_from_checkpoint(
     restore_checkpoint_cfg: utils.RestoreCheckpointConfig,
     train_state_initializer: Optional[utils.TrainStateInitializer],
+    partitioner: partitioning.BasePartitioner,
 ) -> frozen_dict.FrozenDict:
   """Loads the checkpoint and casts the variable."""
   if train_state_initializer is not None:
-    train_state = train_state_initializer.from_checkpoint(
+    restore_cfg, ckpt_paths = utils.get_first_valid_restore_config_and_paths(
         [restore_checkpoint_cfg]
     )
+    if len(ckpt_paths) != 1:
+      raise ValueError(
+          f'Expected only 1 checkpoint but got {len(ckpt_paths)} for '
+          f'config(s): {restore_cfg}'
+      )
+    train_state, _ = utils.create_checkpoint_manager_and_restore(
+        train_state_initializer=train_state_initializer,
+        partitioner=partitioner,
+        restore_checkpoint_cfg=restore_cfg,
+        restore_path=ckpt_paths[0],
+        fallback_init_rng=jax.random.PRNGKey(0),
+        save_checkpoint_cfg=None,
+        model_dir=None,
+        ds_iter=None,
+        use_orbax=True,
+    )
     return train_state.params  # pytype:disable=attribute-error
   else:
     if restore_checkpoint_cfg.mode != 'specific':
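
Note on the rewritten branch above: it resolves the single RestoreCheckpointConfig into concrete checkpoint paths, insists on exactly one path, and only then delegates to the Orbax-backed checkpoint manager. A minimal sketch of the new guard, assuming RestoreCheckpointConfig.path may hold several checkpoint paths (as it can elsewhere in t5x) and using made-up paths:

# Hypothetical illustration of the single-checkpoint guard added in this
# commit; the paths are invented and assumed to resolve to valid checkpoints.
from t5x import utils

multi_ckpt_cfg = utils.RestoreCheckpointConfig(
    path=['/tmp/ckpt_100', '/tmp/ckpt_200'],  # two hypothetical checkpoint dirs
    mode='specific',
)
restore_cfg, ckpt_paths = utils.get_first_valid_restore_config_and_paths(
    [multi_ckpt_cfg]
)
# With two resolved paths, load_params_from_checkpoint now fails fast:
# ValueError: Expected only 1 checkpoint but got 2 for config(s): ...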
@@ -1361,7 +1378,7 @@ def save(
         ..., InferenceFn
     ] = create_inference_function,
     create_postprocessor_fn: CreatePostprocessorFn = create_postprocessor,
-    partitioner: Optional[partitioning.BasePartitioner] = None,
+    partitioner: Optional[partitioning.BasePartitioner],
     create_decoding_state_callback_fn: Optional[
         CreateDecodingStateCallbackFn
     ] = None,
Expand Down Expand Up @@ -1511,6 +1528,7 @@ def save(
params = load_params_from_checkpoint(
restore_checkpoint_cfg=restore_checkpoint_cfg,
train_state_initializer=train_state_initializer,
partitioner=partitioner,
)

logging.info('Preparing Module to save...')
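
For orientation, a caller-side sketch of how the partitioner threads through export after this change; the model, feature lengths, and checkpoint path below are illustrative placeholders, not values from this commit:

# Hedged usage sketch: a concrete partitioner is now needed both when building
# the TrainStateInitializer and when loading parameters.
from t5x import export_lib, partitioning, utils

partitioner = partitioning.PjitPartitioner(num_partitions=1)  # illustrative choice

train_state_initializer = export_lib.get_train_state_initializer(
    model=model,  # a models.BaseTransformerModel instance supplied by the caller
    partitioner=partitioner,  # no longer Optional after this commit
    task_feature_lengths={'inputs': 256, 'targets': 256},  # hypothetical lengths
    batch_size=8,
)

params = export_lib.load_params_from_checkpoint(
    restore_checkpoint_cfg=utils.RestoreCheckpointConfig(
        path='/tmp/t5x_checkpoints/checkpoint_1000',  # hypothetical path
        mode='specific',
    ),
    train_state_initializer=train_state_initializer,
    partitioner=partitioner,  # new argument added by this commit
)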