MATLAB implementation of the OpenAI CLIP deep learning model
For the training code, see trainCLIP.mlx
if ~isfile("net-gpu-poor-squeezenet-bert-tiny.mat")
!curl -L -O https://github.com/Lxrd-AJ/openai-clip-matlab/releases/download/v1.0.0/net-gpu-poor-squeezenet-bert-tiny.mat
end
load net-gpu-poor-squeezenet-bert-tiny.mat net logTemperature
imageInputSize = net.Layers(1).Layers(1).InputSize(1:2);
clip = CLIP(net, Temperature=exp(logTemperature), ImageInputSize=imageInputSize)
clip =
CLIP with no properties.
% Try encoding some images
% First gather all image paths from the dataset
imgBaseDir = "./flickr-dataset/Flicker8k_Dataset/";
% Development-set images
devImages = readlines("flickr-dataset/Flickr_8k.devImages.txt");
devImages = fullfile(imgBaseDir, devImages);
% Get some random images from the dataset
someImagePaths = randsample(devImages, 20);
images = arrayfun(@(x) imread(x), someImagePaths, UniformOutput=false);
montage(images)
[probs, logits] = clip.predict(someImagePaths, ["two Dogs", "Birthday Party"]);
disp(probs)
20x2 single gpuArray dlarray
0.0000 0.0003
0.0000 0.0000
0.0000 0.9989
0.0000 0.0005
0.0000 0.0001
0.0000 0.0000
0.0000 0.0000
0.0000 0.0000
0.0000 0.0000
0.0000 0.0000
0.0000 0.0000
0.9998 0.0000
0.0000 0.0000
0.0000 0.0000
0.0000 0.0000
0.0000 0.0000
0.0002 0.0000
0.0000 0.0002
0.0000 0.0000
0.0000 0.0000
[maxProb, maxIdx] = max(extractdata(gather(probs)));
disp("Query Match Probability: " + maxProb)
"Query Match Probability: 0.99983" "Query Match Probability: 0.99886"
maxImages = {};
for idx=1:numel(maxIdx)
maxImages{end+1} = imread(someImagePaths(maxIdx(idx)));
end
montage(maxImages)
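The probability matrix also supports simple top-k retrieval. Below is a minimal sketch (not part of the CLIP wrapper API) that uses maxk to pull out the three images that best match the first text query:
% Sketch only: top-k retrieval for the first query ("two Dogs"),
% i.e. the first column of probs
k = 3;
queryProbs = extractdata(gather(probs(:, 1)));
[topProbs, topIdx] = maxk(queryProbs, k);
topImages = arrayfun(@(i) imread(someImagePaths(i)), topIdx, UniformOutput=false);
montage(topImages)
disp("Top-" + k + " probabilities: " + join(string(topProbs), ", "))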
trainingDatastore = CLIPDatastore(ImageFolder="./flickr-dataset/Flicker8k_Dataset");
ds = shuffle(trainingDatastore, "PercentageToKeep", 1)
ds =
CLIPDatastore with no properties.
disp("Number of training images " + numel(ds))
Number of training images 30000
testDatastore = CLIPDatastore(ImageFolder="./flickr-dataset/Flicker8k_Dataset", TrainTestVal="./flickr-dataset/Flickr_8k.testImages.txt");
tds = shuffle(testDatastore, "PercentageToKeep", 1)
tds =
CLIPDatastore with no properties.
disp("Number of test images " + numel(tds))
Number of test images 5000
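One of the notes at the end of this document suggests moving image resizing into a datastore transform. A possible sketch, assuming CLIPDatastore follows the standard matlab.io.Datastore interface (so transform applies) and that each read returns an {image, tokens, caption} cell row:
% Sketch only: resize images as a datastore transform instead of inside
% processMiniBatch; imageInputSize comes from the network loaded above
resizeRead = @(data) [{imresize(data{1}, imageInputSize)}, data(2:end)];
dsResized = transform(ds, resizeRead);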
[net,tokenizer] = bert();
r = read(ds)
| 1 | 2 | 3 |
|---|---|---|
| 375x500x3 uint8 | 1x19 double | "A brown and white dog be stand on a beach with a tennis ball beside it ." |
[im, tokens, caption] = r{:};
imshow(im)
%caption = "A group of horse and their rider be race each other .";
[~, segments] = encode(tokenizer, caption);
dltoken = dlarray(tokens, 'CT');
dlsegment = dlarray(segments{1}, 'CT');
mask = dlarray(ones(1, numel(dlsegment)), 'CT');
pred = predict(net, dltoken, mask, dlsegment);
decodedTokens = decode(tokenizer, tokens)
decodedTokens = "[CLS] a brown and white dog be stand on a beach with a tennis ball beside it . [SEP]"
disp(size(pred))
768 19
last = pred(:,1) % Use the [CLS] token; the official CLIP implementation instead trained its own text encoder and used the last token's embedding
last =
768(C) x 1(T) single dlarray
-0.1619
0.1122
0.0165
0.2698
0.4631
-0.1553
-0.8530
...
-0.6838
0.0352
0.1187
-0.6164
0.5134
0.2817
-0.7342
0.0385
0.1096
-0.1773
-0.6413
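As the comment above notes, the official CLIP text encoder pools from the last token rather than [CLS]. A minimal sketch of that alternative on the same prediction, assuming the final position of pred holds the [SEP] token:
% Sketch only: pool the final ([SEP]) token's embedding instead of [CLS]
lastTokenEmbedding = pred(:, size(pred, 2));
disp(size(lastTokenEmbedding))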
% NB: Using Bert for batched prediction
[net, tokenizer] = bert() % or bert("Model","tiny") for a smaller model
net =
dlnetwork with properties:
Layers: [129x1 nnet.cnn.layer.Layer]
Connections: [164x2 table]
Learnables: [197x3 table]
State: [0x3 table]
InputNames: {'input_ids' 'attention_mask' 'seg_ids'}
OutputNames: {'enc12_layernorm2'}
Initialized: 1
View summary with summary.
tokenizer =
bertTokenizer with properties:
IgnoreCase: 1
StripAccents: 1
PaddingToken: "[PAD]"
PaddingCode: 1
StartToken: "[CLS]"
StartCode: 102
UnknownToken: "[UNK]"
UnknownCode: 101
SeparatorToken: "[SEP]"
SeparatorCode: 103
ContextSize: 512
paddingValue = tokenizer.PaddingCode;
str = [
"Coolant is pooling underneath sorter."
"Sorter blows fuses at start up."
"There are some very loud rattling sounds coming from the assembler."];
[inputIdsStr, segmentIdsStr] = encode(tokenizer, str);
% `maskStr` marks the positions that contain `paddingValue` so that the
% model ignores the padded entries
[inputIdsStr, maskStr] = padsequences(inputIdsStr, 2,"PaddingValue",paddingValue);
segmentIdsStr = padsequences(segmentIdsStr, 2,"PaddingValue",paddingValue);
inputIdsStr = dlarray(inputIdsStr, "CTB");
maskStr = dlarray(maskStr, "CTB");
segmentIdsStr = dlarray(segmentIdsStr, "CTB");
predictions = predict(net,inputIdsStr,maskStr,segmentIdsStr);
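To get one embedding per sentence from the batched output, the same [CLS] indexing used later for the text encoder applies. This assumes predictions keeps dlarray's canonical 'CBT' dimension ordering, so the first time step holds the [CLS] embeddings:
% Sketch only: one 768-dimensional [CLS] embedding per input sentence
clsEmbeddingsBatch = predictions(:, :, 1);
disp(size(clsEmbeddingsBatch)) % expected 768 x numel(str)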
net = imageEncoder();
randX = dlarray(randn(net.Layers(1).InputSize), 'SSC');
net = dlnetwork(net, randX);
out = predict(net, randX);
disp(size(out))
100352 1
net = textEncoder();
randX = dlarray(randn(1,1,10), 'CBT');
net = dlnetwork(net, Initialize=false);
net = initialize(net, randX, randX, randX)
net =
dlnetwork with properties:
Layers: [1x1 nnet.cnn.layer.NetworkLayer]
Connections: [0x2 table]
Learnables: [37x3 table]
State: [0x3 table]
InputNames: {'bert_encoder/bert_model/input_ids' 'bert_encoder/bert_model/attention_mask' 'bert_encoder/bert_model/seg_ids'}
OutputNames: {'bert_encoder'}
Initialized: 1
View summary with summary.
randInputIDs = dlarray(randi(1000, [1 3 10]), 'CBT');
attentionMask = dlarray(ones(size(randInputIDs)), 'CBT');
segmentIDs = dlarray(ones(size(randInputIDs)), 'CBT');
out = predict(net, randInputIDs, attentionMask, segmentIDs);
clsEmbeddings = out(:,:,1);
projHead = projectionHead();
net = dlnetwork(projHead, dlarray(randn(1,2048), 'BC'))
net =
dlnetwork with properties:
Layers: [1x1 nnet.cnn.layer.NetworkLayer]
Connections: [0x2 table]
Learnables: [4x3 table]
State: [0x3 table]
InputNames: {'proj'}
OutputNames: {'proj'}
Initialized: 1
View summary with summary.
out = predict(net, dlarray(randn(1,2048), 'BC'));
size(out)
ans = 1x2
256 1
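With both encoders projected into the same 256-dimensional space, CLIP scores image-text pairs with a temperature-scaled cosine similarity. The following is a minimal sketch with random stand-ins for the projected embeddings, not the wrapper's actual implementation:
% Sketch only: CLIP-style similarity with dummy projected embeddings
numPairs = 4;
imageEmb = randn(numPairs, 256, "single"); % stand-in for projected image features
textEmb  = randn(numPairs, 256, "single"); % stand-in for projected text features

% L2-normalise each row so a dot product becomes a cosine similarity
imageEmb = imageEmb ./ vecnorm(imageEmb, 2, 2);
textEmb  = textEmb  ./ vecnorm(textEmb, 2, 2);

% Temperature-scaled logits; logTemperature was loaded at the top of this example
pairLogits = exp(logTemperature) * (imageEmb * textEmb');

% Softmax over the text axis gives image-to-text matching probabilities
expLogits = exp(pairLogits - max(pairLogits, [], 2));
pairProbs = expLogits ./ sum(expLogits, 2); % each row sums to 1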
Data sources for download (see https://hockenmaier.cs.illinois.edu/8k-pictures.html):
- Flickr 8K https://github.com/goodwillyoga/Flickr8k_dataset or see https://github.com/jbrownlee/Datasets/blob/master/Flickr8k_Dataset.names
- Flickr 8K https://www.kaggle.com/datasets/adityajn105/flickr8k/data
- Use memmapfile (https://uk.mathworks.com/help/matlab/ref/memmapfile.html) to store and query the image embeddings for fast search
- Design a smaller model (use BERT tiny and build a smaller image encoder from an existing pretrained image model such as SqueezeNet)
- Allow the encoder models to learn but with a smaller learning rate
- Use the [SEP] token from BERT rather than the [CLS] token
- Allow the model to learn the logits scaling
- Support training on the train, validation and test sets
- Update datastore
- Calculate an accuracy metric: argmax(logits) == targets (see the sketch at the end of this list)
- Perform validation during the training loop
- Compute accuracy on the validation set
- Follow the model design and training guides in Sections 2.4 and 2.5 of the CLIP paper
- Use a cosine learning-rate schedule
- Clip the logit-scaling temperature parameter to a maximum of 100
- Move image resizing out of the processMiniBatch function and into a transform function for the datastore
- Upgrade the datastore class to use the provided train, validation and test sets
- Save the model at different checkpoints during training
- Train on Flickr30k dataset
- Wrapper class around the CLIP model
- See API in https://github.com/openai/CLIP?tab=readme-ov-file#api
- Encode images
- Get softmax and logit scores for a batch of (image, text) pairs
- Find the top-k images that match a given query
- Comparison against CIFAR-10 (and CIFAR-100)
- Front end GUI for interfacing with the model (using uicomponentcontainer)
- Index an existing folder
- Run indexing in backgroundPool
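A minimal sketch of the accuracy metric mentioned in the list above, assuming a square batch of image-to-text logits whose matching pairs sit on the diagonal (the usual CLIP training setup):
% Sketch only: batch accuracy where the target for image j is caption j
batchLogits = randn(8, 8, "single"); % stand-in for image-to-text logits
targets = (1:size(batchLogits, 1))'; % diagonal targets
[~, predicted] = max(batchLogits, [], 2); % argmax over the text axis
accuracy = mean(predicted == targets);
disp("Batch accuracy: " + accuracy)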