Skip to content

Commit

Permalink
Update trainers to use Orbax checkpointing.
Browse files Browse the repository at this point in the history
The config files have also been updated, since Orbax computes the "wait time" (`wait_seconds`) differently.

PiperOrigin-RevId: 557456501
  • Loading branch information
jpuigcerver authored and copybara-github committed Aug 16, 2023
1 parent c2b4ce0 commit 760d672
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 395 deletions.
208 changes: 0 additions & 208 deletions vmoe/checkpoints/periodic_actions.py

This file was deleted.

118 changes: 0 additions & 118 deletions vmoe/checkpoints/periodic_actions_test.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def get_config():
config.save_checkpoint.every_steps = 1_000
config.save_checkpoint.keep_last = 1
config.save_checkpoint.num_shards = 32 # Target number of checkpoint shards.
- config.save_checkpoint.wait_seconds = 1.0
+ config.save_checkpoint.wait_seconds = 300
# Report training progress every 100 steps.
config.report_progress = ml_collections.ConfigDict()
config.report_progress.every_secs = None
Expand Down
2 changes: 1 addition & 1 deletion vmoe/configs/vmoe_paper/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def get_base_config() -> ml_collections.ConfigDict:
config.save_checkpoint = ml_collections.ConfigDict()
config.save_checkpoint.every_steps = 1_000
config.save_checkpoint.keep_last = 1
- config.save_checkpoint.wait_seconds = 1.0
+ config.save_checkpoint.wait_seconds = 300
# Report training progress every minute.
config.report_progress = ml_collections.ConfigDict()
config.report_progress.every_secs = None
Expand Down
6 changes: 6 additions & 0 deletions vmoe/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
from typing import Any, Callable, Dict, Optional, Union

+ from absl import logging
from clu.data import dataset_iterator
import jax
import ml_collections
Expand Down Expand Up @@ -85,6 +86,11 @@ def get_dataset(
Returns:
A DatasetIterator.
"""
+ if variant == 'train' and shuffle_seed is not None:
+ logging.error('Deterministic training is not supported but you specified '
+ 'shuffle_seed=%d for training. This can potentially lead to '
+ 'data being repeated if restarts happen during training.',
+ shuffle_seed)
builder = vmoe.data.builder.get_dataset_builder(
name=name,
split=split,
Expand Down
2 changes: 1 addition & 1 deletion vmoe/projects/soft_moe/configs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_base_config() -> ml_collections.ConfigDict:
config.save_checkpoint = ml_collections.ConfigDict()
config.save_checkpoint.every_steps = 1_000
config.save_checkpoint.keep_last = 1
- config.save_checkpoint.wait_seconds = 10
+ config.save_checkpoint.wait_seconds = 300
# Report training progress every minute to avoid hitting maximum RPC/s quota.
config.report_progress = ml_collections.ConfigDict()
config.report_progress.every_secs = 60.0
Expand Down
Loading

0 comments on commit 760d672

Please sign in to comment.