diff --git a/vmoe/projects/contrastive/evaluators.py b/vmoe/projects/contrastive/evaluators.py
new file mode 100644
index 0000000..11ee1d9
--- /dev/null
+++ b/vmoe/projects/contrastive/evaluators.py
@@ -0,0 +1,169 @@
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Evaluators used during contrastive training."""
+import time
+from typing import Any, Callable, Iterable, Optional, Tuple
+
+from clu import metric_writers
+from clu import periodic_actions
+import jax
+
+# pylint: disable=g-import-not-at-top
+try:
+  from big_vision.evaluators.proj.image_text import discriminative_classifier as bv_discriminative
+except ImportError:
+  bv_discriminative = None
+
+
+try:
+  from big_vision.evaluators.proj.image_text import retrieval as bv_retrieval
+except ImportError:
+  bv_retrieval = None
+# pylint: enable=g-import-not-at-top
+
+Array = jax.Array
+PyTree = Any
+
+
+class ZeroShotPeriodicAction(periodic_actions.PeriodicCallback):
+  """Periodic action that runs Big Vision's zero-shot evaluator repeatedly."""
+
+  def __init__(
+      self,
+      *,
+      metric_writer: metric_writers.MetricWriter,
+      apply_fn: Callable[..., Tuple[Array, Array, Any]],
+      data_sharding: jax.sharding.NamedSharding,
+      every_steps: Optional[int] = None,
+      every_secs: Optional[float] = None,
+      on_steps: Optional[Iterable[int]] = None,
+      report_progress: Optional[periodic_actions.ReportProgress] = None,
+      report_progress_name: str = 'zeroshot',
+      **bv_evaluator_kwargs,
+  ):
+    """Constructor."""
+    if bv_discriminative is None:
+      raise NotImplementedError(
+          'Big Vision must be installed to run the discriminative evaluation.')
+    bv_evaluator = bv_discriminative.Evaluator(
+        predict_fn=apply_fn,
+        devices=list(data_sharding.mesh.devices.flatten()),
+        **bv_evaluator_kwargs,
+    )
+    callback = self._make_callback_fn(
+        evaluator=bv_evaluator,
+        metric_writer=metric_writer,
+        report_progress=report_progress,
+        report_progress_name=report_progress_name,
+    )
+    super().__init__(
+        every_steps=every_steps,
+        every_secs=every_secs,
+        on_steps=on_steps,
+        callback_fn=callback,
+        execute_async=False,
+        pass_step_and_time=True)
+
+  def _make_callback_fn(
+      self, *, evaluator, metric_writer, report_progress,
+      report_progress_name):
+
+    def callback_fn(step: int, t: Optional[float], variables: PyTree, **kwargs):
+      del t  # Unused.
+ metrics = {} + t0 = time.time() + for task in evaluator.datasets: + acc = evaluator.evaluate(variables, task)['accuracy'] + t1 = time.time() + metrics[f'{report_progress_name}/{task}/accuracy'] = acc + metrics[f'{report_progress_name}/{task}/duration_secs'] = t1 - t0 + metrics = metrics | {k: v for k, v in kwargs.items() if v is not None} + metric_writer.write_scalars(step, metrics) + + if report_progress is None: + return callback_fn + else: + return report_progress.timed( + report_progress_name, wait_jax_async_dispatch=False)(callback_fn) + + +class RetrievalPeriodicAction(periodic_actions.PeriodicCallback): + """Periodic action that runs Big Vision's Retrieval evaluator repeatedly.""" + + def __init__( + self, + *, + metric_writer: metric_writers.MetricWriter, + apply_fn: Callable[..., Tuple[Array, Array, Any]], + task: str, + data_sharding: jax.sharding.NamedSharding, + every_steps: Optional[int] = None, + every_secs: Optional[float] = None, + on_steps: Optional[Iterable[int]] = None, + report_progress: Optional[periodic_actions.ReportProgress] = None, + report_progress_name: str = 'retrieval', + **bv_evaluator_kwargs, + ): + """Constructor.""" + if bv_retrieval is None: + raise NotImplementedError( + 'Big Vision must be installed to run the retrieval evaluation.') + bv_evaluator = bv_retrieval.Evaluator( + predict_fn=apply_fn, + devices=list(data_sharding.mesh.devices.flatten()), + **bv_evaluator_kwargs, + ) + callback = self._make_callback_fn( + evaluator=bv_evaluator, + task=task, + metric_writer=metric_writer, + report_progress=report_progress, + report_progress_name=report_progress_name, + ) + super().__init__( + every_steps=every_steps, + every_secs=every_secs, + on_steps=on_steps, + callback_fn=callback, + execute_async=False, + pass_step_and_time=True) + + def _make_callback_fn( + self, *, evaluator, task, metric_writer, report_progress, + report_progress_name): + + def callback_fn(step: int, t: Optional[float], variables: PyTree, **kwargs): + del t # Unused. + metrics = {} + t0 = time.time() + bv_metrics = evaluator.evaluate(variables) + metrics.update({ + f'{report_progress_name}/{task}/txt2img/{k}': v + for k, v in bv_metrics['txt2img'].items() + }) + metrics.update({ + f'{report_progress_name}/{task}/img2txt/{k}': v + for k, v in bv_metrics['img2txt'].items() + }) + t1 = time.time() + metrics[f'{report_progress_name}/{task}/duration_secs'] = t1 - t0 + metrics = metrics | {k: v for k, v in kwargs.items() if v is not None} + metric_writer.write_scalars(step, metrics) + + if report_progress is None: + return callback_fn + else: + return report_progress.timed( + report_progress_name, wait_jax_async_dispatch=False)(callback_fn) diff --git a/vmoe/projects/contrastive/models.py b/vmoe/projects/contrastive/models.py new file mode 100644 index 0000000..bcf5764 --- /dev/null +++ b/vmoe/projects/contrastive/models.py @@ -0,0 +1,186 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Two tower model used for contrastive learning.""" +import functools +import sys +from typing import Any, Mapping, Literal, Optional, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +from vmoe import utils +from vmoe.nn import vit_moe + +Array = jax.Array + +_default_image_module = vit_moe +_default_text_module = sys.modules[__name__] + + +class TextTransformer(nn.Module): + """Text transformer similar to CLIP, allowing blocks with MoEs.""" + + # Differences to CLIP text encoder (gpt-2) that I am aware of: + # 1. https://imgur.com/HNi3jix (gpt-1) + # 2. https://imgur.com/qKGZgBR (gpt-2) + # 3. https://imgur.com/a/xrpYHF0 (clip) + # - LayerNorm is on res-path (like pre-activation resnet) + # - dropout 0.1 everywhere + # - init as var=0.02, scaled by depth + # - BOS and EOS tokens, take repr from EOS. + # - self-attention is autoregressively masked. + # - scaled in width only, with the image model. + vocab_size: int + num_classes: Optional[int] + hidden_size: int + encoder: Mapping[str, Any] + pool_type: Literal['last', 'first', 'gap', 'gmp', 'map'] = 'last' + deterministic: bool = False + head_bias_init: float = 0.0 + head_kernel_zero_init: bool = False + + @property + def kernel_init(self) -> nn.initializers.Initializer: + if self.head_kernel_zero_init: + return nn.initializers.zeros + else: + return nn.linear.default_kernel_init + + @nn.compact + def __call__(self, text): + # We can't use where/argwhere since the output shape is not fixed. + # Here we use the fact that sequences are padded with EOS tokens, that the + # EOS token has value 1, and that argmin returns the first index. + # eos_indices = jnp.argmin(text, axis=1) + + embedding = nn.Embed( + num_embeddings=self.vocab_size, features=self.hidden_size) + x = embedding(text) + + # TODO(jpuigcerver): Move position embedding outside of the Encoder class. + encoder_kwargs = dict(self.encoder) + if encoder_kwargs.get('position_emb', {}).get('name') == 'sincos2d': + raise ValueError( + 'sincos2d position embeddings are not supproted for text.') + + x, metrics = vit_moe.EncoderMoe( + name='Encoder', deterministic=self.deterministic, **encoder_kwargs)(x) + + x = self.apply_pooling(x) + + if self.num_classes: + # Linear head outputing the logits for classification. + logits = nn.Dense( + features=self.num_classes, + name='head', + kernel_init=self.kernel_init, + bias_init=nn.initializers.constant(self.head_bias_init))(x) + return logits, metrics + else: + return x, metrics + + @nn.nowrap + def apply_pooling(self, x): + match self.pool_type: + case 'last': return x[:, -1, :] + case 'first': return x[:, 0, :] + case 'gap': return x.mean(axis=1) + case 'gmp': return x.max(axis=1) + case 'map': + return vit_moe.MapHead( + num_heads=self.encoder['num_heads'], + mlp_dim=self.encoder['mlp_dim'], + qk_norm=self.encoder.get('attention_qk_norm', False), + name='MapHead')(x) + case _: + raise NotImplementedError(f'Cannot do pooling {self.pool_type!r}') + + +class TwoTower(nn.Module): + """A two-tower encoder model.""" + image: Mapping[str, Any] + text: Mapping[str, Any] + scale_init: float = 1.0 + bias_init: float | None = None + deterministic: bool = False + + @functools.cached_property + def image_model_class(self): + # Default model for the image encoder is a Vision Transformer with MoEs. 
+ model_cls = self.image.get('name', 'VisionTransformerMoe') + model_cls, args, kwargs = utils.parse_call(model_cls, _default_image_module) + kwargs.update({k: v for k, v in self.image.items() if k != 'name'}) + return functools.partial( + model_cls, *args, **kwargs, deterministic=self.deterministic) + + @functools.cached_property + def text_model_class(self): + # Default model for the text encoder is a Text Transformer. + model_cls = self.text.get('name', 'TextTransformer') + model_cls, args, kwargs = utils.parse_call(model_cls, _default_text_module) + kwargs.update({k: v for k, v in self.text.items() if k != 'name'}) + return functools.partial( + model_cls, *args, **kwargs, deterministic=self.deterministic) + + @nn.compact + def __call__( + self, + images: Array | None, + texts: Array | None, + ) -> Tuple[Array, Mapping[str, Any]]: + if images is None and texts is None: + raise ValueError('You must give at least one of images or texts arrays.') + zimg, ztxt, metrics = None, None, {} + + if images is not None: + zimg, metrics_img = self.image_model_class(name='img')(images) + zimg_norm = jnp.linalg.norm(zimg, axis=-1, keepdims=True) + zimg /= zimg_norm + 1e-8 + self.sow('intermediates', 'zimg', zimg) + metrics['img'] = metrics_img + + if texts is not None: + ztxt, metrics_txt = self.text_model_class(name='txt')(texts) + ztxt_norm = jnp.linalg.norm(ztxt, axis=-1, keepdims=True) + ztxt /= ztxt_norm + 1e-8 + self.sow('intermediates', 'ztxt', ztxt) + metrics['txt'] = metrics_txt + + if images is None: + # Return text embeddings and metrics. + return ztxt, metrics + elif texts is None: + # Return image embeddings and metrics. + return zimg, metrics + else: + # Compute logits as the dot product of the image and text embeddings. + logits = jnp.einsum('...md,...nd->...mn', zimg, ztxt) + + # Note: Big Vision calls this "temperature", but it's actually + # 1/temperature, if one uses the standard definition of temperature. + scale_init = jnp.log(self.scale_init) + s = self.param('s', nn.initializers.constant(scale_init), + (), jnp.float32).astype(logits.dtype) + s = jnp.exp(s) + logits *= s + metrics['scale'] = s + + if self.bias_init is not None: + b = self.param('b', nn.initializers.constant(self.bias_init), + (), jnp.float32).astype(logits.dtype) + logits += b + + # Return the logits and the metrics. + return logits, metrics diff --git a/vmoe/projects/contrastive/models_test.py b/vmoe/projects/contrastive/models_test.py new file mode 100644 index 0000000..3b59265 --- /dev/null +++ b/vmoe/projects/contrastive/models_test.py @@ -0,0 +1,109 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from absl.testing import absltest +import chex +import jax +import jax.numpy as jnp +from vmoe.projects.contrastive import models + + +class TwoTowerTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.image_config = { + 'num_classes': self.output_dim, + 'patch_size': (2, 2), + 'hidden_size': 64, + 'classifier': 'gap', + 'encoder': { + 'num_layers': 2, + 'mlp_dim': 256, + 'num_heads': 2, + }, + 'head_kernel_zero_init': True, + } + self.text_config = { + 'num_classes': self.output_dim, + 'hidden_size': 64, + 'encoder': { + 'num_layers': 2, + 'mlp_dim': 256, + 'num_heads': 2, + }, + 'vocab_size': 128, + } + + @property + def output_dim(self) -> int: + return 32 + + def test(self): + """Tests initialization and forward pass.""" + batch_size, height, width, text_len = 4, 8, 8, 16 + model = models.TwoTower( + image=self.image_config, + text=self.text_config, + scale_init=2.0, + bias_init=1.0, + ) + + @jax.jit + def init_fn(): + images = jnp.zeros((batch_size, height, width, 3), dtype=jnp.float32) + texts = jnp.zeros((batch_size, text_len), dtype=jnp.int32) + return model.init({'params': jax.random.PRNGKey(0)}, images, texts) + + variables = init_fn() + self.assertIn('txt', variables['params']) + self.assertIn('img', variables['params']) + self.assertIn('s', variables['params']) + self.assertIn('b', variables['params']) + # Check shape and initial values for scale and bias params. + chex.assert_trees_all_close( + variables['params']['s'], jnp.log(jnp.asarray(2., dtype=jnp.float32))) + chex.assert_trees_all_close( + variables['params']['b'], jnp.asarray(1., dtype=jnp.float32)) + + @jax.jit + def forward(variables, images, text): + return model.apply(variables, images, text) + + # Forward with both images and text embeddings, logits' shape must be + # (batch_size, batch_size). + images = jnp.zeros((batch_size, height, width, 3), dtype=jnp.float32) + texts = jnp.zeros((batch_size, text_len), dtype=jnp.int32) + logits, _ = forward(variables, images, texts) + chex.assert_trees_all_equal_shapes_and_dtypes( + logits, + jax.ShapeDtypeStruct((batch_size, batch_size), jnp.float32)) + + # Forward only images: the output should be all 0s, since the image head + # kernel is initialized with 0. + zimg, _ = forward(variables, images, None) + chex.assert_trees_all_close( + zimg, jnp.zeros((batch_size, self.output_dim), jnp.float32)) + + # Forward only texts: the output should be different than 0s, since the text + # head kernel is NOT initialized with 0s. + ztxt, _ = forward(variables, None, texts) + chex.assert_trees_all_equal_shapes_and_dtypes( + ztxt, + jax.ShapeDtypeStruct((batch_size, self.output_dim), jnp.float32)) + self.assertGreater(jnp.abs(ztxt).sum(), 0.) + + +if __name__ == '__main__': + absltest.main() diff --git a/vmoe/projects/contrastive/trainer.py b/vmoe/projects/contrastive/trainer.py new file mode 100644 index 0000000..cb071b7 --- /dev/null +++ b/vmoe/projects/contrastive/trainer.py @@ -0,0 +1,414 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Classes and functions used for training (from-scratch and fine-tuning).""" +import functools +import multiprocessing.pool +import os +import time +from typing import Any, Callable, Mapping, Optional, Sequence, Tuple + +from absl import logging +from clu import metric_writers +import flax +import flax.serialization +import flax.training.train_state +import flax.traverse_util +import jax +import jax.numpy as jnp +import ml_collections +import tensorflow as tf +from vmoe import multihost_utils +from vmoe import partitioning +from vmoe import utils +from vmoe.data import input_pipeline +from vmoe.data import pjit_utils +from vmoe.evaluate import fewshot +from vmoe.projects.contrastive import evaluators +from vmoe.train import periodic_actions as train_periodic_actions +from vmoe.train import train_state as train_state_module +from vmoe.train import trainer +from vmoe.train import tree_summarizer + + +Array = jax.numpy.ndarray +DatasetIterator = input_pipeline.DatasetIterator +Mesh = partitioning.Mesh +ReportProgress = train_periodic_actions.ReportProgress +ThreadPool = multiprocessing.pool.ThreadPool +TrainState = train_state_module.TrainState +TreeSummarizer = tree_summarizer.TreeSummarizer + +accumulate_gradients_and_metrics = trainer.accumulate_gradients_and_metrics +create_checkpoint_manager = trainer.create_checkpoint_manager +create_flax_model = trainer.create_flax_model +create_profile_hook = trainer.create_profile_hook +create_progress_hook = trainer.create_progress_hook +create_tree_summarizer = trainer.create_tree_summarizer +get_dataset_iterator = trainer.get_dataset_iterator +get_train_steps_and_epochs = trainer.get_train_steps_and_epochs +make_create_train_state_fn = trainer.make_create_train_state_fn +make_train_cost_fn = trainer.make_train_cost_fn +override_base_config = trainer.override_base_config +restore_or_create_train_state = trainer.restore_or_create_train_state + + +def create_fewshot_hook( + *, + base_model_config: ml_collections.ConfigDict, + writer: metric_writers.MetricWriter, + progress_hook: ReportProgress, + first_step: int, + train_steps: int, + extra_rng_keys: Sequence[str], + model_overrides: Optional[ml_collections.ConfigDict] = None, + **kwargs) -> Callable[..., Any]: + """Returns a hook to run fewshot evaluation of a model periodically.""" + model_config = override_base_config(base_model_config, model_overrides) + # Few-shot eval requires additional mandatory parameters. If none of those is + # given, we assume that no few-shot eval should be done. + if not kwargs: + return (lambda *args, **kw: None) + model = create_flax_model( + config=model_config.to_dict(), deterministic=True) + # Apply function only embeds images. + apply_fn = lambda p, x, **kw: model.apply(p, images=x, texts=None, **kw) + on_steps = set(kwargs.pop('on_steps', [])) + # Always evaluate on the first and last step. 
+  on_steps.update([first_step, train_steps])
+  periodic_action = fewshot.FewShotPeriodicAction(
+      metric_writer=writer,
+      apply_fn=apply_fn,
+      rng_keys=extra_rng_keys,
+      report_progress=progress_hook,
+      report_progress_name='fewshot',
+      on_steps=on_steps,
+      **kwargs)
+  return periodic_action
+
+
+def create_retrieval_hook(
+    *,
+    base_model_config: ml_collections.ConfigDict,
+    writer: metric_writers.MetricWriter,
+    progress_hook: ReportProgress,
+    first_step: int,
+    train_steps: int,
+    every_steps: Optional[int] = None,
+    every_secs: Optional[int] = None,
+    datasets: Optional[Mapping[str, Mapping[str, Any]]] = None,
+    model_overrides: Optional[ml_collections.ConfigDict] = None,
+    data_sharding: jax.sharding.NamedSharding,
+    **kwargs) -> Callable[..., Any]:
+  """Returns a hook to run retrieval evaluation of a model periodically."""
+  model_config = override_base_config(base_model_config, model_overrides)
+  model = create_flax_model(
+      config=model_config.to_dict(), deterministic=True)
+  # Always evaluate on the first and last step.
+  on_steps = set(kwargs.pop('on_steps', []))
+  on_steps.update([first_step, train_steps])
+
+  # Make the apply_fn function conform with Big Vision's evaluator expected
+  # inputs and outputs.
+  def apply_fn(v, input_dict):
+    img = input_dict.get('image')
+    txt = input_dict.get('labels')
+    if (img is None) == (txt is None):
+      raise ValueError('One and only one of images or texts must be None.')
+    z, _ = model.apply(v, images=img, texts=txt)
+    return (None, z, None) if img is None else (z, None, None)
+
+  datasets = datasets or {}
+  if isinstance(datasets, ml_collections.ConfigDict):
+    datasets = datasets.to_dict()
+  try:
+    # Instantiate hooks for each of the tasks to evaluate.
+    hooks = [
+        evaluators.RetrievalPeriodicAction(
+            metric_writer=writer,
+            apply_fn=apply_fn,
+            task=task,
+            data_sharding=data_sharding,
+            every_steps=every_steps,
+            every_secs=every_secs,
+            on_steps=on_steps,
+            report_progress=progress_hook,
+            **kwargs,
+            **bv_kw)
+        for task, bv_kw in datasets.items()
+    ]
+    def periodic_action(*a, **kw):
+      for hook in hooks:
+        hook(*a, **kw)
+    return periodic_action
+  except NotImplementedError as e:
+    logging.warning('%s', str(e))
+    return (lambda *a, **kw: None)
+
+
+def create_zeroshot_hook(
+    *,
+    base_model_config: ml_collections.ConfigDict,
+    writer: metric_writers.MetricWriter,
+    progress_hook: ReportProgress,
+    first_step: int,
+    train_steps: int,
+    every_steps: Optional[int] = None,
+    every_secs: Optional[int] = None,
+    datasets: Optional[Mapping[str, Mapping[str, Any]]] = None,
+    model_overrides: Optional[ml_collections.ConfigDict] = None,
+    data_sharding: jax.sharding.NamedSharding,
+    **kwargs) -> Callable[..., Any]:
+  """Returns a hook to run zeroshot evaluation of a model periodically."""
+  model_config = override_base_config(base_model_config, model_overrides)
+  model = create_flax_model(
+      config=model_config.to_dict(), deterministic=True)
+  # Always evaluate on the first and last step.
+  on_steps = set(kwargs.pop('on_steps', []))
+  on_steps.update([first_step, train_steps])
+
+  # Make the apply_fn function conform with Big Vision's evaluator expected
+  # inputs and outputs.
+  def apply_fn(v, input_dict):
+    img = input_dict.get('image')
+    txt = input_dict.get('labels')
+    if (img is None) == (txt is None):
+      raise ValueError('One and only one of images or texts must be None.')
+    z, _ = model.apply(v, images=img, texts=txt)
+    return (None, z, None) if img is None else (z, None, None)
+
+  datasets = datasets or {}
+  if isinstance(datasets, ml_collections.ConfigDict):
+    datasets = datasets.to_dict()
+  if not datasets:
+    return (lambda *a, **kw: None)
+
+  try:
+    return evaluators.ZeroShotPeriodicAction(
+        metric_writer=writer,
+        apply_fn=apply_fn,
+        data_sharding=data_sharding,
+        every_steps=every_steps,
+        every_secs=every_secs,
+        on_steps=on_steps,
+        report_progress=progress_hook,
+        dataset_names=tuple(datasets.keys()),
+        dataset_overrides=datasets,
+        **kwargs)
+  except NotImplementedError as e:
+    logging.warning('%s', str(e))
+    return (lambda *a, **kw: None)
+
+
+def sigmoid_loss(logits: Array):
+  if logits.ndim < 2 or logits.shape[-1] != logits.shape[-2]:
+    raise ValueError(
+        f'Last two dims of logits must be equal, but got {logits.shape=}')
+  # SigLIP loss, as described in https://arxiv.org/pdf/2303.15343.pdf.
+  # Positives are in the diagonal, negatives are off-diagonal.
+  z = 2. * jnp.eye(logits.shape[-1], dtype=logits.dtype) - 1.
+  log_lkh = jax.nn.log_sigmoid(jnp.einsum('...mn,mn->...mn', logits, z))
+  # Normalize by npos per column, but that's one, so just sum.
+  return -jnp.sum(log_lkh, axis=-1)
+
+
+def train_step(
+    state: TrainState,
+    images: Array,
+    texts: Array,
+    loss_fn: Callable[[Array], Array],
+    microsteps: Optional[int] = None,
+    summarizer: Optional[TreeSummarizer] = None,
+) -> Tuple[TrainState, Mapping[str, Any]]:
+  """Performs one update step of the given TrainState object."""
+
+  @functools.partial(jax.grad, has_aux=True)
+  def compute_grads_and_metrics(params, images, texts, rngs):
+    rngs, next_rngs = utils.tree_rngs_split(rngs)
+    logits, metrics = state.apply_fn(
+        {'params': params}, images, texts, rngs=rngs)
+    metrics = dict(**metrics)
+    metrics['main_loss'] = jnp.mean(loss_fn(logits))
+    metrics = jax.tree_util.tree_map(jnp.mean, metrics)
+    total_loss = metrics['main_loss'] + metrics.get('auxiliary_loss', 0.0)
+    metrics['total_loss'] = total_loss
+    return total_loss, (next_rngs, metrics)
+
+  compute_grads_and_metrics = accumulate_gradients_and_metrics(
+      compute_grads_and_metrics, microsteps)
+  grads, (next_rngs, metrics) = compute_grads_and_metrics(
+      state.params, images, texts, state.rngs)
+  state, global_norms = state.apply_gradients_and_compute_global_norms(
+      grads, rngs=next_rngs)
+  metrics.update({f'global_norm/{k}': v for k, v in global_norms.items()})
+
+  if summarizer:
+    # Summarize arrays in the gradients tree or the train state.
+    state_flat = flax.traverse_util.flatten_dict(
+        flax.serialization.to_state_dict(state), sep='/')
+    state_flat['params_grads'] = flax.traverse_util.flatten_dict(grads, sep='/')
+    metrics.update(summarizer(state_flat))
+
+  return state, metrics
+
+
+def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str,
+                       mesh: Mesh, writer: metric_writers.MetricWriter):
+  """Trains a model and evaluates it periodically."""
+  datasets = input_pipeline.get_datasets(config.dataset)
+  if 'train' not in datasets:
+    raise KeyError(f'You must have a "train" variant of the dataset. '
+                   f'Available variants are {sorted(datasets.keys())!r}')
+  train_examples = input_pipeline.get_data_num_examples(config.dataset.train)
+  train_batch_size = config.dataset.train.batch_size
+  train_steps, train_epochs = get_train_steps_and_epochs(
+      train_steps=config.get('train_steps'),
+      train_epochs=config.get('train_epochs'),
+      train_batch_size=train_batch_size,
+      train_examples=train_examples)
+  logging.info(
+      'Training for %d steps (%g epochs) over %d examples, with a '
+      'batch size of %d', train_steps, train_epochs, train_examples,
+      train_batch_size)
+
+  # Get the global shape of the image array.
+  dataset_element_shape_dtype = pjit_utils.get_dataset_shape_dtype_struct(
+      datasets['train'])
+
+  ckpt_manager = create_checkpoint_manager(
+      workdir=workdir, **config.get('save_checkpoint', {}))
+  train_state_initialize_fn = make_create_train_state_fn(
+      model=create_flax_model(config=config.model, deterministic=False),
+      optimizer_config=config.optimizer,
+      input_shape_dtypes=(dataset_element_shape_dtype['image'],
+                          dataset_element_shape_dtype['text']),
+      train_steps=train_steps,
+      extra_rng_keys=tuple(config.get('extra_rng_keys', [])),
+      seed=config.get('seed', 0))
+  train_state, last_seen_index = restore_or_create_train_state(
+      ckpt_manager=ckpt_manager,
+      initialize_fn=train_state_initialize_fn,
+      axis_resources_regexes=config.params_axis_resources,
+      thread_pool=ThreadPool(),
+      initialization_kwargs=config.get('initialization'))
+  init_step = int(train_state.step)
+  logging.info('Initial step = %d', init_step)
+  tr_iter = get_dataset_iterator(
+      dataset=datasets['train'],
+      prefetch_size=config.dataset.train.get('prefetch_device', 1),
+      mesh=mesh,
+      last_seen_index=last_seen_index)
+  summarizer = create_tree_summarizer(config.get('summarize_arrays'))
+  train_step_fn = functools.partial(
+      train_step,
+      loss_fn=sigmoid_loss,
+      microsteps=config.get('microsteps'),
+      summarizer=summarizer)
+
+  train_step_pjit = jax.jit(
+      fun=train_step_fn,
+      out_shardings=(
+          jax.tree_util.tree_map(lambda x: x.sharding, train_state),
+          None,
+      ),
+      donate_argnums=(0, 1, 2),
+  )
+
+  # Setup hooks.
+  profile_hook = create_profile_hook(
+      workdir=workdir, **config.get('profile', {}))
+  progress_hook = create_progress_hook(
+      writer=writer, first_step=init_step + 1, train_steps=train_steps,
+      **config.get('report_progress', {}))
+  fewshot_hook = create_fewshot_hook(
+      base_model_config=config.model.copy_and_resolve_references(),
+      writer=writer,
+      progress_hook=progress_hook,
+      first_step=init_step + 1,
+      train_steps=train_steps,
+      extra_rng_keys=config.get('extra_rng_keys', []),
+      **config.get('fewshot', {}))
+  retrieval_hook = create_retrieval_hook(
+      data_sharding=dataset_element_shape_dtype['image'].sharding,
+      base_model_config=config.model.copy_and_resolve_references(),
+      writer=writer,
+      progress_hook=progress_hook,
+      first_step=init_step + 1,
+      train_steps=train_steps,
+      **config.get('retrieval', {}))
+  zeroshot_hook = create_zeroshot_hook(
+      data_sharding=dataset_element_shape_dtype['image'].sharding,
+      base_model_config=config.model.copy_and_resolve_references(),
+      writer=writer,
+      progress_hook=progress_hook,
+      first_step=init_step + 1,
+      train_steps=train_steps,
+      **config.get('zeroshot', {}))
+  # Run checkpoint hook just before starting the loop. This will save the train
+  # state at initialization.
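+  # The saved checkpoint also records the dataset iterator's last_seen_index,
+  # so a restarted job resumes the input pipeline from where it left off
+  # instead of re-reading the epoch from the beginning (see
+  # restore_or_create_train_state and get_dataset_iterator above).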
+  def _save_checkpoint(step, ts, it, force=False):
+    last_seen_index = step * train_batch_size
+    with progress_hook.timed('ckpt', wait_jax_async_dispatch=False):
+      ckpt_manager.save(
+          step,
+          items={
+              'state': ts,
+              'dataset_iterator': {'last_seen_index': last_seen_index},
+          },
+          force=force)
+  if init_step == 0 and not tf.io.gfile.exists(os.path.join(workdir, 'ckpt/0')):
+    multihost_utils.sync_devices('training:ckpt-first')
+    _save_checkpoint(init_step, train_state, tr_iter, force=True)
+  # Explicitly compile train_step here and report the compilation time.
+  t0 = time.time()
+  train_step_pjit = train_step_pjit.lower(
+      train_state,
+      dataset_element_shape_dtype['image'],
+      dataset_element_shape_dtype['text']).compile()
+  t1 = time.time()
+  # Report compilation time, plus per-device flops and optimal seconds per step.
+  writer.write_scalars(init_step + 1, {'train/compile_secs': t1 - t0})
+  train_step_flops_per_device, train_step_seconds_per_device = (
+      utils.get_flops_and_seconds_per_device(train_step_pjit))
+  if train_step_flops_per_device:
+    writer.write_scalars(
+        init_step + 1,
+        {'train/step_flops_per_device': train_step_flops_per_device})
+  if train_step_seconds_per_device:
+    writer.write_scalars(
+        init_step + 1,
+        {'train/step_seconds_per_device': train_step_seconds_per_device})
+  train_cost_fn = make_train_cost_fn(train_step_pjit)
+  for step, batch in zip(range(init_step + 1, train_steps + 1), tr_iter):
+    profile_hook(step)
+    with jax.profiler.StepTraceAnnotation('train', step_num=step):
+      train_state, metrics = train_step_pjit(train_state, batch['image'],
+                                             batch['text'])
+    progress_hook(step, scalar_metrics=(
+        train_cost_fn(step) | {f'train/{k}': v for k, v in metrics.items()}
+    ))
+    _save_checkpoint(step, train_state, tr_iter)
+    fewshot_hook(step, variables={'params': train_state.params},
+                 **train_cost_fn(step))
+    retrieval_hook(step, variables={'params': train_state.params},
+                   **train_cost_fn(step))
+    zeroshot_hook(step, variables={'params': train_state.params},
+                  **train_cost_fn(step))
+  ckpt_manager.wait_until_finished()
+  if not tf.io.gfile.exists(os.path.join(workdir, f'ckpt/{train_steps}')):
+    multihost_utils.sync_devices('training:ckpt-last')
+    _save_checkpoint(train_steps, train_state, tr_iter, force=True)
+    ckpt_manager.wait_until_finished()
+  multihost_utils.sync_devices('training:completed')
+  logging.info('Training completed.')
diff --git a/vmoe/projects/soft_moe/README.md b/vmoe/projects/soft_moe/README.md
index 2bc6aaf..a5d624a 100644
--- a/vmoe/projects/soft_moe/README.md
+++ b/vmoe/projects/soft_moe/README.md
@@ -8,5 +8,12 @@ This folder contains the implementation of Soft MoE, presented in the paper:
 We provide the config files used to run some of the experiments reported in
 the paper.
 
-Notice that all experiments either train on JFT-4B, a proprietary dataset,
+Notice that most experiments either train on JFT-4B, a proprietary dataset,
 or use models pre-trained on it, thus we cannot release any of the checkpoints.
+For reference, we have released the config file that we used to train on
+JFT-4B from scratch.
+
+We have also included a config file to pretrain on LAION-400M, a publicly
+available dataset, which can be used to replicate the experiments that we
+conducted on it and report in the paper. Note, however, that we do not plan
+to release any checkpoints trained on this dataset.
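For context on the config file added next, here is a minimal sketch of how it can be loaded and inspected; the `'soft-s16'` argument selects the Soft MoE S/16 image tower handled by `get_config`, and the sketch assumes the helpers in `configs/common.py` that the file imports are available in your checkout.

```python
# Minimal usage sketch: load and inspect the LAION-400M pretraining config
# defined in vmoe/projects/soft_moe/configs/pretrain_laion.py (added below).
from vmoe.projects.soft_moe.configs import pretrain_laion

config = pretrain_laion.get_config('soft-s16')  # Soft MoE, ViT-S/16 image tower.
print(config.model.name)                # vmoe.projects.contrastive.models.TwoTower
print(config.dataset.train.batch_size)  # 16384
print(config.train_steps)               # 750000
```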
diff --git a/vmoe/projects/soft_moe/configs/pretrain_laion.py b/vmoe/projects/soft_moe/configs/pretrain_laion.py new file mode 100644 index 0000000..bc3c643 --- /dev/null +++ b/vmoe/projects/soft_moe/configs/pretrain_laion.py @@ -0,0 +1,188 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# pylint: disable=line-too-long +r"""Train different models used in the Soft MoE paper on the LAION dataset. + +""" +# pylint: enable=line-too-long +import ml_collections +from vmoe.projects.soft_moe.configs import common + +BATCH_SIZE = 16_384 +DATASET = 'laion400m' + + +def get_default_moe_num_experts_and_last_n(variant, patch_size): + """Default number of experts and MoE layers for Sparse and Soft MoEs.""" + num_experts = { + ('S', 16): 128, + ('S', 14): 256, + ('B', 16): 128, + ('L', 16): 128, + ('H', 14): 256, + }[(variant, patch_size)] + last_n = { + 'S': 6, + 'B': 6, + 'L': 12, + 'H': 16, + }[variant] + return num_experts, last_n + + +# pylint: disable=line-too-long +def tokenize(inkey: str, outkey: str = 'text') -> str: + return f'tokenize(max_len=16, model="c4_en", eos="sticky", inkey="{inkey}", outkey="{outkey}", pad_value=1)' +# pylint: enable=line-too-long + + +def get_config(model='soft-s16') -> ml_collections.ConfigDict: + """Config to train different models used in the Soft MoE paper.""" + # Parse model argument. + model_type, model_backbone = model.split('-') + patch_size = int(model_backbone[1:]) + variant = model_backbone[0].upper() + + # SoftMoEs highly benefit from data augmentation, while ViTs and MoEs with + # Experts Choice routing actually do worse. See Figure 15 in the paper. 
+ # + if model_type in ('vit', 'ec'): + process_str = 'decode|resize(256)' + else: + process_str = 'decode_jpeg_and_inception_crop(256)' + + config = common.get_base_config() + config.dataset = ml_collections.ConfigDict() + config.dataset.train = common.get_data_config( + name=DATASET, + split='full[16384:]', + batch_size=BATCH_SIZE, + process=( + f'{process_str}|value_range(-1,1)|flatten|' + f'{tokenize("text")}|keep("image", "text")' + ), + shuffle_buffer=250_000, + ) + config.fewshot = common.get_fewshot_config( + batch_size=1_024, resize_resolution=292, target_resolution=256, + every_steps=10_000, seeds_per_step=3) + config.fewshot.model_overrides = ml_collections.ConfigDict() + config.retrieval = ml_collections.ConfigDict({ + 'batch_size': 1_024, + 'every_steps': 10_000, + 'datasets': { + 'coco': { + 'dataset': 'coco_captions', + 'txt_name': ('captions', 'text'), + 'pp_img': 'resize(256)|value_range(-1, 1)', + 'pp_txt': f'{tokenize(inkey="texts", outkey="labels")}', + }, + 'flickr': { + 'dataset': 'argus:flickr30k/captions', + 'txt_name': 'texts', + 'pp_img': 'resize(256)|value_range(-1, 1)', + 'pp_txt': f'{tokenize(inkey="texts", outkey="labels")}', + }, + } + }) + config.zeroshot = ml_collections.ConfigDict({ + 'batch_size': 1_024, + 'every_steps': 10_000, + 'pp_img': 'resize(256)|value_range(-1, 1)', + 'pp_txt': f'{tokenize(inkey="texts", outkey="labels")}', + 'datasets': { + 'cifar100': {}, + 'imagenet2012': {'class_names': 'clip', 'split': 'validation'}, + 'oxford_iiit_pet': {}, + }, + }) + + # Optimizer configuration. + config.optimizer = common.get_optimizer_rsqrt_config() + config.optimizer.weight_decay = (('.*/kernel', 0.1),) + config.optimizer.learning_rate.warmup_steps = 20_000 + config.optimizer.learning_rate.cooldown_steps = 20_000 + config.train_steps = 750_000 + + config.model = ml_collections.ConfigDict({ + 'name': 'vmoe.projects.contrastive.models.TwoTower', + 'bias_init': -10.0, + 'scale_init': 10.0, + }) + + # Image encoder hyperparameters depend on the model type. + if model_type == 'vit': + config.model.image = common.get_vit_config(variant, patch_size, None) + elif model_type == 'ec': + num_experts, last_n = get_default_moe_num_experts_and_last_n( + variant, patch_size) + config.model.image = common.get_vmoe_experts_choose_config( + variant, patch_size, None, image_size=256, + num_experts=num_experts, last_n=last_n, capacity_factor=1.0) + elif model_type == 'soft': + num_experts, last_n = get_default_moe_num_experts_and_last_n( + variant, patch_size) + config.model.image = common.get_vmoe_soft_router_config( + variant, patch_size, None, image_size=256, + num_experts=num_experts, last_n=last_n, capacity_factor=None, + num_slots=1) + config.model.image.encoder.moe.router.compute_similarity_metrics = False + else: + raise ValueError(f'Unknown model type: {model_type!r}') + + # Text encoder is a B size model. + config.model.text = ml_collections.ConfigDict({ + 'vocab_size': 32_000, + 'num_classes': config.model.image.hidden_size, + 'hidden_size': 768, + 'encoder': { + 'num_layers': 12, + 'mlp_dim': 3072, + 'num_heads': 12, + 'dropout_rate': 0.0, + 'attention_dropout_rate': 0.0, + 'attention_qk_norm': True, + 'moe': {'layers': ()}, + } + }) + + # These control how the train state is partitioned across the device mesh. 
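+  # For the MoE variants, the rule below shards every expert MLP parameter
+  # (matched by 'Moe/Mlp/.*') along the 'expert' axis of the device mesh,
+  # while the dense ViT variant uses a single partition and no sharding rules.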
+ if model_type == 'vit': + config.num_expert_partitions = 1 + config.params_axis_resources = [] + else: + config.num_expert_partitions = config.model.image.encoder.moe.num_experts + config.params_axis_resources = [('Moe/Mlp/.*', ('expert',))] + config.extra_rng_keys = ('dropout', 'gating') + # Plot summary of different arrays. + config.summarize_arrays = ml_collections.ConfigDict({ + 'rules': [ + 'opt_state/.*/hyperparams/learning_rate', # Learning rate. + 'params/.*/Moe/Router/scale', # Soft MoE scale. + ], + # Maximum values reported per rule and array. + # If you are reporting individual values for every expert parameter, + # increase this accordingly. + 'max_summary_values': 1, + }) + # Keep checkpoints every 50k steps, useful to do intermediate cooldowns. + config.save_checkpoint.keep_last = 2 + config.save_checkpoint.keep_steps_multiple_of = 50_000 + return config + + +def get_hyper(hyper, model='soft-s16'): + del model + return hyper.product([]) diff --git a/vmoe/projects/soft_moe/main_contrastive.py b/vmoe/projects/soft_moe/main_contrastive.py new file mode 100644 index 0000000..beffc62 --- /dev/null +++ b/vmoe/projects/soft_moe/main_contrastive.py @@ -0,0 +1,20 @@ +# Copyright 2024 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main training script for Soft MoE experiments.""" +from vmoe import app +from vmoe.projects.contrastive import trainer + +if __name__ == '__main__': + app.run(trainer.train_and_evaluate)
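As a closing illustration of how the pieces added in this change fit together, the sketch below builds a tiny `TwoTower` model with the same toy hyperparameters used in `models_test.py`, computes the (batch, batch) contrastive logits, and feeds them to the SigLIP-style `sigmoid_loss` from `trainer.py`. It is only a usage sketch, assuming `vmoe` and its dependencies (jax, flax) are installed; it is not part of the change itself.

```python
import jax
import jax.numpy as jnp

from vmoe.projects.contrastive import models
from vmoe.projects.contrastive import trainer

# Toy two-tower model, mirroring the hyperparameters in models_test.py.
encoder = {'num_layers': 2, 'mlp_dim': 256, 'num_heads': 2}
model = models.TwoTower(
    image={'num_classes': 32, 'patch_size': (2, 2), 'hidden_size': 64,
           'classifier': 'gap', 'encoder': dict(encoder)},
    text={'num_classes': 32, 'hidden_size': 64, 'vocab_size': 128,
          'encoder': dict(encoder)},
    scale_init=10.0,
    bias_init=-10.0,
)

images = jnp.zeros((4, 8, 8, 3), dtype=jnp.float32)
texts = jnp.zeros((4, 16), dtype=jnp.int32)
variables = model.init({'params': jax.random.PRNGKey(0)}, images, texts)

# Logits have shape (batch, batch); diagonal entries are the positive pairs.
logits, _ = model.apply(variables, images, texts)
loss = trainer.sigmoid_loss(logits).mean()
print(logits.shape, float(loss))
```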