Skip to content

chore: remove imports from common submodule to other submodules #61

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions clip_eval/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import matplotlib.pyplot as plt
from typer import Option, Typer

from clip_eval.dataset import Split
from clip_eval.common import Split
from clip_eval.compute import compute_embeddings_from_definition
from clip_eval.utils import read_all_cached_embeddings

from .utils import (
Expand Down Expand Up @@ -52,7 +53,7 @@ def build_command(
for embd_defn in definitions:
for split in splits:
try:
embeddings = embd_defn.build_embeddings(split)
embeddings = compute_embeddings_from_definition(embd_defn, split)
embd_defn.save_embeddings(embeddings=embeddings, split=split, overwrite=True)
print(f"Embeddings saved successfully to file at `{embd_defn.embedding_path(split)}`")
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion clip_eval/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from InquirerPy.base.control import Choice
from natsort import natsorted, ns

from clip_eval.common.data_models import EmbeddingDefinition
from clip_eval.common import EmbeddingDefinition
from clip_eval.dataset import DatasetProvider
from clip_eval.model import ModelProvider
from clip_eval.utils import read_all_cached_embeddings
Expand Down
2 changes: 1 addition & 1 deletion clip_eval/common/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .data_models import EmbeddingDefinition, Embeddings
from .base import EmbeddingDefinition, Embeddings, Split
from .numpy_types import ClassArray, EmbeddingArray, ProbabilityArray, ReductionArray
82 changes: 69 additions & 13 deletions clip_eval/common/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,26 @@
import logging
from enum import StrEnum, auto
from pathlib import Path
from typing import Annotated, Any

import numpy as np
from pydantic import BaseModel, model_validator
from torch.utils.data import DataLoader
from pydantic.functional_validators import AfterValidator

from clip_eval.constants import NPZ_KEYS
from clip_eval.dataset import Dataset
from clip_eval.model import Model
from clip_eval.constants import NPZ_KEYS, PROJECT_PATHS

from .numpy_types import ClassArray, EmbeddingArray
from .string_utils import safe_str

SafeName = Annotated[str, AfterValidator(safe_str)]
logger = logging.getLogger("multiclips")


class Split(StrEnum):
TRAIN = auto()
VALIDATION = auto()
TEST = auto()
ALL = auto()


class Embeddings(BaseModel):
Expand Down Expand Up @@ -68,14 +80,58 @@ def to_file(self, path: Path) -> Path:
)
return path

@staticmethod
def build_embedding(model: Model, dataset: Dataset, batch_size: int = 50) -> "Embeddings":
dataset.set_transform(model.get_transform())
dataloader = DataLoader(dataset, collate_fn=model.get_collate_fn(), batch_size=batch_size)

image_embeddings, class_embeddings, labels = model.build_embedding(dataloader)
embeddings = Embeddings(images=image_embeddings, classes=class_embeddings, labels=labels)
return embeddings

class Config:
arbitrary_types_allowed = True


class EmbeddingDefinition(BaseModel):
    """A (model, dataset) pair identifying one embedding combination.

    Knows where the corresponding embedding and reduction files live on
    disk and how to load/save them. Instances are hashable and comparable
    so they can be used in sets and as dictionary keys.
    """

    model: SafeName
    dataset: SafeName

    def _get_embedding_path(self, split: Split, suffix: str) -> Path:
        # On-disk layout: <dataset>/<model>_<split><suffix>
        return Path(self.dataset) / f"{self.model}_{split}{suffix}"

    def embedding_path(self, split: Split) -> Path:
        """Path of the cached `.npz` embeddings file for *split*."""
        return PROJECT_PATHS.EMBEDDINGS / self._get_embedding_path(split, ".npz")

    def get_reduction_path(self, reduction_name: str, split: Split) -> Path:
        """Path of the cached 2-d reduction file for *split* produced by *reduction_name*."""
        return PROJECT_PATHS.REDUCTIONS / self._get_embedding_path(split, f".{reduction_name}.2d.npy")

    def load_embeddings(self, split: Split) -> Embeddings | None:
        """
        Load embeddings for embedding configuration or return None
        """
        try:
            return Embeddings.from_file(self.embedding_path(split))
        except ValueError:
            return None

    def save_embeddings(self, embeddings: Embeddings, split: Split, overwrite: bool = False) -> bool:
        """
        Save embeddings associated to the embedding definition.
        Args:
            embeddings: The embeddings to store
            split: The dataset split that corresponds to the embeddings
            overwrite: If false, won't overwrite and will return False

        Returns:
            True iff file stored successfully

        """
        target = self.embedding_path(split)
        if target.is_file() and not overwrite:
            # Respect existing caches unless the caller explicitly opts in.
            logger.warning(f"Not saving embeddings to file `{target}` as overwrite is False and file exists")
            return False
        target.parent.mkdir(exist_ok=True, parents=True)
        embeddings.to_file(target)
        return True

    def __str__(self):
        return f"{self.model}_{self.dataset}"

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, EmbeddingDefinition):
            return False
        return (self.model, self.dataset) == (other.model, other.dataset)

    def __hash__(self):
        # Consistent with __eq__: equal definitions hash equally.
        return hash((self.model, self.dataset))
114 changes: 0 additions & 114 deletions clip_eval/common/data_models.py

This file was deleted.

20 changes: 20 additions & 0 deletions clip_eval/compute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from torch.utils.data import DataLoader

from clip_eval.common import EmbeddingDefinition, Embeddings, Split
from clip_eval.dataset import Dataset, DatasetProvider
from clip_eval.model import Model, ModelProvider


def compute_embeddings(model: Model, dataset: Dataset, batch_size: int = 50) -> Embeddings:
    """Run *model* over every sample of *dataset* and collect the results.

    Args:
        model: Provides the transform, collate function and the embedding pass.
        dataset: The dataset to embed; its transform is replaced by the model's.
        batch_size: Number of samples handed to the model per batch.

    Returns:
        An ``Embeddings`` bundle of image embeddings, class embeddings and labels.
    """
    # The model dictates how raw samples must be pre-processed.
    dataset.set_transform(model.get_transform())
    loader = DataLoader(dataset, collate_fn=model.get_collate_fn(), batch_size=batch_size)

    image_embeddings, class_embeddings, labels = model.build_embedding(loader)
    return Embeddings(images=image_embeddings, classes=class_embeddings, labels=labels)


def compute_embeddings_from_definition(
    definition: EmbeddingDefinition, split: Split, batch_size: int = 50
) -> Embeddings:
    """Resolve the model and dataset named by *definition* and embed *split*.

    Args:
        definition: Names the model and dataset to use.
        split: The dataset split to embed.
        batch_size: Number of samples handed to the model per batch; previously
            this could not be controlled and ``compute_embeddings`` always fell
            back to its own default. Defaults to 50, preserving old behavior.

    Returns:
        The computed ``Embeddings`` for the given definition and split.
    """
    model = ModelProvider.get_model(definition.model)
    dataset = DatasetProvider.get_dataset(definition.dataset, split)
    return compute_embeddings(model, dataset, batch_size=batch_size)
9 changes: 1 addition & 8 deletions clip_eval/dataset/base.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,18 @@
from abc import ABC, abstractmethod
from enum import StrEnum, auto
from os.path import relpath
from pathlib import Path

from pydantic import BaseModel, ConfigDict
from torch.utils.data import Dataset as TorchDataset

from clip_eval.common import Split
from clip_eval.constants import CACHE_PATH, SOURCES_PATH

DEFAULT_DATASET_TYPES_LOCATION = (
Path(relpath(str(__file__), SOURCES_PATH.DATASET_INSTANCE_DEFINITIONS)).parent / "types" / "__init__.py"
)


class Split(StrEnum):
TRAIN = auto()
VALIDATION = auto()
TEST = auto()
ALL = auto()


class DatasetDefinitionSpec(BaseModel):
dataset_type: str
title: str
Expand Down
2 changes: 1 addition & 1 deletion clip_eval/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import requests
from tqdm.auto import tqdm

from .base import Split
from clip_eval.common import Split

T = TypeVar("T")
G = TypeVar("G")
Expand Down
2 changes: 1 addition & 1 deletion clip_eval/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from natsort import natsorted, ns
from tabulate import tabulate

from clip_eval.common.data_models import EmbeddingDefinition, Split
from clip_eval.common import EmbeddingDefinition, Split
from clip_eval.constants import OUTPUT_PATH
from clip_eval.evaluation import (
EvaluationModel,
Expand Down
5 changes: 3 additions & 2 deletions clip_eval/evaluation/image_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
import numpy as np
from autofaiss import build_index

from clip_eval.common.data_models import Embeddings
from clip_eval.evaluation.base import EvaluationModel
from clip_eval.common import Embeddings

from .base import EvaluationModel

logger = logging.getLogger("multiclips")

Expand Down
8 changes: 4 additions & 4 deletions clip_eval/evaluation/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import numpy as np
from autofaiss import build_index

from clip_eval.common.data_models import Embeddings
from clip_eval.common.numpy_types import ClassArray, ProbabilityArray
from clip_eval.evaluation.base import ClassificationModel
from clip_eval.evaluation.utils import softmax
from clip_eval.common import ClassArray, Embeddings, ProbabilityArray

from .base import ClassificationModel
from .utils import softmax

logger = logging.getLogger("multiclips")

Expand Down
2 changes: 1 addition & 1 deletion clip_eval/plotting/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from PIL import Image

from clip_eval.common import EmbeddingDefinition
from clip_eval.common.data_models import SafeName
from clip_eval.common.base import SafeName
from clip_eval.common.numpy_types import ClassArray, N2Array
from clip_eval.constants import OUTPUT_PATH
from clip_eval.dataset import DatasetProvider, Split
Expand Down
3 changes: 1 addition & 2 deletions clip_eval/plotting/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from clip_eval.common import EmbeddingArray, EmbeddingDefinition, ReductionArray
from clip_eval.dataset import Split
from clip_eval.common import EmbeddingArray, EmbeddingDefinition, ReductionArray, Split


class Reducer:
Expand Down
2 changes: 1 addition & 1 deletion clip_eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from itertools import chain
from typing import Literal, overload

from clip_eval.common.data_models import EmbeddingDefinition
from clip_eval.common import EmbeddingDefinition
from clip_eval.constants import PROJECT_PATHS


Expand Down
Loading
Loading