Skip to content

Commit

Permalink
Launch Orbax Checkpointing by default in t5x.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 610796739
  • Loading branch information
liangyaning33 authored and t5-copybara committed Feb 27, 2024
1 parent ecb126e commit 65ce7e9
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions t5x/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,13 +236,13 @@ def detect_checkpoint_type(
)
return checkpoint_type
else:
checkpoint_type = CheckpointTypes.ORBAX
checkpoint_type = CheckpointTypes.T5X
_warn_if_unexpected_type(
checkpoint_path,
checkpoint_type,
expected,
'Did not detect ts.Spec nor the {"version", "optimizer"} keys in the'
'checkpoint msgpack file, so the checkpoint was assumed to be '
'written with Orbax.',
'written with T5X.',
)
return checkpoint_type
4 changes: 2 additions & 2 deletions t5x/checkpoint_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_detect_checkpoint_type(self):
ret = checkpoint_utils.detect_checkpoint_type(
orbax_ckpt, expected=checkpoint_utils.CheckpointTypes.ORBAX
)
self.assertEqual(ret, checkpoint_utils.CheckpointTypes.ORBAX)
self.assertEqual(ret, checkpoint_utils.CheckpointTypes.T5X)

with self.assertLogs(level="WARN") as log_output:
checkpoint_utils.detect_checkpoint_type(
Expand All @@ -185,7 +185,7 @@ def test_detect_checkpoint_type(self):
self.assertRegex(
log_output[0][0].message,
".*to be CheckpointTypes.T5X_TF format, but the actual detected format"
" was CheckpointTypes.ORBAX.*",
" was CheckpointTypes.T5X.*",
)


Expand Down
2 changes: 1 addition & 1 deletion t5x/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def evaluate(
] = utils.TrainStateInitializer,
train_eval_get_dataset_fn: utils.GetEvalDatasetCallable = utils.get_training_eval_datasets,
fallback_init_rng: Optional[int] = None,
use_orbax: bool = False,
use_orbax: bool = True,
):
"""Evaluation function.
Expand Down
2 changes: 1 addition & 1 deletion t5x/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,7 @@ def infer(
output_vocab_feature_name: str = 'targets',
file_extension: str = 'jsonl',
keep_aux_as_numpy: bool = False,
use_orbax: bool = False,
use_orbax: bool = True,
):
"""Infer function.
Expand Down
2 changes: 1 addition & 1 deletion t5x/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def train(
train_state_initializer_cls: Type[
utils.TrainStateInitializer
] = utils.TrainStateInitializer,
use_orbax: bool = False,
use_orbax: bool = True,
verify_matching_vocabs_fn: Optional[
Callable[[utils.DatasetConfig, models.BaseTransformerModel], None]
] = utils.verify_matching_vocabs,
Expand Down

0 comments on commit 65ce7e9

Please sign in to comment.