diff --git a/clip_eval/common/base.py b/clip_eval/common/base.py
index 3bb66e9..4aafa95 100644
--- a/clip_eval/common/base.py
+++ b/clip_eval/common/base.py
@@ -1,7 +1,6 @@
 from pathlib import Path

 import numpy as np
-import torch
 from pydantic import BaseModel, model_validator
 from torch.utils.data import DataLoader

@@ -78,19 +77,8 @@ def from_embedding_definition(model_name: str, dataset_name: str) -> "Embeddings

     @staticmethod
     def build_embedding(model: CLIPModel, dataset: Dataset, batch_size: int = 50) -> "Embeddings":
-        def _collate_fn(examples) -> dict[str, torch.Tensor]:
-            images = []
-            labels = []
-            for example in examples:
-                images.append(example["image"])
-                labels.append(example["label"])
-
-            pixel_values = torch.stack(images)
-            labels = torch.tensor(labels)
-            return {"pixel_values": pixel_values, "labels": labels}
-
         dataset.set_transform(model.get_transform())
-        dataloader = DataLoader(dataset, collate_fn=_collate_fn, batch_size=batch_size)
+        dataloader = DataLoader(dataset, collate_fn=model.get_collate_fn(), batch_size=batch_size)

         image_embeddings, labels = model.build_embedding(dataloader)
         embeddings = Embeddings(images=image_embeddings, labels=labels)
diff --git a/clip_eval/models/CLIP_model.py b/clip_eval/models/CLIP_model.py
index b87ca85..6d14ea8 100644
--- a/clip_eval/models/CLIP_model.py
+++ b/clip_eval/models/CLIP_model.py
@@ -3,6 +3,7 @@
 from typing import Any

 import numpy as np
+import open_clip
 import torch
 from torch.utils.data import DataLoader
 from tqdm import tqdm
@@ -27,6 +28,7 @@ def __init__(
         device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self._check_device(device)
         self.__device = torch.device(device)
+        self._setup(**kwargs)

     @property
     def title(self) -> str:
@@ -41,15 +43,19 @@ def device(self) -> torch.device:
         return self.__device

     @abstractmethod
-    def _setup(self, **kwargs):
+    def _setup(self, **kwargs) -> None:
         pass

     @abstractmethod
-    def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, ClassArray]:
+    def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
+        ...
+
+    @abstractmethod
+    def get_collate_fn(self) -> Callable[[Any], Any]:
         ...

     @abstractmethod
-    def get_transform(self) -> Callable[[Any], Any]:
+    def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, ClassArray]:
         ...

     @staticmethod
@@ -64,9 +70,8 @@ def _check_device(device):
 class closed_CLIPModel(CLIPModel):
     def __init__(self, title: str, title_in_source: str, device: str | None = None) -> None:
         super().__init__(title, title_in_source, device)
-        self._setup()

-    def define_process_fn(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
+    def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
         def process_fn(batch) -> dict[str, list[Any]]:
             images = [i.convert("RGB") for i in batch["image"]]
             batch["image"] = [
@@ -76,11 +81,24 @@ def process_fn(batch) -> dict[str, list[Any]]:

         return process_fn

-    def _setup(self):
+    def get_collate_fn(self) -> Callable[[Any], Any]:
+        def collate_fn(examples) -> dict[str, torch.Tensor]:
+            images = []
+            labels = []
+            for example in examples:
+                images.append(example["image"])
+                labels.append(example["label"])
+
+            pixel_values = torch.stack(images)
+            labels = torch.tensor(labels)
+            return {"pixel_values": pixel_values, "labels": labels}
+
+        return collate_fn
+
+    def _setup(self, **kwargs) -> None:
         self.model = HF_ClipModel.from_pretrained(self.title_in_source).to(self.device)  # type: ignore
         load_result = HF_ClipProcessor.from_pretrained(self.title_in_source)
         self.processor = load_result[0] if isinstance(load_result, tuple) else load_result
-        self.process_fn = self.define_process_fn()

     def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, ClassArray]:
         tmp_embeddings = []
@@ -96,42 +114,74 @@ def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, Class
         labels = class_array.numpy()
         return image_embeddings, labels

-    def get_transform(self) -> Callable[[Any], Any]:
-        return self.process_fn
-

 class open_CLIPModel(CLIPModel):
     def __init__(
         self,
         title: str,
-        title_in_source: str | None = None,
+        model_name: str,
+        pretrained: str,
         device: str | None = None,
         **kwargs,
     ) -> None:
+        self.pretrained = pretrained
+        self.model_name = model_name
+        title_in_source = model_name + "_" + pretrained
         super().__init__(title, title_in_source, device, **kwargs)
-        self._setup()

-    def _setup(self, **kwargs):
-        raise NotImplementedError("open Clip not implemented")
+    def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
+        def process_fn(batch) -> dict[str, list[Any]]:
+            images = [i.convert("RGB") for i in batch["image"]]
+            batch["image"] = [self.processor(i).to(self.device).unsqueeze(0) for i in images]
+            return batch

-    def build_embedding(self, dataloader: DataLoader):
-        raise NotImplementedError("open Clip not implemented")
+        return process_fn
+
+    def get_collate_fn(self) -> Callable[[Any], Any]:
+        def collate_fn(examples) -> dict[str, torch.Tensor]:
+            images = []
+            labels = []
+            for example in examples:
+                images.append(example["image"])
+                labels.append(example["label"])
+
+            torch_images = torch.stack(images)
+            labels = torch.tensor(labels)
+            return {"image": torch_images, "labels": labels}

-    def get_transform(self) -> Callable[[Any], Any]:
-        raise NotImplementedError("open Clip not implemented")
+        return collate_fn
+
+    def _setup(self, **kwargs) -> None:
+        model, _, preprocess = open_clip.create_model_and_transforms(
+            model_name=self.model_name, pretrained=self.pretrained, **kwargs
+        )
+        self.model = model
+        self.processor = preprocess
+
+    def build_embedding(self, dataloader: DataLoader):
+        tmp_embeddings = []
+        tmp_labels = []
+        with torch.inference_mode():
+            for batch in tqdm(dataloader, desc=f"Embedding dataset with {self.title}"):
+                tmp_labels.append(batch["labels"])
+                features = torch.stack([self.model.encode_image(image) for image in batch["image"]])
+                emb = (features / features.norm(p=2, dim=-1, keepdim=True)).squeeze()
+                tmp_embeddings.append(emb.to("cpu"))
+
+        image_embeddings: EmbeddingArray = np.concatenate(tmp_embeddings, 0)
+        class_array = torch.concatenate(tmp_labels)
+        labels = class_array.numpy()
+        return image_embeddings, labels


 class SiglipModel(CLIPModel):
     def __init__(self, title: str, title_in_source: str | None = None, device: str | None = None, **kwargs) -> None:
         super().__init__(title, title_in_source, device, **kwargs)
-        self._setup()

     def _setup(self, **kwargs):
         self.model = HF_SiglipModel.from_pretrained(self.title_in_source).to(self.device)
         self.processor = HF_SiglipProcessor.from_pretrained(self.title_in_source)
-        self.process_fn = self.define_process_fn()

-    def define_process_fn(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
+    def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]:
         def process_fn(batch) -> dict[str, list[Any]]:
             images = [i.convert("RGB") for i in batch["image"]]
             batch["image"] = [
@@ -141,6 +191,20 @@ def process_fn(batch) -> dict[str, list[Any]]:

         return process_fn

+    def get_collate_fn(self) -> Callable[[Any], Any]:
+        def collate_fn(examples) -> dict[str, torch.Tensor]:
+            images = []
+            labels = []
+            for example in examples:
+                images.append(example["image"])
+                labels.append(example["label"])
+
+            pixel_values = torch.stack(images)
+            labels = torch.tensor(labels)
+            return {"pixel_values": pixel_values, "labels": labels}
+
+        return collate_fn
+
     def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, ClassArray]:
         tmp_embeddings = []
         tmp_labels = []
@@ -154,6 +218,3 @@ def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, Class
         class_array = torch.concatenate(tmp_labels)
         labels = class_array.numpy()
         return image_embeddings, labels
-
-    def get_transform(self) -> Callable[[Any], Any]:
-        return self.process_fn
diff --git a/clip_eval/models/provider.py b/clip_eval/models/provider.py
index 1b53e3c..33d1055 100644
--- a/clip_eval/models/provider.py
+++ b/clip_eval/models/provider.py
@@ -37,5 +37,6 @@ def list_model_names(self) -> list[str]:
 model_provider.register_model("street", closed_CLIPModel, title_in_source="geolocal/StreetCLIP")
 model_provider.register_model("apple", open_CLIPModel, title_in_source="apple/DFN5B-CLIP-ViT-H-14")
 model_provider.register_model("eva-clip", open_CLIPModel, title_in_source="BAAI/EVA-CLIP-8B-448")
+model_provider.register_model("vit-b-32-laion2b", open_CLIPModel, model_name="ViT-B-32", pretrained="laion2b_e16")
 model_provider.register_model("siglip_small", SiglipModel, title_in_source="google/siglip-base-patch16-224")
 model_provider.register_model("siglip_large", SiglipModel, title_in_source="google/siglip-large-patch16-256")
diff --git a/poetry.lock b/poetry.lock
index 284f7d3..060c82d 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -959,6 +959,20 @@ smb = ["smbprotocol"]
 ssh = ["paramiko"]
 tqdm = ["tqdm"]

+[[package]]
+name = "ftfy"
+version = "6.1.3"
+description = "Fixes mojibake and other problems with Unicode, after the fact"
+optional = false
+python-versions = ">=3.8,<4"
+files = [
+    {file = "ftfy-6.1.3-py3-none-any.whl", hash = "sha256:e49c306c06a97f4986faa7a8740cfe3c13f3106e85bcec73eb629817e671557c"},
+    {file = "ftfy-6.1.3.tar.gz", hash = "sha256:693274aead811cff24c1e8784165aa755cd2f6e442a5ec535c7d697f6422a422"},
+]
+
+[package.dependencies]
+wcwidth = ">=0.2.12,<0.3.0"
+
 [[package]]
 name = "huggingface-hub"
 version = "0.20.3"
@@ -1854,6 +1868,31 @@ files = [
     {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
 ]

+[[package]]
+name = "open-clip-torch"
+version = "2.24.0"
+description = "OpenCLIP"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "open_clip_torch-2.24.0-py3-none-any.whl", hash = "sha256:2537dbe76c8008caa46652bc97cb32bceeae56baff6289e7b4eb22539a80c801"},
+    {file = "open_clip_torch-2.24.0.tar.gz", hash = "sha256:1ae2482aee313827c399eb8a4e735f0b0cd31e4c62085ce2dbfa3a13190219ff"},
+]
+
+[package.dependencies]
+ftfy = "*"
+huggingface-hub = "*"
+protobuf = "*"
+regex = "*"
+sentencepiece = "*"
+timm = "*"
+torch = ">=1.9.0"
+torchvision = "*"
+tqdm = "*"
+
+[package.extras]
+training = ["braceexpand", "fsspec", "ftfy", "huggingface-hub", "pandas", "regex", "timm (>=0.9.8)", "torch (>=1.9.0)", "torchvision", "tqdm", "transformers", "webdataset (>=0.2.5)"]
+
 [[package]]
 name = "packaging"
 version = "23.2"
@@ -3039,6 +3078,24 @@ files = [
     {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"},
 ]

+[[package]]
+name = "timm"
+version = "0.9.16"
+description = "PyTorch Image Models"
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "timm-0.9.16-py3-none-any.whl", hash = "sha256:bf5704014476ab011589d3c14172ee4c901fd18f9110a928019cac5be2945914"},
+    {file = "timm-0.9.16.tar.gz", hash = "sha256:891e54f375d55adf31a71ab0c117761f0e472f9f3971858ecdd1e7376b7071e6"},
+]
+
+[package.dependencies]
+huggingface_hub = "*"
+pyyaml = "*"
+safetensors = "*"
+torch = "*"
+torchvision = "*"
+
 [[package]]
 name = "tokenizers"
 version = "0.15.1"
@@ -3755,4 +3812,4 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "d5853877788ce5c368275dda25dad97257495b9827cfc78af5cd4851ec2f3e3f"
+content-hash = "67f87a0b699167c3cfa1a6ff4b1cf08c1006f50090a5901cb506357003bf82e0"
diff --git a/pyproject.toml b/pyproject.toml
index f501cf9..996a476 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ inquirerpy = "^0.3.4"
 tabulate = "^0.9.0"
 sentencepiece = "^0.2.0"
 protobuf = "^4.25.3"
+open-clip-torch = "^2.24.0"

 [tool.poetry.group.dev.dependencies]
 mypy = "^1.8.0"