From 662341d007650d5bbb7c6a2bef7f3c759a20cc7e Mon Sep 17 00:00:00 2001
From: Joan Puigcerver
Date: Tue, 15 Aug 2023 04:59:07 -0700
Subject: [PATCH] Release Soft MoE code.

PiperOrigin-RevId: 557096170
---
 vmoe/nn/vit_moe.py                            |   7 +
 vmoe/projects/soft_moe/README.md              |  12 +
 vmoe/projects/soft_moe/configs/common.py      | 264 ++++++++++++++++++
 .../soft_moe/configs/pretrain_jft4b.py        | 132 +++++++++
 vmoe/projects/soft_moe/main.py                |  20 ++
 vmoe/projects/soft_moe/router.py              |   6 +-
 vmoe/projects/soft_moe/router_test.py         |   1 +
 7 files changed, 441 insertions(+), 1 deletion(-)
 create mode 100644 vmoe/projects/soft_moe/README.md
 create mode 100644 vmoe/projects/soft_moe/configs/common.py
 create mode 100644 vmoe/projects/soft_moe/configs/pretrain_jft4b.py
 create mode 100644 vmoe/projects/soft_moe/main.py

diff --git a/vmoe/nn/vit_moe.py b/vmoe/nn/vit_moe.py
index d88c8ca..aeedae4 100644
--- a/vmoe/nn/vit_moe.py
+++ b/vmoe/nn/vit_moe.py
@@ -129,6 +129,7 @@ class MapHead(nn.Module):
   """Multihead Attention Pooling."""
   mlp_dim: int
   num_heads: int
+  qk_norm: bool = False
 
   @nn.compact
   def __call__(self, x):
@@ -140,6 +141,7 @@ def __call__(self, x):
         num_heads=self.num_heads,
         kernel_init=nn.initializers.xavier_uniform(),
         deterministic=True,
+        normalize_qk=self.qk_norm,
         name='MultiHeadDotProductAttention')(inputs_q=probe, inputs_kv=x)
     y = nn.LayerNorm(name='LayerNorm')(x)
     y = MlpBlock(
@@ -154,6 +156,7 @@ class EncoderBlock(nn.Module):
   dtype: Optional[DType] = None
   dropout_rate: float = 0.0
   attention_dropout_rate: float = 0.0
+  attention_qk_norm: bool = False
   deterministic: bool = False
 
   @nn.compact
@@ -166,6 +169,7 @@ def __call__(self, inputs):
         broadcast_dropout=False,
         deterministic=self.deterministic,
         dropout_rate=self.attention_dropout_rate,
+        normalize_qk=self.attention_qk_norm,
         num_heads=self.num_heads,
         name='SelfAttention')(inputs_q=x, inputs_kv=x)
     x = nn.Dropout(rate=self.dropout_rate, deterministic=self.deterministic)(x)
@@ -214,6 +218,7 @@ class EncoderMoe(nn.Module):
   num_heads: int
   dropout_rate: float = 0.0
   attention_dropout_rate: float = 0.0
+  attention_qk_norm: bool = False
   moe: Optional[KwArgs] = None
   deterministic: bool = False
   dtype: Optional[DType] = None
@@ -241,6 +246,7 @@ def __call__(self, inputs):
         num_heads=self.num_heads,
         dropout_rate=self.dropout_rate,
         attention_dropout_rate=self.attention_dropout_rate,
+        attention_qk_norm=self.attention_qk_norm,
         deterministic=self.deterministic,
         dtype=self.dtype)
 
@@ -347,6 +353,7 @@ def __call__(self, inputs):
     elif self.classifier == 'map':
       x = MapHead(
           num_heads=self.encoder['num_heads'], mlp_dim=self.encoder['mlp_dim'],
+          qk_norm=self.encoder.get('attention_qk_norm', False),
           name='MapHead')(x)
     else:
       raise ValueError(f'Unknown classifier: {self.classifier!r}')
diff --git a/vmoe/projects/soft_moe/README.md b/vmoe/projects/soft_moe/README.md
new file mode 100644
index 0000000..2bc6aaf
--- /dev/null
+++ b/vmoe/projects/soft_moe/README.md
@@ -0,0 +1,12 @@
+# From Sparse to Soft Mixtures of Experts
+
+This folder contains the implementation of Soft MoE, presented in the paper:
+
+- [From Sparse to Soft Mixtures of Experts](https://arxiv.org/abs/2308.00951),
+  by Joan Puigcerver, Carlos Riquelme, Basil Mustafa, and Neil Houlsby.
+
+We provide the config files used to run some of the experiments reported in the
+paper.
+
+Notice that all experiments either train on JFT-4B, a proprietary dataset,
+or use models pre-trained on it, so we cannot release any of the checkpoints.
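For readers who only want the gist of what the released `SoftRouter` computes, here is a minimal, self-contained sketch of the soft routing described in the paper. It is not the code added by this patch: names and shapes are illustrative, and the real implementation in `vmoe/projects/soft_moe/router.py` additionally handles input/slot normalization, a learned scale, and device sharding.

```python
# Illustrative sketch of Soft MoE routing for one group of tokens (per paper).
import jax
import jax.numpy as jnp


def soft_moe(x, phi, expert_fns):
  """x: [tokens, dim]; phi: [dim, num_experts * slots_per_expert]."""
  logits = x @ phi                           # [tokens, slots]
  dispatch = jax.nn.softmax(logits, axis=0)  # each slot: weights over tokens
  combine = jax.nn.softmax(logits, axis=1)   # each token: weights over slots
  slot_inputs = dispatch.T @ x               # [slots, dim], mixtures of tokens
  per_expert = slot_inputs.reshape(len(expert_fns), -1, x.shape[-1])
  slot_outputs = jnp.concatenate(
      [f(s) for f, s in zip(expert_fns, per_expert)], axis=0)
  return combine @ slot_outputs              # [tokens, dim]
```

Every slot is a weighted average of all tokens and every output token is a weighted average of all slot outputs, so no tokens are dropped.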
diff --git a/vmoe/projects/soft_moe/configs/common.py b/vmoe/projects/soft_moe/configs/common.py
new file mode 100644
index 0000000..87f9021
--- /dev/null
+++ b/vmoe/projects/soft_moe/configs/common.py
@@ -0,0 +1,264 @@
+# Copyright 2023 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Common configurations used in the Soft router experiments."""
+import math
+from typing import Optional
+
+import ml_collections
+from ml_collections import config_dict
+from vmoe.configs import common_fewshot
+
+get_fewshot_config = common_fewshot.get_fewshot
+
+
+def flatten_dict(config, prefix=''):
+  if isinstance(config, ml_collections.ConfigDict):
+    config = config.to_dict()
+  flat_dict = {}
+  for k, v in config.items():
+    if isinstance(v, dict):
+      flat_dict.update(flatten_dict(v, prefix=f'{prefix}{k}.'))
+    else:
+      flat_dict[f'{prefix}{k}'] = v
+  return flat_dict
+
+
+def get_base_config() -> ml_collections.ConfigDict:
+  """Returns the base config with options for saving checkpoints, profiling, etc."""
+  config = ml_collections.ConfigDict()
+  # Write checkpoints every 1000 steps.
+  config.save_checkpoint = ml_collections.ConfigDict()
+  config.save_checkpoint.every_steps = 1_000
+  config.save_checkpoint.keep_last = 1
+  config.save_checkpoint.wait_seconds = 10
+  # Report training progress every minute to avoid hitting maximum RPC/s quota.
+  config.report_progress = ml_collections.ConfigDict()
+  config.report_progress.every_secs = 60.0
+  config.report_progress.every_steps = 250
+  # Evaluate on the validation set every 1000 steps.
+  config.evaluate = ml_collections.ConfigDict()
+  config.evaluate.every_steps = 1_000
+  # Run device profiling on process_index = 0, for 5 steps, starting at step 10.
+  # Then repeat profiling every hour.
+  config.profile = ml_collections.ConfigDict()
+  config.profile.all_processes = False
+  config.profile.num_profile_steps = 5
+  config.profile.first_profile = 10
+  config.profile.every_secs = 3600.0
+  # Seed for generating random numbers.
+  config.seed = 0
+  return config
+
+
+def get_data_config(
+    name: str,
+    split: str,
+    process: str,
+    batch_size: int,
+    shuffle_buffer: Optional[int] = None,
+    cache: Optional[str] = None,
+    data_dir: Optional[str] = None,
+) -> ml_collections.ConfigDict:
+  """Returns dataset parameters."""
+  config = ml_collections.ConfigDict(type_safe=False)
+  config.name = name
+  config.split = split
+  config.process = process
+  config.batch_size = batch_size
+  config.prefetch = 'autotune'
+  config.prefetch_device = 2
+  if shuffle_buffer:
+    config.shuffle_buffer = shuffle_buffer or config_dict.placeholder(int)
+  if cache:
+    config.cache = cache or config_dict.placeholder(str)
+  if data_dir:
+    config.data_dir = data_dir or config_dict.placeholder(str)
+  return config
+
+
+def get_adam_config() -> ml_collections.ConfigDict:
+  config = ml_collections.ConfigDict(type_safe=False)
+  config.name = 'adam'
+  config.b1 = 0.9
+  config.b2 = 0.999
+  config.mu_dtype = 'float32'  # Optionally, use bfloat16 to save memory.
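  # (Editorial aside, not part of the patch.) The weight_decay field below
  # appears to be a sequence of (regex, value) pairs matched against flattened
  # parameter names: under that assumption the output head kernel
  # ('head/kernel') decays with strength 3.0, every other kernel matches
  # '.*/kernel' and decays with 0.03, and non-kernel parameters (biases,
  # scales) get no decay here.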
+  config.weight_decay = (
+      ('head/kernel', 3.0),
+      ('.*/kernel', 0.03),
+  )
+  config.gradient_clip = ml_collections.ConfigDict({'global_norm': 1.0})
+  return config
+
+
+def get_optimizer_linear_config() -> ml_collections.ConfigDict:
+  """Returns optimizer parameters as in the "Scaling Vision Transformers" paper with linear LR decay."""
+  config = get_adam_config()
+  # Parameters of the learning rate schedule.
+  config.learning_rate = ml_collections.ConfigDict()
+  config.learning_rate.schedule = 'warmup_linear_decay'
+  config.learning_rate.peak_value = 8e-4
+  config.learning_rate.end_value = 0.
+  config.learning_rate.warmup_steps = 10_000
+  return config
+
+
+def get_optimizer_rsqrt_config() -> ml_collections.ConfigDict:
+  """Returns optimizer parameters as in the ViT 22b paper."""
+  config = get_adam_config()
+  # Parameters of the learning rate schedule.
+  config.learning_rate = ml_collections.ConfigDict()
+  config.learning_rate.schedule = 'big_vision_rsqrt'
+  config.learning_rate.peak_value = 1e-3
+  config.learning_rate.warmup_steps = 10_000
+  config.learning_rate.cooldown_steps = 50_000
+  config.learning_rate.timescale = 10_000
+  return config
+
+
+def get_imagenet_config(
+    batch_size: int,
+    resize_hi: int = 256,
+    resize_lo: int = 224,
+    randaug: str = '',
+    data_dir: Optional[str] = None,
+) -> ml_collections.ConfigDict:
+  """Returns configuration for training/evaluating on ImageNet."""
+  randaug = f'|{randaug}' if randaug and randaug[0] != '|' else randaug
+  # pylint: disable=line-too-long
+  pp_common_fn = lambda inkey: f'value_range(-1,1)|onehot(1000, inkey="{inkey}", outkey="labels")|keep("image", "labels")'
+  pp_train = f'decode_jpeg_and_inception_crop({resize_lo})|flip_lr{randaug}|{pp_common_fn("label")}'
+  pp_eval1 = f'decode|resize_small({resize_hi})|central_crop({resize_lo})|{pp_common_fn("label")}'
+  pp_eval2 = f'decode|resize_small({resize_hi})|central_crop({resize_lo})|ignore_no_labels(labels_key="real_label")|{pp_common_fn("real_label")}'
+  # pylint: enable=line-too-long
+  return ml_collections.ConfigDict({
+      'train': {
+          'name': 'imagenet2012',
+          'split': 'train[:99%]',
+          'process': pp_train,
+          'batch_size': batch_size,
+          'data_dir': data_dir,
+          'cache': 'loaded',
+          'shuffle_buffer': 250_000,
+      },
+      'val': {
+          'name': 'imagenet2012',
+          'split': 'train[99%:]',
+          'process': pp_eval1,
+          'batch_size': batch_size,
+          'data_dir': data_dir,
+          'cache': 'batched',
+      },
+      'test': {
+          'name': 'imagenet2012',
+          'split': 'validation',
+          'process': pp_eval1,
+          'batch_size': batch_size,
+          'data_dir': data_dir,
+          'cache': 'batched',
+      },
+      'v2': {
+          'name': 'imagenet_v2',
+          'split': 'test',
+          'process': pp_eval1,
+          'batch_size': batch_size,
+          'data_dir': data_dir,
+          'cache': 'batched',
+      },
+      'real': {
+          'name': 'imagenet2012_real',
+          'split': 'validation',
+          'process': pp_eval2,
+          'batch_size': batch_size,
+          'data_dir': data_dir,
+          'cache': 'batched',
+      },
+  })
+
+
+def get_vit_config(
+    variant: str, patch_size: int, num_classes: Optional[int],
+) -> ml_collections.ConfigDict:
+  """Returns transformer parameters for different canonical architectures."""
+  variant_idx = ['Ti', 'S', 'B', 'L', 'H'].index(variant)
+  return ml_collections.ConfigDict({
+      'name': 'VisionTransformerMoe',
+      'num_classes': num_classes,
+      'patch_size': (patch_size, patch_size),
+      'hidden_size': [192, 384, 768, 1024, 1280][variant_idx],
+      'classifier': 'gap',
+      'head_bias_init': -math.log(num_classes) if num_classes else 0.0,
+      'encoder': {
+          'num_layers': [12, 12, 12, 24, 32][variant_idx],
+          'mlp_dim': [768,
+                      1536, 3072, 4096, 5120][variant_idx],
+          'num_heads': [3, 6, 12, 16, 16][variant_idx],
+          'dropout_rate': 0.0,
+          'attention_dropout_rate': 0.0,
+          'attention_qk_norm': True,
+          'moe': {'layers': ()},
+      },
+  }, type_safe=False)
+
+
+def get_vmoe_experts_choose_config(
+    variant: str, patch_size: int, num_classes: Optional[int], *,
+    image_size: int, num_experts: int, last_n: int,
+    capacity_factor: float = 1.0,
+) -> ml_collections.ConfigDict:
+  """Returns a ViT model with MoE layers using the ExpertsChoose router."""
+  config = get_vit_config(variant, patch_size, num_classes)
+  config.encoder.moe = ml_collections.ConfigDict({
+      'layers': tuple(range(config.encoder.num_layers))[-last_n:],
+      'num_experts': num_experts,
+      'group_size': (image_size // patch_size)**2,
+      'split_rngs': False,
+      'router': {
+          'name': 'NoisyTopItemsPerExpertRouter',
+          'noise_std': 1.0,
+          'dispatcher': {
+              'name': 'einsum',
+              'bfloat16': True,
+              'capacity_factor': capacity_factor,
+              # Note: this is what the soft router uses, so we change the
+              # defaults for a fair comparison. Otherwise, the actual
+              # capacity_factor can be significantly bigger.
+              'capacity_ceil_or_round': 'round',
+              'capacity_multiple_of': 1,
+              'partition_spec': (('expert', 'replica'),),
+          },
+      }
+  })
+  return config
+
+
+def get_vmoe_soft_router_config(
+    variant: str, patch_size: int, num_classes: Optional[int], *,
+    image_size: int, num_experts: int, last_n: int,
+    capacity_factor: Optional[float] = 1.0, num_slots: Optional[int] = None):
+  """Returns a ViT model with MoE layers using the Soft router."""
+  config = get_vit_config(variant, patch_size, num_classes)
+  config.encoder.moe = ml_collections.ConfigDict({
+      'layers': tuple(range(config.encoder.num_layers))[-last_n:],
+      'num_experts': num_experts,
+      'group_size': (image_size // patch_size)**2,
+      'split_rngs': False,
+      'router': {
+          'name': 'SoftRouter',
+          'capacity_factor': capacity_factor,
+          'num_slots': num_slots,
+          'partition_spec': (('expert', 'replica'),),
+          'compute_similarity_metrics': True,
+      }
+  })
+  return config
diff --git a/vmoe/projects/soft_moe/configs/pretrain_jft4b.py b/vmoe/projects/soft_moe/configs/pretrain_jft4b.py
new file mode 100644
index 0000000..58fedeb
--- /dev/null
+++ b/vmoe/projects/soft_moe/configs/pretrain_jft4b.py
@@ -0,0 +1,132 @@
+# Copyright 2023 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pylint: disable=line-too-long
+r"""Train different models used in the Soft MoE paper.
+
+This includes the configs used for some of the "Long training runs" results.
+
+Notice that we pretrain on JFT-4B, a proprietary dataset by Google, which is
+not available externally. Feel free to use this config file as a template and
+adapt it to train on your favorite dataset.
+
+"""
+# pylint: enable=line-too-long
+import ml_collections
+from vmoe.projects.soft_moe.configs import common
+
+BATCH_SIZE = 4096
+DATASET = 'jft4b'  # This is a proprietary dataset by Google.
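# (Editorial aside, not part of the patch.) Rough arithmetic behind the router
# configs above, assuming the capacity formula implied by the dispatcher
# fields: with image_size=224 and patch_size=16 the group size is
# (224 // 16)**2 = 196 tokens; for the ExpertsChoose S/16 model with
# num_experts=128 and capacity_factor=1.0 that gives a per-expert capacity of
# round(1.0 * 196 / 128) = round(1.53) = 2 tokens, whereas the Soft MoE
# configs below use num_slots=1, i.e. exactly one learned input slot per
# expert.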
+NUM_CLASSES = 29_593
+PP_COMMON = f'value_range(-1,1)|onehot({NUM_CLASSES})|keep("image", "labels")'
+
+
+def get_default_moe_num_experts_and_last_n(variant, patch_size):
+  """Default number of experts and MoE layers for Sparse and Soft MoEs."""
+  num_experts = {
+      ('S', 16): 128,
+      ('S', 14): 256,
+      ('B', 16): 128,
+      ('L', 16): 128,
+      ('H', 14): 256,
+  }[(variant, patch_size)]
+  last_n = {
+      'S': 6,
+      'B': 6,
+      'L': 12,
+      'H': 16,
+  }[variant]
+  return num_experts, last_n
+
+
+def get_config(model='soft-s16') -> ml_collections.ConfigDict:
+  """Config to train different models used in the Soft MoE paper."""
+  # Parse model argument.
+  model_type, model_backbone = model.split('-')
+  patch_size = int(model_backbone[1:])
+  variant = model_backbone[0].upper()
+
+  config = common.get_base_config()
+  config.dataset = ml_collections.ConfigDict()
+  config.dataset.train = common.get_data_config(
+      name=DATASET,
+      split='full[16384:]',
+      batch_size=BATCH_SIZE,
+      process=f'decode_jpeg_and_inception_crop(224)|flip_lr|{PP_COMMON}',
+  )
+  config.dataset.val = common.get_data_config(
+      name=DATASET,
+      split='full[:16384]',
+      batch_size=BATCH_SIZE,
+      process=f'decode|resize_small(256)|central_crop(224)|{PP_COMMON}',
+      cache='batched',
+  )
+  config.fewshot = common.get_fewshot_config(
+      batch_size=BATCH_SIZE, resize_resolution=256, target_resolution=224)
+  config.loss = ml_collections.ConfigDict({'name': 'sigmoid_xent'})
+  config.optimizer = common.get_optimizer_rsqrt_config()
+  # Model hyperparameters depend on the model type.
+  if model_type == 'vit':
+    config.model = common.get_vit_config(variant, patch_size, NUM_CLASSES)
+  elif model_type == 'ec':
+    num_experts, last_n = get_default_moe_num_experts_and_last_n(
+        variant, patch_size)
+    config.model = common.get_vmoe_experts_choose_config(
+        variant, patch_size, NUM_CLASSES, image_size=224,
+        num_experts=num_experts, last_n=last_n, capacity_factor=1.0)
+  elif model_type == 'soft':
+    num_experts, last_n = get_default_moe_num_experts_and_last_n(
+        variant, patch_size)
+    config.model = common.get_vmoe_soft_router_config(
+        variant, patch_size, NUM_CLASSES, image_size=224,
+        num_experts=num_experts, last_n=last_n, capacity_factor=None,
+        num_slots=1)
+    config.model.encoder.moe.router.compute_similarity_metrics = False
+    config.optimizer.weight_decay = config.optimizer.weight_decay + (
+        ('.*/Moe/Router/scale', 0.03),  # SoftMoE doesn't have a kernel param.
+    )
+  else:
+    raise ValueError(f'Unknown model type: {model_type!r}')
+  if variant == 'H':
+    config.train_steps = 2_000_000
+  else:
+    config.train_steps = 4_000_000
+  # These control how the train state is partitioned across the device mesh.
+  if model_type == 'vit':
+    config.num_expert_partitions = 1
+    config.params_axis_resources = []
+  else:
+    config.num_expert_partitions = config.model.encoder.moe.num_experts
+    config.params_axis_resources = [('Moe/Mlp/.*', ('expert',))]
+  config.extra_rng_keys = ('dropout', 'gating')
+  # Plot summary of different arrays.
+  config.summarize_arrays = ml_collections.ConfigDict({
+      'rules': [
+          'opt_state/.*/hyperparams/learning_rate',  # Learning rate.
+          'params/.*/Moe/Router/scale',  # Soft MoE scale.
+      ],
+      # Maximum values reported per rule and array.
+      # If you are reporting individual values for every expert parameter,
+      # increase this accordingly.
+      'max_summary_values': 1,
+  })
+  # Keep checkpoints every 50k steps, useful to do intermediate cooldowns.
+  config.save_checkpoint.keep_last = 2
+  config.save_checkpoint.keep_steps_multiple_of = 50_000
+  return config
+
+
+def get_hyper(hyper, model='soft-s16'):
+  del model
+  return hyper.product([])
diff --git a/vmoe/projects/soft_moe/main.py b/vmoe/projects/soft_moe/main.py
new file mode 100644
index 0000000..aba33e6
--- /dev/null
+++ b/vmoe/projects/soft_moe/main.py
@@ -0,0 +1,20 @@
+# Copyright 2023 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Main training script for Soft MoE experiments."""
+from vmoe import app
+from vmoe.train import trainer
+
+if __name__ == '__main__':
+  app.run(trainer.train_and_evaluate)
diff --git a/vmoe/projects/soft_moe/router.py b/vmoe/projects/soft_moe/router.py
index 79771c5..1fbae84 100644
--- a/vmoe/projects/soft_moe/router.py
+++ b/vmoe/projects/soft_moe/router.py
@@ -12,7 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Soft router merging tokens as inputs/outputs of the experts."""
+"""Soft MoE implemented as a router merging tokens as inputs/outputs of the experts.
+
+Results using this algorithm are presented in the paper:
+  - "From Sparse to Soft Mixtures of Experts" (https://arxiv.org/abs/2308.00951).
+"""
 from typing import Dict, Optional, Tuple
 
 from absl import logging
diff --git a/vmoe/projects/soft_moe/router_test.py b/vmoe/projects/soft_moe/router_test.py
index 4685b81..d6d8476 100644
--- a/vmoe/projects/soft_moe/router_test.py
+++ b/vmoe/projects/soft_moe/router_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+"""Tests for the Soft MoE implementation."""
 from absl.testing import absltest
 import chex
 import flax.core
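To close, a hypothetical usage sketch (not part of the patch, and it requires the vmoe repository with this change applied): it builds the Soft MoE S/16 pretraining config added above and inspects a few fields derived from the `model='soft-s16'` argument; the commented values follow from the code in `pretrain_jft4b.py` and `common.py`.

```python
# Hypothetical usage of the config added by this patch.
from vmoe.projects.soft_moe.configs import pretrain_jft4b

config = pretrain_jft4b.get_config(model='soft-s16')
print(config.model.name)                     # 'VisionTransformerMoe'
print(config.model.encoder.moe.router.name)  # 'SoftRouter'
print(config.model.encoder.moe.num_experts)  # 128, default for the S/16 backbone
print(config.model.encoder.moe.layers)       # last 6 of the 12 encoder blocks
print(config.train_steps)                    # 4_000_000 for non-H variants
```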