diff --git a/vmoe/checkpoints/periodic_actions.py b/vmoe/checkpoints/periodic_actions.py deleted file mode 100644 index 7a801a8..0000000 --- a/vmoe/checkpoints/periodic_actions.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2023 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""PeriodicAction that saves checkpoints periodically.""" -import multiprocessing -import os -from typing import Iterable, Optional - -from clu import periodic_actions -from clu.data import dataset_iterator as clu_dataset_iterator -import jax - -from vmoe import multihost_utils -from vmoe.checkpoints import base as checkpoints_base -from vmoe.checkpoints import partitioned as checkpoints_partitioned - - -DatasetIterator = clu_dataset_iterator.DatasetIterator -MapResult = checkpoints_partitioned.MapResult -PyTree = checkpoints_partitioned.PyTree -ThreadPool = checkpoints_partitioned.ThreadPool - - -class PeriodicSaveCheckpoint(periodic_actions.PeriodicCallback): - """Saves checkpoints of a partitioned training state periodically. - - Example: - saver = PeriodicSaveCheckpoint(prefix='/tmp/ckpt', every_steps=10) - for step in range(100): - state = update_state(...) - saver(step=step, state=state) # Saves at steps 0, 10, 20, 30, ... - """ - - def __init__( - self, - *, - prefix: str, - num_shards: int = 0, - num_threads: Optional[int] = None, - wait_seconds: Optional[int] = None, - every_steps: Optional[int] = None, - every_secs: Optional[float] = None, - on_steps: Optional[Iterable[int]] = None, - keep_last: Optional[int] = None, - keep_steps_multiple_of: Optional[int] = None, - execute_async: bool = True, - report_progress: Optional[periodic_actions.ReportProgress] = None, - report_progress_name: str = 'ckpt'): - """Initializer. - - Args: - prefix: Prefix for the checkpoint files. The step number is appended to - this when a checkpoint is written (e.g. prefix='ckpt_' gives checkpoints - 'ckpt_1', 'ckpt_2', ...). - num_shards: Number of checkpoint shards. If `num_shards <= 0`, the minimum - number of shards will be used. If `num_shards > 0`, this number is only - tentative. - num_threads: Number of threads to use for writing checkpoint shards. If - None, `multiprocessing.pool.cpu_count()` is used. - wait_seconds: If given, we wait at most this number of seconds for the - checkpoint writing to complete. Otherwise, TimeoutError is raised. - every_steps: If given, writes a checkpoint every `every_steps` steps. - every_secs: If given, writes a checkpoint every `every_secs` seconds. - on_steps: If given, writes a checkpoint on these particular steps. - keep_last: If given, we only keep the last `keep_last` checkpoints. - If None, only the last checkpoint is kept. - keep_steps_multiple_of: If given, all steps multiple of this number are - kept (in addition to the `keep_last` steps). - execute_async: If True, writes checkpoints shards asynchronously. - If False, waits `wait_seconds` for the writing to complete. Note that, - even if this is True, we always wait up to `wait_seconds` between two - consecutive checkpointing steps. - report_progress: When given, the `timed()` method of this `ReportProgress` - is used to time the saving of checkpoints. - report_progress_name: Name used by `ReportProgress.timed()`. - """ - self._thread_pool = ThreadPool(processes=num_threads) - self._async_result = None # type: Optional[MapResult] - self._wait_seconds = wait_seconds - self._makedirs(os.path.dirname(prefix)) - keep_last = max(keep_last or 1, 1) - keep_multiple = max(keep_steps_multiple_of or 0, 0) - super().__init__( - every_steps=every_steps, - every_secs=every_secs, - on_steps=on_steps, - callback_fn=self._make_callback_fn( - prefix, num_shards, wait_seconds, keep_last, keep_multiple, - execute_async, self._thread_pool, - report_progress, report_progress_name), - # Note: save_checkpoint() is still asynchronous. This just means that - # we wait until the callback_fn returns. - execute_async=False, - pass_step_and_time=True) - - def __del__(self): - if self._async_result: - self._block_async_result(self._wait_seconds) - self._thread_pool.close() - - @classmethod - def _makedirs(cls, workdir: str): - # Process 0 creates the workdir if it doesn't exist. All processes wait - # until this is done. - if jax.process_index() == 0 and not os.path.exists(workdir): - checkpoints_base.gfile.makedirs(workdir) - multihost_utils.sync_devices(f'checkpoints:mkdir:{workdir}') - - @classmethod - def _remove_old_checkpoints(cls, prefix: str, keep_last: int, - keep_multiple: int, thread_pool: ThreadPool): - - def _parse_step_from_filepath(filepath): - m = checkpoints_base.CHECKPOINT_REGEX.fullmatch(filepath) - step_str = m.group(2) if m else None - return int(step_str[1:]) if step_str else None - - def _find_step_numbers(filepaths): - for step in map(_parse_step_from_filepath, filepaths): - if step is not None: - yield step - - def _remove(): - # Find step number of pending shards. - workdir = os.path.dirname(prefix) - basename = os.path.basename(prefix) - prefix_tmp = os.path.join(workdir, f'.tmp.{basename}') + '*' - checkpoints_tmp = checkpoints_base.gfile.glob(prefix_tmp) - pending_steps = set(_find_step_numbers(checkpoints_tmp)) - # Find all completed shards. - checkpoints = checkpoints_base.gfile.glob(prefix + '*') - completed_steps = set(_find_step_numbers(checkpoints)) - # Keep `keep_last` completed steps. - keep_steps = set(sorted(completed_steps - pending_steps)[-keep_last:]) - # Keep steps multiple of `keep_multiple`. - if keep_multiple > 0: - keep_steps.update([ - step for step in completed_steps if step % keep_multiple == 0]) - # Always keep pending steps. - keep_steps.update(pending_steps) - # Remove checkpoints. - def match_remove_fn(filepath): - # Returns True (to remove) if the step is not in `keep_steps`. - step = _parse_step_from_filepath(filepath) - return (step not in keep_steps) if step is not None else False - checkpoints_base.remove_checkpoints( - checkpoints, match_remove_fn, thread_pool=thread_pool) - - # Only process 0 removes files. All processes wait untils this is done. - if jax.process_index() == 0: - _remove() - multihost_utils.sync_devices(f'checkpoints:remove:{prefix}') - - def _block_async_result(self, wait_seconds: Optional[int]): - try: - self._async_result.get(wait_seconds) - self._async_result = None - except multiprocessing.context.TimeoutError as exc: - raise TimeoutError('Timeout while writing checkpoint files after ' - f'{wait_seconds} seconds.') from exc - - def _make_callback_fn(self, prefix, num_shards, wait_seconds, keep_last, - keep_multiple, execute_async, thread_pool, - report_progress, report_progress_name): - - def callback_fn(step: int, t: float, state: PyTree, - iterator: Optional[DatasetIterator] = None): - del t # Unused. - # Wait up to `wait_seconds` seconds, until the previous checkpoint is - # completed before starting to write a new checkpoint. If the timeout - # expires, an exception is raised. This is to avoid having multiple copies - # of the model in the CPU memory. - if self._async_result: - self._block_async_result(wait_seconds) - multihost_utils.sync_devices(f'checkpoints:sync_pending:{prefix}') - # Remove outdated checkpoints before starting writing new ones. - self._remove_old_checkpoints( - prefix, keep_last, keep_multiple, thread_pool) - # Save new checkpoint. - self._async_result = checkpoints_partitioned.save_checkpoint( - prefix=f'{prefix}_{step}', - tree=state, - num_shards=num_shards, - thread_pool=thread_pool, - makedirs=False, - overwrite=True) - # Optionally, wait `wait_seconds` until the checkpointing is done, or - # raise an exception if writing doesn't finish in `wait_seconds`. - if not execute_async: - self._block_async_result(wait_seconds) - multihost_utils.sync_devices(f'checkpoints:no_async:{prefix}') - - if report_progress is None: - return callback_fn - else: - return report_progress.timed( - report_progress_name, wait_jax_async_dispatch=False)(callback_fn) diff --git a/vmoe/checkpoints/periodic_actions_test.py b/vmoe/checkpoints/periodic_actions_test.py deleted file mode 100644 index fb4ebc8..0000000 --- a/vmoe/checkpoints/periodic_actions_test.py +++ /dev/null @@ -1,118 +0,0 @@ -# Copyright 2023 Google LLC. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for periodic_actions.""" -import os -import random -import time -from unittest import mock - -from absl.testing import absltest -from vmoe.checkpoints import periodic_actions - - -class PeriodicSaveCheckpointRemoveOldCheckpointsTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.workdir = self.create_tempdir() - for step in range(10): - self.workdir.create_file('ckpt_with_no_step') - self.workdir.create_file(f'not_a_ckpt_{step}') - self.workdir.create_file(f'ckpt_main_{step}') - self.workdir.create_file(f'ckpt_model_{step}-00000-of-00001') - self.workdir.create_file(f'ckpt_host_{step}-00000-of-00002') - self.workdir.create_file(f'ckpt_host_{step}-00001-of-00002') - - def test_remove_old_checkpoints_with_keep_multiple_of(self): - periodic_actions.PeriodicSaveCheckpoint._remove_old_checkpoints( - prefix=os.path.join(self.workdir.full_path, 'ckpt_'), - keep_last=2, - keep_multiple=5, - thread_pool=None) - expected = (['ckpt_with_no_step'] + - [f'not_a_ckpt_{step}' for step in range(10)] + - [f'ckpt_main_{step}' for step in [0, 5, 8, 9]] + - [f'ckpt_model_{step}-00000-of-00001' for step in [0, 5, 8, 9]] + - [f'ckpt_host_{step}-00000-of-00002' for step in [0, 5, 8, 9]] + - [f'ckpt_host_{step}-00001-of-00002' for step in [0, 5, 8, 9]]) - self.assertCountEqual(expected, os.listdir(self.workdir.full_path)) - - def test_remove_old_checkpoints_with_no_keep_multiples(self): - periodic_actions.PeriodicSaveCheckpoint._remove_old_checkpoints( - prefix=os.path.join(self.workdir.full_path, 'ckpt_'), - keep_last=2, - keep_multiple=0, - thread_pool=None) - expected = (['ckpt_with_no_step'] + - [f'not_a_ckpt_{step}' for step in range(10)] + - [f'ckpt_main_{step}' for step in [8, 9]] + - [f'ckpt_model_{step}-00000-of-00001' for step in [8, 9]] + - [f'ckpt_host_{step}-00000-of-00002' for step in [8, 9]] + - [f'ckpt_host_{step}-00001-of-00002' for step in [8, 9]]) - self.assertCountEqual(expected, os.listdir(self.workdir.full_path)) - - -class PeriodicSaveCheckpointTest(absltest.TestCase): - - @mock.patch.object( - periodic_actions.checkpoints_partitioned, 'save_checkpoint') - def test(self, mock_save_checkpoint): - # When calling save_checkpoint, we do nothing but we'll wait a few seconds. - def _save_checkpoint_side_effect(*args, thread_pool, **kwargs): - del args - del kwargs - wait = lambda: time.sleep(random.randint(1, 3)) - return thread_pool.apply_async(wait) - mock_save_checkpoint.side_effect = _save_checkpoint_side_effect - # Run a few steps, calling saver on each step. - prefix = os.path.join(self.create_tempdir().full_path, 'ckpt') - saver = periodic_actions.PeriodicSaveCheckpoint( - prefix=prefix, - every_steps=4) - for step in range(1, 10): - saver(step=step, state={}) - # Check that the saver was called twice, on steps 4 and 8. - call_args_list = mock_save_checkpoint.call_args_list - self.assertLen(call_args_list, 2) - self.assertEqual(call_args_list[0], - mock.call(prefix=prefix + '_4', num_shards=0, - makedirs=False, overwrite=True, tree={}, - thread_pool=mock.ANY)) - self.assertEqual(call_args_list[1], - mock.call(prefix=prefix + '_8', num_shards=0, - makedirs=False, overwrite=True, tree={}, - thread_pool=mock.ANY)) - saver.__del__() - - def test_report_progress(self): - mock_report_progress = mock.MagicMock( - periodic_actions.periodic_actions.ReportProgress) - # Run a few steps, calling saver on each step. - prefix = os.path.join(self.create_tempdir().full_path, 'ckpt') - saver = periodic_actions.PeriodicSaveCheckpoint( - prefix=prefix, - every_steps=4, - report_progress=mock_report_progress, - report_progress_name='foo') - for step in range(1, 10): - saver(step=step, state={}) - call_args_list = mock_report_progress.timed.call_args_list - self.assertLen(call_args_list, 1) - self.assertEqual(call_args_list[0], - mock.call('foo', wait_jax_async_dispatch=False)) - - -if __name__ == '__main__': - absltest.main() diff --git a/vmoe/configs/eee_paper/eee_s32_last2_ilsvrc2012_ft_cifar100.py b/vmoe/configs/eee_paper/eee_s32_last2_ilsvrc2012_ft_cifar100.py index d9f49fd..9a884b1 100644 --- a/vmoe/configs/eee_paper/eee_s32_last2_ilsvrc2012_ft_cifar100.py +++ b/vmoe/configs/eee_paper/eee_s32_last2_ilsvrc2012_ft_cifar100.py @@ -165,7 +165,7 @@ def get_config(): config.save_checkpoint.every_steps = 1_000 config.save_checkpoint.keep_last = 1 config.save_checkpoint.num_shards = 32 # Target number of checkpoint shards. - config.save_checkpoint.wait_seconds = 1.0 + config.save_checkpoint.wait_seconds = 300 # Report training progress every 100 steps. config.report_progress = ml_collections.ConfigDict() config.report_progress.every_secs = None diff --git a/vmoe/configs/vmoe_paper/common.py b/vmoe/configs/vmoe_paper/common.py index 494b61b..f4fd172 100644 --- a/vmoe/configs/vmoe_paper/common.py +++ b/vmoe/configs/vmoe_paper/common.py @@ -60,7 +60,7 @@ def get_base_config() -> ml_collections.ConfigDict: config.save_checkpoint = ml_collections.ConfigDict() config.save_checkpoint.every_steps = 1_000 config.save_checkpoint.keep_last = 1 - config.save_checkpoint.wait_seconds = 1.0 + config.save_checkpoint.wait_seconds = 300 # Report training progress every minute. config.report_progress = ml_collections.ConfigDict() config.report_progress.every_secs = None diff --git a/vmoe/data/input_pipeline.py b/vmoe/data/input_pipeline.py index 40b7183..60c0cb6 100644 --- a/vmoe/data/input_pipeline.py +++ b/vmoe/data/input_pipeline.py @@ -19,6 +19,7 @@ """ from typing import Any, Callable, Dict, Optional, Union +from absl import logging from clu.data import dataset_iterator import jax import ml_collections @@ -85,6 +86,11 @@ def get_dataset( Returns: A DatasetIterator. """ + if variant == 'train' and shuffle_seed is not None: + logging.error('Deterministic training is not supported but you specified ' + 'shuffle_seed=%d for training. This can potentially lead to ' + 'data being repeated if restarts happen during training.', + shuffle_seed) builder = vmoe.data.builder.get_dataset_builder( name=name, split=split, diff --git a/vmoe/projects/soft_moe/configs/common.py b/vmoe/projects/soft_moe/configs/common.py index 87f9021..649c031 100644 --- a/vmoe/projects/soft_moe/configs/common.py +++ b/vmoe/projects/soft_moe/configs/common.py @@ -42,7 +42,7 @@ def get_base_config() -> ml_collections.ConfigDict: config.save_checkpoint = ml_collections.ConfigDict() config.save_checkpoint.every_steps = 1_000 config.save_checkpoint.keep_last = 1 - config.save_checkpoint.wait_seconds = 10 + config.save_checkpoint.wait_seconds = 300 # Report training progress every minute to avoid hitting maximum RPC/s quota. config.report_progress = ml_collections.ConfigDict() config.report_progress.every_secs = 60.0 diff --git a/vmoe/train/trainer.py b/vmoe/train/trainer.py index 081ccca..eb2375a 100644 --- a/vmoe/train/trainer.py +++ b/vmoe/train/trainer.py @@ -36,12 +36,12 @@ import ml_collections import numpy as np import optax -from vmoe import checkpoints +import orbax.checkpoint +import tensorflow as tf from vmoe import initialization from vmoe import multihost_utils from vmoe import partitioning from vmoe import utils -from vmoe.checkpoints import periodic_actions as checkpoints_periodic_actions from vmoe.data import input_pipeline from vmoe.data import pjit_utils from vmoe.evaluate import ensemble @@ -62,7 +62,6 @@ Mesh = partitioning.Mesh NamedSharding = jax.sharding.NamedSharding PartitionSpec = partitioning.PartitionSpec -PeriodicCheckpointSaver = checkpoints_periodic_actions.PeriodicSaveCheckpoint PRNGKey = Union[jax.numpy.ndarray, jax.random.KeyArray] PyTree = Any ReportProgress = train_periodic_actions.ReportProgress @@ -132,18 +131,38 @@ def accum_fn(i, state): return new_grad_and_metrics_fn -def create_checkpoint_hook(*, workdir: str, progress_hook: ReportProgress, - train_steps: int, - **kwargs) -> PeriodicCheckpointSaver: - on_steps = set(kwargs.pop('on_steps', [])) - # Always save checkpoint on the last step. - on_steps.update((0, train_steps)) - return PeriodicCheckpointSaver( - prefix=os.path.join(workdir, 'ckpt'), - report_progress=progress_hook, - report_progress_name='ckpt', - on_steps=on_steps, - **kwargs) +def create_checkpoint_manager( + *, + workdir: str, + every_steps: int, + keep_last: Optional[int] = None, + keep_steps_multiple_of: Optional[int] = None, + wait_seconds: int = 300, +) -> orbax.checkpoint.CheckpointManager: + """Creates an Orbax checkpoint manager.""" + directory = os.path.join(workdir, 'ckpt') + if jax.process_index() == 0 and not tf.io.gfile.exists(directory): + tf.io.gfile.makedirs(directory) + multihost_utils.sync_devices('create-ckpt-dir') + ckpt_options = orbax.checkpoint.CheckpointManagerOptions( + save_interval_steps=every_steps, + max_to_keep=keep_last, + keep_period=keep_steps_multiple_of, + ) + ckpt_manager = orbax.checkpoint.CheckpointManager( + directory, + { + 'state': orbax.checkpoint.AsyncCheckpointer( + orbax.checkpoint.PyTreeCheckpointHandler(), + timeout_secs=wait_seconds, + ), + 'dataset_iterator': orbax.checkpoint.Checkpointer( + orbax.checkpoint.JsonCheckpointHandler() + ), + }, + options=ckpt_options, + ) + return ckpt_manager def create_evaluation_hook( @@ -301,10 +320,11 @@ def _initialize_fn(): def get_dataset_iterator( - dataset: DatasetIterator, prefetch_size: int, init_step: int, mesh: Mesh, - workdir: str): + dataset: DatasetIterator, prefetch_size: int, mesh: Mesh, + last_seen_index: Optional[int] = None): """Creates a dataset iterator with device prefetching.""" - del init_step, workdir + logging.warning("Your dataset iterator doesn't allow checkpointing!") + del last_seen_index return pjit_utils.prefetch_to_device(dataset, size=prefetch_size, mesh=mesh) @@ -350,19 +370,17 @@ def initialize(): def restore_or_create_train_state( *, - prefix: str, + ckpt_manager: orbax.checkpoint.CheckpointManager, initialize_fn: Callable[[], TrainState], axis_resources_regexes: partitioning.AxisResourcesRegexes, mesh: Optional[Mesh] = None, thread_pool: Optional[ThreadPool] = None, initialization_kwargs: Optional[Mapping[str, Any]] = None, -) -> TrainState: +) -> Tuple[TrainState, Optional[int]]: """Restores a TrainState from the latest complete checkpoint or creates one. Args: - prefix: Prefix used to find the checkpoint (e.g. '/tmp/ckpt'). This assumes - that checkpoints are partitioned. Thus, a complete checkpoint has files - such as '/tmp/ckpt_1.index' and '/tmp/ckpt_1.data-?????-of-?????'. + ckpt_manager: Checkpoint manager. initialize_fn: Function used to create and initialize a train state from scratch. axis_resources_regexes: Regular expressions specifying how the TrainState @@ -373,7 +391,7 @@ def restore_or_create_train_state( initialize the TrainState from an existing checkpoint. Returns: - A TrainState. + A TrainState and (optionally) the last_seen_index. """ mesh = mesh or maps.thread_resources.env.physical_mesh train_state_shape_dtype = jax.eval_shape(initialize_fn) @@ -388,15 +406,25 @@ def restore_or_create_train_state( lambda x, y: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=y), train_state_shape_dtype, train_state_axis_resources) - prefix = checkpoints.find_latest_complete_checkpoint_for_prefix( - prefix=prefix, suffixes=('.index', '.data')) - if prefix: - logging.info('Continue training from checkpoint prefix = %r', prefix) - # Restore train_state from checkpoints to CPU memory. - return checkpoints.restore_checkpoint_partitioned( - prefix=prefix, - tree=train_state, - thread_pool=thread_pool) + if (step := ckpt_manager.latest_step()) is not None: + logging.info('Continue training from checkpoint at step %d', step) + def _array_restore_args_fn(x: jax.ShapeDtypeStruct): + return orbax.checkpoint.ArrayRestoreArgs( + dtype=x.dtype, sharding=x.sharding, global_shape=x.shape) + restore_kwargs = { + 'state': { + 'restore_args': jax.tree_map(_array_restore_args_fn, train_state), + }, + } + items = ckpt_manager.restore( + step, + items={ + 'state': train_state, + 'dataset_iterator': {'last_seen_index': 0}, + }, + restore_kwargs=restore_kwargs) + return items['state'], items['dataset_iterator']['last_seen_index'] + if initialization_kwargs: logging.info('Partially initializing the TrainState: %r', initialization_kwargs) @@ -409,7 +437,7 @@ def restore_or_create_train_state( train_state_shape_dtype.params, include_stats=False, msg='Parameter overview:') return create_or_reuse_train_state( - train_state=train_state, initialize_fn=initialize_fn, mesh=mesh) + train_state=train_state, initialize_fn=initialize_fn, mesh=mesh), None def get_loss_fn(name: str, **kwargs): @@ -503,6 +531,9 @@ def initialize_train_state_from_checkpoint( elif name == 'initialize_from_vit': return initialization.initialize_from_vit(target=train_state, mesh=mesh, **kwargs) + elif name == 'initialize_from_orbax': + return initialization.initialize_from_orbax(target=train_state, mesh=mesh, + **kwargs) else: raise ValueError(f'Unknown initialization method: {name!r}') @@ -661,6 +692,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str, datataset_element_shape_dtype = pjit_utils.get_dataset_shape_dtype_struct( datasets['train']) + ckpt_manager = create_checkpoint_manager( + workdir=workdir, **config.get('save_checkpoint', {})) train_state_initialize_fn = make_create_train_state_fn( model=create_flax_model(config=config.model, deterministic=False), optimizer_config=config.optimizer, @@ -669,8 +702,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str, train_steps=train_steps, extra_rng_keys=tuple(config.get('extra_rng_keys', [])), seed=config.get('seed', 0)) - train_state = restore_or_create_train_state( - prefix=os.path.join(workdir, 'ckpt'), + train_state, last_seen_index = restore_or_create_train_state( + ckpt_manager=ckpt_manager, initialize_fn=train_state_initialize_fn, axis_resources_regexes=config.params_axis_resources, thread_pool=ThreadPool(), @@ -680,7 +713,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str, tr_iter = get_dataset_iterator( dataset=datasets['train'], prefetch_size=config.dataset.train.get('prefetch_device', 1), - init_step=init_step, mesh=mesh, workdir=workdir) + mesh=mesh, + last_seen_index=last_seen_index) train_loss_fn, eval_loss_fn, label_pred_fn = get_loss_fn(**config.loss) summarizer = create_tree_summarizer(config.get('summarize_arrays')) train_step_fn = functools.partial( @@ -715,9 +749,6 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str, progress_hook = create_progress_hook( writer=writer, first_step=init_step + 1, train_steps=train_steps, **config.get('report_progress', {})) - checkpoint_hook = create_checkpoint_hook( - workdir=workdir, progress_hook=progress_hook, - train_steps=train_steps, **config.get('save_checkpoint', {})) evaluation_hook, config_model_eval = create_evaluation_hook( base_model_config=config.model.copy_and_resolve_references(), writer=writer, @@ -739,8 +770,19 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str, **config.get('fewshot', {})) # Run checkpoint hook just before starting the loop. This will save the train # state at initialization. - if init_step == 0: - checkpoint_hook(init_step, state=train_state, iterator=tr_iter) + def _save_checkpoint(step, ts, it, force=False): + last_seen_index = step * train_batch_size + with progress_hook.timed('ckpt', wait_jax_async_dispatch=False): + ckpt_manager.save( + step, + items={ + 'state': ts, + 'dataset_iterator': {'last_seen_index': last_seen_index}, + }, + force=force) + if init_step == 0 and not tf.io.gfile.exists(os.path.join(workdir, 'ckpt/0')): + multihost_utils.sync_devices('training:ckpt-first') + _save_checkpoint(init_step, train_state, tr_iter, force=True) # Explicitly compile train_step here and report the compilation time. t0 = time.time() train_step_pjit = train_step_pjit.lower( @@ -756,9 +798,14 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str, batch['labels']) progress_hook( step, scalar_metrics={f'train/{k}': v for k, v in metrics.items()}) - checkpoint_hook(step, state=train_state, iterator=tr_iter) + _save_checkpoint(step, train_state, tr_iter) evaluation_hook(step, params=train_state.params) fewshot_hook(step, variables={'params': train_state.params}) + ckpt_manager.wait_until_finished() + if not tf.io.gfile.exists(os.path.join(workdir, f'ckpt/{train_steps}')): + multihost_utils.sync_devices('training:ckpt-last') + _save_checkpoint(train_steps, train_state, tr_iter, force=True) + ckpt_manager.wait_until_finished() multihost_utils.sync_devices('training:completed') logging.info('Training completed.') diff --git a/vmoe/train/trainer_test.py b/vmoe/train/trainer_test.py index 8843347..b304c54 100644 --- a/vmoe/train/trainer_test.py +++ b/vmoe/train/trainer_test.py @@ -14,7 +14,6 @@ """Tests for trainer.""" import functools -import os from unittest import mock from absl.testing import absltest @@ -29,6 +28,7 @@ import ml_collections import numpy as np import optax +import orbax.checkpoint import tensorflow as tf from vmoe.train import trainer @@ -192,8 +192,7 @@ class InitializeTrainStateFromCheckpointTest(absltest.TestCase): """Tests that the appropriate initialization functions are called.""" @mock.patch.object( - trainer.initialization, 'initialize_from_vmoe', - autospec=True) + trainer.initialization, 'initialize_from_vmoe', autospec=True) def test_initialize_from_vmoe( self, mock_initialize_from_vmoe): train_state = mock.create_autospec(trainer.TrainState, instance=True) @@ -206,8 +205,7 @@ def test_initialize_from_vmoe( rules=[]) @mock.patch.object( - trainer.initialization, 'initialize_from_vit', - autospec=True) + trainer.initialization, 'initialize_from_vit', autospec=True) def test_initialize_from_vit( self, mock_initialize_from_vit): train_state = mock.create_autospec(trainer.TrainState, instance=True) @@ -218,6 +216,17 @@ def test_initialize_from_vit( mock_initialize_from_vit.assert_called_once_with( target=train_state, mesh=mesh, filepath='/foo', rules=[]) + @mock.patch.object( + trainer.initialization, 'initialize_from_orbax', autospec=True) + def test_initialize_from_orbax(self, mock_initialize_from_orbax): + train_state = mock.create_autospec(trainer.TrainState, instance=True) + mesh = mock.create_autospec(jax.sharding.Mesh, instance=True) + _ = trainer.initialize_train_state_from_checkpoint( + train_state=train_state, name='initialize_from_orbax', mesh=mesh, + directory='/foo', rules=[]) + mock_initialize_from_orbax.assert_called_once_with( + target=train_state, mesh=mesh, directory='/foo', rules=[]) + def test_unknown_method_raises(self): train_state = mock.create_autospec(trainer.TrainState, instance=True) mesh = mock.create_autospec(jax.sharding.Mesh, instance=True) @@ -326,23 +335,23 @@ def initialize_fn(): def test_create_from_scratch(self): """Tests when training starts from scratch.""" - prefix = os.path.join(self.create_tempdir().full_path, 'ckpt_1') - train_state = trainer.restore_or_create_train_state( - prefix=prefix, initialize_fn=self.initialize_fn, - axis_resources_regexes=[], mesh=self.mesh, + ckpt_manager = mock.create_autospec(orbax.checkpoint.CheckpointManager, + instance=True) + ckpt_manager.latest_step.return_value = None + train_state, last_seen = trainer.restore_or_create_train_state( + ckpt_manager=ckpt_manager, + initialize_fn=self.initialize_fn, + axis_resources_regexes=[], + mesh=self.mesh, initialization_kwargs={}) chex.assert_trees_all_close(flax.core.unfreeze(train_state.params), { 'a': 1 * np.ones((5,), dtype=np.float32), 'b': 2 * np.ones((10,), dtype=np.float32), }) chex.assert_trees_all_equal(train_state.step, 0) + self.assertIsNone(last_seen) - @mock.patch.object(trainer.checkpoints, - 'find_latest_complete_checkpoint_for_prefix', - return_value='/foo/ckpt_1') - @mock.patch.object(trainer.checkpoints, 'restore_checkpoint_partitioned', - autospec=True) - def test_continue_training(self, mock_restore_checkpoint, _): + def test_continue_training(self): """Tests when training continues from an existing checkpoint.""" # Mock the call to restore_checkpoint_partitioned. def restore_checkpoint_side_effect(*args, **kwargs): @@ -356,19 +365,29 @@ def f(): }) return train_state with self.mesh: - return pjit.pjit(f, out_shardings=None)() - mock_restore_checkpoint.side_effect = restore_checkpoint_side_effect + state = pjit.pjit(f, out_shardings=None)() + return { + 'state': state, + 'dataset_iterator': {'last_seen_index': 16}, + } + ckpt_manager = mock.create_autospec(orbax.checkpoint.CheckpointManager, + instance=True) + ckpt_manager.latest_step.return_value = 3 + ckpt_manager.restore.side_effect = restore_checkpoint_side_effect # Call restore_or_create_train_state and check that the outputs are the # expected ones. - train_state = trainer.restore_or_create_train_state( - prefix='/foo/ckpt_1', initialize_fn=self.initialize_fn, - axis_resources_regexes=[], mesh=self.mesh, + train_state, last_seen = trainer.restore_or_create_train_state( + ckpt_manager=ckpt_manager, + initialize_fn=self.initialize_fn, + axis_resources_regexes=[], + mesh=self.mesh, initialization_kwargs={}) chex.assert_trees_all_close(train_state.params, { 'a': 3 * np.ones((5,), dtype=np.float32), 'b': 4 * np.ones((10,), dtype=np.float32), }) chex.assert_trees_all_equal(train_state.step, 3) + self.assertEqual(last_seen, 16) @mock.patch.object(trainer, 'initialize_train_state_from_checkpoint') def test_initialize_from_checkpoint(self, @@ -399,15 +418,21 @@ def initialize_train_state_from_ckpt_side_effect(*args, **kwargs): initialize_train_state_from_ckpt_side_effect) # Call restore_or_create_train_state and check that the outputs are the # expected ones. - train_state = trainer.restore_or_create_train_state( - prefix='/foo/ckpt_1', initialize_fn=self.initialize_fn, - axis_resources_regexes=[], mesh=self.mesh, + ckpt_manager = mock.create_autospec(orbax.checkpoint.CheckpointManager, + instance=True) + ckpt_manager.latest_step.return_value = None + train_state, last_seen = trainer.restore_or_create_train_state( + ckpt_manager=ckpt_manager, + initialize_fn=self.initialize_fn, + axis_resources_regexes=[], + mesh=self.mesh, initialization_kwargs={'foo': 'bar'}) chex.assert_trees_all_close(flax.core.unfreeze(train_state.params), { 'a': 1 * np.ones((5,), dtype=np.float32), 'b': 5 * np.ones((10,), dtype=np.float32), }) chex.assert_trees_all_equal(train_state.step, 0) + self.assertIsNone(last_seen) class TrainAndEvaluateTest(parameterized.TestCase):