From 354a5a05109f1aff95704a7441e60be669a3ec70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eloy=20P=C3%A9rez=20Torres?= <99720527+eloy-encord@users.noreply.github.com> Date: Mon, 15 Apr 2024 09:58:37 +0100 Subject: [PATCH] feat: detach resources from providers (#53) Model and dataset types and definitions are now dynamically fetched from a `sources` folder. With this change, previously built-in resources can be modified, disabled or removed at the user's discretion, and new models and datasets can be added seamlessly by providing the corresponding implementation (a new type) or definition (an instance of an existing type). On the side, the naming convention of several components moved by these changes was refactored so that they all follow the same patterns. For example, the model base class was renamed from `CLIPModel` to `Model`, since the class represents a higher level of abstraction than CLIP alone. Additionally, the dataset and model providers now use the singleton pattern, along with some minor changes, to avoid the partially initialised module errors that appeared when dynamic resources were loaded before all local resources had been successfully imported. The singleton pattern specifically prevents the creation of multiple provider instances, as the original global provider could no longer be instantiated at import time. --- clip_eval/cli/main.py | 8 +- clip_eval/cli/utils.py | 8 +- clip_eval/common/base.py | 4 +- clip_eval/common/data_models.py | 8 +- clip_eval/constants.py | 13 +- clip_eval/dataset/__init__.py | 2 +- clip_eval/dataset/base.py | 13 ++ clip_eval/dataset/encord.py | 159 -------------- clip_eval/dataset/provider.py | 140 ++++++------- clip_eval/dataset/utils.py | 8 + clip_eval/models/CLIP_model.py | 198 ------------------ clip_eval/models/__init__.py | 4 +- clip_eval/models/base.py | 80 +++++++ clip_eval/models/provider.py | 90 ++++---- clip_eval/plotting/animation.py | 4 +- .../datasets/dataset-definition-schema.json | 70 +++++++ .../datasets/definitions/Alzheimer-MRI.json | 6 + .../definitions/LungCancer4Types.json | 7 + .../datasets/definitions/NIH-Chest-X-ray.json | 9 + .../chest-xray-classification.json | 8 + .../datasets/definitions/geo-landmarks.json | 6 + sources/datasets/definitions/plants.json | 6 + sources/datasets/definitions/rsicd.json | 7 + sources/datasets/definitions/skin-cancer.json | 7 + .../definitions/sports-classification.json | 6 + .../datasets/types/encord_ds.py | 156 +++++++++++++- .../datasets/types/hugging_face.py | 2 +- sources/models/definitions/apple.json | 6 + sources/models/definitions/bioclip.json | 6 + sources/models/definitions/clip.json | 6 + sources/models/definitions/eva-clip.json | 6 + sources/models/definitions/fashion.json | 6 + sources/models/definitions/plip.json | 6 + sources/models/definitions/pubmed.json | 6 + sources/models/definitions/rsicd-encord.json | 6 + sources/models/definitions/rsicd.json | 6 + sources/models/definitions/siglip_large.json | 6 + sources/models/definitions/siglip_small.json | 6 + sources/models/definitions/street.json | 6 + sources/models/definitions/tinyclip.json | 6 + .../models/definitions/vit-b-32-laion2b.json | 7 + sources/models/model-definition-schema.json | 62 ++++++ sources/models/types/hugging_face_clip.py | 76 +++++++ .../models/types/local_clip.py | 2 +- sources/models/types/open_clip.py | 75 +++++++ 45 files changed, 836 insertions(+), 493 deletions(-) delete mode 100644 clip_eval/dataset/encord.py delete mode 100644 clip_eval/models/CLIP_model.py create mode 100644 clip_eval/models/base.py create
mode 100644 sources/datasets/dataset-definition-schema.json create mode 100644 sources/datasets/definitions/Alzheimer-MRI.json create mode 100644 sources/datasets/definitions/LungCancer4Types.json create mode 100644 sources/datasets/definitions/NIH-Chest-X-ray.json create mode 100644 sources/datasets/definitions/chest-xray-classification.json create mode 100644 sources/datasets/definitions/geo-landmarks.json create mode 100644 sources/datasets/definitions/plants.json create mode 100644 sources/datasets/definitions/rsicd.json create mode 100644 sources/datasets/definitions/skin-cancer.json create mode 100644 sources/datasets/definitions/sports-classification.json rename clip_eval/dataset/encord_utils.py => sources/datasets/types/encord_ds.py (54%) rename clip_eval/dataset/hf.py => sources/datasets/types/hugging_face.py (98%) create mode 100644 sources/models/definitions/apple.json create mode 100644 sources/models/definitions/bioclip.json create mode 100644 sources/models/definitions/clip.json create mode 100644 sources/models/definitions/eva-clip.json create mode 100644 sources/models/definitions/fashion.json create mode 100644 sources/models/definitions/plip.json create mode 100644 sources/models/definitions/pubmed.json create mode 100644 sources/models/definitions/rsicd-encord.json create mode 100644 sources/models/definitions/rsicd.json create mode 100644 sources/models/definitions/siglip_large.json create mode 100644 sources/models/definitions/siglip_small.json create mode 100644 sources/models/definitions/street.json create mode 100644 sources/models/definitions/tinyclip.json create mode 100644 sources/models/definitions/vit-b-32-laion2b.json create mode 100644 sources/models/model-definition-schema.json create mode 100644 sources/models/types/hugging_face_clip.py rename clip_eval/models/local.py => sources/models/types/local_clip.py (89%) create mode 100644 sources/models/types/open_clip.py diff --git a/clip_eval/cli/main.py b/clip_eval/cli/main.py index 173876f..82f3bbd 100644 --- a/clip_eval/cli/main.py +++ b/clip_eval/cli/main.py @@ -130,12 +130,12 @@ def list_models_datasets( Option(help="List all models and dataset that are available via the tool."), ] = False, ): - from clip_eval.dataset.provider import dataset_provider - from clip_eval.models import model_provider + from clip_eval.dataset import DatasetProvider + from clip_eval.models import ModelProvider if all: - datasets = dataset_provider.list_dataset_titles() - models = model_provider.list_model_titles() + datasets = DatasetProvider.list_dataset_titles() + models = ModelProvider.list_model_titles() print(f"Available datasets are: {', '.join(datasets)}") print(f"Available models are: {', '.join(models)}") return diff --git a/clip_eval/cli/utils.py b/clip_eval/cli/utils.py index 5c011a5..8d5adc4 100644 --- a/clip_eval/cli/utils.py +++ b/clip_eval/cli/utils.py @@ -6,8 +6,8 @@ from natsort import natsorted, ns from clip_eval.common.data_models import EmbeddingDefinition -from clip_eval.dataset.provider import dataset_provider -from clip_eval.models.provider import model_provider +from clip_eval.dataset import DatasetProvider +from clip_eval.models import ModelProvider from clip_eval.utils import read_all_cached_embeddings @@ -80,8 +80,8 @@ def select_from_all_embedding_definitions( ) -> list[EmbeddingDefinition]: existing = set(read_all_cached_embeddings(as_list=True)) - models = model_provider.list_model_titles() - datasets = dataset_provider.list_dataset_titles() + models = ModelProvider.list_model_titles() + datasets = 
DatasetProvider.list_dataset_titles() defs = [EmbeddingDefinition(dataset=d, model=m) for d, m in product(datasets, models)] if not include_existing: diff --git a/clip_eval/common/base.py b/clip_eval/common/base.py index 4d6b3da..ff1d528 100644 --- a/clip_eval/common/base.py +++ b/clip_eval/common/base.py @@ -6,7 +6,7 @@ from clip_eval.constants import NPZ_KEYS from clip_eval.dataset import Dataset -from clip_eval.models import CLIPModel +from clip_eval.models import Model from .numpy_types import ClassArray, EmbeddingArray @@ -69,7 +69,7 @@ def to_file(self, path: Path) -> Path: return path @staticmethod - def build_embedding(model: CLIPModel, dataset: Dataset, batch_size: int = 50) -> "Embeddings": + def build_embedding(model: Model, dataset: Dataset, batch_size: int = 50) -> "Embeddings": dataset.set_transform(model.get_transform()) dataloader = DataLoader(dataset, collate_fn=model.get_collate_fn(), batch_size=batch_size) diff --git a/clip_eval/common/data_models.py b/clip_eval/common/data_models.py index d4150a0..653f490 100644 --- a/clip_eval/common/data_models.py +++ b/clip_eval/common/data_models.py @@ -7,8 +7,8 @@ from pydantic.functional_validators import AfterValidator from clip_eval.constants import PROJECT_PATHS -from clip_eval.dataset import Split, dataset_provider -from clip_eval.models import model_provider +from clip_eval.dataset import DatasetProvider, Split +from clip_eval.models import ModelProvider from .base import Embeddings from .string_utils import safe_str @@ -61,8 +61,8 @@ def save_embeddings(self, embeddings: Embeddings, split: Split, overwrite: bool return True def build_embeddings(self, split: Split) -> Embeddings: - model = model_provider.get_model(self.model) - dataset = dataset_provider.get_dataset(self.dataset, split) + model = ModelProvider.get_model(self.model) + dataset = DatasetProvider.get_dataset(self.dataset, split) return Embeddings.build_embedding(model, dataset) def __str__(self): diff --git a/clip_eval/constants.py b/clip_eval/constants.py index 183c2f8..5f85fd2 100644 --- a/clip_eval/constants.py +++ b/clip_eval/constants.py @@ -6,8 +6,10 @@ load_dotenv() # If the cache directory is not explicitly specified, use the `.cache` directory located in the project's root. 
-CACHE_PATH = Path(os.environ.get("CLIP_CACHE_PATH", Path(__file__).parent.parent / ".cache")) -_OUTPUT_PATH = Path(os.environ.get("OUTPUT_PATH", "output")) +_CLIP_EVAL_ROOT_DIR = Path(__file__).parent.parent +CACHE_PATH = Path(os.environ.get("CLIP_EVAL_CACHE_PATH", _CLIP_EVAL_ROOT_DIR / ".cache")) +_OUTPUT_PATH = Path(os.environ.get("CLIP_EVAL_OUTPUT_PATH", _CLIP_EVAL_ROOT_DIR / "output")) +_SOURCES_PATH = _CLIP_EVAL_ROOT_DIR / "sources" class PROJECT_PATHS: @@ -25,3 +27,10 @@ class NPZ_KEYS: class OUTPUT_PATH: ANIMATIONS = _OUTPUT_PATH / "animations" EVALUATIONS = _OUTPUT_PATH / "evaluations" + + +class SOURCES_PATH: + DATASET_TYPES = _SOURCES_PATH / "datasets" / "types" + DATASET_INSTANCE_DEFINITIONS = _SOURCES_PATH / "datasets" / "definitions" + MODEL_TYPES = _SOURCES_PATH / "models" / "types" + MODEL_INSTANCE_DEFINITIONS = _SOURCES_PATH / "models" / "definitions" diff --git a/clip_eval/dataset/__init__.py b/clip_eval/dataset/__init__.py index 29cbb56..3ff74e6 100644 --- a/clip_eval/dataset/__init__.py +++ b/clip_eval/dataset/__init__.py @@ -1,2 +1,2 @@ from .base import Dataset, Split -from .provider import dataset_provider +from .provider import DatasetProvider diff --git a/clip_eval/dataset/base.py b/clip_eval/dataset/base.py index 62af44e..57822b1 100644 --- a/clip_eval/dataset/base.py +++ b/clip_eval/dataset/base.py @@ -2,6 +2,7 @@ from enum import StrEnum, auto from pathlib import Path +from pydantic import BaseModel, ConfigDict from torch.utils.data import Dataset as TorchDataset from clip_eval.constants import CACHE_PATH @@ -14,6 +15,18 @@ class Split(StrEnum): ALL = auto() +class DatasetDefinitionSpec(BaseModel): + dataset_type: str + module_path: Path + title: str + split: Split = Split.ALL + title_in_source: str | None = None + cache_dir: Path | None = None + + # Allow additional dataset configuration fields + model_config = ConfigDict(extra="allow") + + class Dataset(TorchDataset, ABC): def __init__( self, diff --git a/clip_eval/dataset/encord.py b/clip_eval/dataset/encord.py deleted file mode 100644 index b36aa97..0000000 --- a/clip_eval/dataset/encord.py +++ /dev/null @@ -1,159 +0,0 @@ -import json -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Any - -from encord import EncordUserClient -from encord.objects import Classification, LabelRowV2 -from encord.objects.common import PropertyType -from PIL import Image - -from .base import Dataset, Split -from .encord_utils import ( - download_data_from_project, - get_frame_file, - get_label_row_annotations_file, - get_label_rows_info_file, - simple_project_split, -) - - -class EncordDataset(Dataset): - def __init__( - self, - title: str, - project_hash: str, - classification_hash: str, - *, - split: Split = Split.ALL, - title_in_source: str | None = None, - transform=None, - cache_dir: str | None = None, - ssh_key_path: str | None = None, - **kwargs, - ): - super().__init__(title, split=split, title_in_source=title_in_source, transform=transform, cache_dir=cache_dir) - self._setup(project_hash, classification_hash, ssh_key_path, **kwargs) - - def __getitem__(self, idx): - frame_path = self._dataset_indices_info[idx].image_file - img = Image.open(frame_path) - label = self._dataset_indices_info[idx].label - - if self.transform is not None: - _d = self.transform(dict(image=[img], label=[label])) - res_item = dict(image=_d["image"][0], label=_d["label"][0]) - else: - res_item = dict(image=img, label=label) - return res_item - - def __len__(self): - return len(self._dataset_indices_info) - 
- def _get_frame_file(self, label_row: LabelRowV2, frame: int) -> Path: - return get_frame_file( - data_dir=self._cache_dir, - project_hash=self._project.project_hash, - label_row=label_row, - frame=frame, - ) - - def _get_label_row_annotations_file(self, label_row: LabelRowV2) -> Path: - return get_label_row_annotations_file( - data_dir=self._cache_dir, - project_hash=self._project.project_hash, - label_row_hash=label_row.label_hash, - ) - - def _ensure_answers_availability(self) -> dict: - lrs_info_file = get_label_rows_info_file(self._cache_dir, self._project.project_hash) - label_rows_info: dict = json.loads(lrs_info_file.read_text(encoding="utf-8")) - should_update_info = False - class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices - for label_row in self._label_rows: - if "answers" not in label_rows_info[label_row.label_hash]: - if not label_row.is_labelling_initialised: - # Retrieve label row content from local storage - anns_path = self._get_label_row_annotations_file(label_row) - label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8"))) - - answers = dict() - for frame_view in label_row.get_frame_views(): - clf_instances = frame_view.get_classification_instances(self._classification) - # Skip frames where the input classification is missing - if len(clf_instances) == 0: - continue - - clf_instance = clf_instances[0] - clf_answer = clf_instance.get_answer(self._attribute) - # Skip frames where the input classification has no answer (probable annotation error) - if clf_answer is None: - continue - - answers[frame_view.frame] = { - "image_file": self._get_frame_file(label_row, frame_view.frame).as_posix(), - "label": class_name_to_idx[clf_answer.title], - } - label_rows_info[label_row.label_hash]["answers"] = answers - should_update_info = True - if should_update_info: - lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8") - return label_rows_info - - def _setup( - self, - project_hash: str, - classification_hash: str, - ssh_key_path: str | None = None, - **kwargs, - ): - ssh_key_path = ssh_key_path or os.getenv("ENCORD_SSH_KEY_PATH") - if ssh_key_path is None: - raise ValueError( - "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing." - "Please set one of them to proceed" - ) - client = EncordUserClient.create_with_ssh_private_key(ssh_private_key_path=ssh_key_path) - self._project = client.get_project(project_hash) - - self._classification = self._project.ontology_structure.get_child_by_hash( - classification_hash, type_=Classification - ) - radio_attribute = self._classification.attributes[0] - if radio_attribute.get_property_type() != PropertyType.RADIO: - raise ValueError("Expected a classification hash with an attribute of type `Radio`") - self._attribute = radio_attribute - self.class_names = [o.title for o in self._attribute.options] - - # Fetch the label rows of the selected split - splits_file = self._cache_dir / "splits.json" - split_to_lr_hashes: dict[str, list[str]] - if splits_file.exists(): - split_to_lr_hashes = json.loads(splits_file.read_text(encoding="utf-8")) - else: - split_to_lr_hashes = simple_project_split(self._project) - splits_file.write_text(json.dumps(split_to_lr_hashes), encoding="utf-8") - self._label_rows = self._project.list_label_rows_v2(label_hashes=split_to_lr_hashes[self.split]) - - # Get data from source. 
Users may supply the `overwrite_annotations` keyword in the init to download everything - download_data_from_project( - self._project, - self._cache_dir, - self._label_rows, - tqdm_desc=f"Downloading {self.split} data from Encord project `{self._project.title}`", - **kwargs, - ) - - # Prepare data for the __getitem__ method - self._dataset_indices_info: list[EncordDataset.DatasetIndexInfo] = [] - label_rows_info = self._ensure_answers_availability() - for label_row in self._label_rows: - answers: dict[int, Any] = label_rows_info[label_row.label_hash]["answers"] - for frame_num in sorted(answers.keys()): - self._dataset_indices_info.append(EncordDataset.DatasetIndexInfo(**answers[frame_num])) - - @dataclass - class DatasetIndexInfo: - image_file: Path | str - label: int diff --git a/clip_eval/dataset/provider.py b/clip_eval/dataset/provider.py index 1979320..c744251 100644 --- a/clip_eval/dataset/provider.py +++ b/clip_eval/dataset/provider.py @@ -1,103 +1,99 @@ from copy import deepcopy +from pathlib import Path from typing import Any from natsort import natsorted, ns -from clip_eval.constants import CACHE_PATH +from clip_eval.constants import CACHE_PATH, SOURCES_PATH -from .base import Dataset, Split -from .encord import EncordDataset -from .hf import HFDataset +from .base import Dataset, DatasetDefinitionSpec, Split +from .utils import load_class_from_path class DatasetProvider: + __instance = None + __global_settings: dict[str, Any] = dict() + __known_dataset_types: dict[tuple[Path, str], Any] = dict() + def __init__(self): self._datasets = {} - self._global_settings: dict[str, Any] = dict() - - @property - def global_settings(self) -> dict: - return deepcopy(self._global_settings) - - def add_global_setting(self, name: str, value: Any) -> None: - self._global_settings[name] = value - - def remove_global_setting(self, name: str): - self._global_settings.pop(name, None) - def register_dataset(self, source: type[Dataset], title: str, split: Split | None = None, **kwargs): + @classmethod + def prepare(cls): + if cls.__instance is None: + cls.__instance = cls() + cls.register_datasets_from_sources_dir(SOURCES_PATH.DATASET_INSTANCE_DEFINITIONS) + # Global settings + cls.__instance.add_global_setting("cache_dir", CACHE_PATH) + return cls.__instance + + @classmethod + def global_settings(cls) -> dict: + return deepcopy(cls.__global_settings) + + @classmethod + def add_global_setting(cls, name: str, value: Any) -> None: + cls.__global_settings[name] = value + + @classmethod + def remove_global_setting(cls, name: str) -> None: + cls.__global_settings.pop(name, None) + + @classmethod + def register_dataset(cls, source: type[Dataset], title: str, split: Split | None = None, **kwargs) -> None: + instance = cls.prepare() if split is None: # One dataset with all the split definitions - self._datasets[title] = (source, kwargs) + instance._datasets[title] = (source, kwargs) else: # One dataset is defined per split kwargs.update(split=split) - self._datasets[(title, split)] = (source, kwargs) - - def get_dataset(self, title: str, split: Split) -> Dataset: - if (title, split) in self._datasets: + instance._datasets[(title, split)] = (source, kwargs) + + @classmethod + def register_dataset_from_json_definition(cls, json_definition: Path) -> None: + spec = DatasetDefinitionSpec.model_validate_json(json_definition.read_text(encoding="utf-8")) + if not spec.module_path.is_absolute(): # Handle relative module paths + spec.module_path = (json_definition.parent / spec.module_path).resolve() + + # Fetch the 
class of the dataset type stated in the definition + dataset_type = cls.__known_dataset_types.get((spec.module_path, spec.dataset_type)) + if dataset_type is None: + dataset_type = load_class_from_path(spec.module_path.as_posix(), spec.dataset_type) + if not issubclass(dataset_type, Dataset): + raise ValueError( + f"Dataset type specified in the JSON definition file `{json_definition.as_posix()}` " + f"does not inherit from the base class `Dataset`" + ) + cls.__known_dataset_types[(spec.module_path, spec.dataset_type)] = dataset_type + cls.register_dataset(dataset_type, **spec.model_dump(exclude={"module_path", "dataset_type"})) + + @classmethod + def register_datasets_from_sources_dir(cls, source_dir: Path) -> None: + for f in source_dir.glob("*.json"): + cls.register_dataset_from_json_definition(f) + + @classmethod + def get_dataset(cls, title: str, split: Split) -> Dataset: + instance = cls.prepare() + if (title, split) in instance._datasets: # The split corresponds to fetching a whole dataset (one-to-one relationship) dict_key = (title, split) split = Split.ALL # Ensure to read the whole dataset - elif title in self._datasets: + elif title in instance._datasets: # The dataset internally knows how to determine the split dict_key = title else: raise ValueError(f"Unrecognized dataset: {title}") - source, kwargs = self._datasets[dict_key] + source, kwargs = instance._datasets[dict_key] # Apply global settings. Values of local settings take priority when local and global settings share keys. - kwargs_with_global_settings = self.global_settings | kwargs + kwargs_with_global_settings = cls.global_settings() | kwargs return source(title, split=split, **kwargs_with_global_settings) - def list_dataset_titles(self) -> list[str]: + @classmethod + def list_dataset_titles(cls) -> list[str]: dataset_titles = [ - dict_key[0] if isinstance(dict_key, tuple) else dict_key for dict_key in self._datasets.keys() + dict_key[0] if isinstance(dict_key, tuple) else dict_key for dict_key in cls.prepare()._datasets.keys() ] return natsorted(set(dataset_titles), alg=ns.IGNORECASE) - - -dataset_provider = DatasetProvider() -# Global settings -dataset_provider.add_global_setting("cache_dir", CACHE_PATH) - -# Hugging Face datasets -dataset_provider.register_dataset(HFDataset, "plants", title_in_source="sampath017/plants") -dataset_provider.register_dataset(HFDataset, "Alzheimer-MRI", title_in_source="Falah/Alzheimer_MRI") -dataset_provider.register_dataset(HFDataset, "skin-cancer", title_in_source="marmal88/skin_cancer", target_feature="dx") -dataset_provider.register_dataset(HFDataset, "geo-landmarks", title_in_source="Qdrant/google-landmark-geo") -dataset_provider.register_dataset( - HFDataset, - "LungCancer4Types", - title_in_source="Kabil007/LungCancer4Types", - revision="a1aab924c6bed6b080fc85552fd7b39724931605", -) -dataset_provider.register_dataset( - HFDataset, - "NIH-Chest-X-ray", - title_in_source="alkzar90/NIH-Chest-X-ray-dataset", - name="image-classification", - target_feature="labels", - trust_remote_code=True, -) -dataset_provider.register_dataset( - HFDataset, - "chest-xray-classification", - title_in_source="trpakov/chest-xray-classification", - name="full", - target_feature="labels", -) - -dataset_provider.register_dataset( - HFDataset, - "sports-classification", - title_in_source="HES-XPLAIN/SportsImageClassification", -) - -# Encord datasets -dataset_provider.register_dataset( - EncordDataset, - "rsicd", - project_hash="46ba913e-1428-48ef-be7f-2553e69bc1e6", - classification_hash="4f6cf0c8", -) 
diff --git a/clip_eval/dataset/utils.py b/clip_eval/dataset/utils.py index 2109bf2..c24fdd4 100644 --- a/clip_eval/dataset/utils.py +++ b/clip_eval/dataset/utils.py @@ -1,3 +1,4 @@ +import importlib.util import os from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor as Executor @@ -68,6 +69,13 @@ def download_file( f.flush() +def load_class_from_path(module_path: str, class_name: str): + spec = importlib.util.spec_from_file_location(module_path, module_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return getattr(module, class_name) + + def simple_random_split( dataset_size: int, seed: int = 42, diff --git a/clip_eval/models/CLIP_model.py b/clip_eval/models/CLIP_model.py deleted file mode 100644 index 45622ed..0000000 --- a/clip_eval/models/CLIP_model.py +++ /dev/null @@ -1,198 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Callable -from pathlib import Path -from typing import Any - -import numpy as np -import open_clip -import torch -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import AutoModel as HF_AutoModel -from transformers import AutoProcessor as HF_AutoProcessor -from transformers import AutoTokenizer as HF_AutoTokenizer - -from clip_eval.common.numpy_types import ClassArray, EmbeddingArray -from clip_eval.constants import CACHE_PATH -from clip_eval.dataset import Dataset - - -class CLIPModel(ABC): - def __init__( - self, - title: str, - device: str | None = None, - *, - title_in_source: str | None = None, - cache_dir: str | None = None, - **kwargs, - ) -> None: - self.__title = title - self.__title_in_source = title_in_source or title - device = device or ("cuda" if torch.cuda.is_available() else "cpu") - self._check_device(device) - self.__device = torch.device(device) - if cache_dir is None: - cache_dir = CACHE_PATH - self._cache_dir = Path(cache_dir).expanduser().resolve() / "models" / title - - @property - def title(self) -> str: - return self.__title - - @property - def title_in_source(self) -> str: - return self.__title_in_source - - @property - def device(self) -> torch.device: - return self.__device - - @abstractmethod - def _setup(self, **kwargs) -> None: - pass - - @abstractmethod - def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: - ... - - @abstractmethod - def get_collate_fn(self) -> Callable[[Any], Any]: - ... - - @abstractmethod - def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: - ... 
- - @staticmethod - def _check_device(device: str): - # Check if the input device exists and is available - if device not in {"cuda", "cpu"}: - raise ValueError(f"Unrecognized device: {device}") - if not getattr(torch, device).is_available(): - raise ValueError(f"Unavailable device: {device}") - - -class ClosedCLIPModel(CLIPModel): - def __init__( - self, - title: str, - device: str | None = None, - *, - title_in_source: str | None = None, - cache_dir: str | None = None, - **kwargs, - ) -> None: - super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir) - self._setup(**kwargs) - - def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: - def process_fn(batch) -> dict[str, list[Any]]: - images = [i.convert("RGB") for i in batch["image"]] - batch["image"] = [ - self.processor(images=[i], return_tensors="pt").to(self.device).pixel_values.squeeze() for i in images - ] - return batch - - return process_fn - - def get_collate_fn(self) -> Callable[[Any], Any]: - def collate_fn(examples) -> dict[str, torch.Tensor]: - images = [] - labels = [] - for example in examples: - images.append(example["image"]) - labels.append(example["label"]) - - pixel_values = torch.stack(images) - labels = torch.tensor(labels) - return {"pixel_values": pixel_values, "labels": labels} - - return collate_fn - - def _setup(self, **kwargs) -> None: - self.model = HF_AutoModel.from_pretrained(self.title_in_source, cache_dir=self._cache_dir).to(self.device) - load_result = HF_AutoProcessor.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) - self.processor = load_result[0] if isinstance(load_result, tuple) else load_result - self.tokenizer = HF_AutoTokenizer.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) - - def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: - all_image_embeddings = [] - all_labels = [] - with torch.inference_mode(): - _dataset: Dataset = dataloader.dataset - inputs = self.tokenizer(_dataset.text_queries, padding=True, return_tensors="pt").to(self.device) - class_features = self.model.get_text_features(**inputs) - normalized_class_features = class_features / class_features.norm(p=2, dim=-1, keepdim=True) - class_embeddings = normalized_class_features.numpy(force=True) - for batch in tqdm(dataloader, desc=f"Embedding dataset with {self.title}"): - image_features = self.model.get_image_features(pixel_values=batch["pixel_values"].to(self.device)) - normalized_image_features = (image_features / image_features.norm(p=2, dim=-1, keepdim=True)).squeeze() - all_image_embeddings.append(normalized_image_features) - all_labels.append(batch["labels"]) - image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True) - labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32) - return image_embeddings, class_embeddings, labels - - -class OpenCLIPModel(CLIPModel): - def __init__( - self, - title: str, - device: str | None = None, - *, - title_in_source: str, - pretrained: str | None = None, - cache_dir: str | None = None, - **kwargs, - ) -> None: - self.pretrained = pretrained - super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs) - self._setup(**kwargs) - - def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: - def process_fn(batch) -> dict[str, list[Any]]: - images = [i.convert("RGB") for i in batch["image"]] - batch["image"] = [self.processor(i) for i in images] - return batch - - return 
process_fn - - def get_collate_fn(self) -> Callable[[Any], Any]: - def collate_fn(examples) -> dict[str, torch.Tensor]: - images = [] - labels = [] - for example in examples: - images.append(example["image"]) - labels.append(example["label"]) - - torch_images = torch.stack(images) - labels = torch.tensor(labels) - return {"image": torch_images, "labels": labels} - - return collate_fn - - def _setup(self, **kwargs) -> None: - self.model, _, self.processor = open_clip.create_model_and_transforms( - model_name=self.title_in_source, - pretrained=self.pretrained, - cache_dir=self._cache_dir.as_posix(), - device=self.device, - **kwargs, - ) - self.tokenizer = open_clip.get_tokenizer(model_name=self.title_in_source) - - def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: - all_image_embeddings = [] - all_labels = [] - with torch.inference_mode(): - _dataset: Dataset = dataloader.dataset - text = self.tokenizer(_dataset.text_queries).to(self.device) - class_embeddings = self.model.encode_text(text, normalize=True).numpy(force=True) - for batch in tqdm(dataloader, desc=f"Embedding dataset with {self.title}"): - image_features = self.model.encode_image(batch["image"].to(self.device), normalize=True) - all_image_embeddings.append(image_features) - all_labels.append(batch["labels"]) - image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True) - labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32) - return image_embeddings, class_embeddings, labels diff --git a/clip_eval/models/__init__.py b/clip_eval/models/__init__.py index 6970a20..7a1abaa 100644 --- a/clip_eval/models/__init__.py +++ b/clip_eval/models/__init__.py @@ -1,2 +1,2 @@ -from .CLIP_model import CLIPModel -from .provider import model_provider +from .base import Model +from .provider import ModelProvider diff --git a/clip_eval/models/base.py b/clip_eval/models/base.py new file mode 100644 index 0000000..909f353 --- /dev/null +++ b/clip_eval/models/base.py @@ -0,0 +1,80 @@ +from abc import ABC, abstractmethod +from collections.abc import Callable +from pathlib import Path +from typing import Any + +import torch +from pydantic import BaseModel, ConfigDict +from torch.utils.data import DataLoader + +from clip_eval.common.numpy_types import ClassArray, EmbeddingArray +from clip_eval.constants import CACHE_PATH + + +class ModelDefinitionSpec(BaseModel): + model_type: str + module_path: Path + title: str + device: str | None = None + title_in_source: str | None = None + cache_dir: Path | None = None + + # Allow additional model configuration fields + # Also, silence spurious Pydantic UserWarning: Field "model_type" has conflict with protected namespace "model_" + model_config = ConfigDict(protected_namespaces=(), extra="allow") + + +class Model(ABC): + def __init__( + self, + title: str, + device: str | None = None, + *, + title_in_source: str | None = None, + cache_dir: str | None = None, + **kwargs, + ) -> None: + self.__title = title + self.__title_in_source = title_in_source or title + device = device or ("cuda" if torch.cuda.is_available() else "cpu") + self._check_device(device) + self.__device = torch.device(device) + if cache_dir is None: + cache_dir = CACHE_PATH + self._cache_dir = Path(cache_dir).expanduser().resolve() / "models" / title + + @property + def title(self) -> str: + return self.__title + + @property + def title_in_source(self) -> str: + return self.__title_in_source + + @property + def device(self) -> torch.device: + return 
self.__device + + @abstractmethod + def _setup(self, **kwargs) -> None: + pass + + @abstractmethod + def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: + ... + + @abstractmethod + def get_collate_fn(self) -> Callable[[Any], Any]: + ... + + @abstractmethod + def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: + ... + + @staticmethod + def _check_device(device: str): + # Check if the input device exists and is available + if device not in {"cuda", "cpu"}: + raise ValueError(f"Unrecognized device: {device}") + if not getattr(torch, device).is_available(): + raise ValueError(f"Unavailable device: {device}") diff --git a/clip_eval/models/provider.py b/clip_eval/models/provider.py index c88ad34..a95a7b2 100644 --- a/clip_eval/models/provider.py +++ b/clip_eval/models/provider.py @@ -1,49 +1,63 @@ +from pathlib import Path +from typing import Any + from natsort import natsorted, ns -from .CLIP_model import CLIPModel, ClosedCLIPModel, OpenCLIPModel -from .local import LocalCLIPModel +from clip_eval.constants import SOURCES_PATH +from clip_eval.dataset.utils import load_class_from_path + +from .base import Model, ModelDefinitionSpec class ModelProvider: + __instance = None + __known_model_types: dict[tuple[Path, str], Any] = dict() + def __init__(self) -> None: self._models = {} - def register_model(self, title: str, source: type[CLIPModel], **kwargs): - self._models[title] = (source, kwargs) - - def get_model(self, title: str) -> CLIPModel: - if title not in self._models: + @classmethod + def prepare(cls): + if cls.__instance is None: + cls.__instance = cls() + cls.register_models_from_sources_dir(SOURCES_PATH.MODEL_INSTANCE_DEFINITIONS) + return cls.__instance + + @classmethod + def register_model(cls, source: type[Model], title: str, **kwargs): + cls.prepare()._models[title] = (source, kwargs) + + @classmethod + def register_model_from_json_definition(cls, json_definition: Path) -> None: + spec = ModelDefinitionSpec.model_validate_json(json_definition.read_text(encoding="utf-8")) + if not spec.module_path.is_absolute(): # Handle relative module paths + spec.module_path = (json_definition.parent / spec.module_path).resolve() + + # Fetch the class of the model type stated in the definition + model_type = cls.__known_model_types.get((spec.module_path, spec.model_type)) + if model_type is None: + model_type = load_class_from_path(spec.module_path.as_posix(), spec.model_type) + if not issubclass(model_type, Model): + raise ValueError( + f"Model type specified in the JSON definition file `{json_definition.as_posix()}` " + f"does not inherit from the base class `Model`" + ) + cls.__known_model_types[(spec.module_path, spec.model_type)] = model_type + cls.register_model(model_type, **spec.model_dump(exclude={"module_path", "model_type"})) + + @classmethod + def register_models_from_sources_dir(cls, source_dir: Path) -> None: + for f in source_dir.glob("*.json"): + cls.register_model_from_json_definition(f) + + @classmethod + def get_model(cls, title: str) -> Model: + instance = cls.prepare() + if title not in instance._models: raise ValueError(f"Unrecognized model: {title}") - source, kwargs = self._models[title] + source, kwargs = instance._models[title] return source(title, **kwargs) - def list_model_titles(self) -> list[str]: - return natsorted(self._models.keys(), alg=ns.IGNORECASE) - - -model_provider = ModelProvider() -model_provider.register_model("clip", ClosedCLIPModel, 
title_in_source="openai/clip-vit-large-patch14-336") -model_provider.register_model("plip", ClosedCLIPModel, title_in_source="vinid/plip") -model_provider.register_model( - "pubmed", - ClosedCLIPModel, - title_in_source="flaviagiammarino/pubmed-clip-vit-base-patch32", -) -model_provider.register_model( - "tinyclip", - ClosedCLIPModel, - title_in_source="wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M", -) -model_provider.register_model("fashion", ClosedCLIPModel, title_in_source="patrickjohncyh/fashion-clip") -model_provider.register_model("rsicd", ClosedCLIPModel, title_in_source="flax-community/clip-rsicd") -model_provider.register_model("street", ClosedCLIPModel, title_in_source="geolocal/StreetCLIP") -model_provider.register_model("siglip_small", ClosedCLIPModel, title_in_source="google/siglip-base-patch16-224") -model_provider.register_model("siglip_large", ClosedCLIPModel, title_in_source="google/siglip-large-patch16-256") - -model_provider.register_model("apple", OpenCLIPModel, title_in_source="hf-hub:apple/DFN5B-CLIP-ViT-H-14") -model_provider.register_model("eva-clip", OpenCLIPModel, title_in_source="BAAI/EVA-CLIP-8B-448") -model_provider.register_model("bioclip", OpenCLIPModel, title_in_source="hf-hub:imageomics/bioclip") -model_provider.register_model("vit-b-32-laion2b", OpenCLIPModel, title_in_source="ViT-B-32", pretrained="laion2b_e16") - -# Local sources -model_provider.register_model("rsicd-encord", LocalCLIPModel, title_in_source="ViT-B-32") + @classmethod + def list_model_titles(cls) -> list[str]: + return natsorted(cls.prepare()._models.keys(), alg=ns.IGNORECASE) diff --git a/clip_eval/plotting/animation.py b/clip_eval/plotting/animation.py index 038b494..28e0838 100644 --- a/clip_eval/plotting/animation.py +++ b/clip_eval/plotting/animation.py @@ -12,7 +12,7 @@ from clip_eval.common.data_models import SafeName from clip_eval.common.numpy_types import ClassArray, N2Array from clip_eval.constants import OUTPUT_PATH -from clip_eval.dataset import Split, dataset_provider +from clip_eval.dataset import DatasetProvider, Split from .reduction import REDUCTIONS, reduction_from_string @@ -212,7 +212,7 @@ def build_animation( reduction: REDUCTIONS = "umap", interactive: bool = False, ) -> animation.FuncAnimation | None: - dataset = dataset_provider.get_dataset(defn_1.dataset, split) + dataset = DatasetProvider.get_dataset(defn_1.dataset, split) embeds = defn_1.load_embeddings(split) # FIXME: This is expensive to get just labels if embeds is None: diff --git a/sources/datasets/dataset-definition-schema.json b/sources/datasets/dataset-definition-schema.json new file mode 100644 index 0000000..faa98c2 --- /dev/null +++ b/sources/datasets/dataset-definition-schema.json @@ -0,0 +1,70 @@ +{ + "$defs": { + "Split": { + "enum": [ + "train", + "validation", + "test", + "all" + ], + "title": "Split", + "type": "string" + } + }, + "additionalProperties": true, + "properties": { + "dataset_type": { + "title": "Dataset Type", + "type": "string" + }, + "module_path": { + "format": "path", + "title": "Module Path", + "type": "string" + }, + "title": { + "title": "Title", + "type": "string" + }, + "split": { + "allOf": [ + { + "$ref": "#/$defs/Split" + } + ], + "default": "all" + }, + "title_in_source": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Title In Source" + }, + "cache_dir": { + "anyOf": [ + { + "format": "path", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Cache Dir" + } + }, + "required": [ + 
"dataset_type", + "module_path", + "title" + ], + "title": "DatasetDefinitionSpec", + "type": "object" +} diff --git a/sources/datasets/definitions/Alzheimer-MRI.json b/sources/datasets/definitions/Alzheimer-MRI.json new file mode 100644 index 0000000..63999a5 --- /dev/null +++ b/sources/datasets/definitions/Alzheimer-MRI.json @@ -0,0 +1,6 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "Alzheimer-MRI", + "title_in_source": "Falah/Alzheimer_MRI" +} diff --git a/sources/datasets/definitions/LungCancer4Types.json b/sources/datasets/definitions/LungCancer4Types.json new file mode 100644 index 0000000..e71eb5d --- /dev/null +++ b/sources/datasets/definitions/LungCancer4Types.json @@ -0,0 +1,7 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "LungCancer4Types", + "title_in_source": "Kabil007/LungCancer4Types", + "revision": "a1aab924c6bed6b080fc85552fd7b39724931605" +} diff --git a/sources/datasets/definitions/NIH-Chest-X-ray.json b/sources/datasets/definitions/NIH-Chest-X-ray.json new file mode 100644 index 0000000..c82cea2 --- /dev/null +++ b/sources/datasets/definitions/NIH-Chest-X-ray.json @@ -0,0 +1,9 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "NIH-Chest-X-ray", + "title_in_source": "alkzar90/NIH-Chest-X-ray-dataset", + "name": "image-classification", + "target_feature": "labels", + "trust_remote_code": true +} diff --git a/sources/datasets/definitions/chest-xray-classification.json b/sources/datasets/definitions/chest-xray-classification.json new file mode 100644 index 0000000..1ca4a10 --- /dev/null +++ b/sources/datasets/definitions/chest-xray-classification.json @@ -0,0 +1,8 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "chest-xray-classification", + "title_in_source": "trpakov/chest-xray-classification", + "name": "full", + "target_feature": "labels" +} diff --git a/sources/datasets/definitions/geo-landmarks.json b/sources/datasets/definitions/geo-landmarks.json new file mode 100644 index 0000000..21acac5 --- /dev/null +++ b/sources/datasets/definitions/geo-landmarks.json @@ -0,0 +1,6 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "geo-landmarks", + "title_in_source": "Qdrant/google-landmark-geo" +} diff --git a/sources/datasets/definitions/plants.json b/sources/datasets/definitions/plants.json new file mode 100644 index 0000000..739a4af --- /dev/null +++ b/sources/datasets/definitions/plants.json @@ -0,0 +1,6 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "plants", + "title_in_source": "sampath017/plants" +} diff --git a/sources/datasets/definitions/rsicd.json b/sources/datasets/definitions/rsicd.json new file mode 100644 index 0000000..a5bc102 --- /dev/null +++ b/sources/datasets/definitions/rsicd.json @@ -0,0 +1,7 @@ +{ + "dataset_type": "EncordDataset", + "module_path": "../types/encord_ds.py", + "title": "rsicd", + "project_hash": "46ba913e-1428-48ef-be7f-2553e69bc1e6", + "classification_hash": "4f6cf0c8" +} diff --git a/sources/datasets/definitions/skin-cancer.json b/sources/datasets/definitions/skin-cancer.json new file mode 100644 index 0000000..0394732 --- /dev/null +++ b/sources/datasets/definitions/skin-cancer.json @@ -0,0 +1,7 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "skin-cancer", + "title_in_source": "marmal88/skin_cancer", + "target_feature": "dx" 
+} diff --git a/sources/datasets/definitions/sports-classification.json b/sources/datasets/definitions/sports-classification.json new file mode 100644 index 0000000..2b188cd --- /dev/null +++ b/sources/datasets/definitions/sports-classification.json @@ -0,0 +1,6 @@ +{ + "dataset_type": "HFDataset", + "module_path": "../types/hugging_face.py", + "title": "sports-classification", + "title_in_source": "HES-XPLAIN/SportsImageClassification" +} diff --git a/clip_eval/dataset/encord_utils.py b/sources/datasets/types/encord_ds.py similarity index 54% rename from clip_eval/dataset/encord_utils.py rename to sources/datasets/types/encord_ds.py index 997e514..77b5773 100644 --- a/clip_eval/dataset/encord_utils.py +++ b/sources/datasets/types/encord_ds.py @@ -1,14 +1,164 @@ import json +import os +from dataclasses import dataclass from pathlib import Path from typing import Any -from encord import Project +from encord import EncordUserClient, Project from encord.common.constants import DATETIME_STRING_FORMAT +from encord.objects import Classification, LabelRowV2 +from encord.objects.common import PropertyType from encord.orm.dataset import DataType, Image, Video -from encord.project import LabelRowV2 +from PIL import Image as PILImage from tqdm.auto import tqdm -from .utils import Split, collect_async, download_file, simple_random_split +from clip_eval.dataset import Dataset, Split +from clip_eval.dataset.utils import collect_async, download_file, simple_random_split + + +class EncordDataset(Dataset): + def __init__( + self, + title: str, + project_hash: str, + classification_hash: str, + *, + split: Split = Split.ALL, + title_in_source: str | None = None, + transform=None, + cache_dir: str | None = None, + ssh_key_path: str | None = None, + **kwargs, + ): + super().__init__(title, split=split, title_in_source=title_in_source, transform=transform, cache_dir=cache_dir) + self._setup(project_hash, classification_hash, ssh_key_path, **kwargs) + + def __getitem__(self, idx): + frame_path = self._dataset_indices_info[idx].image_file + img = PILImage.open(frame_path) + label = self._dataset_indices_info[idx].label + + if self.transform is not None: + _d = self.transform(dict(image=[img], label=[label])) + res_item = dict(image=_d["image"][0], label=_d["label"][0]) + else: + res_item = dict(image=img, label=label) + return res_item + + def __len__(self): + return len(self._dataset_indices_info) + + def _get_frame_file(self, label_row: LabelRowV2, frame: int) -> Path: + return get_frame_file( + data_dir=self._cache_dir, + project_hash=self._project.project_hash, + label_row=label_row, + frame=frame, + ) + + def _get_label_row_annotations_file(self, label_row: LabelRowV2) -> Path: + return get_label_row_annotations_file( + data_dir=self._cache_dir, + project_hash=self._project.project_hash, + label_row_hash=label_row.label_hash, + ) + + def _ensure_answers_availability(self) -> dict: + lrs_info_file = get_label_rows_info_file(self._cache_dir, self._project.project_hash) + label_rows_info: dict = json.loads(lrs_info_file.read_text(encoding="utf-8")) + should_update_info = False + class_name_to_idx = {name: idx for idx, name in enumerate(self.class_names)} # Fast lookup of class indices + for label_row in self._label_rows: + if "answers" not in label_rows_info[label_row.label_hash]: + if not label_row.is_labelling_initialised: + # Retrieve label row content from local storage + anns_path = self._get_label_row_annotations_file(label_row) + 
label_row.from_labels_dict(json.loads(anns_path.read_text(encoding="utf-8"))) + + answers = dict() + for frame_view in label_row.get_frame_views(): + clf_instances = frame_view.get_classification_instances(self._classification) + # Skip frames where the input classification is missing + if len(clf_instances) == 0: + continue + + clf_instance = clf_instances[0] + clf_answer = clf_instance.get_answer(self._attribute) + # Skip frames where the input classification has no answer (probable annotation error) + if clf_answer is None: + continue + + answers[frame_view.frame] = { + "image_file": self._get_frame_file(label_row, frame_view.frame).as_posix(), + "label": class_name_to_idx[clf_answer.title], + } + label_rows_info[label_row.label_hash]["answers"] = answers + should_update_info = True + if should_update_info: + lrs_info_file.write_text(json.dumps(label_rows_info), encoding="utf-8") + return label_rows_info + + def _setup( + self, + project_hash: str, + classification_hash: str, + ssh_key_path: str | None = None, + **kwargs, + ): + ssh_key_path = ssh_key_path or os.getenv("ENCORD_SSH_KEY_PATH") + if ssh_key_path is None: + raise ValueError( + "The `ssh_key_path` parameter and the `ENCORD_SSH_KEY_PATH` environment variable are both missing." + "Please set one of them to proceed" + ) + client = EncordUserClient.create_with_ssh_private_key(ssh_private_key_path=ssh_key_path) + self._project = client.get_project(project_hash) + + self._classification = self._project.ontology_structure.get_child_by_hash( + classification_hash, type_=Classification + ) + radio_attribute = self._classification.attributes[0] + if radio_attribute.get_property_type() != PropertyType.RADIO: + raise ValueError("Expected a classification hash with an attribute of type `Radio`") + self._attribute = radio_attribute + self.class_names = [o.title for o in self._attribute.options] + + # Fetch the label rows of the selected split + splits_file = self._cache_dir / "splits.json" + split_to_lr_hashes: dict[str, list[str]] + if splits_file.exists(): + split_to_lr_hashes = json.loads(splits_file.read_text(encoding="utf-8")) + else: + split_to_lr_hashes = simple_project_split(self._project) + splits_file.write_text(json.dumps(split_to_lr_hashes), encoding="utf-8") + self._label_rows = self._project.list_label_rows_v2(label_hashes=split_to_lr_hashes[self.split]) + + # Get data from source. 
Users may supply the `overwrite_annotations` keyword in the init to download everything + download_data_from_project( + self._project, + self._cache_dir, + self._label_rows, + tqdm_desc=f"Downloading {self.split} data from Encord project `{self._project.title}`", + **kwargs, + ) + + # Prepare data for the __getitem__ method + self._dataset_indices_info: list[EncordDataset.DatasetIndexInfo] = [] + label_rows_info = self._ensure_answers_availability() + for label_row in self._label_rows: + answers: dict[int, Any] = label_rows_info[label_row.label_hash]["answers"] + for frame_num in sorted(answers.keys()): + self._dataset_indices_info.append(EncordDataset.DatasetIndexInfo(**answers[frame_num])) + + @dataclass + class DatasetIndexInfo: + image_file: Path | str + label: int + + +# ----------------------------------------------------------------------- +# UTILITY FUNCTIONS +# ----------------------------------------------------------------------- def _download_image(image_data: Image | Video, destination_dir: Path) -> Path: diff --git a/clip_eval/dataset/hf.py b/sources/datasets/types/hugging_face.py similarity index 98% rename from clip_eval/dataset/hf.py rename to sources/datasets/types/hugging_face.py index 82ff75e..145bb10 100644 --- a/clip_eval/dataset/hf.py +++ b/sources/datasets/types/hugging_face.py @@ -1,6 +1,6 @@ from datasets import ClassLabel, DatasetDict, Sequence, Value, load_dataset -from .base import Dataset, Split +from clip_eval.dataset import Dataset, Split class HFDataset(Dataset): diff --git a/sources/models/definitions/apple.json b/sources/models/definitions/apple.json new file mode 100644 index 0000000..28de863 --- /dev/null +++ b/sources/models/definitions/apple.json @@ -0,0 +1,6 @@ +{ + "model_type": "OpenCLIPModel", + "module_path": "../types/open_clip.py", + "title": "apple", + "title_in_source": "hf-hub:apple/DFN5B-CLIP-ViT-H-14" +} diff --git a/sources/models/definitions/bioclip.json b/sources/models/definitions/bioclip.json new file mode 100644 index 0000000..cd9d899 --- /dev/null +++ b/sources/models/definitions/bioclip.json @@ -0,0 +1,6 @@ +{ + "model_type": "OpenCLIPModel", + "module_path": "../types/open_clip.py", + "title": "bioclip", + "title_in_source": "hf-hub:imageomics/bioclip" +} diff --git a/sources/models/definitions/clip.json b/sources/models/definitions/clip.json new file mode 100644 index 0000000..19941f6 --- /dev/null +++ b/sources/models/definitions/clip.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "clip", + "title_in_source": "openai/clip-vit-large-patch14-336" +} diff --git a/sources/models/definitions/eva-clip.json b/sources/models/definitions/eva-clip.json new file mode 100644 index 0000000..7d1c999 --- /dev/null +++ b/sources/models/definitions/eva-clip.json @@ -0,0 +1,6 @@ +{ + "model_type": "OpenCLIPModel", + "module_path": "../types/open_clip.py", + "title": "eva-clip", + "title_in_source": "BAAI/EVA-CLIP-8B-448" +} diff --git a/sources/models/definitions/fashion.json b/sources/models/definitions/fashion.json new file mode 100644 index 0000000..e3c1762 --- /dev/null +++ b/sources/models/definitions/fashion.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "fashion", + "title_in_source": "patrickjohncyh/fashion-clip" +} diff --git a/sources/models/definitions/plip.json b/sources/models/definitions/plip.json new file mode 100644 index 0000000..058e482 --- /dev/null +++ 
b/sources/models/definitions/plip.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "plip", + "title_in_source": "vinid/plip" +} diff --git a/sources/models/definitions/pubmed.json b/sources/models/definitions/pubmed.json new file mode 100644 index 0000000..23995a5 --- /dev/null +++ b/sources/models/definitions/pubmed.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "pubmed", + "title_in_source": "flaviagiammarino/pubmed-clip-vit-base-patch32" +} diff --git a/sources/models/definitions/rsicd-encord.json b/sources/models/definitions/rsicd-encord.json new file mode 100644 index 0000000..88481bb --- /dev/null +++ b/sources/models/definitions/rsicd-encord.json @@ -0,0 +1,6 @@ +{ + "model_type": "LocalCLIPModel", + "module_path": "../types/local_clip.py", + "title": "rsicd-encord", + "title_in_source": "ViT-B-32" +} diff --git a/sources/models/definitions/rsicd.json b/sources/models/definitions/rsicd.json new file mode 100644 index 0000000..e02f5de --- /dev/null +++ b/sources/models/definitions/rsicd.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "rsicd", + "title_in_source": "flax-community/clip-rsicd" +} diff --git a/sources/models/definitions/siglip_large.json b/sources/models/definitions/siglip_large.json new file mode 100644 index 0000000..ea356f6 --- /dev/null +++ b/sources/models/definitions/siglip_large.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "siglip_large", + "title_in_source": "google/siglip-large-patch16-256" +} diff --git a/sources/models/definitions/siglip_small.json b/sources/models/definitions/siglip_small.json new file mode 100644 index 0000000..5c6db71 --- /dev/null +++ b/sources/models/definitions/siglip_small.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "siglip_small", + "title_in_source": "google/siglip-base-patch16-224" +} diff --git a/sources/models/definitions/street.json b/sources/models/definitions/street.json new file mode 100644 index 0000000..404456f --- /dev/null +++ b/sources/models/definitions/street.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "street", + "title_in_source": "geolocal/StreetCLIP" +} diff --git a/sources/models/definitions/tinyclip.json b/sources/models/definitions/tinyclip.json new file mode 100644 index 0000000..10d66ac --- /dev/null +++ b/sources/models/definitions/tinyclip.json @@ -0,0 +1,6 @@ +{ + "model_type": "ClosedCLIPModel", + "module_path": "../types/hugging_face_clip.py", + "title": "tinyclip", + "title_in_source": "wkcn/TinyCLIP-ViT-40M-32-Text-19M-LAION400M" +} diff --git a/sources/models/definitions/vit-b-32-laion2b.json b/sources/models/definitions/vit-b-32-laion2b.json new file mode 100644 index 0000000..a100000 --- /dev/null +++ b/sources/models/definitions/vit-b-32-laion2b.json @@ -0,0 +1,7 @@ +{ + "model_type": "OpenCLIPModel", + "module_path": "../types/open_clip.py", + "title": "vit-b-32-laion2b", + "title_in_source": "ViT-B-32", + "pretrained": "laion2b_e16" +} diff --git a/sources/models/model-definition-schema.json b/sources/models/model-definition-schema.json new file mode 100644 index 0000000..40aa14e --- /dev/null +++ b/sources/models/model-definition-schema.json @@ -0,0 +1,62 @@ +{ + 
"additionalProperties": true, + "properties": { + "model_type": { + "title": "Model Type", + "type": "string" + }, + "module_path": { + "format": "path", + "title": "Module Path", + "type": "string" + }, + "title": { + "title": "Title", + "type": "string" + }, + "device": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Device" + }, + "title_in_source": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Title In Source" + }, + "cache_dir": { + "anyOf": [ + { + "format": "path", + "type": "string" + }, + { + "type": "null" + } + ], + "default": null, + "title": "Cache Dir" + } + }, + "required": [ + "model_type", + "module_path", + "title" + ], + "title": "ModelDefinitionSpec", + "type": "object" +} diff --git a/sources/models/types/hugging_face_clip.py b/sources/models/types/hugging_face_clip.py new file mode 100644 index 0000000..f5b39a4 --- /dev/null +++ b/sources/models/types/hugging_face_clip.py @@ -0,0 +1,76 @@ +from collections.abc import Callable +from typing import Any + +import numpy as np +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoModel as HF_AutoModel +from transformers import AutoProcessor as HF_AutoProcessor +from transformers import AutoTokenizer as HF_AutoTokenizer + +from clip_eval.common import ClassArray, EmbeddingArray +from clip_eval.dataset import Dataset +from clip_eval.models import Model + + +class ClosedCLIPModel(Model): + def __init__( + self, + title: str, + device: str | None = None, + *, + title_in_source: str | None = None, + cache_dir: str | None = None, + **kwargs, + ) -> None: + super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir) + self._setup(**kwargs) + + def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: + def process_fn(batch) -> dict[str, list[Any]]: + images = [i.convert("RGB") for i in batch["image"]] + batch["image"] = [ + self.processor(images=[i], return_tensors="pt").to(self.device).pixel_values.squeeze() for i in images + ] + return batch + + return process_fn + + def get_collate_fn(self) -> Callable[[Any], Any]: + def collate_fn(examples) -> dict[str, torch.Tensor]: + images = [] + labels = [] + for example in examples: + images.append(example["image"]) + labels.append(example["label"]) + + pixel_values = torch.stack(images) + labels = torch.tensor(labels) + return {"pixel_values": pixel_values, "labels": labels} + + return collate_fn + + def _setup(self, **kwargs) -> None: + self.model = HF_AutoModel.from_pretrained(self.title_in_source, cache_dir=self._cache_dir).to(self.device) + load_result = HF_AutoProcessor.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) + self.processor = load_result[0] if isinstance(load_result, tuple) else load_result + self.tokenizer = HF_AutoTokenizer.from_pretrained(self.title_in_source, cache_dir=self._cache_dir) + + def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: + all_image_embeddings = [] + all_labels = [] + with torch.inference_mode(): + _dataset: Dataset = dataloader.dataset + inputs = self.tokenizer(_dataset.text_queries, padding=True, return_tensors="pt").to(self.device) + class_features = self.model.get_text_features(**inputs) + normalized_class_features = class_features / class_features.norm(p=2, dim=-1, keepdim=True) + class_embeddings = normalized_class_features.numpy(force=True) + for batch in 
tqdm(dataloader, desc=f"Embedding dataset with {self.title}"): + image_features = self.model.get_image_features(pixel_values=batch["pixel_values"].to(self.device)) + normalized_image_features = (image_features / image_features.norm(p=2, dim=-1, keepdim=True)).squeeze() + all_image_embeddings.append(normalized_image_features) + all_labels.append(batch["labels"]) + image_embeddings = torch.concatenate(all_image_embeddings).numpy(force=True) + labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32) + return image_embeddings, class_embeddings, labels diff --git a/clip_eval/models/local.py b/sources/models/types/local_clip.py similarity index 89% rename from clip_eval/models/local.py rename to sources/models/types/local_clip.py index 6103b06..fd55f3f 100644 --- a/clip_eval/models/local.py +++ b/sources/models/types/local_clip.py @@ -1,4 +1,4 @@ -from .CLIP_model import OpenCLIPModel +from sources.models.types.open_clip import OpenCLIPModel class LocalCLIPModel(OpenCLIPModel): diff --git a/sources/models/types/open_clip.py b/sources/models/types/open_clip.py new file mode 100644 index 0000000..65aab6c --- /dev/null +++ b/sources/models/types/open_clip.py @@ -0,0 +1,75 @@ +from collections.abc import Callable +from typing import Any + +import numpy as np +import open_clip +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from clip_eval.common import ClassArray, EmbeddingArray +from clip_eval.dataset import Dataset +from clip_eval.models import Model + + +class OpenCLIPModel(Model): + def __init__( + self, + title: str, + device: str | None = None, + *, + title_in_source: str, + pretrained: str | None = None, + cache_dir: str | None = None, + **kwargs, + ) -> None: + self.pretrained = pretrained + super().__init__(title, device, title_in_source=title_in_source, cache_dir=cache_dir, **kwargs) + self._setup(**kwargs) + + def get_transform(self) -> Callable[[dict[str, Any]], dict[str, list[Any]]]: + def process_fn(batch) -> dict[str, list[Any]]: + images = [i.convert("RGB") for i in batch["image"]] + batch["image"] = [self.processor(i) for i in images] + return batch + + return process_fn + + def get_collate_fn(self) -> Callable[[Any], Any]: + def collate_fn(examples) -> dict[str, torch.Tensor]: + images = [] + labels = [] + for example in examples: + images.append(example["image"]) + labels.append(example["label"]) + + torch_images = torch.stack(images) + labels = torch.tensor(labels) + return {"image": torch_images, "labels": labels} + + return collate_fn + + def _setup(self, **kwargs) -> None: + self.model, _, self.processor = open_clip.create_model_and_transforms( + model_name=self.title_in_source, + pretrained=self.pretrained, + cache_dir=self._cache_dir.as_posix(), + device=self.device, + **kwargs, + ) + self.tokenizer = open_clip.get_tokenizer(model_name=self.title_in_source) + + def build_embedding(self, dataloader: DataLoader) -> tuple[EmbeddingArray, EmbeddingArray, ClassArray]: + all_image_embeddings = [] + all_labels = [] + with torch.inference_mode(): + _dataset: Dataset = dataloader.dataset + text = self.tokenizer(_dataset.text_queries).to(self.device) + class_embeddings = self.model.encode_text(text, normalize=True).numpy(force=True) + for batch in tqdm(dataloader, desc=f"Embedding dataset with {self.title}"): + image_features = self.model.encode_image(batch["image"].to(self.device), normalize=True) + all_image_embeddings.append(image_features) + all_labels.append(batch["labels"]) + image_embeddings = 
torch.concatenate(all_image_embeddings).numpy(force=True) + labels = torch.concatenate(all_labels).numpy(force=True).astype(np.int32) + return image_embeddings, class_embeddings, labels
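
As a quick illustration of the workflow this patch enables, the sketch below registers an extra model purely by dropping a JSON definition into `sources/models/definitions/` and then resolves it through the new `ModelProvider` singleton. This is only a hypothetical usage sketch, not part of the patch: the title `my-clip-variant` and the checkpoint `openai/clip-vit-base-patch32` are made-up examples, while the definition fields, `SOURCES_PATH` and the `ModelProvider` classmethods are taken from the code above. The same pattern should apply to datasets via `sources/datasets/definitions/` and `DatasetProvider.get_dataset`.

import json

from clip_eval.constants import SOURCES_PATH
from clip_eval.models import ModelProvider

# Hypothetical definition that reuses the built-in ClosedCLIPModel type shipped in
# sources/models/types/hugging_face_clip.py; title and checkpoint are invented.
definition = {
    "model_type": "ClosedCLIPModel",
    "module_path": "../types/hugging_face_clip.py",  # resolved relative to the JSON file
    "title": "my-clip-variant",
    "title_in_source": "openai/clip-vit-base-patch32",
}
definition_file = SOURCES_PATH.MODEL_INSTANCE_DEFINITIONS / "my-clip-variant.json"
definition_file.write_text(json.dumps(definition, indent=2), encoding="utf-8")

# The provider scans the definitions directory lazily in prepare(), the first time
# any classmethod is used, so the new entry appears without touching clip_eval itself.
print(ModelProvider.list_model_titles())  # should now include "my-clip-variant"
model = ModelProvider.get_model("my-clip-variant")  # instantiates ClosedCLIPModel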