Skip to content

Commit

Permalink
Enable ocdbt in t5x by default.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652712150
  • Loading branch information
liangyaning33 authored and t5-copybara committed Jul 16, 2024
1 parent 29ebb8e commit dcdcda9
Showing 1 changed file with 1 addition and 2 deletions.
3 changes: 1 addition & 2 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2323,8 +2323,7 @@ def __init__(
self._should_write_dataset_ckpt = (
self._dataset_iterator and data_layout.is_first_host_in_replica_set
)
# TODO(b/273803615) Enable OCDBT.
self._state_handler = ocp.PyTreeCheckpointHandler(use_ocdbt=False)
self._state_handler = ocp.PyTreeCheckpointHandler(use_ocdbt=True)
item_handlers = {
_STATE_KEY: self._state_handler,
_DATASET_KEY: DatasetCheckpointHandler(
Expand Down

1 comment on commit dcdcda9

@stefan-it
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @liangyaning33 ,

I have one question about the new Orbax checkpointing format.

Unfortunately, it is no longer possible to use the load_t5x_checkpoint() method to load checkpoints.

Do you have some code snippet to show how it is possible with the new Orbax format?

Many thanks in advance!

Please sign in to comment.