Skip to content

Commit

Permalink
Allow clu.data.dataset_iterator.DatasetIterator in addition to tf.data.Iterator
Browse files Browse the repository at this point in the history

PiperOrigin-RevId: 615907993
  • Loading branch information
liangyaning33 authored and t5-copybara committed Mar 14, 2024
1 parent 962b9d5 commit 707995a
Showing 1 changed file with 30 additions and 11 deletions.
41 changes: 30 additions & 11 deletions t5x/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2015,7 +2015,11 @@ class DatasetCheckpointHandler(ocp.CheckpointHandler):
def __init__(self, checkpoint_filename: str):
self._checkpoint_filename = checkpoint_filename

def save(self, directory: epath.Path, item: tf.data.Iterator):
def save(
self,
directory: epath.Path,
item: Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator],
):
"""Saves the given item.
Args:
Expand All @@ -2025,13 +2029,20 @@ def save(self, directory: epath.Path, item: tf.data.Iterator):
if jax.process_count() > 1:
directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
directory.mkdir(parents=False, exist_ok=False)
ckpt = tf.train.Checkpoint(ds=item)
ckpt.write(os.fspath(directory / self._checkpoint_filename))
if isinstance(item, tf.data.Iterator):
ckpt = tf.train.Checkpoint(ds=item)
ckpt.write(os.fspath(directory / self._checkpoint_filename))
elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
item.save(os.fspath(directory / self._checkpoint_filename))
multihost_utils.sync_global_devices('DatasetCheckpointHandler:save')

def restore(
self, directory: epath.Path, item: Optional[tf.data.Iterator] = None
) -> tf.data.Iterator:
self,
directory: epath.Path,
item: Optional[
Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator]
] = None,
) -> Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator]:
"""Restores the given item.
Args:
Expand All @@ -2045,10 +2056,13 @@ def restore(
raise ValueError('Must provide item to restore')
if jax.process_count() > 1:
directory /= f'process_{jax.process_index()}-of-{jax.process_count()}'
ckpt = tf.train.Checkpoint(ds=item)
ckpt.read(
os.fspath(directory / self._checkpoint_filename)
).assert_consumed()
if isinstance(item, tf.data.Iterator):
ckpt = tf.train.Checkpoint(ds=item)
ckpt.read(
os.fspath(directory / self._checkpoint_filename)
).assert_consumed()
elif isinstance(item, clu.data.dataset_iterator.DatasetIterator):
item.load(os.fspath(directory / self._checkpoint_filename))
return item

def structure(self, directory: epath.Path) -> Any:
Expand Down Expand Up @@ -2259,7 +2273,9 @@ def __init__(
directory: str,
train_state: train_state_lib.TrainState,
partitioner: partitioning.BasePartitioner,
dataset_iterator: Optional[tf.data.Iterator] = None,
dataset_iterator: Optional[
Union[tf.data.Iterator, clu.data.dataset_iterator.DatasetIterator]
] = None,
save_dtype: Optional[jnp.dtype] = None,
restore_dtype: Optional[jnp.dtype] = None,
keep: Optional[int] = None,
Expand All @@ -2276,11 +2292,14 @@ def __init__(
del keep_dataset_checkpoints
self._train_state = train_state
self._partitioner = partitioner
if isinstance(
dataset_iterator, clu.data.dataset_iterator.TfDatasetIterator
):
assert dataset_iterator._checkpoint
self._dataset_iterator = dataset_iterator
self._save_dtype = save_dtype
self._restore_dtype = restore_dtype
self._tmp_directory: Optional[epath.PathLike] = None

data_layout = partitioner.get_data_layout()
dataset_ckpt_name = (
f'{_TRAIN_DS_PREFIX}-'
Expand Down

0 comments on commit 707995a

Please sign in to comment.