Fix sync_global_devices issue.
PiperOrigin-RevId: 634815978
liangyaning33 authored and t5-copybara committed May 17, 2024
1 parent a42d50b commit dddbfca
Showing 2 changed files with 88 additions and 34 deletions.
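The mechanism of the fix, as inferred from the diff (not an official description): the old DatasetCheckpointHandler.save ended with an unconditional multihost_utils.sync_global_devices barrier, yet the handler was only registered when dataset checkpointing was enabled, so in multi-process runs some processes could wait at a collective barrier that others never reached. The commit registers the handler on every process and gates the actual dataset I/O on a new should_write_dataset_ckpt flag instead. A minimal, hypothetical sketch of that pattern (illustrative only, not the t5x implementation):

class GatedDatasetHandler:
  """Always registered; a uniform no-op when the flag is False."""

  def __init__(self, should_write_dataset_ckpt: bool):
    self._should_write_dataset_ckpt = should_write_dataset_ckpt

  def save(self, directory, item):
    # Every process takes the same branch, so no process blocks alone
    # inside a collective synchronization point.
    if not self._should_write_dataset_ckpt:
      return
    ...  # write `item` under `directory`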
73 changes: 39 additions & 34 deletions t5x/checkpoints.py
@@ -69,7 +69,9 @@
LazyArray = checkpoint_importer.LazyArray
LazyAwaitableArray = checkpoint_importer.LazyAwaitableArray
LazyThreadPoolArray = checkpoint_importer.LazyThreadPoolArray
-Dataset = Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator]
+Dataset = Union[
+    tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator, None
+]

# Version 3 is used since 2021-06-10, compared to version 2 the only change is
# that `bfloat16` arrays are written in Tensorstore using its native `bfloat16`
@@ -2019,8 +2021,9 @@ class _OrbaxParamInfo:
class DatasetCheckpointHandler(ocp.CheckpointHandler):
  """A CheckpointHandler implementation that handles tf.data.Iterator."""

-  def __init__(self, checkpoint_filename: str):
+  def __init__(self, checkpoint_filename: str, should_write_dataset_ckpt: bool):
    self._checkpoint_filename = checkpoint_filename
+    self._should_write_dataset_ckpt = should_write_dataset_ckpt

  def save(
      self,
@@ -2033,18 +2036,18 @@ def save(
      directory: save location directory.
      args: DatasetArgs (see below).
    """
-    item = args.item
-    if item is None:
-      raise ValueError('Must provide item to save.')
-    if jax.process_count() > 1:
-      directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
-      directory.mkdir(parents=False, exist_ok=False)
-    if isinstance(item, tf.data.Iterator):
-      ckpt = tf.train.Checkpoint(ds=item)
-      ckpt.write(os.fspath(directory / self._checkpoint_filename))
-    elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
-      item.save(os.fspath(directory / self._checkpoint_filename))
-    multihost_utils.sync_global_devices('DatasetCheckpointHandler:save')
+    if self._should_write_dataset_ckpt:
+      item = args.item
+      if item is None:
+        raise ValueError('Must provide item to save.')
+      if jax.process_count() > 1:
+        directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
+        directory.mkdir(parents=False, exist_ok=False)
+      if isinstance(item, tf.data.Iterator):
+        ckpt = tf.train.Checkpoint(ds=item)
+        ckpt.write(os.fspath(directory / self._checkpoint_filename))
+      elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
+        item.save(os.fspath(directory / self._checkpoint_filename))

  def restore(
      self,
@@ -2060,19 +2063,20 @@
    Returns:
      a tf.data.Iterator restored from `directory`.
    """
-    if args is None:
-      raise ValueError('Must provide args to restore.')
-    item = args.item
-    if jax.process_count() > 1:
-      directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
-    if isinstance(item, tf.data.Iterator):
-      ckpt = tf.train.Checkpoint(ds=item)
-      ckpt.read(
-          os.fspath(directory / self._checkpoint_filename)
-      ).assert_consumed()
-    elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
-      item.load(os.fspath(directory / self._checkpoint_filename))
-    return item
+    if self._should_write_dataset_ckpt:
+      if args is None:
+        raise ValueError('Must provide args to restore.')
+      item = args.item
+      if jax.process_count() > 1:
+        directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
+      if isinstance(item, tf.data.Iterator):
+        ckpt = tf.train.Checkpoint(ds=item)
+        ckpt.read(
+            os.fspath(directory / self._checkpoint_filename)
+        ).assert_consumed()
+      elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
+        item.load(os.fspath(directory / self._checkpoint_filename))
+      return item


@ocp.args.register_with_handler(
@@ -2323,11 +2327,13 @@ def __init__(
    )
    # TODO(b/273803615) Enable OCDBT.
    self._state_handler = ocp.PyTreeCheckpointHandler(use_ocdbt=False)
-    item_handlers = {_STATE_KEY: self._state_handler}
-    if self._should_write_dataset_ckpt:
-      item_handlers[_DATASET_KEY] = DatasetCheckpointHandler(
-          checkpoint_filename=dataset_ckpt_name
-      )
+    item_handlers = {
+        _STATE_KEY: self._state_handler,
+        _DATASET_KEY: DatasetCheckpointHandler(
+            checkpoint_filename=dataset_ckpt_name,
+            should_write_dataset_ckpt=self._should_write_dataset_ckpt,
+        ),
+    }

    def best_fn(metrics):
      return metrics[metric_name_to_monitor]
@@ -2425,9 +2431,8 @@ def save(
            state_dict,
            save_args=save_args,
        ),
+        _DATASET_KEY: DatasetArgs(self._dataset_iterator),
    }
-    if self._should_write_dataset_ckpt:
-      args[_DATASET_KEY] = DatasetArgs(self._dataset_iterator)
    args = ocp.args.Composite(**args)
    saved = self._manager.save(step, args=args, force=force)

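With this change the dataset handler and its args entry are always present on every process; a hedged construction sketch, assuming t5x is installed (the checkpoint filename below is a made-up example):

from t5x import checkpoints

# Hypothetical wiring mirroring the diff: the handler is created
# unconditionally, and the flag decides at save/restore time whether the
# dataset iterator is actually written or read.
handler = checkpoints.DatasetCheckpointHandler(
    checkpoint_filename='train_ds-checkpoint',  # made-up name
    should_write_dataset_ckpt=False,  # save() and restore() become no-ops
)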
49 changes: 49 additions & 0 deletions t5x/test_utils.py
@@ -30,6 +30,7 @@
import seqio
import t5.data
from t5x import adafactor
+from t5x import checkpoints
from t5x import models
from t5x import optimizers
from t5x import partitioning
@@ -409,3 +410,51 @@ def partition(

  def compile(self, partitioned_fn, *args):
    return None
+
+# -------------------- Checkpoint helpers --------------------
+
+
+def _train_state_shapes(train_state):
+  def _maybe_get(x):
+    if isinstance(x, LazyArray):
+      return x.get()
+    return x
+
+  train_state = jax.tree_util.tree_map(_maybe_get, train_state)
+  return jax.eval_shape(lambda: train_state)
+
+
+def save(checkpointer_or_manager, train_state, force=False):
+  saved = checkpointer_or_manager.save(train_state, force=force)
+  checkpointer_or_manager.wait_until_finished()
+  return saved
+
+
+def create_checkpointer_or_manager(
+    train_state_shapes,
+    partitioner,
+    directory,
+    dataset_iterator=None,
+    save_dtype=None,
+    restore_dtype=None,
+    best=False,
+    keep=None,
+    period=1,
+    checkpoint_steps=None,
+    keep_checkpoints_without_metrics=True,
+):
+  """Creates an Orbax CheckpointManagerInterface."""
+  metric_name_to_monitor = 'train/accuracy' if best else None
+  return checkpoints.OrbaxCheckpointManagerInterface(
+      directory,
+      train_state_shapes,
+      partitioner,
+      dataset_iterator=dataset_iterator,
+      save_dtype=save_dtype,
+      restore_dtype=restore_dtype,
+      keep=keep,
+      period=period,
+      checkpoint_steps=checkpoint_steps,
+      metric_name_to_monitor=metric_name_to_monitor,
+      keep_checkpoints_without_metrics=keep_checkpoints_without_metrics,
+  )
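
A hedged usage sketch of the helpers above; train_state and partitioner are placeholders for objects a real test would build with the fixtures in this module:

import tempfile

shapes = _train_state_shapes(train_state)  # train_state: placeholder
manager = create_checkpointer_or_manager(
    shapes,
    partitioner,  # placeholder partitioner from test fixtures
    tempfile.mkdtemp(),
    dataset_iterator=None,
)
save(manager, train_state)  # blocks until the write is finished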
