Release config and code necessary to pretrain on LAION-400M.
PiperOrigin-RevId: 622146719
jpuigcerver authored and copybara-github committed Apr 5, 2024
1 parent f0a8702 commit 28f0b0b
Showing 7 changed files with 1,094 additions and 1 deletion.
169 changes: 169 additions & 0 deletions vmoe/projects/contrastive/evaluators.py
@@ -0,0 +1,169 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluators used during contrastive training."""
import time
from typing import Any, Callable, Iterable, Optional, Tuple

from clu import metric_writers
from clu import periodic_actions
import jax

# pylint: disable=g-import-not-at-top
try:
  from big_vision.evaluators.proj.image_text import discriminative_classifier as bv_discriminative
except ImportError:
  bv_discriminative = None


try:
  from big_vision.evaluators.proj.image_text import retrieval as bv_retrieval
except ImportError:
  bv_retrieval = None
# pylint: enable=g-import-not-at-top

Array = jax.Array
PyTree = Any


class ZeroShotPeriodicAction(periodic_actions.PeriodicCallback):
  """Periodic action that runs Big Vision's zero-shot evaluator repeatedly."""

  def __init__(
      self,
      *,
      metric_writer: metric_writers.MetricWriter,
      apply_fn: Callable[..., Tuple[Array, Array, Any]],
      data_sharding: jax.sharding.NamedSharding,
      every_steps: Optional[int] = None,
      every_secs: Optional[float] = None,
      on_steps: Optional[Iterable[int]] = None,
      report_progress: Optional[periodic_actions.ReportProgress] = None,
      report_progress_name: str = 'zeroshot',
      **bv_evaluator_kwargs,
  ):
    """Constructor."""
    if bv_discriminative is None:
      raise NotImplementedError(
          'Big Vision must be installed to run the discriminative evaluation.')
    bv_evaluator = bv_discriminative.Evaluator(
        predict_fn=apply_fn,
        devices=list(data_sharding.mesh.devices.flatten()),
        **bv_evaluator_kwargs,
    )
    callback = self._make_callback_fn(
        evaluator=bv_evaluator,
        metric_writer=metric_writer,
        report_progress=report_progress,
        report_progress_name=report_progress_name,
    )
    super().__init__(
        every_steps=every_steps,
        every_secs=every_secs,
        on_steps=on_steps,
        callback_fn=callback,
        execute_async=False,
        pass_step_and_time=True)

  def _make_callback_fn(
      self, *, evaluator, metric_writer, report_progress,
      report_progress_name):

    def callback_fn(step: int, t: Optional[float], variables: PyTree, **kwargs):
      del t  # Unused.
      metrics = {}
      t0 = time.time()
      for task in evaluator.datasets:
        acc = evaluator.evaluate(variables, task)['accuracy']
        t1 = time.time()
        metrics[f'{report_progress_name}/{task}/accuracy'] = acc
        metrics[f'{report_progress_name}/{task}/duration_secs'] = t1 - t0
      metrics = metrics | {k: v for k, v in kwargs.items() if v is not None}
      metric_writer.write_scalars(step, metrics)

    if report_progress is None:
      return callback_fn
    else:
      return report_progress.timed(
          report_progress_name, wait_jax_async_dispatch=False)(callback_fn)


class RetrievalPeriodicAction(periodic_actions.PeriodicCallback):
  """Periodic action that runs Big Vision's Retrieval evaluator repeatedly."""

  def __init__(
      self,
      *,
      metric_writer: metric_writers.MetricWriter,
      apply_fn: Callable[..., Tuple[Array, Array, Any]],
      task: str,
      data_sharding: jax.sharding.NamedSharding,
      every_steps: Optional[int] = None,
      every_secs: Optional[float] = None,
      on_steps: Optional[Iterable[int]] = None,
      report_progress: Optional[periodic_actions.ReportProgress] = None,
      report_progress_name: str = 'retrieval',
      **bv_evaluator_kwargs,
  ):
    """Constructor."""
    if bv_retrieval is None:
      raise NotImplementedError(
          'Big Vision must be installed to run the retrieval evaluation.')
    bv_evaluator = bv_retrieval.Evaluator(
        predict_fn=apply_fn,
        devices=list(data_sharding.mesh.devices.flatten()),
        **bv_evaluator_kwargs,
    )
    callback = self._make_callback_fn(
        evaluator=bv_evaluator,
        task=task,
        metric_writer=metric_writer,
        report_progress=report_progress,
        report_progress_name=report_progress_name,
    )
    super().__init__(
        every_steps=every_steps,
        every_secs=every_secs,
        on_steps=on_steps,
        callback_fn=callback,
        execute_async=False,
        pass_step_and_time=True)

  def _make_callback_fn(
      self, *, evaluator, task, metric_writer, report_progress,
      report_progress_name):

    def callback_fn(step: int, t: Optional[float], variables: PyTree, **kwargs):
      del t  # Unused.
      metrics = {}
      t0 = time.time()
      bv_metrics = evaluator.evaluate(variables)
      metrics.update({
          f'{report_progress_name}/{task}/txt2img/{k}': v
          for k, v in bv_metrics['txt2img'].items()
      })
      metrics.update({
          f'{report_progress_name}/{task}/img2txt/{k}': v
          for k, v in bv_metrics['img2txt'].items()
      })
      t1 = time.time()
      metrics[f'{report_progress_name}/{task}/duration_secs'] = t1 - t0
      metrics = metrics | {k: v for k, v in kwargs.items() if v is not None}
      metric_writer.write_scalars(step, metrics)

    if report_progress is None:
      return callback_fn
    else:
      return report_progress.timed(
          report_progress_name, wait_jax_async_dispatch=False)(callback_fn)
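
For orientation, here is a minimal sketch of how one of these periodic actions might be wired into a training loop. The log directory, mesh layout, `model.apply`, `train_step`, and `num_steps` are illustrative assumptions, not part of this commit; any remaining keyword arguments are forwarded to Big Vision's evaluator.

import jax
import numpy as np
from clu import metric_writers

writer = metric_writers.create_default_writer('/tmp/logs')  # hypothetical logdir
mesh = jax.sharding.Mesh(np.array(jax.devices()), ('data',))
data_sharding = jax.sharding.NamedSharding(
    mesh, jax.sharding.PartitionSpec('data'))

zeroshot = ZeroShotPeriodicAction(
    metric_writer=writer,
    apply_fn=model.apply,  # hypothetical model; returns (zimg, ztxt, extras)
    data_sharding=data_sharding,
    every_steps=10_000,
    # Remaining kwargs are passed through to Big Vision's Evaluator.
)

for step in range(1, num_steps + 1):
  variables = train_step()  # hypothetical; yields the current model variables
  zeroshot(step, variables=variables)  # only fires every `every_steps` steps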
186 changes: 186 additions & 0 deletions vmoe/projects/contrastive/models.py
@@ -0,0 +1,186 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Two tower model used for contrastive learning."""
import functools
import sys
from typing import Any, Literal, Mapping, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from vmoe import utils
from vmoe.nn import vit_moe

Array = jax.Array

_default_image_module = vit_moe
_default_text_module = sys.modules[__name__]


class TextTransformer(nn.Module):
  """Text transformer similar to CLIP, allowing blocks with MoEs."""

  # Differences to CLIP text encoder (gpt-2) that I am aware of:
  # 1. https://imgur.com/HNi3jix (gpt-1)
  # 2. https://imgur.com/qKGZgBR (gpt-2)
  # 3. https://imgur.com/a/xrpYHF0 (clip)
  # - LayerNorm is on res-path (like pre-activation resnet)
  # - dropout 0.1 everywhere
  # - init as var=0.02, scaled by depth
  # - BOS and EOS tokens, take repr from EOS.
  # - self-attention is autoregressively masked.
  # - scaled in width only, with the image model.
  vocab_size: int
  num_classes: Optional[int]
  hidden_size: int
  encoder: Mapping[str, Any]
  pool_type: Literal['last', 'first', 'gap', 'gmp', 'map'] = 'last'
  deterministic: bool = False
  head_bias_init: float = 0.0
  head_kernel_zero_init: bool = False

  @property
  def kernel_init(self) -> nn.initializers.Initializer:
    if self.head_kernel_zero_init:
      return nn.initializers.zeros
    else:
      return nn.linear.default_kernel_init

  @nn.compact
  def __call__(self, text):
    # We can't use where/argwhere since the output shape is not fixed.
    # Here we use the fact that sequences are padded with EOS tokens, that the
    # EOS token has value 1, and that argmin returns the first index.
    # eos_indices = jnp.argmin(text, axis=1)

    embedding = nn.Embed(
        num_embeddings=self.vocab_size, features=self.hidden_size)
    x = embedding(text)

    # TODO(jpuigcerver): Move position embedding outside of the Encoder class.
    encoder_kwargs = dict(self.encoder)
    if encoder_kwargs.get('position_emb', {}).get('name') == 'sincos2d':
      raise ValueError(
          'sincos2d position embeddings are not supported for text.')

    x, metrics = vit_moe.EncoderMoe(
        name='Encoder', deterministic=self.deterministic, **encoder_kwargs)(x)

    x = self.apply_pooling(x)

    if self.num_classes:
      # Linear head outputting the logits for classification.
      logits = nn.Dense(
          features=self.num_classes,
          name='head',
          kernel_init=self.kernel_init,
          bias_init=nn.initializers.constant(self.head_bias_init))(x)
      return logits, metrics
    else:
      return x, metrics

  @nn.nowrap
  def apply_pooling(self, x):
    match self.pool_type:
      case 'last': return x[:, -1, :]
      case 'first': return x[:, 0, :]
      case 'gap': return x.mean(axis=1)
      case 'gmp': return x.max(axis=1)
      case 'map':
        return vit_moe.MapHead(
            num_heads=self.encoder['num_heads'],
            mlp_dim=self.encoder['mlp_dim'],
            qk_norm=self.encoder.get('attention_qk_norm', False),
            name='MapHead')(x)
      case _:
        raise NotImplementedError(f'Cannot do pooling {self.pool_type!r}')


class TwoTower(nn.Module):
  """A two-tower encoder model."""
  image: Mapping[str, Any]
  text: Mapping[str, Any]
  scale_init: float = 1.0
  bias_init: float | None = None
  deterministic: bool = False

  @functools.cached_property
  def image_model_class(self):
    # Default model for the image encoder is a Vision Transformer with MoEs.
    model_cls = self.image.get('name', 'VisionTransformerMoe')
    model_cls, args, kwargs = utils.parse_call(model_cls, _default_image_module)
    kwargs.update({k: v for k, v in self.image.items() if k != 'name'})
    return functools.partial(
        model_cls, *args, **kwargs, deterministic=self.deterministic)

  @functools.cached_property
  def text_model_class(self):
    # Default model for the text encoder is a Text Transformer.
    model_cls = self.text.get('name', 'TextTransformer')
    model_cls, args, kwargs = utils.parse_call(model_cls, _default_text_module)
    kwargs.update({k: v for k, v in self.text.items() if k != 'name'})
    return functools.partial(
        model_cls, *args, **kwargs, deterministic=self.deterministic)

  @nn.compact
  def __call__(
      self,
      images: Array | None,
      texts: Array | None,
  ) -> Tuple[Array, Mapping[str, Any]]:
    if images is None and texts is None:
      raise ValueError('You must give at least one of images or texts arrays.')
    zimg, ztxt, metrics = None, None, {}

    if images is not None:
      zimg, metrics_img = self.image_model_class(name='img')(images)
      zimg_norm = jnp.linalg.norm(zimg, axis=-1, keepdims=True)
      zimg /= zimg_norm + 1e-8
      self.sow('intermediates', 'zimg', zimg)
      metrics['img'] = metrics_img

    if texts is not None:
      ztxt, metrics_txt = self.text_model_class(name='txt')(texts)
      ztxt_norm = jnp.linalg.norm(ztxt, axis=-1, keepdims=True)
      ztxt /= ztxt_norm + 1e-8
      self.sow('intermediates', 'ztxt', ztxt)
      metrics['txt'] = metrics_txt

    if images is None:
      # Return text embeddings and metrics.
      return ztxt, metrics
    elif texts is None:
      # Return image embeddings and metrics.
      return zimg, metrics
    else:
      # Compute logits as the dot product of the image and text embeddings.
      logits = jnp.einsum('...md,...nd->...mn', zimg, ztxt)

      # Note: Big Vision calls this "temperature", but it's actually
      # 1/temperature, if one uses the standard definition of temperature.
      scale_init = jnp.log(self.scale_init)
      s = self.param('s', nn.initializers.constant(scale_init),
                     (), jnp.float32).astype(logits.dtype)
      s = jnp.exp(s)
      logits *= s
      metrics['scale'] = s

      if self.bias_init is not None:
        b = self.param('b', nn.initializers.constant(self.bias_init),
                       (), jnp.float32).astype(logits.dtype)
        logits += b

      # Return the logits and the metrics.
      return logits, metrics
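
The logits returned by TwoTower form a scaled image-to-text similarity matrix, so a standard symmetric contrastive objective can be computed directly from them. The sketch below shows the usual softmax (InfoNCE-style) formulation under the assumption that image i and text i are aligned pairs; it is an illustration using optax, not code from this commit.

import jax.numpy as jnp
import optax

def contrastive_loss(logits: jnp.ndarray) -> jnp.ndarray:
  """Symmetric softmax contrastive loss over an (n, n) logit matrix."""
  # logits[i, j] is the scaled similarity between image i and text j;
  # aligned pairs sit on the diagonal, so the targets are the identity.
  n = logits.shape[0]
  labels = jnp.eye(n)
  loss_i2t = optax.softmax_cross_entropy(logits, labels).mean()
  loss_t2i = optax.softmax_cross_entropy(logits.T, labels).mean()
  return 0.5 * (loss_i2t + loss_t2i)

The optional learnable bias `b` (enabled via `bias_init`) suggests the pairwise sigmoid formulation of SigLIP as an alternative to the row/column softmaxes above.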