Skip to content

Commit

Permalink
Enable AsyncCheckpointer by default in t5x.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 627123534
  • Loading branch information
liangyaning33 authored and t5-copybara committed Apr 22, 2024
1 parent a81874c commit 6b02d25
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 2 deletions.
10 changes: 8 additions & 2 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def latest_step(checkpoints_dir: str) -> Optional[int]:
def get_local_data(x):
"""Get local buffer for input data.

If `x` is a `jax.Array` that is not a `jax.core.Tracer`, the result is
derived from `x.addressable_data(0)`; any other input is returned unchanged.
"""
if isinstance(x, jax.Array) and not isinstance(x, jax.core.Tracer):
# NOTE(review): this text is a rendered diff with the +/- markers lost, so
# the two `return` lines below appear to be the removed/added pair of one
# hunk rather than consecutive statements. The ordering suggests the
# `np.asarray(...)` form is the post-commit line -- confirm against the
# repository before relying on either.
return x.addressable_data(0)
return np.asarray(x.addressable_data(0))
else:
# Non-jax.Array values and tracers pass through untouched.
return x

Expand Down Expand Up @@ -2310,7 +2310,7 @@ def __init__(
)
# TODO(b/273803615) Enable OCDBT.
self._state_handler = ocp.PyTreeCheckpointHandler(use_ocdbt=False)
checkpointers = {_STATE_KEY: ocp.Checkpointer(self._state_handler)}
checkpointers = {_STATE_KEY: ocp.AsyncCheckpointer(self._state_handler)}
if self._should_write_dataset_ckpt:
checkpointers[_DATASET_KEY] = ocp.Checkpointer(
DatasetCheckpointHandler(checkpoint_filename=dataset_ckpt_name)
Expand Down Expand Up @@ -2354,6 +2354,12 @@ def latest_step(self) -> Optional[int]:
def should_save(self, step: int) -> bool:
  """Ask the wrapped manager whether a checkpoint should be saved at `step`.

  Args:
    step: the step number to query.

  Returns:
    Whatever `self._manager.should_save(step)` returns.
  """
  decision = self._manager.should_save(step)
  return decision

def wait_until_finished(self):
  """Forward to the wrapped manager's `wait_until_finished()`.

  NOTE(review): presumably this blocks until in-flight asynchronous saves
  have completed -- confirm against the manager's documentation.

  Returns:
    Whatever `self._manager.wait_until_finished()` returns.
  """
  manager = self._manager
  return manager.wait_until_finished()

def close(self):
  """Forward to the wrapped manager's `close()`.

  Returns:
    Whatever `self._manager.close()` returns.
  """
  manager = self._manager
  return manager.close()

def save(
self,
train_state: train_state_lib.TrainState,
Expand Down
5 changes: 5 additions & 0 deletions t5x/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,6 +753,7 @@ def _as_gda(spec):
trainer.train_state,
checkpoint_cfg.save.state_transformation_fns, # pytype: disable=attribute-error
)
checkpoint_manager.wait_until_finished()
logging.info('Saving emergency checkpoint done.')
raise e

Expand All @@ -773,6 +774,9 @@ def _as_gda(spec):
trainer.train_state,
checkpoint_cfg.save.state_transformation_fns, # pytype: disable=attribute-error
)
# `_run_training_eval` depends upon the result of the checkpoint,
# thus calling `wait_until_finished()` here.
checkpoint_manager.wait_until_finished()
checkpoint_tock = time.time()
train_metrics.write_scalar(
'timing/checkpoint_seconds',
Expand All @@ -793,6 +797,7 @@ def _as_gda(spec):
# Inference Evaluation (i.e., with decoding or scoring).
if is_eval_epoch and evaluator is not None:
_run_inference_eval()
checkpoint_manager.close()

# Wait until computations are done before exiting
_cleanup()
Expand Down
6 changes: 6 additions & 0 deletions t5x/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,12 @@ def __init__(
strict=strict,
)

def wait_until_finished(self):
  """Do nothing and return `None`.

  NOTE(review): presumably present so callers can invoke
  `wait_until_finished()` on this class the same way they do on a manager
  that performs asynchronous saves -- confirm against the callers.
  """

def close(self):
  """Do nothing and return `None`.

  NOTE(review): presumably present so callers can invoke `close()` on this
  class the same way they do on a manager that holds resources -- confirm
  against the callers.
  """

def save(
self,
train_state: train_state_lib.TrainState,
Expand Down

0 comments on commit 6b02d25

Please sign in to comment.