
CLIP in JAX/Flax

Introduction
Installation
Usage
Training
    • Dataset Preparation
        • CSV
        • TFRecord
    • Single-Worker Training
    • Multi-Worker Training
Available Models
Acknowledgements
Citations

Introduction

open_clip_jax is an open-source JAX/Flax implementation of OpenAI's CLIP, including image and text towers, pre-trained parameters, training utilities, and more. It is inspired by, but not affiliated with, OpenCLIP and aims to deliver similar functionality with a JAX backend.

Installation

The JAX installation process differs depending on one's hardware (e.g., CPU, GPU, or TPU), so JAX needs to be installed manually by the user. Afterwards, open_clip_jax can be installed with pip install git+https://github.com/BobMcDear/open-clip-jax.git.

Usage

CLIPInference is a convenience class for conducting inference, which can be called on raw images and texts to compute their similarity scores, as demonstrated below.

from PIL import Image
from open_clip_jax import CLIPInference


clip = CLIPInference(
    'vit-base-patch32',
    softmax_temp=100.,
    pretrained='laion2b-s34b-b79k',
    )
image = Image.open('CLIP.png').convert('RGB')
text = ['A diagram', 'A dog', 'A cat']

# image and text can be single data points or lists.
probs, _ = clip(image, text)
print(probs)

Under the hood, CLIPInference utilizes create_model_with_params to create the CLIP model, create_image_transforms to pre-process the image(s), and tokenize to tokenize the text(s). A sample usage of these functions, equivalent to the code above, is exhibited in the following snippet. Breaking CLIPInference into these smaller components can offer greater flexibility.

from typing import Dict

import jax
import numpy as np
from flax import linen as nn
from PIL import Image
from jax import Array
from open_clip_jax import create_image_transforms, create_model_with_params, tokenize


model, vars = create_model_with_params(
    'vit-base-patch32',
    pretrained='laion2b-s34b-b79k',
    )
image_transforms = create_image_transforms(
    train=False,
    input_format='image',
    do_batch_transforms=False,
    )

# The transforms and the tokenizer produce TensorFlow tensors; convert them to NumPy.
image = image_transforms(Image.open('CLIP.png').convert('RGB')).numpy()
# Add a batch dimension, since the transforms above process a single image.
image = np.expand_dims(image, axis=0)
text = tokenize(['A diagram', 'A dog', 'A cat']).numpy()

def calculate_similarity(vars: Dict, image: Array, text: Array) -> Array:
    # CLIP returns L2-normalized image and text features.
    image_proj, text_proj = model.apply(vars, image, text)
    # Scale the cosine similarities by 100 and normalize over the captions.
    return nn.softmax(100 * image_proj @ text_proj.T)

probs = jax.jit(calculate_similarity)(vars, image, text)
print(probs)
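The arithmetic inside calculate_similarity can be checked without downloading any weights. Because both feature sets are L2-normalized, image_proj @ text_proj.T contains cosine similarities in [-1, 1]; scaling them by 100 (the same factor passed as softmax_temp to CLIPInference) before the softmax makes the distribution over captions much sharper. A NumPy sketch with random stand-in features, where the 512-dimensional width and the helper names are assumptions for illustration only:

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Scale each row to unit Euclidean norm, as CLIP does to its projected features.
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def softmax(logits: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis (one entry per caption).
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def caption_probs(image_proj: np.ndarray, text_proj: np.ndarray, temp: float = 100.0) -> np.ndarray:
    # Cosine similarities (rows: images, columns: captions) scaled by temp, then softmax.
    return softmax(temp * image_proj @ text_proj.T)


rng = np.random.default_rng(0)
# Random stand-ins for real embeddings: 1 image and 3 captions (hypothetical sizes).
image_proj = l2_normalize(rng.normal(size=(1, 512)))
text_proj = l2_normalize(rng.normal(size=(3, 512)))
probs = caption_probs(image_proj, text_proj)
print(probs.shape)  # (1, 3)
```

probs has one row per image and one column per caption, so probs.argmax(axis=-1) picks the best-matching caption for each image.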

Training

This repository also supports training CLIP models from scratch, using either the utilities supplied by open_clip_jax.training for more fine-grained control or main.py for a fully-featured training script. The ensuing sections elaborate on training with main.py.

Dataset Preparation

main.py accepts two data formats: CSV files and TFRecords. TFRecords should generally be preferred because tf.data pipelines built around them are efficient, especially when the data is stored remotely, but CSV files can be more convenient and are adequate for smaller datasets. Preparation instructions for each format are outlined below.

CSV

To prepare a dataset for CSV training, create a CSV file with one column holding image paths and another holding the caption of each image. Other columns may be present but are ignored, and column order does not matter. For instance, the table below shows how such a file may be structured.

caption                          image_path
Diagram of OpenAI's CLIP model   clip.jpg
A Siamese cat                    cat.jpg
Dog running on grass             dog.jpg
...                              ...
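Such a file can be written with Python's standard csv module. A sketch, where write_clip_csv is a hypothetical helper and the column names caption and image_path are arbitrary (they only need to match the --text-key and --image-key flags passed to main.py, as in this README's single-worker command):

```python
import csv
from typing import Iterable, Tuple


def write_clip_csv(path: str, pairs: Iterable[Tuple[str, str]]) -> None:
    """Write (image_path, caption) pairs to a CSV file with a header row."""
    with open(path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.DictWriter(f, fieldnames=['caption', 'image_path'])
        writer.writeheader()
        for image_path, caption in pairs:
            writer.writerow({'caption': caption, 'image_path': image_path})


pairs = [
    ('clip.jpg', "Diagram of OpenAI's CLIP model"),
    ('cat.jpg', 'A Siamese cat'),
    ('dog.jpg', 'Dog running on grass'),
]
write_clip_csv('train.csv', pairs)
```

csv.DictWriter automatically quotes captions containing commas or quote characters, so they survive the round trip intact.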

TensorFlow integrates seamlessly with Google Cloud Storage (GCS), so the CSV file and images may also be stored in a GCS bucket, as shown below. However, this slows down data loading because GCS has a high time to first byte (TTFB), which is incurred for every individual image that is read; TFRecords are therefore the appropriate option when storing data in the cloud.

caption                          image_path
Diagram of OpenAI's CLIP model   gs://open_clip_jax/clip.jpg
A Siamese cat                    gs://open_clip_jax/cat.jpg
Dog running on grass             gs://open_clip_jax/dog.jpg
...                              ...
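Moving an existing local CSV dataset to GCS then amounts to uploading the images (e.g., with gsutil cp) and prefixing every image path with the bucket location. A sketch with a hypothetical rewrite_to_gcs helper; the bucket name gs://open_clip_jax mirrors the example table and is only illustrative:

```python
import csv


def rewrite_to_gcs(src_csv: str, dst_csv: str, bucket_prefix: str, image_key: str = 'image_path') -> None:
    """Copy src_csv to dst_csv, prefixing every image path with a GCS location."""
    with open(src_csv, newline='', encoding='utf-8') as src:
        reader = csv.DictReader(src)
        rows = list(reader)
        fieldnames = reader.fieldnames
    # Normalize so exactly one '/' separates the prefix from the relative path.
    prefix = bucket_prefix.rstrip('/') + '/'
    with open(dst_csv, 'w', newline='', encoding='utf-8') as dst:
        writer = csv.DictWriter(dst, fieldnames=fieldnames)
        writer.writeheader()
        for row in rows:
            row[image_key] = prefix + row[image_key].lstrip('/')
            writer.writerow(row)
```

For example, rewrite_to_gcs('train.csv', 'train_gcs.csv', 'gs://open_clip_jax') turns cat.jpg into gs://open_clip_jax/cat.jpg while leaving the other columns untouched.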

TFRecord

To prepare a dataset for TFRecord training, every image-text pair must be written as a tf.train.Example Protobuf message (images as JPEG-encoded bytes, captions as strings) to TFRecord files stored in a single directory, either local or in a GCS bucket. Each file should ideally hold 100+ MB of data, or 10,000 samples. img2dataset can automatically convert image URLs into such datasets when the output format is set to TFRecord via --output_format tfrecord, and it supports many popular datasets such as COCO and LAION-400M.
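Without img2dataset, the format described above can be produced by hand with TensorFlow's tf.train and tf.io APIs. A minimal sketch: the feature keys 'jpg' and 'caption' are an assumption chosen to match the --image-key jpg and --text-key caption flags in this README's multi-worker command, and the write_shards helper and its shard naming scheme (00000.tfrecord, 00001.tfrecord, ...) are hypothetical:

```python
import os
from typing import Iterable, List, Tuple

import tensorflow as tf


def make_example(jpeg_bytes: bytes, caption: str) -> tf.train.Example:
    """Pack one image-text pair into a tf.train.Example Protobuf message."""
    return tf.train.Example(features=tf.train.Features(feature={
        # Assumed key 'jpg': JPEG-encoded image bytes.
        'jpg': tf.train.Feature(bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
        # Assumed key 'caption': the caption, stored as UTF-8 bytes.
        'caption': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[caption.encode('utf-8')])),
    }))


def write_shards(
    pairs: Iterable[Tuple[bytes, str]],
    out_dir: str,
    samples_per_file: int = 10_000,
) -> List[str]:
    """Write (jpeg_bytes, caption) pairs to numbered TFRecord files in out_dir."""
    tf.io.gfile.makedirs(out_dir)  # Accepts local paths as well as gs:// URIs.
    paths: List[str] = []
    writer = None
    for i, (jpeg_bytes, caption) in enumerate(pairs):
        if i % samples_per_file == 0:  # Start a new shard.
            if writer is not None:
                writer.close()
            path = os.path.join(out_dir, f'{len(paths):05d}.tfrecord')
            writer = tf.io.TFRecordWriter(path)
            paths.append(path)
        writer.write(make_example(jpeg_bytes, caption).SerializeToString())
    if writer is not None:
        writer.close()
    return paths
```

Reading existing JPEG files as raw bytes (open(path, 'rb').read()) avoids decoding and re-encoding the images.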

Single-Worker Training

In single-worker settings, assuming JAX and open_clip_jax have been installed, training only requires downloading and running main.py. On a cloud machine, first log into the remote server and optionally transfer the data to it, then run the following commands, ideally inside a tmux or screen session.

wget https://raw.githubusercontent.com/BobMcDear/open-clip-jax/main/main.py -q
python3 main.py \
    --train-path train.csv \
    --valid-path valid.csv \
    --image-key image_path \
    --text-key caption \
    --global-batch-size 128 \
    --model-name vit-base-patch32 \
    --learning-rate 1e-3 \
    --n-epochs 30
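The number of optimizer steps per epoch follows the formula used in the multi-worker caveat of this README: the number of samples divided by the global batch size. A sketch for a CSV dataset, assuming the final incomplete batch is dropped (floor division), which may not match main.py exactly; count_samples and steps_per_epoch are hypothetical helpers:

```python
import csv


def count_samples(csv_path: str) -> int:
    """Count data rows (excluding the header) in a CSV dataset."""
    with open(csv_path, newline='', encoding='utf-8') as f:
        return sum(1 for _ in csv.DictReader(f))


def steps_per_epoch(n_samples: int, global_batch_size: int) -> int:
    # Assumption: incomplete final batches are dropped; main.py may differ.
    return n_samples // global_batch_size
```

For the command above, steps_per_epoch(count_samples('train.csv'), 128) estimates the steps per epoch. Counting rows with csv.DictReader rather than counting lines keeps captions with quoted newlines from being double-counted.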

Multi-Worker Training

Multi-worker training has been tested only on TPUs, although the process should be largely identical for GPU clusters. To train on a TPU pod slice, the same commands are sent to every worker in parallel to install the necessary packages and start training. The data must also be stored in a GCS bucket located in the same region as the VM (e.g., us-central1 for the us-central1-a zone used below). An end-to-end minimal example, including TPU creation and deletion commands, can be seen below.

NAME=open-clip-jax
ZONE=us-central1-a
TYPE=v3-32
VERSION=v2-alpha

# Create pod slice
gcloud compute tpus tpu-vm create $NAME \
    --zone=$ZONE \
    --accelerator-type=$TYPE \
    --version=$VERSION

# Connect to TPUs and train
gcloud compute tpus tpu-vm ssh $NAME \
    --zone $ZONE \
    --worker=all \
    --command "
        pip install -U pip &&
        pip install -U 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html &&
        pip install git+https://github.com/BobMcDear/open-clip-jax.git &&
        wget https://raw.githubusercontent.com/BobMcDear/open-clip-jax/main/main.py -q &&
        python3 main.py \
            --train-path gs://open_clip_jax/train_tfrecords/ \
            --valid-path gs://open_clip_jax/valid_tfrecords/ \
            --image-key jpg \
            --text-key caption \
            --global-batch-size 1024 \
            --model-name vit-base-patch32 \
            --learning-rate 3e-3 \
            --n-epochs 30 \
            --checkpoint-dir gs://open_clip_jax/checkpoints/
        "

# Delete VM
gcloud compute tpus tpu-vm delete $NAME \
    --zone $ZONE
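Since --global-batch-size is global, each host and each TPU core processes only a fraction of it. The arithmetic below assumes an even data-parallel split across all devices (consistent with the per-worker batch size in the caveat that follows, but not otherwise spelled out for main.py) and that a v3-32 slice consists of 4 hosts with 8 TensorCores each; split_global_batch is a hypothetical helper:

```python
from typing import Tuple


def split_global_batch(global_batch_size: int, n_hosts: int, devices_per_host: int) -> Tuple[int, int]:
    """Return (per-host, per-device) batch sizes for an even data-parallel split."""
    n_devices = n_hosts * devices_per_host
    if global_batch_size % n_devices:
        raise ValueError(f'{global_batch_size} is not divisible by {n_devices} devices')
    return global_batch_size // n_hosts, global_batch_size // n_devices


# A v3-32 slice: 4 TPU VM hosts with 8 TensorCores each (32 JAX devices).
print(split_global_batch(1024, n_hosts=4, devices_per_host=8))  # (256, 32)
```

On a live slice, jax.process_count() and jax.local_device_count() report the number of hosts and devices per host, so they can replace the hard-coded values.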

An important caveat is that epoch boundaries become blurry if the number of batches available to each worker varies. For example, suppose a dataset consists of two TFRecord files, one containing 32 samples and the other 64, and training is performed on two workers (each receiving one file) with a per-worker batch size of 16. The number of steps per epoch is then (64 + 32) / 32 = 3, where the 32 in the denominator is the global batch size, so every epoch takes 3 batches, or 16 * 3 = 48 samples, from each file. Consequently, during the first epoch, half the samples of the 32-sample file are iterated over twice, whereas 16 samples of the 64-sample file are not seen at all. In the subsequent epoch, those remaining samples of the 64-sample file are the first to be read, and the process restarts once the entire file has been consumed. Such extreme discrepancies become rarer as the number of files increases and are unlikely to pose a problem in practice, but one should be aware of them nevertheless.
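The walk-through above can be reproduced with a toy simulation. It assumes each worker streams its own file on repeat in a fixed order without shuffling; real pipelines typically shuffle, so which samples repeat will differ, but the counts are the same. epoch_streams is a hypothetical helper:

```python
from collections import Counter
from itertools import cycle, islice
from typing import List


def epoch_streams(file_size: int, per_worker_batch: int, steps: int, n_epochs: int) -> List[List[int]]:
    """Sample indices one worker reads in each epoch, streaming its file on repeat."""
    stream = cycle(range(file_size))  # Unshuffled, endlessly repeated file.
    samples_per_epoch = per_worker_batch * steps
    return [list(islice(stream, samples_per_epoch)) for _ in range(n_epochs)]


file_sizes = [32, 64]  # One TFRecord file per worker.
per_worker_batch = 16
global_batch = per_worker_batch * len(file_sizes)  # 32
steps = sum(file_sizes) // global_batch  # (32 + 64) // 32 = 3

for size in file_sizes:
    first, second = epoch_streams(size, per_worker_batch, steps, n_epochs=2)
    counts = Counter(first)
    twice = sum(1 for c in counts.values() if c == 2)
    missed = size - len(counts)
    print(f'{size}-sample file: epoch 1 reads {twice} samples twice and misses {missed}; '
          f'epoch 2 starts at sample {second[0]}')
```

It reports 16 repeated samples for the 32-sample file and 16 missed samples for the 64-sample file, and the second epoch of the 64-sample file begins at sample 48, i.e., with the samples skipped in the first epoch.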

Available Models

There are three functions related to listing available models and pre-trained parameters:

  • list_models: Returns the names of all models, some of which, such as ViT-Small, have no associated pre-trained parameters.
  • list_pretrained: Returns tuples of (model name, pre-trained parameters name). A model may have several sets of pre-trained parameters, so multiple entries may share a model name but differ in their pre-trained parameters.
  • list_pretrained_by_model: Returns the names of a particular model's pre-trained parameters.
>>> import open_clip_jax
>>> open_clip_jax.list_models()
('convnext-base-w',
 'convnext-base',
 'convnext-large-d',
 'vit-base-patch16',
 'vit-base-patch32',
 'vit-huge-patch14',
 'vit-huge-patch16',
 'vit-large-patch14',
 'vit-large-patch16',
 'vit-nano-patch32',
 'vit-small-patch16',
 'vit-small-patch32')
>>> open_clip_jax.list_pretrained()
(('convnext-base', 'laion400m-s13b-b51k'),
 ('convnext-base-w', 'laion-aesthetic-s13b-b82k'),
 ('convnext-base-w', 'laion-aesthetic-s13b-b82k-320'),
 ('convnext-base-w', 'laion-aesthetic-s13b-b82k-augreg-320'),
 ('convnext-base-w', 'laion2b-s13b-b82k'),
 ('convnext-base-w', 'laion2b-s13b-b82k-augreg'),
 ('convnext-large-d', 'laion2b-s26b-b102k-augreg'),
 ('convnext-large-d', 'laion2b-s29b-b131k-ft-320'),
 ('convnext-large-d', 'laion2b-s29b-b131k-ft-soup-320'),
 ('vit-base-patch32', 'laion400m-e31'),
 ('vit-base-patch32', 'laion400m-e32'),
 ('vit-base-patch32', 'laion2b-e16'),
 ('vit-base-patch32', 'laion2b-s34b-b79k'),
 ('vit-base-patch16', 'laion400m-e31'),
 ('vit-base-patch16', 'laion400m-e32'),
 ('vit-base-patch16', 'laion2b-s34b-b88k'),
 ('vit-large-patch14', 'laion400m-e31'),
 ('vit-large-patch14', 'laion400m-e32'),
 ('vit-large-patch14', 'laion2b-s32b-b82k'),
 ('vit-huge-patch14', 'laion2b-s32b-b79k'))
>>> open_clip_jax.list_pretrained_by_model('vit-base-patch32')
('laion400m-e31', 'laion400m-e32', 'laion2b-e16', 'laion2b-s34b-b79k')
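list_pretrained_by_model is effectively a filter on the output of list_pretrained. The grouping can be sketched in pure Python over a copy of some of the tuples printed above; group_by_model is a hypothetical helper, not part of open_clip_jax:

```python
from typing import Dict, Iterable, List, Tuple

# A subset of the (model, pre-trained parameters) pairs printed by list_pretrained above.
PRETRAINED = (
    ('convnext-base', 'laion400m-s13b-b51k'),
    ('vit-base-patch32', 'laion400m-e31'),
    ('vit-base-patch32', 'laion400m-e32'),
    ('vit-base-patch32', 'laion2b-e16'),
    ('vit-base-patch32', 'laion2b-s34b-b79k'),
    ('vit-huge-patch14', 'laion2b-s32b-b79k'),
)


def group_by_model(pairs: Iterable[Tuple[str, str]]) -> Dict[str, Tuple[str, ...]]:
    """Map each model name to its pre-trained parameter names, preserving listing order."""
    grouped: Dict[str, List[str]] = {}
    for model, tag in pairs:
        grouped.setdefault(model, []).append(tag)
    return {model: tuple(tags) for model, tags in grouped.items()}


print(group_by_model(PRETRAINED)['vit-base-patch32'])
```

Similarly, set(open_clip_jax.list_models()) - {model for model, _ in open_clip_jax.list_pretrained()} yields the models, such as the ViT-Small variants, that have no pre-trained parameters.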

The pre-trained parameters have been ported from OpenCLIP, and more information regarding them, such as their training recipes or zero-shot performance, can be found in the OpenCLIP repository or as model cards on Hugging Face Hub.

Acknowledgements

Thanks to Google's TPU Research Cloud (TRC) program for providing hardware used to accelerate the development of this project.

Citations

@software{ilharco_gabriel_2021_5143773,
  author       = {Ilharco, Gabriel and
                  Wortsman, Mitchell and
                  Wightman, Ross and
                  Gordon, Cade and
                  Carlini, Nicholas and
                  Taori, Rohan and
                  Dave, Achal and
                  Shankar, Vaishaal and
                  Namkoong, Hongseok and
                  Miller, John and
                  Hajishirzi, Hannaneh and
                  Farhadi, Ali and
                  Schmidt, Ludwig},
  title        = {OpenCLIP},
  month        = jul,
  year         = 2021,
  note         = {If you use this software, please cite it as below.},
  publisher    = {Zenodo},
  version      = {0.1},
  doi          = {10.5281/zenodo.5143773},
  url          = {https://doi.org/10.5281/zenodo.5143773}
}
@inproceedings{Radford2021LearningTV,
  title={Learning Transferable Visual Models From Natural Language Supervision},
  author={Alec Radford and Jong Wook Kim and Chris Hallacy and A. Ramesh and Gabriel Goh and Sandhini Agarwal and Girish Sastry and Amanda Askell and Pamela Mishkin and Jack Clark and Gretchen Krueger and Ilya Sutskever},
  booktitle={ICML},
  year={2021}
}
@inproceedings{schuhmann2022laionb,
  title={{LAION}-5B: An open large-scale dataset for training next generation image-text models},
  author={Christoph Schuhmann and
          Romain Beaumont and
          Richard Vencu and
          Cade W Gordon and
          Ross Wightman and
          Mehdi Cherti and
          Theo Coombes and
          Aarush Katta and
          Clayton Mullis and
          Mitchell Wortsman and
          Patrick Schramowski and
          Srivatsa R Kundurthy and
          Katherine Crowson and
          Ludwig Schmidt and
          Robert Kaczmarczyk and
          Jenia Jitsev},
  booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track},
  year={2022},
  url={https://openreview.net/forum?id=M3Y74vmsMcY}
}