diff --git a/clip_eval/cli/main.py b/clip_eval/cli/main.py index 173876f..82f3bbd 100644 --- a/clip_eval/cli/main.py +++ b/clip_eval/cli/main.py @@ -130,12 +130,12 @@ def list_models_datasets( Option(help="List all models and dataset that are available via the tool."), ] = False, ): - from clip_eval.dataset.provider import dataset_provider - from clip_eval.models import model_provider + from clip_eval.dataset import DatasetProvider + from clip_eval.models import ModelProvider if all: - datasets = dataset_provider.list_dataset_titles() - models = model_provider.list_model_titles() + datasets = DatasetProvider.list_dataset_titles() + models = ModelProvider.list_model_titles() print(f"Available datasets are: {', '.join(datasets)}") print(f"Available models are: {', '.join(models)}") return diff --git a/clip_eval/cli/utils.py b/clip_eval/cli/utils.py index 5c011a5..8d5adc4 100644 --- a/clip_eval/cli/utils.py +++ b/clip_eval/cli/utils.py @@ -6,8 +6,8 @@ from natsort import natsorted, ns from clip_eval.common.data_models import EmbeddingDefinition -from clip_eval.dataset.provider import dataset_provider -from clip_eval.models.provider import model_provider +from clip_eval.dataset import DatasetProvider +from clip_eval.models import ModelProvider from clip_eval.utils import read_all_cached_embeddings @@ -80,8 +80,8 @@ def select_from_all_embedding_definitions( ) -> list[EmbeddingDefinition]: existing = set(read_all_cached_embeddings(as_list=True)) - models = model_provider.list_model_titles() - datasets = dataset_provider.list_dataset_titles() + models = ModelProvider.list_model_titles() + datasets = DatasetProvider.list_dataset_titles() defs = [EmbeddingDefinition(dataset=d, model=m) for d, m in product(datasets, models)] if not include_existing: diff --git a/clip_eval/common/base.py b/clip_eval/common/base.py index 4d6b3da..ff1d528 100644 --- a/clip_eval/common/base.py +++ b/clip_eval/common/base.py @@ -6,7 +6,7 @@ from clip_eval.constants import NPZ_KEYS from clip_eval.dataset import Dataset -from clip_eval.models import CLIPModel +from clip_eval.models import Model from .numpy_types import ClassArray, EmbeddingArray @@ -69,7 +69,7 @@ def to_file(self, path: Path) -> Path: return path @staticmethod - def build_embedding(model: CLIPModel, dataset: Dataset, batch_size: int = 50) -> "Embeddings": + def build_embedding(model: Model, dataset: Dataset, batch_size: int = 50) -> "Embeddings": dataset.set_transform(model.get_transform()) dataloader = DataLoader(dataset, collate_fn=model.get_collate_fn(), batch_size=batch_size) diff --git a/clip_eval/common/data_models.py b/clip_eval/common/data_models.py index d4150a0..653f490 100644 --- a/clip_eval/common/data_models.py +++ b/clip_eval/common/data_models.py @@ -7,8 +7,8 @@ from pydantic.functional_validators import AfterValidator from clip_eval.constants import PROJECT_PATHS -from clip_eval.dataset import Split, dataset_provider -from clip_eval.models import model_provider +from clip_eval.dataset import DatasetProvider, Split +from clip_eval.models import ModelProvider from .base import Embeddings from .string_utils import safe_str @@ -61,8 +61,8 @@ def save_embeddings(self, embeddings: Embeddings, split: Split, overwrite: bool return True def build_embeddings(self, split: Split) -> Embeddings: - model = model_provider.get_model(self.model) - dataset = dataset_provider.get_dataset(self.dataset, split) + model = ModelProvider.get_model(self.model) + dataset = DatasetProvider.get_dataset(self.dataset, split) return Embeddings.build_embedding(model, 
dataset) def __str__(self): diff --git a/clip_eval/constants.py b/clip_eval/constants.py index 183c2f8..5f85fd2 100644 --- a/clip_eval/constants.py +++ b/clip_eval/constants.py @@ -6,8 +6,10 @@ load_dotenv() # If the cache directory is not explicitly specified, use the `.cache` directory located in the project's root. -CACHE_PATH = Path(os.environ.get("CLIP_CACHE_PATH", Path(__file__).parent.parent / ".cache")) -_OUTPUT_PATH = Path(os.environ.get("OUTPUT_PATH", "output")) +_CLIP_EVAL_ROOT_DIR = Path(__file__).parent.parent +CACHE_PATH = Path(os.environ.get("CLIP_EVAL_CACHE_PATH", _CLIP_EVAL_ROOT_DIR / ".cache")) +_OUTPUT_PATH = Path(os.environ.get("CLIP_EVAL_OUTPUT_PATH", _CLIP_EVAL_ROOT_DIR / "output")) +_SOURCES_PATH = _CLIP_EVAL_ROOT_DIR / "sources" class PROJECT_PATHS: @@ -25,3 +27,10 @@ class NPZ_KEYS: class OUTPUT_PATH: ANIMATIONS = _OUTPUT_PATH / "animations" EVALUATIONS = _OUTPUT_PATH / "evaluations" + + +class SOURCES_PATH: + DATASET_TYPES = _SOURCES_PATH / "datasets" / "types" + DATASET_INSTANCE_DEFINITIONS = _SOURCES_PATH / "datasets" / "definitions" + MODEL_TYPES = _SOURCES_PATH / "models" / "types" + MODEL_INSTANCE_DEFINITIONS = _SOURCES_PATH / "models" / "definitions" diff --git a/clip_eval/dataset/__init__.py b/clip_eval/dataset/__init__.py index 29cbb56..3ff74e6 100644 --- a/clip_eval/dataset/__init__.py +++ b/clip_eval/dataset/__init__.py @@ -1,2 +1,2 @@ from .base import Dataset, Split -from .provider import dataset_provider +from .provider import DatasetProvider diff --git a/clip_eval/dataset/base.py b/clip_eval/dataset/base.py index 62af44e..57822b1 100644 --- a/clip_eval/dataset/base.py +++ b/clip_eval/dataset/base.py @@ -2,6 +2,7 @@ from enum import StrEnum, auto from pathlib import Path +from pydantic import BaseModel, ConfigDict from torch.utils.data import Dataset as TorchDataset from clip_eval.constants import CACHE_PATH @@ -14,6 +15,18 @@ class Split(StrEnum): ALL = auto() +class DatasetDefinitionSpec(BaseModel): + dataset_type: str + module_path: Path + title: str + split: Split = Split.ALL + title_in_source: str | None = None + cache_dir: Path | None = None + + # Allow additional dataset configuration fields + model_config = ConfigDict(extra="allow") + + class Dataset(TorchDataset, ABC): def __init__( self, diff --git a/clip_eval/dataset/encord.py b/clip_eval/dataset/encord.py deleted file mode 100644 index b36aa97..0000000 --- a/clip_eval/dataset/encord.py +++ /dev/null @@ -1,159 +0,0 @@ -import json -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -from encord import EncordUserClient -from encord.objects import Classification, LabelRowV2 -from encord.objects.common import PropertyType -from PIL import Image - -from .base import Dataset, Split -from .encord_utils import ( - download_data_from_project, - get_frame_file, - get_label_row_annotations_file, - get_label_rows_info_file, - simple_project_split, -) - - -class EncordDataset(Dataset): - def __init__( - self, - title: str, - project_hash: str, - classification_hash: str, - *, - split: Split = Split.ALL, - title_in_source: str | None = None, - transform=None, - cache_dir: str | None = None, - ssh_key_path: str | None = None, - **kwargs, - ): - super().__init__(title, split=split, title_in_source=title_in_source, transform=transform, cache_dir=cache_dir) - self._setup(project_hash, classification_hash, ssh_key_path, **kwargs) - - def __getitem__(self, idx): - frame_path = self._dataset_indices_info[idx].image_file - img = Image.open(frame_path) - label = 
self._dataset_indices_info[idx].label - - if self.transform is not None: - _d = self.transform(dict(image=[img], label=[label])) - res_item = dict(image=_d["image"][0], label=_d["label"][0]) - else: - res_item = dict(image=img, label=label) - return res_item - - def __len__(self): - return len(self._dataset_indices_info) - - def _get_frame_file(self, label_row: LabelRowV2, frame: int) -> Path: - return get_frame_file( - data_dir=self._cache_dir, - project_hash=self._project.project_hash, - label_row=label_row, - frame=frame, - ) - - def _get_label_row_annotations_file(self, label_row: LabelRowV2) -> Path: - return get_label_row_annotations_file( - data_dir=self._cache_dir, - project_hash=self._project.project_hash, - label_row_hash=label_row.label_hash, - ) - - def _ensure_answers_availability(self) -> dict: - lrs_info_file = get_label_rows_info_file(self._cache_dir, self._project.project_hash) - label_rows_info: dict = json.loads(lrs_info_file.read_text(encoding="utf-8")) - should_update_info = False - class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices - for label_row in self._label_rows: - if "answers" not in label_rows_info[label_row.label_hash]: - if not label_row.is_labelling_initialised: - # Retrieve label row content from local storage - anns_path = self._get_label_row_annotations_file(label_row) - label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8"))) - - answers = dict() - for frame_view in label_row.get_frame_views(): - clf_instances = frame_view.get_classification_instances(self._classification) - # Skip frames where the input classification is missing - if len(clf_instances) == 0: - continue - - clf_instance = clf_instances[0] - clf_answer = clf_instance.get_answer(self._attribute) - # Skip frames where the input classification has no answer (probable annotation error) - if clf_answer is None: - continue - - answers[frame_view.frame] = { - "image_file": self._get_frame_file(label_row, frame_view.frame).as_posix(), - "label": class_name_to_idx[clf_answer.title], - } - label_rows_info[label_row.label_hash]["answers"] = answers - should_update_info = True - if should_update_info: - lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8") - return label_rows_info - - def _setup( - self, - project_hash: str, - classification_hash: str, - ssh_key_path: str | None = None, - **kwargs, - ): - ssh_key_path = ssh_key_path or os.getenv("ENCORD_SSH_KEY_PATH") - if ssh_key_path is None: - raise ValueError( - "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing." 
- "Please set one of them to proceed" - ) - client = EncordUserClient.create_with_ssh_private_key(ssh_private_key_path=ssh_key_path) - self._project = client.get_project(project_hash) - - self._classification = self._project.ontology_structure.get_child_by_hash( - classification_hash, type_=Classification - ) - radio_attribute = self._classification.attributes[0] - if radio_attribute.get_property_type() != PropertyType.RADIO: - raise ValueError("Expected a classification hash with an attribute of type `Radio`") - self._attribute = radio_attribute - self.class_names = [o.title for o in self._attribute.options] - - # Fetch the label rows of the selected split - splits_file = self._cache_dir / "splits.json" - split_to_lr_hashes: dict[str, list[str]] - if splits_file.exists(): - split_to_lr_hashes = json.loads(splits_file.read_text(encoding="utf-8")) - else: - split_to_lr_hashes = simple_project_split(self._project) - splits_file.write_text(json.dumps(split_to_lr_hashes), encoding="utf-8") - self._label_rows = self._project.list_label_rows_v2(label_hashes=split_to_lr_hashes[self.split]) - - # Get data from source. Users may supply the `overwrite_annotations` keyword in the init to download everything - download_data_from_project( - self._project, - self._cache_dir, - self._label_rows, - tqdm_desc=f"Downloading {self.split} data from Encord project `{self._project.title}`", - **kwargs, - ) - - # Prepare data for the __getitem__ method - self._dataset_indices_info: list[EncordDataset.DatasetIndexInfo] = [] - label_rows_info = self._ensure_answers_availability() - for label_row in self._label_rows: - answers: dict[int, Any] = label_rows_info[label_row.label_hash]["answers"] - for frame_num in sorted(answers.keys()): - self._dataset_indices_info.append(EncordDataset.DatasetIndexInfo(**answers[frame_num])) - - @dataclass - class DatasetIndexInfo: - image_file: Path | str - label: int diff --git a/clip_eval/dataset/provider.py b/clip_eval/dataset/provider.py index 1979320..c744251 100644 --- a/clip_eval/dataset/provider.py +++ b/clip_eval/dataset/provider.py @@ -1,103 +1,99 @@ from copy import deepcopy +from pathlib import Path from typing import Any from natsort import natsorted, ns -from clip_eval.constants import CACHE_PATH +from clip_eval.constants import CACHE_PATH, SOURCES_PATH -from .base import Dataset, Split -from .encord import EncordDataset -from .hf import HFDataset +from .base import Dataset, DatasetDefinitionSpec, Split +from .utils import load_class_from_path class DatasetProvider: + __instance = None + __global_settings: dict[str, Any] = dict() + __known_dataset_types: dict[tuple[Path, str], Any] = dict() + def __init__(self): self._datasets = {} - self._global_settings: dict[str, Any] = dict() - - @property - def global_settings(self) -> dict: - return deepcopy(self._global_settings) - - def add_global_setting(self, name: str, value: Any) -> None: - self._global_settings[name] = value - - def remove_global_setting(self, name: str): - self._global_settings.pop(name, None) - def register_dataset(self, source: type[Dataset], title: str, split: Split | None = None, **kwargs): + @classmethod + def prepare(cls): + if cls.__instance is None: + cls.__instance = cls() + cls.register_datasets_from_sources_dir(SOURCES_PATH.DATASET_INSTANCE_DEFINITIONS) + # Global settings + cls.__instance.add_global_setting("cache_dir", CACHE_PATH) + return cls.__instance + + @classmethod + def global_settings(cls) -> dict: + return deepcopy(cls.__global_settings) + + @classmethod + def add_global_setting(cls, 
name: str, value: Any) -> None: + cls.__global_settings[name] = value + + @classmethod + def remove_global_setting(cls, name: str) -> None: + cls.__global_settings.pop(name, None) + + @classmethod + def register_dataset(cls, source: type[Dataset], title: str, split: Split | None = None, **kwargs) -> None: + instance = cls.prepare() if split is None: # One dataset with all the split definitions - self._datasets[title] = (source, kwargs) + instance._datasets[title] = (source, kwargs) else: # One dataset is defined per split kwargs.update(split=split) - self._datasets[(title, split)] = (source, kwargs) - - def get_dataset(self, title: str, split: Split) -> Dataset: - if (title, split) in self._datasets: + instance._datasets[(title, split)] = (source, kwargs) + + @classmethod + def register_dataset_from_json_definition(cls, json_definition: Path) -> None: + spec = DatasetDefinitionSpec.model_validate_json(json_definition.read_text(encoding="utf-8")) + if not spec.module_path.is_absolute(): # Handle relative module paths + spec.module_path = (json_definition.parent / spec.module_path).resolve() + + # Fetch the class of the dataset type stated in the definition + dataset_type = cls.__known_dataset_types.get((spec.module_path, spec.dataset_type)) + if dataset_type is None: + dataset_type = load_class_from_path(spec.module_path.as_posix(), spec.dataset_type) + if not issubclass(dataset_type, Dataset): + raise ValueError( + f"Dataset type specified in the JSON definition file `{json_definition.as_posix()}` " + f"does not inherit from the base class `Dataset`" + ) + cls.__known_dataset_types[(spec.module_path, spec.dataset_type)] = dataset_type + cls.register_dataset(dataset_type, **spec.model_dump(exclude={"module_path", "dataset_type"})) + + @classmethod + def register_datasets_from_sources_dir(cls, source_dir: Path) -> None: + for f in source_dir.glob("*.json"): + cls.register_dataset_from_json_definition(f) + + @classmethod + def get_dataset(cls, title: str, split: Split) -> Dataset: + instance = cls.prepare() + if (title, split) in instance._datasets: # The split corresponds to fetching a whole dataset (one-to-one relationship) dict_key = (title, split) split = Split.ALL # Ensure to read the whole dataset - elif title in self._datasets: + elif title in instance._datasets: # The dataset internally knows how to determine the split dict_key = title else: raise ValueError(f"Unrecognized dataset: {title}") - source, kwargs = self._datasets[dict_key] + source, kwargs = instance._datasets[dict_key] # Apply global settings. Values of local settings take priority when local and global settings share keys. 
- kwargs_with_global_settings = self.global_settings | kwargs + kwargs_with_global_settings = cls.global_settings() | kwargs return source(title, split=split, **kwargs_with_global_settings) - def list_dataset_titles(self) -> list[str]: + @classmethod + def list_dataset_titles(cls) -> list[str]: dataset_titles = [ - dict_key[0] if isinstance(dict_key, tuple) else dict_key for dict_key in self._datasets.keys() + dict_key[0] if isinstance(dict_key, tuple) else dict_key for dict_key in cls.prepare()._datasets.keys() ] return natsorted(set(dataset_titles), alg=ns.IGNORECASE) - - -dataset_provider = DatasetProvider() -# Global settings -dataset_provider.add_global_setting("cache_dir", CACHE_PATH) - -# Hugging Face datasets -dataset_provider.register_dataset(HFDataset, "plants", title_in_source="sampath017/plants") -dataset_provider.register_dataset(HFDataset, "Alzheimer-MRI", title_in_source="Falah/Alzheimer_MRI") -dataset_provider.register_dataset(HFDataset, "skin-cancer", title_in_source="marmal88/skin_cancer", target_feature="dx") -dataset_provider.register_dataset(HFDataset, "geo-landmarks", title_in_source="Qdrant/google-landmark-geo") -dataset_provider.register_dataset( - HFDataset, - "LungCancer4Types", - title_in_source="Kabil007/LungCancer4Types", - revision="a1aab924c6bed6b080fc85552fd7b39724931605", -) -dataset_provider.register_dataset( - HFDataset, - "NIH-Chest-X-ray", - title_in_source="alkzar90/NIH-Chest-X-ray-dataset", - name="image-classification", - target_feature="labels", - trust_remote_code=True, -) -dataset_provider.register_dataset( - HFDataset, - "chest-xray-classification", - title_in_source="trpakov/chest-xray-classification", - name="full", - target_feature="labels", -) - -dataset_provider.register_dataset( - HFDataset, - "sports-classification", - title_in_source="HES-XPLAIN/SportsImageClassification", -) - -# Encord datasets -dataset_provider.register_dataset( - EncordDataset, - "rsicd", - project_hash="46ba913e-1428-48ef-be7f-2553e69bc1e6", - classification_hash="4f6cf0c8", -) diff --git a/clip_eval/dataset/utils.py b/clip_eval/dataset/utils.py index 2109bf2..c24fdd4 100644 --- a/clip_eval/dataset/utils.py +++ b/clip_eval/dataset/utils.py @@ -1,3 +1,4 @@ +import importlib.util import os from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor as Executor @@ -68,6 +69,13 @@ def download_file( f.flush() +def load_class_from_path(module_path: str, class_name: str): + spec = importlib.util.spec_from_file_location(module_path, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + def simple_random_split( dataset_size: int, seed: int = 42, diff --git a/clip_eval/models/CLIP_model.py b/clip_eval/models/CLIP_model.py deleted file mode 100644 index 45622ed..0000000 --- a/clip_eval/models/CLIP_model.py +++ /dev/null @@ -1,198 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Callable -from pathlib import Path -from typing import Any - -import numpy as np -import open_clip -import torch -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import AutoModel as HF_AutoModel -from transformers import AutoProcessor as HF_AutoProcessor -from transformers import AutoTokenizer as HF_AutoTokenizer - -from clip_eval.common.numpy_types import ClassArray, EmbeddingArray -from clip_eval.constants import CACHE_PATH -from clip_eval.dataset import Dataset - - -class CLIPModel(ABC): - def __init__( - self, - title: str, - device: 
str | None = None, - *, - title_in_source: str | None = None, - cache_dir: str | None = None, - **kwargs, - ) -> None: - self.__title = title - self.__title_in_source = title_in_source or title - device = device or ("cuda" if torch.cuda.is_available() else "cpu") - self._check_device(device) - self.__device = torch.device(device) - if cache_dir is None: - cache_dir = CACHE_PATH - self._cache_dir = Path(cache_dir).expanduser().resolve() / "models" / title - - @property - def title(self) -> str: - return self.__title - - @property - def title_in_source(self) -> str: - return self.__title_in_source - - @property - def device(self) -> torch.device: - return self.__device - - @abstractmethod - def _setup(self, **kwargs) -> None: - pass - - @abstractmethod - def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: - ... - - @abstractmethod - def get_collate_fn(self) -> Callable[[Any], Any]: - ... - - @abstractmethod - def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: - ... - - @staticmethod - def _check_device(device: str): - # Check if the input device exists and is available - if device not in {"cuda", "cpu"}: - raise ValueError(f"Unrecognized device: {device}") - if not getattr(torch, device).is_available(): - raise ValueError(f"Unavailable device: {device}") - - -class ClosedCLIPModel(CLIPModel): - def __init__( - self, - title: str, - device: str | None = None, - *, - title_in_source: str | None = None, - cache_dir: str | None = None, - **kwargs, - ) -> None: - super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir) - self._setup(**kwargs) - - def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: - def process_fn(batch) -> dict[str, list[Any]]: - images = [i.convert("RGB") for i in batch["image"]] - batch["image"] = [ - self.processor(images=[i], return_tensors="pt").to(self.device).pixel_values.squeeze() for i in images - ] - return batch - - return process_fn - - def get_collate_fn(self) -> Callable[[Any], Any]: - def collate_fn(examples) -> dict[str, torch.Tensor]: - images = [] - labels = [] - for example in examples: - images.append(example["image"]) - labels.append(example["label"]) - - pixel_values = torch.stack(images) - labels = torch.tensor(labels) - return {"pixel_values": pixel_values, "labels": labels} - - return collate_fn - - def _setup(self, **kwargs) -> None: - self.model = HF_AutoModel.from_pretrained(self.title_in_source, cache_dir=self._cache_dir).to(self.device) - load_result = HF_AutoProcessor.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) - self.processor = load_result[0] if isinstance(load_result, tuple) else load_result - self.tokenizer = HF_AutoTokenizer.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) - - def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: - all_image_embeddings = [] - all_labels = [] - with torch.inference_mode(): - _dataset: Dataset = dataloader.dataset - inputs = self.tokenizer(_dataset.text_queries, padding=True, return_tensors="pt").to(self.device) - class_features = self.model.get_text_features(**inputs) - normalized_class_features = class_features / class_features.norm(p=2, dim=-1, keepdim=True) - class_embeddings = normalized_class_features.numpy(force=True) - for batch in tqdm(dataloader, desc=f"Embedding dataset with {self.title}"): - image_features = 
self.model.get_image_features(pixel_values=batch["pixel_values"].to(self.device)) - normalized_image_features = (image_features / image_features.norm(p=2, dim=-1, keepdim=True)).squeeze() - all_image_embeddings.append(normalized_image_features) - all_labels.append(batch["labels"]) - image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True) - labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32) - return image_embeddings, class_embeddings, labels - - -class OpenCLIPModel(CLIPModel): - def __init__( - self, - title: str, - device: str | None = None, - *, - title_in_source: str, - pretrained: str | None = None, - cache_dir: str | None = None, - **kwargs, - ) -> None: - self.pretrained = pretrained - super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs) - self._setup(**kwargs) - - def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: - def process_fn(batch) -> dict[str, list[Any]]: - images = [i.convert("RGB") for i in batch["image"]] - batch["image"] = [self.processor(i) for i in images] - return batch - - return process_fn - - def get_collate_fn(self) -> Callable[[Any], Any]: - def collate_fn(examples) -> dict[str, torch.Tensor]: - images = [] - labels = [] - for example in examples: - images.append(example["image"]) - labels.append(example["label"]) - - torch_images = torch.stack(images) - labels = torch.tensor(labels) - return {"image": torch_images, "labels": labels} - - return collate_fn - - def _setup(self, **kwargs) -> None: - self.model, _, self.processor = open_clip.create_model_and_transforms( - model_name=self.title_in_source, - pretrained=self.pretrained, - cache_dir=self._cache_dir.as_posix(), - device=self.device, - **kwargs, - ) - self.tokenizer = open_clip.get_tokenizer(model_name=self.title_in_source) - - def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: - all_image_embeddings = [] - all_labels = [] - with torch.inference_mode(): - _dataset: Dataset = dataloader.dataset - text = self.tokenizer(_dataset.text_queries).to(self.device) - class_embeddings = self.model.encode_text(text, normalize=True).numpy(force=True) - for batch in tqdm(dataloader, desc=f"Embedding dataset with {self.title}"): - image_features = self.model.encode_image(batch["image"].to(self.device), normalize=True) - all_image_embeddings.append(image_features) - all_labels.append(batch["labels"]) - image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True) - labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32) - return image_embeddings, class_embeddings, labels diff --git a/clip_eval/models/__init__.py b/clip_eval/models/__init__.py index 6970a20..7a1abaa 100644 --- a/clip_eval/models/__init__.py +++ b/clip_eval/models/__init__.py @@ -1,2 +1,2 @@ -from .CLIP_model import CLIPModel -from .provider import model_provider +from .base import Model +from .provider import ModelProvider diff --git a/clip_eval/models/base.py b/clip_eval/models/base.py new file mode 100644 index 0000000..909f353 --- /dev/null +++ b/clip_eval/models/base.py @@ -0,0 +1,80 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import torch +from pydantic import BaseModel, ConfigDict +from torch.utils.data import DataLoader + +from clip_eval.common.numpy_types import ClassArray, EmbeddingArray +from clip_eval.constants import CACHE_PATH + + +class 
ModelDefinitionSpec(BaseModel): + model_type: str + module_path: Path + title: str + device: str | None = None + title_in_source: str | None = None + cache_dir: Path | None = None + + # Allow additional model configuration fields + # Also, silence spurious Pydantic UserWarning: Field "model_type" has conflict with protected namespace "model_" + model_config = ConfigDict(protected_namespaces=(), extra="allow") + + +class Model(ABC): + def __init__( + self, + title: str, + device: str | None = None, + *, + title_in_source: str | None = None, + cache_dir: str | None = None, + **kwargs, + ) -> None: + self.__title = title + self.__title_in_source = title_in_source or title + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self._check_device(device) + self.__device = torch.device(device) + if cache_dir is None: + cache_dir = CACHE_PATH + self._cache_dir = Path(cache_dir).expanduser().resolve() / "models" / title + + @property + def title(self) -> str: + return self.__title + + @property + def title_in_source(self) -> str: + return self.__title_in_source + + @property + def device(self) -> torch.device: + return self.__device + + @abstractmethod + def _setup(self, **kwargs) -> None: + pass + + @abstractmethod + def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: + ... + + @abstractmethod + def get_collate_fn(self) -> Callable[[Any], Any]: + ... + + @abstractmethod + def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: + ... + + @staticmethod + def _check_device(device: str): + # Check if the input device exists and is available + if device not in {"cuda", "cpu"}: + raise ValueError(f"Unrecognized device: {device}") + if not getattr(torch, device).is_available(): + raise ValueError(f"Unavailable device: {device}") diff --git a/clip_eval/models/provider.py b/clip_eval/models/provider.py index c88ad34..a95a7b2 100644 --- a/clip_eval/models/provider.py +++ b/clip_eval/models/provider.py @@ -1,49 +1,63 @@ +from pathlib import Path +from typing import Any + from natsort import natsorted, ns -from .CLIP_model import CLIPModel, ClosedCLIPModel, OpenCLIPModel -from .local import LocalCLIPModel +from clip_eval.constants import SOURCES_PATH +from clip_eval.dataset.utils import load_class_from_path + +from .base import Model, ModelDefinitionSpec class ModelProvider: + __instance = None + __known_model_types: dict[tuple[Path, str], Any] = dict() + def __init__(self) -> None: self._models = {} - def register_model(self, title: str, source: type[CLIPModel], **kwargs): - self._models[title] = (source, kwargs) - - def get_model(self, title: str) -> CLIPModel: - if title not in self._models: + @classmethod + def prepare(cls): + if cls.__instance is None: + cls.__instance = cls() + cls.register_models_from_sources_dir(SOURCES_PATH.MODEL_INSTANCE_DEFINITIONS) + return cls.__instance + + @classmethod + def register_model(cls, source: type[Model], title: str, **kwargs): + cls.prepare()._models[title] = (source, kwargs) + + @classmethod + def register_model_from_json_definition(cls, json_definition: Path) -> None: + spec = ModelDefinitionSpec.model_validate_json(json_definition.read_text(encoding="utf-8")) + if not spec.module_path.is_absolute(): # Handle relative module paths + spec.module_path = (json_definition.parent / spec.module_path).resolve() + + # Fetch the class of the model type stated in the definition + model_type = cls.__known_model_types.get((spec.module_path, spec.model_type)) + if model_type is None: + 
model_type = load_class_from_path(spec.module_path.as_posix(), spec.model_type) + if not issubclass(model_type, Model): + raise ValueError( + f"Model type specified in the JSON definition file `{json_definition.as_posix()}` " + f"does not inherit from the base class `Model`" + ) + cls.__known_model_types[(spec.module_path, spec.model_type)] = model_type + cls.register_model(model_type, **spec.model_dump(exclude={"module_path", "model_type"})) + + @classmethod + def register_models_from_sources_dir(cls, source_dir: Path) -> None: + for f in source_dir.glob("*.json"): + cls.register_model_from_json_definition(f) + + @classmethod + def get_model(cls, title: str) -> Model: + instance = cls.prepare() + if title not in instance._models: raise ValueError(f"Unrecognized model: {title}") - source, kwargs = self._models[title] + source, kwargs = instance._models[title] return source(title, **kwargs) - def list_model_titles(self) -> list[str]: - return natsorted(self._models.keys(), alg=ns.IGNORECASE) - - -model_provider = ModelProvider() -model_provider.register_model("clip", ClosedCLIPModel, title_in_source="openai/clip-vit-large-patch14-336") -model_provider.register_model("plip", ClosedCLIPModel, title_in_source="vinid/plip") -model_provider.register_model( - "pubmed", - ClosedCLIPModel, - title_in_source="flaviagiammarino/pubmed-clip-vit-base-patch32", -) -model_provider.register_model( - "tinyclip", - ClosedCLIPModel, - title_in_source="wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M", -) -model_provider.register_model("fashion", ClosedCLIPModel, title_in_source="patrickjohncyh/fashion-clip") -model_provider.register_model("rsicd", ClosedCLIPModel, title_in_source="flax-community/clip-rsicd") -model_provider.register_model("street", ClosedCLIPModel, title_in_source="geolocal/StreetCLIP") -model_provider.register_model("siglip_small", ClosedCLIPModel, title_in_source="google/siglip-base-patch16-224") -model_provider.register_model("siglip_large", ClosedCLIPModel, title_in_source="google/siglip-large-patch16-256") - -model_provider.register_model("apple", OpenCLIPModel, title_in_source="hf-hub:apple/DFN5B-CLIP-ViT-H-14") -model_provider.register_model("eva-clip", OpenCLIPModel, title_in_source="BAAI/EVA-CLIP-8B-448") -model_provider.register_model("bioclip", OpenCLIPModel, title_in_source="hf-hub:imageomics/bioclip") -model_provider.register_model("vit-b-32-laion2b", OpenCLIPModel, title_in_source="ViT-B-32", pretrained="laion2b_e16") - -# Local sources -model_provider.register_model("rsicd-encord", LocalCLIPModel, title_in_source="ViT-B-32") + @classmethod + def list_model_titles(cls) -> list[str]: + return natsorted(cls.prepare()._models.keys(), alg=ns.IGNORECASE) diff --git a/clip_eval/plotting/animation.py b/clip_eval/plotting/animation.py index 038b494..28e0838 100644 --- a/clip_eval/plotting/animation.py +++ b/clip_eval/plotting/animation.py @@ -12,7 +12,7 @@ from clip_eval.common.data_models import SafeName from clip_eval.common.numpy_types import ClassArray, N2Array from clip_eval.constants import OUTPUT_PATH -from clip_eval.dataset import Split, dataset_provider +from clip_eval.dataset import DatasetProvider, Split from .reduction import REDUCTIONS, reduction_from_string @@ -212,7 +212,7 @@ def build_animation( reduction: REDUCTIONS = "umap", interactive: bool = False, ) -> animation.FuncAnimation | None: - dataset = dataset_provider.get_dataset(defn_1.dataset, split) + dataset = DatasetProvider.get_dataset(defn_1.dataset, split) embeds = defn_1.load_embeddings(split) # FIXME: This is 
expensive to get just labels if embeds is None: diff --git a/sources/datasets/dataset-definition-schema.json b/sources/datasets/dataset-definition-schema.json new file mode 100644 index 0000000..faa98c2 --- /dev/null +++ b/sources/datasets/dataset-definition-schema.json @@ -0,0 +1,70 @@ +{ + "$defs": { + "Split": { + "enum": [ + "train", + "validation", + "test", + "all" + ], + "title": "Split", + "type": "string" + } + }, + "additionalProperties": true, + "properties": { + "dataset_type": { + "title": "Dataset Type", + "type": "string" + }, + "module_path": { + "format": "path", + "title": "Module Path", + "type": "string" + }, + "title": { + "title": "Title", + "type": "string" + }, + "split": { + "allOf": [ + { + "$ref": "#/$defs/Split" + } + ], + "default": "all" + }, + "title_in_source": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Title In Source" + }, + "cache_dir": { + "anyOf": [ + { + "format": "path", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Cache Dir" + } + }, + "required": [ + "dataset_type", + "module_path", + "title" + ], + "title": "DatasetDefinitionSpec", + "type": "object" +} diff --git a/sources/datasets/definitions/Alzheimer-MRI.json b/sources/datasets/definitions/Alzheimer-MRI.json new file mode 100644 index 0000000..63999a5 --- /dev/null +++ b/sources/datasets/definitions/Alzheimer-MRI.json @@ -0,0 +1,6 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "Alzheimer-MRI", + "title_in_source": "Falah/Alzheimer_MRI" +} diff --git a/sources/datasets/definitions/LungCancer4Types.json b/sources/datasets/definitions/LungCancer4Types.json new file mode 100644 index 0000000..e71eb5d --- /dev/null +++ b/sources/datasets/definitions/LungCancer4Types.json @@ -0,0 +1,7 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "LungCancer4Types", + "title_in_source": "Kabil007/LungCancer4Types", + "revision": "a1aab924c6bed6b080fc85552fd7b39724931605" +} diff --git a/sources/datasets/definitions/NIH-Chest-X-ray.json b/sources/datasets/definitions/NIH-Chest-X-ray.json new file mode 100644 index 0000000..c82cea2 --- /dev/null +++ b/sources/datasets/definitions/NIH-Chest-X-ray.json @@ -0,0 +1,9 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "NIH-Chest-X-ray", + "title_in_source": "alkzar90/NIH-Chest-X-ray-dataset", + "name": "image-classification", + "target_feature": "labels", + "trust_remote_code": true +} diff --git a/sources/datasets/definitions/chest-xray-classification.json b/sources/datasets/definitions/chest-xray-classification.json new file mode 100644 index 0000000..1ca4a10 --- /dev/null +++ b/sources/datasets/definitions/chest-xray-classification.json @@ -0,0 +1,8 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "chest-xray-classification", + "title_in_source": "trpakov/chest-xray-classification", + "name": "full", + "target_feature": "labels" +} diff --git a/sources/datasets/definitions/geo-landmarks.json b/sources/datasets/definitions/geo-landmarks.json new file mode 100644 index 0000000..21acac5 --- /dev/null +++ b/sources/datasets/definitions/geo-landmarks.json @@ -0,0 +1,6 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "geo-landmarks", + "title_in_source": "Qdrant/google-landmark-geo" +} diff --git a/sources/datasets/definitions/plants.json 
b/sources/datasets/definitions/plants.json new file mode 100644 index 0000000..739a4af --- /dev/null +++ b/sources/datasets/definitions/plants.json @@ -0,0 +1,6 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "plants", + "title_in_source": "sampath017/plants" +} diff --git a/sources/datasets/definitions/rsicd.json b/sources/datasets/definitions/rsicd.json new file mode 100644 index 0000000..a5bc102 --- /dev/null +++ b/sources/datasets/definitions/rsicd.json @@ -0,0 +1,7 @@ +{ + "dataset_type": "EncordDataset", + "module_path": "../types/encord_ds.py", + "title": "rsicd", + "project_hash": "46ba913e-1428-48ef-be7f-2553e69bc1e6", + "classification_hash": "4f6cf0c8" +} diff --git a/sources/datasets/definitions/skin-cancer.json b/sources/datasets/definitions/skin-cancer.json new file mode 100644 index 0000000..0394732 --- /dev/null +++ b/sources/datasets/definitions/skin-cancer.json @@ -0,0 +1,7 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "skin-cancer", + "title_in_source": "marmal88/skin_cancer", + "target_feature": "dx" +} diff --git a/sources/datasets/definitions/sports-classification.json b/sources/datasets/definitions/sports-classification.json new file mode 100644 index 0000000..2b188cd --- /dev/null +++ b/sources/datasets/definitions/sports-classification.json @@ -0,0 +1,6 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "sports-classification", + "title_in_source": "HES-XPLAIN/SportsImageClassification" +} diff --git a/clip_eval/dataset/encord_utils.py b/sources/datasets/types/encord_ds.py similarity index 54% rename from clip_eval/dataset/encord_utils.py rename to sources/datasets/types/encord_ds.py index 997e514..77b5773 100644 --- a/clip_eval/dataset/encord_utils.py +++ b/sources/datasets/types/encord_ds.py @@ -1,14 +1,164 @@ import json +import os +from dataclasses import dataclass from pathlib import Path from typing import Any -from encord import Project +from encord import EncordUserClient, Project from encord.common.constants import DATETIME_STRING_FORMAT +from encord.objects import Classification, LabelRowV2 +from encord.objects.common import PropertyType from encord.orm.dataset import DataType, Image, Video -from encord.project import LabelRowV2 +from PIL import Image as PILImage from tqdm.auto import tqdm -from .utils import Split, collect_async, download_file, simple_random_split +from clip_eval.dataset import Dataset, Split +from clip_eval.dataset.utils import collect_async, download_file, simple_random_split + + +class EncordDataset(Dataset): + def __init__( + self, + title: str, + project_hash: str, + classification_hash: str, + *, + split: Split = Split.ALL, + title_in_source: str | None = None, + transform=None, + cache_dir: str | None = None, + ssh_key_path: str | None = None, + **kwargs, + ): + super().__init__(title, split=split, title_in_source=title_in_source, transform=transform, cache_dir=cache_dir) + self._setup(project_hash, classification_hash, ssh_key_path, **kwargs) + + def __getitem__(self, idx): + frame_path = self._dataset_indices_info[idx].image_file + img = PILImage.open(frame_path) + label = self._dataset_indices_info[idx].label + + if self.transform is not None: + _d = self.transform(dict(image=[img], label=[label])) + res_item = dict(image=_d["image"][0], label=_d["label"][0]) + else: + res_item = dict(image=img, label=label) + return res_item + + def __len__(self): + return len(self._dataset_indices_info) + + def 
_get_frame_file(self, label_row: LabelRowV2, frame: int) -> Path: + return get_frame_file( + data_dir=self._cache_dir, + project_hash=self._project.project_hash, + label_row=label_row, + frame=frame, + ) + + def _get_label_row_annotations_file(self, label_row: LabelRowV2) -> Path: + return get_label_row_annotations_file( + data_dir=self._cache_dir, + project_hash=self._project.project_hash, + label_row_hash=label_row.label_hash, + ) + + def _ensure_answers_availability(self) -> dict: + lrs_info_file = get_label_rows_info_file(self._cache_dir, self._project.project_hash) + label_rows_info: dict = json.loads(lrs_info_file.read_text(encoding="utf-8")) + should_update_info = False + class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices + for label_row in self._label_rows: + if "answers" not in label_rows_info[label_row.label_hash]: + if not label_row.is_labelling_initialised: + # Retrieve label row content from local storage + anns_path = self._get_label_row_annotations_file(label_row) + label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8"))) + + answers = dict() + for frame_view in label_row.get_frame_views(): + clf_instances = frame_view.get_classification_instances(self._classification) + # Skip frames where the input classification is missing + if len(clf_instances) == 0: + continue + + clf_instance = clf_instances[0] + clf_answer = clf_instance.get_answer(self._attribute) + # Skip frames where the input classification has no answer (probable annotation error) + if clf_answer is None: + continue + + answers[frame_view.frame] = { + "image_file": self._get_frame_file(label_row, frame_view.frame).as_posix(), + "label": class_name_to_idx[clf_answer.title], + } + label_rows_info[label_row.label_hash]["answers"] = answers + should_update_info = True + if should_update_info: + lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8") + return label_rows_info + + def _setup( + self, + project_hash: str, + classification_hash: str, + ssh_key_path: str | None = None, + **kwargs, + ): + ssh_key_path = ssh_key_path or os.getenv("ENCORD_SSH_KEY_PATH") + if ssh_key_path is None: + raise ValueError( + "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing." + "Please set one of them to proceed" + ) + client = EncordUserClient.create_with_ssh_private_key(ssh_private_key_path=ssh_key_path) + self._project = client.get_project(project_hash) + + self._classification = self._project.ontology_structure.get_child_by_hash( + classification_hash, type_=Classification + ) + radio_attribute = self._classification.attributes[0] + if radio_attribute.get_property_type() != PropertyType.RADIO: + raise ValueError("Expected a classification hash with an attribute of type `Radio`") + self._attribute = radio_attribute + self.class_names = [o.title for o in self._attribute.options] + + # Fetch the label rows of the selected split + splits_file = self._cache_dir / "splits.json" + split_to_lr_hashes: dict[str, list[str]] + if splits_file.exists(): + split_to_lr_hashes = json.loads(splits_file.read_text(encoding="utf-8")) + else: + split_to_lr_hashes = simple_project_split(self._project) + splits_file.write_text(json.dumps(split_to_lr_hashes), encoding="utf-8") + self._label_rows = self._project.list_label_rows_v2(label_hashes=split_to_lr_hashes[self.split]) + + # Get data from source. 
Users may supply the `overwrite_annotations` keyword in the init to download everything + download_data_from_project( + self._project, + self._cache_dir, + self._label_rows, + tqdm_desc=f"Downloading {self.split} data from Encord project `{self._project.title}`", + **kwargs, + ) + + # Prepare data for the __getitem__ method + self._dataset_indices_info: list[EncordDataset.DatasetIndexInfo] = [] + label_rows_info = self._ensure_answers_availability() + for label_row in self._label_rows: + answers: dict[int, Any] = label_rows_info[label_row.label_hash]["answers"] + for frame_num in sorted(answers.keys()): + self._dataset_indices_info.append(EncordDataset.DatasetIndexInfo(**answers[frame_num])) + + @dataclass + class DatasetIndexInfo: + image_file: Path | str + label: int + + +# ----------------------------------------------------------------------- +# UTILITY FUNCTIONS +# ----------------------------------------------------------------------- def _download_image(image_data: Image | Video, destination_dir: Path) -> Path: diff --git a/clip_eval/dataset/hf.py b/sources/datasets/types/hugging_face.py similarity index 98% rename from clip_eval/dataset/hf.py rename to sources/datasets/types/hugging_face.py index 82ff75e..145bb10 100644 --- a/clip_eval/dataset/hf.py +++ b/sources/datasets/types/hugging_face.py @@ -1,6 +1,6 @@ from datasets import ClassLabel, DatasetDict, Sequence, Value, load_dataset -from .base import Dataset, Split +from clip_eval.dataset import Dataset, Split class HFDataset(Dataset): diff --git a/sources/models/definitions/apple.json b/sources/models/definitions/apple.json new file mode 100644 index 0000000..28de863 --- /dev/null +++ b/sources/models/definitions/apple.json @@ -0,0 +1,6 @@ +{ + "model_type": "OpenCLIPModel", + "module_path": "../types/open_clip.py", + "title": "apple", + "title_in_source": "hf-hub:apple/DFN5B-CLIP-ViT-H-14" +} diff --git a/sources/models/definitions/bioclip.json b/sources/models/definitions/bioclip.json new file mode 100644 index 0000000..cd9d899 --- /dev/null +++ b/sources/models/definitions/bioclip.json @@ -0,0 +1,6 @@ +{ + "model_type": "OpenCLIPModel", + "module_path": "../types/open_clip.py", + "title": "bioclip", + "title_in_source": "hf-hub:imageomics/bioclip" +} diff --git a/sources/models/definitions/clip.json b/sources/models/definitions/clip.json new file mode 100644 index 0000000..19941f6 --- /dev/null +++ b/sources/models/definitions/clip.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "clip", + "title_in_source": "openai/clip-vit-large-patch14-336" +} diff --git a/sources/models/definitions/eva-clip.json b/sources/models/definitions/eva-clip.json new file mode 100644 index 0000000..7d1c999 --- /dev/null +++ b/sources/models/definitions/eva-clip.json @@ -0,0 +1,6 @@ +{ + "model_type": "OpenCLIPModel", + "module_path": "../types/open_clip.py", + "title": "eva-clip", + "title_in_source": "BAAI/EVA-CLIP-8B-448" +} diff --git a/sources/models/definitions/fashion.json b/sources/models/definitions/fashion.json new file mode 100644 index 0000000..e3c1762 --- /dev/null +++ b/sources/models/definitions/fashion.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "fashion", + "title_in_source": "patrickjohncyh/fashion-clip" +} diff --git a/sources/models/definitions/plip.json b/sources/models/definitions/plip.json new file mode 100644 index 0000000..058e482 --- /dev/null +++ 
b/sources/models/definitions/plip.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "plip", + "title_in_source": "vinid/plip" +} diff --git a/sources/models/definitions/pubmed.json b/sources/models/definitions/pubmed.json new file mode 100644 index 0000000..23995a5 --- /dev/null +++ b/sources/models/definitions/pubmed.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "pubmed", + "title_in_source": "flaviagiammarino/pubmed-clip-vit-base-patch32" +} diff --git a/sources/models/definitions/rsicd-encord.json b/sources/models/definitions/rsicd-encord.json new file mode 100644 index 0000000..88481bb --- /dev/null +++ b/sources/models/definitions/rsicd-encord.json @@ -0,0 +1,6 @@ +{ + "model_type": "LocalCLIPModel", + "module_path": "../types/local_clip.py", + "title": "rsicd-encord", + "title_in_source": "ViT-B-32" +} diff --git a/sources/models/definitions/rsicd.json b/sources/models/definitions/rsicd.json new file mode 100644 index 0000000..e02f5de --- /dev/null +++ b/sources/models/definitions/rsicd.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "rsicd", + "title_in_source": "flax-community/clip-rsicd" +} diff --git a/sources/models/definitions/siglip_large.json b/sources/models/definitions/siglip_large.json new file mode 100644 index 0000000..ea356f6 --- /dev/null +++ b/sources/models/definitions/siglip_large.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "siglip_large", + "title_in_source": "google/siglip-large-patch16-256" +} diff --git a/sources/models/definitions/siglip_small.json b/sources/models/definitions/siglip_small.json new file mode 100644 index 0000000..5c6db71 --- /dev/null +++ b/sources/models/definitions/siglip_small.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "siglip_small", + "title_in_source": "google/siglip-base-patch16-224" +} diff --git a/sources/models/definitions/street.json b/sources/models/definitions/street.json new file mode 100644 index 0000000..404456f --- /dev/null +++ b/sources/models/definitions/street.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "street", + "title_in_source": "geolocal/StreetCLIP" +} diff --git a/sources/models/definitions/tinyclip.json b/sources/models/definitions/tinyclip.json new file mode 100644 index 0000000..10d66ac --- /dev/null +++ b/sources/models/definitions/tinyclip.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "tinyclip", + "title_in_source": "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M" +} diff --git a/sources/models/definitions/vit-b-32-laion2b.json b/sources/models/definitions/vit-b-32-laion2b.json new file mode 100644 index 0000000..a100000 --- /dev/null +++ b/sources/models/definitions/vit-b-32-laion2b.json @@ -0,0 +1,7 @@ +{ + "model_type": "OpenCLIPModel", + "module_path": "../types/open_clip.py", + "title": "vit-b-32-laion2b", + "title_in_source": "ViT-B-32", + "pretrained": "laion2b_e16" +} diff --git a/sources/models/model-definition-schema.json b/sources/models/model-definition-schema.json new file mode 100644 index 0000000..40aa14e --- /dev/null +++ b/sources/models/model-definition-schema.json @@ -0,0 +1,62 @@ +{ + 
"additionalProperties": true, + "properties": { + "model_type": { + "title": "Model Type", + "type": "string" + }, + "module_path": { + "format": "path", + "title": "Module Path", + "type": "string" + }, + "title": { + "title": "Title", + "type": "string" + }, + "device": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Device" + }, + "title_in_source": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Title In Source" + }, + "cache_dir": { + "anyOf": [ + { + "format": "path", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Cache Dir" + } + }, + "required": [ + "model_type", + "module_path", + "title" + ], + "title": "ModelDefinitionSpec", + "type": "object" +} diff --git a/sources/models/types/hugging_face_clip.py b/sources/models/types/hugging_face_clip.py new file mode 100644 index 0000000..f5b39a4 --- /dev/null +++ b/sources/models/types/hugging_face_clip.py @@ -0,0 +1,76 @@ +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoModel as HF_AutoModel +from transformers import AutoProcessor as HF_AutoProcessor +from transformers import AutoTokenizer as HF_AutoTokenizer + +from clip_eval.common import ClassArray, EmbeddingArray +from clip_eval.dataset import Dataset +from clip_eval.models import Model + + +class ClosedCLIPModel(Model): + def __init__( + self, + title: str, + device: str | None = None, + *, + title_in_source: str | None = None, + cache_dir: str | None = None, + **kwargs, + ) -> None: + super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir) + self._setup(**kwargs) + + def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: + def process_fn(batch) -> dict[str, list[Any]]: + images = [i.convert("RGB") for i in batch["image"]] + batch["image"] = [ + self.processor(images=[i], return_tensors="pt").to(self.device).pixel_values.squeeze() for i in images + ] + return batch + + return process_fn + + def get_collate_fn(self) -> Callable[[Any], Any]: + def collate_fn(examples) -> dict[str, torch.Tensor]: + images = [] + labels = [] + for example in examples: + images.append(example["image"]) + labels.append(example["label"]) + + pixel_values = torch.stack(images) + labels = torch.tensor(labels) + return {"pixel_values": pixel_values, "labels": labels} + + return collate_fn + + def _setup(self, **kwargs) -> None: + self.model = HF_AutoModel.from_pretrained(self.title_in_source, cache_dir=self._cache_dir).to(self.device) + load_result = HF_AutoProcessor.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) + self.processor = load_result[0] if isinstance(load_result, tuple) else load_result + self.tokenizer = HF_AutoTokenizer.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) + + def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: + all_image_embeddings = [] + all_labels = [] + with torch.inference_mode(): + _dataset: Dataset = dataloader.dataset + inputs = self.tokenizer(_dataset.text_queries, padding=True, return_tensors="pt").to(self.device) + class_features = self.model.get_text_features(**inputs) + normalized_class_features = class_features / class_features.norm(p=2, dim=-1, keepdim=True) + class_embeddings = normalized_class_features.numpy(force=True) + for batch in 
tqdm(dataloader, desc=f"Embedding dataset with {self.title}"): + image_features = self.model.get_image_features(pixel_values=batch["pixel_values"].to(self.device)) + normalized_image_features = (image_features / image_features.norm(p=2, dim=-1, keepdim=True)).squeeze() + all_image_embeddings.append(normalized_image_features) + all_labels.append(batch["labels"]) + image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True) + labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32) + return image_embeddings, class_embeddings, labels diff --git a/clip_eval/models/local.py b/sources/models/types/local_clip.py similarity index 89% rename from clip_eval/models/local.py rename to sources/models/types/local_clip.py index 6103b06..fd55f3f 100644 --- a/clip_eval/models/local.py +++ b/sources/models/types/local_clip.py @@ -1,4 +1,4 @@ -from .CLIP_model import OpenCLIPModel +from sources.models.types.open_clip import OpenCLIPModel class LocalCLIPModel(OpenCLIPModel): diff --git a/sources/models/types/open_clip.py b/sources/models/types/open_clip.py new file mode 100644 index 0000000..65aab6c --- /dev/null +++ b/sources/models/types/open_clip.py @@ -0,0 +1,75 @@ +from collections.abc import Callable +from typing import Any + +import numpy as np +import open_clip +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from clip_eval.common import ClassArray, EmbeddingArray +from clip_eval.dataset import Dataset +from clip_eval.models import Model + + +class OpenCLIPModel(Model): + def __init__( + self, + title: str, + device: str | None = None, + *, + title_in_source: str, + pretrained: str | None = None, + cache_dir: str | None = None, + **kwargs, + ) -> None: + self.pretrained = pretrained + super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs) + self._setup(**kwargs) + + def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: + def process_fn(batch) -> dict[str, list[Any]]: + images = [i.convert("RGB") for i in batch["image"]] + batch["image"] = [self.processor(i) for i in images] + return batch + + return process_fn + + def get_collate_fn(self) -> Callable[[Any], Any]: + def collate_fn(examples) -> dict[str, torch.Tensor]: + images = [] + labels = [] + for example in examples: + images.append(example["image"]) + labels.append(example["label"]) + + torch_images = torch.stack(images) + labels = torch.tensor(labels) + return {"image": torch_images, "labels": labels} + + return collate_fn + + def _setup(self, **kwargs) -> None: + self.model, _, self.processor = open_clip.create_model_and_transforms( + model_name=self.title_in_source, + pretrained=self.pretrained, + cache_dir=self._cache_dir.as_posix(), + device=self.device, + **kwargs, + ) + self.tokenizer = open_clip.get_tokenizer(model_name=self.title_in_source) + + def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: + all_image_embeddings = [] + all_labels = [] + with torch.inference_mode(): + _dataset: Dataset = dataloader.dataset + text = self.tokenizer(_dataset.text_queries).to(self.device) + class_embeddings = self.model.encode_text(text, normalize=True).numpy(force=True) + for batch in tqdm(dataloader, desc=f"Embedding dataset with {self.title}"): + image_features = self.model.encode_image(batch["image"].to(self.device), normalize=True) + all_image_embeddings.append(image_features) + all_labels.append(batch["labels"]) + image_embeddings = 
torch.concatenate(all_image_embeddings).numpy(force=True)
+        labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32)
+        return image_embeddings, class_embeddings, labels
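
For reference, a minimal end-to-end sketch of the classmethod-based provider API introduced above. It assumes the bundled sources/ directory sits at the project root so that prepare() can discover the JSON definitions; the titles "clip" and "plants" come from the definition files added in this diff.

from clip_eval.common.base import Embeddings
from clip_eval.dataset import DatasetProvider, Split
from clip_eval.models import ModelProvider

# get_*/list_*/register_* call prepare() internally, which registers every
# JSON definition found under sources/{datasets,models}/definitions.
print("datasets:", DatasetProvider.list_dataset_titles())
print("models:", ModelProvider.list_model_titles())

# Same flow as EmbeddingDefinition.build_embeddings: resolve both providers,
# then embed the dataset with the model (downloads weights and data on first use).
model = ModelProvider.get_model("clip")
dataset = DatasetProvider.get_dataset("plants", Split.ALL)
embeddings = Embeddings.build_embedding(model, dataset)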
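
A sketch of how a new dataset would be described under the JSON-driven registration added here. The dataset title and Hugging Face repo below are hypothetical placeholders; the field names follow DatasetDefinitionSpec, whose extra="allow" configuration forwards any additional keys (e.g. target_feature) to the dataset constructor.

from clip_eval.dataset.base import DatasetDefinitionSpec

# Hypothetical definition. Dropping an equivalent .json file into
# sources/datasets/definitions/ lets DatasetProvider.prepare() pick it up,
# or a single file can be passed to
# DatasetProvider.register_dataset_from_json_definition().
definition_json = """
{
    "dataset_type": "HFDataset",
    "module_path": "../types/hugging_face.py",
    "title": "my-dataset",
    "title_in_source": "some-org/my-dataset",
    "target_feature": "labels"
}
"""

spec = DatasetDefinitionSpec.model_validate_json(definition_json)
# Everything except module_path and dataset_type becomes constructor kwargs.
print(spec.dataset_type, spec.model_dump(exclude={"module_path", "dataset_type"}))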
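
Underneath both providers, dynamic loading reduces to load_class_from_path; a short sketch, assuming it runs from the repository root so the relative path below resolves.

from clip_eval.dataset.utils import load_class_from_path
from clip_eval.models import Model

# Load the OpenCLIPModel type the same way ModelProvider does for a JSON
# definition whose module_path points at sources/models/types/open_clip.py.
OpenCLIPModel = load_class_from_path("sources/models/types/open_clip.py", "OpenCLIPModel")
assert issubclass(OpenCLIPModel, Model)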