Commit 28f0b0b (1 parent: f0a8702)

Release config and code necessary to pretrain on LAION-400M.

PiperOrigin-RevId: 622146719

Showing 7 changed files with 1,094 additions and 1 deletion.
@@ -0,0 +1,169 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluators used during contrastive training."""
import time
from typing import Any, Callable, Iterable, Optional, Tuple

from clu import metric_writers
from clu import periodic_actions
import jax

# Big Vision is an optional dependency: each evaluator below only raises if it
# is instantiated while the corresponding import is missing.
# pylint: disable=g-import-not-at-top
try:
  from big_vision.evaluators.proj.image_text import discriminative_classifier as bv_discriminative
except ImportError:
  bv_discriminative = None

try:
  from big_vision.evaluators.proj.image_text import retrieval as bv_retrieval
except ImportError:
  bv_retrieval = None
# pylint: enable=g-import-not-at-top

Array = jax.Array
PyTree = Any


class ZeroShotPeriodicAction(periodic_actions.PeriodicCallback):
  """Periodic action that runs Big Vision's discriminative (zero-shot) evaluator repeatedly."""

  def __init__(
      self,
      *,
      metric_writer: metric_writers.MetricWriter,
      apply_fn: Callable[..., Tuple[Array, Array, Any]],
      data_sharding: jax.sharding.NamedSharding,
      every_steps: Optional[int] = None,
      every_secs: Optional[float] = None,
      on_steps: Optional[Iterable[int]] = None,
      report_progress: Optional[periodic_actions.ReportProgress] = None,
      report_progress_name: str = 'zeroshot',
      **bv_evaluator_kwargs,
  ):
    """Constructor."""
    if bv_discriminative is None:
      raise NotImplementedError(
          'Big Vision must be installed to run the discriminative evaluation.')
    bv_evaluator = bv_discriminative.Evaluator(
        predict_fn=apply_fn,
        devices=list(data_sharding.mesh.devices.flatten()),
        **bv_evaluator_kwargs,
    )
    callback = self._make_callback_fn(
        evaluator=bv_evaluator,
        metric_writer=metric_writer,
        report_progress=report_progress,
        report_progress_name=report_progress_name,
    )
    super().__init__(
        every_steps=every_steps,
        every_secs=every_secs,
        on_steps=on_steps,
        callback_fn=callback,
        execute_async=False,
        pass_step_and_time=True)

  def _make_callback_fn(
      self, *, evaluator, metric_writer, report_progress,
      report_progress_name):

    def callback_fn(step: int, t: Optional[float], variables: PyTree, **kwargs):
      del t  # Unused.
      metrics = {}
      t0 = time.time()
      for task in evaluator.datasets:
        acc = evaluator.evaluate(variables, task)['accuracy']
        t1 = time.time()
        metrics[f'{report_progress_name}/{task}/accuracy'] = acc
        metrics[f'{report_progress_name}/{task}/duration_secs'] = t1 - t0
      metrics = metrics | {k: v for k, v in kwargs.items() if v is not None}
      metric_writer.write_scalars(step, metrics)

    if report_progress is None:
      return callback_fn
    else:
      return report_progress.timed(
          report_progress_name, wait_jax_async_dispatch=False)(callback_fn)


class RetrievalPeriodicAction(periodic_actions.PeriodicCallback):
  """Periodic action that runs Big Vision's retrieval evaluator repeatedly."""

  def __init__(
      self,
      *,
      metric_writer: metric_writers.MetricWriter,
      apply_fn: Callable[..., Tuple[Array, Array, Any]],
      task: str,
      data_sharding: jax.sharding.NamedSharding,
      every_steps: Optional[int] = None,
      every_secs: Optional[float] = None,
      on_steps: Optional[Iterable[int]] = None,
      report_progress: Optional[periodic_actions.ReportProgress] = None,
      report_progress_name: str = 'retrieval',
      **bv_evaluator_kwargs,
  ):
    """Constructor."""
    if bv_retrieval is None:
      raise NotImplementedError(
          'Big Vision must be installed to run the retrieval evaluation.')
    bv_evaluator = bv_retrieval.Evaluator(
        predict_fn=apply_fn,
        devices=list(data_sharding.mesh.devices.flatten()),
        **bv_evaluator_kwargs,
    )
    callback = self._make_callback_fn(
        evaluator=bv_evaluator,
        task=task,
        metric_writer=metric_writer,
        report_progress=report_progress,
        report_progress_name=report_progress_name,
    )
    super().__init__(
        every_steps=every_steps,
        every_secs=every_secs,
        on_steps=on_steps,
        callback_fn=callback,
        execute_async=False,
        pass_step_and_time=True)

  def _make_callback_fn(
      self, *, evaluator, task, metric_writer, report_progress,
      report_progress_name):

    def callback_fn(step: int, t: Optional[float], variables: PyTree, **kwargs):
      del t  # Unused.
      metrics = {}
      t0 = time.time()
      bv_metrics = evaluator.evaluate(variables)
      metrics.update({
          f'{report_progress_name}/{task}/txt2img/{k}': v
          for k, v in bv_metrics['txt2img'].items()
      })
      metrics.update({
          f'{report_progress_name}/{task}/img2txt/{k}': v
          for k, v in bv_metrics['img2txt'].items()
      })
      t1 = time.time()
      metrics[f'{report_progress_name}/{task}/duration_secs'] = t1 - t0
      metrics = metrics | {k: v for k, v in kwargs.items() if v is not None}
      metric_writer.write_scalars(step, metrics)

    if report_progress is None:
      return callback_fn
    else:
      return report_progress.timed(
          report_progress_name, wait_jax_async_dispatch=False)(callback_fn)
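Both classes are thin wrappers around clu's periodic_actions.PeriodicCallback: the constructor builds a Big Vision evaluator, and the callback writes its metrics (plus any non-None extras the trainer passes as keyword arguments) through the metric writer. A minimal runnable sketch of that shared wiring follows, with the Big Vision evaluation replaced by a stub; the step counts, the 'imagenet' task name, and the use of a logging-only writer are illustrative assumptions, not part of this commit:

import time
from clu import metric_writers
from clu import periodic_actions

writer = metric_writers.create_default_writer(just_logging=True)

def eval_callback(step, t, variables, **kwargs):
  del t  # Unused, as in the callbacks above.
  t0 = time.time()
  accuracy = 0.0  # Stub for evaluator.evaluate(variables, task)['accuracy'].
  metrics = {
      'zeroshot/imagenet/accuracy': accuracy,
      'zeroshot/imagenet/duration_secs': time.time() - t0,
  }
  # Merge in any extra non-None scalars the trainer passed along.
  metrics = metrics | {k: v for k, v in kwargs.items() if v is not None}
  writer.write_scalars(step, metrics)

action = periodic_actions.PeriodicCallback(
    every_steps=100,
    callback_fn=eval_callback,
    execute_async=False,
    pass_step_and_time=True)

for step in range(1, 301):
  # Stand-in for a train step; the real trainer passes its current variables.
  action(step, variables={'params': {}})  # Fires at steps 100, 200, 300.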
@@ -0,0 +1,186 @@
# Copyright 2024 Google LLC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Two tower model used for contrastive learning."""
import functools
import sys
from typing import Any, Literal, Mapping, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
from vmoe import utils
from vmoe.nn import vit_moe

Array = jax.Array

_default_image_module = vit_moe
_default_text_module = sys.modules[__name__]


class TextTransformer(nn.Module):
  """Text transformer similar to CLIP, allowing blocks with MoEs."""

  # Differences to the CLIP text encoder (GPT-2) that I am aware of:
  # 1. https://imgur.com/HNi3jix (gpt-1)
  # 2. https://imgur.com/qKGZgBR (gpt-2)
  # 3. https://imgur.com/a/xrpYHF0 (clip)
  # - LayerNorm is on the residual path (like pre-activation ResNet).
  # - Dropout 0.1 everywhere.
  # - Init as var=0.02, scaled by depth.
  # - BOS and EOS tokens, take repr from EOS.
  # - Self-attention is autoregressively masked.
  # - Scaled in width only, with the image model.
  vocab_size: int
  num_classes: Optional[int]
  hidden_size: int
  encoder: Mapping[str, Any]
  pool_type: Literal['last', 'first', 'gap', 'gmp', 'map'] = 'last'
  deterministic: bool = False
  head_bias_init: float = 0.0
  head_kernel_zero_init: bool = False

  @property
  def kernel_init(self) -> nn.initializers.Initializer:
    if self.head_kernel_zero_init:
      return nn.initializers.zeros
    else:
      return nn.linear.default_kernel_init

  @nn.compact
  def __call__(self, text):
    # We can't use where/argwhere since the output shape is not fixed.
    # Here we use the fact that sequences are padded with EOS tokens, that the
    # EOS token has value 1, and that argmin returns the first index.
    # eos_indices = jnp.argmin(text, axis=1)

    embedding = nn.Embed(
        num_embeddings=self.vocab_size, features=self.hidden_size)
    x = embedding(text)

    # TODO(jpuigcerver): Move position embedding outside of the Encoder class.
    encoder_kwargs = dict(self.encoder)
    if encoder_kwargs.get('position_emb', {}).get('name') == 'sincos2d':
      raise ValueError(
          'sincos2d position embeddings are not supported for text.')

    x, metrics = vit_moe.EncoderMoe(
        name='Encoder', deterministic=self.deterministic, **encoder_kwargs)(x)

    x = self.apply_pooling(x)

    if self.num_classes:
      # Linear head outputting the logits for classification.
      logits = nn.Dense(
          features=self.num_classes,
          name='head',
          kernel_init=self.kernel_init,
          bias_init=nn.initializers.constant(self.head_bias_init))(x)
      return logits, metrics
    else:
      return x, metrics

  @nn.nowrap
  def apply_pooling(self, x):
    match self.pool_type:
      case 'last': return x[:, -1, :]
      case 'first': return x[:, 0, :]
      case 'gap': return x.mean(axis=1)
      case 'gmp': return x.max(axis=1)
      case 'map':
        return vit_moe.MapHead(
            num_heads=self.encoder['num_heads'],
            mlp_dim=self.encoder['mlp_dim'],
            qk_norm=self.encoder.get('attention_qk_norm', False),
            name='MapHead')(x)
      case _:
        raise NotImplementedError(f'Cannot do pooling {self.pool_type!r}')


class TwoTower(nn.Module):
  """A two-tower encoder model."""
  image: Mapping[str, Any]
  text: Mapping[str, Any]
  scale_init: float = 1.0
  bias_init: float | None = None
  deterministic: bool = False

  @functools.cached_property
  def image_model_class(self):
    # Default model for the image encoder is a Vision Transformer with MoEs.
    model_cls = self.image.get('name', 'VisionTransformerMoe')
    model_cls, args, kwargs = utils.parse_call(model_cls, _default_image_module)
    kwargs.update({k: v for k, v in self.image.items() if k != 'name'})
    return functools.partial(
        model_cls, *args, **kwargs, deterministic=self.deterministic)

  @functools.cached_property
  def text_model_class(self):
    # Default model for the text encoder is a Text Transformer.
    model_cls = self.text.get('name', 'TextTransformer')
    model_cls, args, kwargs = utils.parse_call(model_cls, _default_text_module)
    kwargs.update({k: v for k, v in self.text.items() if k != 'name'})
    return functools.partial(
        model_cls, *args, **kwargs, deterministic=self.deterministic)

  @nn.compact
  def __call__(
      self,
      images: Array | None,
      texts: Array | None,
  ) -> Tuple[Array, Mapping[str, Any]]:
    if images is None and texts is None:
      raise ValueError('You must give at least one of images or texts arrays.')
    zimg, ztxt, metrics = None, None, {}

    if images is not None:
      zimg, metrics_img = self.image_model_class(name='img')(images)
      zimg_norm = jnp.linalg.norm(zimg, axis=-1, keepdims=True)
      zimg /= zimg_norm + 1e-8
      self.sow('intermediates', 'zimg', zimg)
      metrics['img'] = metrics_img

    if texts is not None:
      ztxt, metrics_txt = self.text_model_class(name='txt')(texts)
      ztxt_norm = jnp.linalg.norm(ztxt, axis=-1, keepdims=True)
      ztxt /= ztxt_norm + 1e-8
      self.sow('intermediates', 'ztxt', ztxt)
      metrics['txt'] = metrics_txt

    if images is None:
      # Return text embeddings and metrics.
      return ztxt, metrics
    elif texts is None:
      # Return image embeddings and metrics.
      return zimg, metrics
    else:
      # Compute logits as the dot product of the image and text embeddings.
      logits = jnp.einsum('...md,...nd->...mn', zimg, ztxt)

      # Note: Big Vision calls this "temperature", but it's actually
      # 1/temperature, if one uses the standard definition of temperature.
      scale_init = jnp.log(self.scale_init)
      s = self.param('s', nn.initializers.constant(scale_init),
                     (), jnp.float32).astype(logits.dtype)
      s = jnp.exp(s)
      logits *= s
      metrics['scale'] = s

      if self.bias_init is not None:
        b = self.param('b', nn.initializers.constant(self.bias_init),
                       (), jnp.float32).astype(logits.dtype)
        logits += b

      # Return the logits and the metrics.
      return logits, metrics
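The final branch above is the standard CLIP-style contrastive head: both embeddings are L2-normalized, every image-text pair is scored with an einsum, and the scores are multiplied by a learnable inverse temperature stored in log space, so exp(s) stays positive and equals scale_init at initialization. A self-contained sketch of that computation with toy arrays; the function name and shapes are illustrative, not part of the model:

import jax
import jax.numpy as jnp

def contrastive_logits(zimg, ztxt, scale_init=1.0, bias=None):
  # L2-normalize each embedding; the epsilon guards against zero norms.
  zimg = zimg / (jnp.linalg.norm(zimg, axis=-1, keepdims=True) + 1e-8)
  ztxt = ztxt / (jnp.linalg.norm(ztxt, axis=-1, keepdims=True) + 1e-8)
  # Cosine similarity of every image with every text: shape [m, n].
  logits = jnp.einsum('...md,...nd->...mn', zimg, ztxt)
  # TwoTower stores s = log(scale_init) as the parameter and applies exp(s),
  # so the multiplier is always positive and starts at exactly scale_init.
  logits = logits * jnp.exp(jnp.log(scale_init))
  if bias is not None:
    logits = logits + bias  # Optional additive bias, cf. bias_init above.
  return logits

key_img, key_txt = jax.random.split(jax.random.PRNGKey(0))
zimg = jax.random.normal(key_img, (2, 8))  # 2 image embeddings of dim 8.
ztxt = jax.random.normal(key_txt, (3, 8))  # 3 text embeddings of dim 8.
print(contrastive_logits(zimg, ztxt, scale_init=10.0).shape)  # (2, 3)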