Skip to content

Commit

Permalink
feat: detach resources from providers (#53)
Browse files Browse the repository at this point in the history
Types and definitions of models and datasets are now dynamically fetched
from a `sources` folder. With this change, previously built-in resources
can be modified, silenced or removed if the user wants to. It also
enables the seamless addition of new models and datasets by just adding
the corresponding implementation (a new type) or definition (an instance
of an existing type).

In addition, refactored the naming convention of several components that
were moved as part of these changes, so that they share the same patterns.
E.g. the model base class was renamed from `CLIPModel` to `Model`, as the
class itself represents a higher-level abstraction than CLIP.

Also, added the singleton pattern and made some minor changes to the
dataset and model providers in order to avoid the partially initialised
module errors that appeared when loading of dynamic resources was
attempted before all local resources had been successfully imported. The
singleton pattern is specifically used to avoid the creation of several
provider instances, as the original global provider couldn't be loaded on
initialisation.
  • Loading branch information
eloy-encord authored Apr 15, 2024
1 parent 190a90f commit 354a5a0
Show file tree
Hide file tree
Showing 45 changed files with 836 additions and 493 deletions.
8 changes: 4 additions & 4 deletions clip_eval/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,12 +130,12 @@ def list_models_datasets(
Option(help="List all models and dataset that are available via the tool."),
] = False,
):
from clip_eval.dataset.provider import dataset_provider
from clip_eval.models import model_provider
from clip_eval.dataset import DatasetProvider
from clip_eval.models import ModelProvider

if all:
datasets = dataset_provider.list_dataset_titles()
models = model_provider.list_model_titles()
datasets = DatasetProvider.list_dataset_titles()
models = ModelProvider.list_model_titles()
print(f"Available datasets are: {', '.join(datasets)}")
print(f"Available models are: {', '.join(models)}")
return
Expand Down
8 changes: 4 additions & 4 deletions clip_eval/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from natsort import natsorted, ns

from clip_eval.common.data_models import EmbeddingDefinition
from clip_eval.dataset.provider import dataset_provider
from clip_eval.models.provider import model_provider
from clip_eval.dataset import DatasetProvider
from clip_eval.models import ModelProvider
from clip_eval.utils import read_all_cached_embeddings


Expand Down Expand Up @@ -80,8 +80,8 @@ def select_from_all_embedding_definitions(
) -> list[EmbeddingDefinition]:
existing = set(read_all_cached_embeddings(as_list=True))

models = model_provider.list_model_titles()
datasets = dataset_provider.list_dataset_titles()
models = ModelProvider.list_model_titles()
datasets = DatasetProvider.list_dataset_titles()

defs = [EmbeddingDefinition(dataset=d, model=m) for d, m in product(datasets, models)]
if not include_existing:
Expand Down
4 changes: 2 additions & 2 deletions clip_eval/common/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from clip_eval.constants import NPZ_KEYS
from clip_eval.dataset import Dataset
from clip_eval.models import CLIPModel
from clip_eval.models import Model

from .numpy_types import ClassArray, EmbeddingArray

Expand Down Expand Up @@ -69,7 +69,7 @@ def to_file(self, path: Path) -> Path:
return path

@staticmethod
def build_embedding(model: CLIPModel, dataset: Dataset, batch_size: int = 50) -> "Embeddings":
def build_embedding(model: Model, dataset: Dataset, batch_size: int = 50) -> "Embeddings":
dataset.set_transform(model.get_transform())
dataloader = DataLoader(dataset, collate_fn=model.get_collate_fn(), batch_size=batch_size)

Expand Down
8 changes: 4 additions & 4 deletions clip_eval/common/data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from pydantic.functional_validators import AfterValidator

from clip_eval.constants import PROJECT_PATHS
from clip_eval.dataset import Split, dataset_provider
from clip_eval.models import model_provider
from clip_eval.dataset import DatasetProvider, Split
from clip_eval.models import ModelProvider

from .base import Embeddings
from .string_utils import safe_str
Expand Down Expand Up @@ -61,8 +61,8 @@ def save_embeddings(self, embeddings: Embeddings, split: Split, overwrite: bool
return True

def build_embeddings(self, split: Split) -> Embeddings:
model = model_provider.get_model(self.model)
dataset = dataset_provider.get_dataset(self.dataset, split)
model = ModelProvider.get_model(self.model)
dataset = DatasetProvider.get_dataset(self.dataset, split)
return Embeddings.build_embedding(model, dataset)

def __str__(self):
Expand Down
13 changes: 11 additions & 2 deletions clip_eval/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
load_dotenv()

# If the cache directory is not explicitly specified, use the `.cache` directory located in the project's root.
CACHE_PATH = Path(os.environ.get("CLIP_CACHE_PATH", Path(__file__).parent.parent / ".cache"))
_OUTPUT_PATH = Path(os.environ.get("OUTPUT_PATH", "output"))
_CLIP_EVAL_ROOT_DIR = Path(__file__).parent.parent
CACHE_PATH = Path(os.environ.get("CLIP_EVAL_CACHE_PATH", _CLIP_EVAL_ROOT_DIR / ".cache"))
_OUTPUT_PATH = Path(os.environ.get("CLIP_EVAL_OUTPUT_PATH", _CLIP_EVAL_ROOT_DIR / "output"))
_SOURCES_PATH = _CLIP_EVAL_ROOT_DIR / "sources"


class PROJECT_PATHS:
Expand All @@ -25,3 +27,10 @@ class NPZ_KEYS:
class OUTPUT_PATH:
    """Output locations, each a direct subdirectory of `_OUTPUT_PATH`."""

    ANIMATIONS = _OUTPUT_PATH.joinpath("animations")
    EVALUATIONS = _OUTPUT_PATH.joinpath("evaluations")


class SOURCES_PATH:
    """Locations under `_SOURCES_PATH`, one per resource (datasets, models) and kind (types, definitions)."""

    DATASET_TYPES = _SOURCES_PATH.joinpath("datasets", "types")
    DATASET_INSTANCE_DEFINITIONS = _SOURCES_PATH.joinpath("datasets", "definitions")
    MODEL_TYPES = _SOURCES_PATH.joinpath("models", "types")
    MODEL_INSTANCE_DEFINITIONS = _SOURCES_PATH.joinpath("models", "definitions")
2 changes: 1 addition & 1 deletion clip_eval/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .base import Dataset, Split
from .provider import dataset_provider
from .provider import DatasetProvider
13 changes: 13 additions & 0 deletions clip_eval/dataset/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from enum import StrEnum, auto
from pathlib import Path

from pydantic import BaseModel, ConfigDict
from torch.utils.data import Dataset as TorchDataset

from clip_eval.constants import CACHE_PATH
Expand All @@ -14,6 +15,18 @@ class Split(StrEnum):
ALL = auto()


class DatasetDefinitionSpec(BaseModel):
    # Pydantic schema for a single dataset definition.
    # Keys that are not declared below are accepted rather than rejected, so a
    # definition can carry additional dataset configuration fields.
    model_config = ConfigDict(extra="allow")

    # Required fields.
    # NOTE(review): presumably `dataset_type` names the dataset type to use and
    # `module_path` points at the module implementing it — confirm against the provider.
    dataset_type: str
    module_path: Path
    title: str

    # Optional fields.
    split: Split = Split.ALL  # Defaults to the `ALL` split.
    title_in_source: str | None = None
    cache_dir: Path | None = None


class Dataset(TorchDataset, ABC):
def __init__(
self,
Expand Down
159 changes: 0 additions & 159 deletions clip_eval/dataset/encord.py

This file was deleted.

Loading

0 comments on commit 354a5a0

Please sign in to comment.