Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: set fixed evaluators order and title as classmethod #76

Merged
merged 3 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions tti_eval/evaluation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,7 @@
from .utils import normalize


class EvaluationModelTitleInterface:
    """Hold the title of an evaluation model and expose it as a read-only property.

    Enforces the evaluation models' title while removing its explicit mention in the
    `EvaluationModel` init. This way, when `EvaluationModel` is used as a type hint,
    there won't be a warning about an unfilled title.
    """

    def __init__(self, title: str, **kwargs) -> None:
        """Store the title.

        :param title: Title of the evaluation model.
        :param kwargs: Additional keyword arguments; accepted but not used here.
        """
        self._title = title

    @property
    def title(self) -> str:
        """Title that was passed to the constructor."""
        return self._title


class EvaluationModel(EvaluationModelTitleInterface, ABC):
class EvaluationModel(ABC):
def __init__(
self,
train_embeddings: Embeddings,
Expand Down Expand Up @@ -95,6 +84,11 @@ def num_classes(self) -> int:
def evaluate(self) -> float:
...

@classmethod
@abstractmethod
def title(cls) -> str:
    """Return the title of the evaluation model.

    Declared abstract, so every concrete subclass has to provide its own
    implementation. Being a classmethod, it can be called on the class itself
    without building an instance.
    """
    ...


class ClassificationModel(EvaluationModel):
@abstractmethod
Expand Down
15 changes: 9 additions & 6 deletions tti_eval/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from natsort import natsorted, ns
from tabulate import tabulate
from tqdm.auto import tqdm

from tti_eval.common import EmbeddingDefinition, Split
from tti_eval.constants import OUTPUT_PATH
Expand Down Expand Up @@ -47,9 +48,9 @@ def run_evaluation(
embedding_definitions: list[EmbeddingDefinition],
) -> dict[EmbeddingDefinition, dict[str, float]]:
embeddings_performance: dict[EmbeddingDefinition, dict[str, float]] = {}
model_keys: set[str] = set()
used_evaluators: set[str] = set()

for def_ in embedding_definitions:
for def_ in tqdm(embedding_definitions, desc="Evaluating embedding definitions", leave=False):
train_embeddings = def_.load_embeddings(Split.TRAIN)
validation_embeddings = def_.load_embeddings(Split.VALIDATION)

Expand All @@ -68,11 +69,13 @@ def run_evaluation(
train_embeddings=train_embeddings,
validation_embeddings=validation_embeddings,
)
evaluator_performance[evaluator.title] = evaluator.evaluate()
model_keys.add(evaluator.title)
evaluator_performance[evaluator.title()] = evaluator.evaluate()
used_evaluators.add(evaluator.title())

for n in model_keys:
print_evaluation_results(embeddings_performance, n)
for evaluator_type in evaluators:
evaluator_title = evaluator_type.title()
if evaluator_title in used_evaluators:
print_evaluation_results(embeddings_performance, evaluator_title)
return embeddings_performance


Expand Down
9 changes: 8 additions & 1 deletion tti_eval/evaluation/image_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
from autofaiss import build_index

from tti_eval.common import Embeddings
from tti_eval.utils import disable_tqdm, enable_tqdm

from .base import EvaluationModel

logger = logging.getLogger("multiclips")


class I2IRetrievalEvaluator(EvaluationModel):
@classmethod
def title(cls) -> str:
    """Return the title of this evaluator, ``"I2IR"``; callable on the class itself."""
    return "I2IR"

def __init__(
self,
train_embeddings: Embeddings,
Expand All @@ -33,14 +38,16 @@ def __init__(

:raises ValueError: If the build of the faiss index for similarity search fails.
"""
super().__init__(train_embeddings, validation_embeddings, num_classes, title="I2IR")
super().__init__(train_embeddings, validation_embeddings, num_classes)
self.k = min(k, len(validation_embeddings.images))

class_ids, counts = np.unique(self._val_embeddings.labels, return_counts=True)
self._class_counts = np.zeros(self.num_classes, dtype=np.int32)
self._class_counts[class_ids] = counts

disable_tqdm() # Disable tqdm progress bar when building the index
index, self.index_infos = build_index(self._val_embeddings.images, save_on_disk=False, verbose=logging.ERROR)
enable_tqdm()
if index is None:
raise ValueError("Failed to build an index for knn search")
self._index = index
Expand Down
15 changes: 9 additions & 6 deletions tti_eval/evaluation/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from autofaiss import build_index

from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
from tti_eval.utils import disable_tqdm, enable_tqdm

from .base import ClassificationModel
from .utils import softmax
Expand All @@ -13,6 +14,10 @@


class WeightedKNNClassifier(ClassificationModel):
@classmethod
def title(cls) -> str:
    """Return the title of this evaluator, ``"wKNN"``; callable on the class itself."""
    return "wKNN"

def __init__(
self,
train_embeddings: Embeddings,
Expand All @@ -36,15 +41,13 @@ def __init__(

:raises ValueError: If the build of the faiss index for KNN fails.
"""
super().__init__(train_embeddings, validation_embeddings, num_classes, title="wKNN")
super().__init__(train_embeddings, validation_embeddings, num_classes)
self.k = k

disable_tqdm() # Disable tqdm progress bar when building the index
index, self.index_infos = build_index(
train_embeddings.images,
metric_type="l2",
save_on_disk=False,
verbose=logging.ERROR,
train_embeddings.images, metric_type="l2", save_on_disk=False, verbose=logging.ERROR
)
enable_tqdm()
if index is None:
raise ValueError("Failed to build an index for knn search")
self._index = index
Expand Down
6 changes: 5 additions & 1 deletion tti_eval/evaluation/linear_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@


class LinearProbeClassifier(ClassificationModel):
@classmethod
def title(cls) -> str:
    """Return the title of this evaluator, ``"linear_probe"``; callable on the class itself."""
    return "linear_probe"

def __init__(
self,
train_embeddings: Embeddings,
Expand All @@ -28,7 +32,7 @@ def __init__(
:param log_reg_params: Parameters for the Logistic Regression model.
:param use_cross_validation: Flag that indicates whether to use cross-validation when training the model.
"""
super().__init__(train_embeddings, validation_embeddings, num_classes, title="linear_probe")
super().__init__(train_embeddings, validation_embeddings, num_classes)

params = log_reg_params or {}
self.classifier: LogisticRegressionCV | LogisticRegression
Expand Down
6 changes: 5 additions & 1 deletion tti_eval/evaluation/zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@


class ZeroShotClassifier(ClassificationModel):
@classmethod
def title(cls) -> str:
    """Return the title of this evaluator, ``"zero_shot"``; callable on the class itself."""
    return "zero_shot"

def __init__(
self,
train_embeddings: Embeddings,
Expand All @@ -19,7 +23,7 @@ def __init__(
:param validation_embeddings: Embeddings and their labels used for evaluating the search space.
:param num_classes: Number of classes. If not specified, it will be inferred from the train labels.
"""
super().__init__(train_embeddings, validation_embeddings, num_classes, title="zero_shot")
super().__init__(train_embeddings, validation_embeddings, num_classes)
if self._train_embeddings.classes is None:
raise ValueError("Expected class embeddings in `train_embeddings`, got `None`")

Expand Down
11 changes: 11 additions & 0 deletions tti_eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from functools import partialmethod
from itertools import chain
from typing import Literal, overload

from tqdm import tqdm

from tti_eval.common import EmbeddingDefinition
from tti_eval.constants import PROJECT_PATHS


# Unpatched ``tqdm.__init__``, captured lazily the first time a toggle runs.
# Lazy capture keeps importing this module free of side effects on tqdm.
_original_tqdm_init = None


def _get_original_tqdm_init():
    """Return the ``tqdm.__init__`` that was in place before the first toggle."""
    global _original_tqdm_init
    if _original_tqdm_init is None:
        _original_tqdm_init = tqdm.__init__
    return _original_tqdm_init


def disable_tqdm() -> None:
    """Make newly created ``tqdm`` progress bars disabled by default.

    Patches ``tqdm.__init__`` process-wide so that ``disable=True`` becomes the
    default keyword. Callers that pass ``disable`` explicitly still override it.

    The patch always wraps the originally captured initializer rather than the
    currently installed one. Wrapping the current one would add a new layer on
    every call, so the call chain would grow with each disable/enable cycle and
    could eventually exhaust the recursion limit.
    """
    tqdm.__init__ = partialmethod(_get_original_tqdm_init(), disable=True)


def enable_tqdm() -> None:
    """Undo :func:`disable_tqdm` by reinstating the original ``tqdm.__init__``.

    Restoring the captured initializer (instead of stacking another wrapper)
    returns tqdm to its own defaults and keeps repeated toggling constant-cost.
    Calling this without a prior :func:`disable_tqdm` leaves tqdm unchanged.
    """
    tqdm.__init__ = _get_original_tqdm_init()


@overload
def read_all_cached_embeddings(as_list: Literal[True]) -> list[EmbeddingDefinition]:
...
Expand Down
Loading