feat: add OpenCLIP models (#35)
Jim-Encord authored Feb 28, 2024
1 parent 65d43fe commit 50caf37
Showing 5 changed files with 146 additions and 38 deletions.
14 changes: 1 addition & 13 deletions clip_eval/common/base.py
@@ -1,7 +1,6 @@
from pathlib import Path

import numpy as np
import torch
from pydantic import BaseModel, model_validator
from torch.utils.data import DataLoader

@@ -78,19 +77,8 @@ def from_embedding_definition(model_name: str, dataset_name: str) -> "Embeddings

@staticmethod
def build_embedding(model: CLIPModel, dataset: Dataset, batch_size: int = 50) -> "Embeddings":
def _collate_fn(examples) -> dict[str, torch.Tensor]:
images = []
labels = []
for example in examples:
images.append(example["image"])
labels.append(example["label"])

pixel_values = torch.stack(images)
labels = torch.tensor(labels)
return {"pixel_values": pixel_values, "labels": labels}

dataset.set_transform(model.get_transform())
dataloader = DataLoader(dataset, collate_fn=_collate_fn, batch_size=batch_size)
dataloader = DataLoader(dataset, collate_fn=model.get_collate_fn(), batch_size=batch_size)

image_embeddings, labels = model.build_embedding(dataloader)
embeddings = Embeddings(images=image_embeddings, labels=labels)
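For context, the change above swaps the locally defined `_collate_fn` for a collate function supplied by the model itself. A minimal sketch of the resulting call flow, assuming `model` is a `CLIPModel` subclass and `dataset` is a Hugging Face dataset with `image` and `label` columns (both names are placeholders):

```python
from torch.utils.data import DataLoader

# Placeholders: `model` (a CLIPModel subclass) and `dataset` (a Hugging Face
# dataset with "image" and "label" columns) are assumed to exist already.
dataset.set_transform(model.get_transform())  # model-specific preprocessing
dataloader = DataLoader(
    dataset,
    collate_fn=model.get_collate_fn(),  # model-specific batching, new in this commit
    batch_size=50,
)
image_embeddings, labels = model.build_embedding(dataloader)
```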
109 changes: 85 additions & 24 deletions clip_eval/models/CLIP_model.py
@@ -3,6 +3,7 @@
from typing import Any

import numpy as np
import open_clip
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
@@ -27,6 +28,7 @@ def __init__(
device = device or ("cuda" if torch.cuda.is_available() else "cpu")
self._check_device(device)
self.__device = torch.device(device)
self._setup(**kwargs)

@property
def title(self) -> str:
@@ -41,15 +43,19 @@ def device(self) -> torch.device:
return self.__device

@abstractmethod
def _setup(self, **kwargs):
def _setup(self, **kwargs) -> None:
pass

@abstractmethod
def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, ClassArray]:
def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
...

@abstractmethod
def get_collate_fn(self) -> Callable[[Any], Any]:
...

@abstractmethod
def get_transform(self) -> Callable[[Any], Any]:
def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, ClassArray]:
...

@staticmethod
@@ -64,9 +70,8 @@ def _check_device(device: str):
class closed_CLIPModel(CLIPModel):
def __init__(self, title: str, title_in_source: str, device: str | None = None) -> None:
super().__init__(title, title_in_source, device)
self._setup()

def define_process_fn(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
def process_fn(batch) -> dict[str, list[Any]]:
images = [i.convert("RGB") for i in batch["image"]]
batch["image"] = [
@@ -76,11 +81,24 @@ def process_fn(batch) -> dict[str, list[Any]]:

return process_fn

def _setup(self):
def get_collate_fn(self) -> Callable[[Any], Any]:
def collate_fn(examples) -> dict[str, torch.Tensor]:
images = []
labels = []
for example in examples:
images.append(example["image"])
labels.append(example["label"])

pixel_values = torch.stack(images)
labels = torch.tensor(labels)
return {"pixel_values": pixel_values, "labels": labels}

return collate_fn

def _setup(self, **kwargs) -> None:
self.model = HF_ClipModel.from_pretrained(self.title_in_source).to(self.device) # type: ignore
load_result = HF_ClipProcessor.from_pretrained(self.title_in_source)
self.processor = load_result[0] if isinstance(load_result, tuple) else load_result
self.process_fn = self.define_process_fn()

def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, ClassArray]:
tmp_embeddings = []
@@ -96,42 +114,74 @@ def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, Class
labels = class_array.numpy()
return image_embeddings, labels

def get_transform(self) -> Callable[[Any], Any]:
return self.process_fn


class open_CLIPModel(CLIPModel):
def __init__(
self,
title: str,
title_in_source: str | None = None,
model_name: str,
pretrained: str,
device: str | None = None,
**kwargs,
) -> None:
self.pretrained = pretrained
self.model_name = model_name
title_in_source = model_name + "_" + pretrained
super().__init__(title, title_in_source, device, **kwargs)
self._setup()

def _setup(self, **kwargs):
raise NotImplementedError("open Clip not implemented")
def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
def process_fn(batch) -> dict[str, list[Any]]:
images = [i.convert("RGB") for i in batch["image"]]
batch["image"] = [self.processor(i).to(self.device).unsqueeze(0) for i in images]
return batch

def build_embedding(self, dataloader: DataLoader):
raise NotImplementedError("open Clip not implemented")
return process_fn

def get_collate_fn(self) -> Callable[[Any], Any]:
def collate_fn(examples) -> dict[str, torch.Tensor]:
images = []
labels = []
for example in examples:
images.append(example["image"])
labels.append(example["label"])

torch_images = torch.stack(images)
labels = torch.tensor(labels)
return {"image": torch_images, "labels": labels}

def get_transform(self) -> Callable[[Any], Any]:
raise NotImplementedError("open Clip not implemented")
return collate_fn

def _setup(self, **kwargs) -> None:
model, _, preprocess = open_clip.create_model_and_transforms(
model_name=self.model_name, pretrained=self.pretrained, **kwargs
)
self.model = model
self.processor = preprocess

def build_embedding(self, dataloader: DataLoader):
tmp_embeddings = []
tmp_labels = []
with torch.inference_mode():
for batch in tqdm(dataloader, desc=f"Embedding dataset with {self.title}"):
tmp_labels.append(batch["labels"])
features = torch.stack([self.model.encode_image(image) for image in batch["image"]])
emb = (features / features.norm(p=2, dim=-1, keepdim=True)).squeeze()
tmp_embeddings.append(emb.to("cpu"))
image_embeddings: EmbeddingArray = np.concatenate(tmp_embeddings, 0)
class_array = torch.concatenate(tmp_labels)
labels = class_array.numpy()
return image_embeddings, labels


class SiglipModel(CLIPModel):
def __init__(self, title: str, title_in_source: str | None = None, device: str | None = None, **kwargs) -> None:
super().__init__(title, title_in_source, device, **kwargs)
self._setup()

def _setup(self, **kwargs):
self.model = HF_SiglipModel.from_pretrained(self.title_in_source).to(self.device)
self.processor = HF_SiglipProcessor.from_pretrained(self.title_in_source)
self.process_fn = self.define_process_fn()

def define_process_fn(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
def process_fn(batch) -> dict[str, list[Any]]:
images = [i.convert("RGB") for i in batch["image"]]
batch["image"] = [
@@ -141,6 +191,20 @@ def process_fn(batch) -> dict[str, list[Any]]:

return process_fn

def get_collate_fn(self) -> Callable[[Any], Any]:
def collate_fn(examples) -> dict[str, torch.Tensor]:
images = []
labels = []
for example in examples:
images.append(example["image"])
labels.append(example["label"])

pixel_values = torch.stack(images)
labels = torch.tensor(labels)
return {"pixel_values": pixel_values, "labels": labels}

return collate_fn

def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, ClassArray]:
tmp_embeddings = []
tmp_labels = []
@@ -154,6 +218,3 @@ def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, Class
class_array = torch.concatenate(tmp_labels)
labels = class_array.numpy()
return image_embeddings, labels

def get_transform(self) -> Callable[[Any], Any]:
return self.process_fn
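The new `open_CLIPModel` wraps the standard OpenCLIP image-embedding workflow shown in its `_setup` and `build_embedding` methods above. A standalone sketch of the underlying library calls, using the `ViT-B-32` / `laion2b_e16` pair registered in `provider.py` below (the image path is a placeholder):

```python
import open_clip
import torch
from PIL import Image

# Same call _setup makes internally; returns the model plus train/val transforms.
model, _, preprocess = open_clip.create_model_and_transforms(
    model_name="ViT-B-32", pretrained="laion2b_e16"
)

image = preprocess(Image.open("example.jpg").convert("RGB")).unsqueeze(0)  # placeholder path
with torch.inference_mode():
    features = model.encode_image(image)
    # L2-normalise, mirroring build_embedding above.
    embedding = features / features.norm(p=2, dim=-1, keepdim=True)
```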
1 change: 1 addition & 0 deletions clip_eval/models/provider.py
@@ -37,5 +37,6 @@ def list_model_names(self) -> list[str]:
model_provider.register_model("street", closed_CLIPModel, title_in_source="geolocal/StreetCLIP")
model_provider.register_model("apple", open_CLIPModel, title_in_source="apple/DFN5B-CLIP-ViT-H-14")
model_provider.register_model("eva-clip", open_CLIPModel, title_in_source="BAAI/EVA-CLIP-8B-448")
model_provider.register_model("vit-b-32-laion2b", open_CLIPModel, model_name="ViT-B-32", pretrained="laion2b_e16")
model_provider.register_model("siglip_small", SiglipModel, title_in_source="google/siglip-base-patch16-224")
model_provider.register_model("siglip_large", SiglipModel, title_in_source="google/siglip-large-patch16-256")
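Further OpenCLIP checkpoints could be registered the same way as the new `vit-b-32-laion2b` entry, by passing a `model_name`/`pretrained` pair that `open_clip` recognises. A hypothetical example (the alias and checkpoint tag below are illustrative, not part of this commit):

```python
# Hypothetical registration following the pattern introduced above.
model_provider.register_model(
    "vit-l-14-laion2b",               # illustrative alias, not in this commit
    open_CLIPModel,
    model_name="ViT-L-14",
    pretrained="laion2b_s32b_b82k",   # an OpenCLIP pretrained tag (assumption)
)
```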
59 changes: 58 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -30,6 +30,7 @@ inquirerpy = "^0.3.4"
tabulate = "^0.9.0"
sentencepiece = "^0.2.0"
protobuf = "^4.25.3"
open-clip-torch = "^2.24.0"

[tool.poetry.group.dev.dependencies]
mypy = "^1.8.0"
