From f8d923318664c9dc648dda5c3aa8a9e1edc89336 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Eloy=20P=C3=A9rez=20Torres?= <99720527+eloy-encord@users.noreply.github.com>
Date: Thu, 9 May 2024 16:35:52 +0100
Subject: [PATCH] fix: set fixed evaluators order and title as classmethod (#76)

---
 tti_eval/evaluation/base.py            | 18 ++++++------------
 tti_eval/evaluation/evaluator.py       | 15 +++++++++------
 tti_eval/evaluation/image_retrieval.py |  9 ++++++++-
 tti_eval/evaluation/knn.py             | 15 +++++++++------
 tti_eval/evaluation/linear_probe.py    |  6 +++++-
 tti_eval/evaluation/zero_shot.py       |  6 +++++-
 tti_eval/utils.py                      | 11 +++++++++++
 7 files changed, 53 insertions(+), 27 deletions(-)

diff --git a/tti_eval/evaluation/base.py b/tti_eval/evaluation/base.py
index e529b50..6c74d43 100644
--- a/tti_eval/evaluation/base.py
+++ b/tti_eval/evaluation/base.py
@@ -6,18 +6,7 @@
 from .utils import normalize
 
 
-class EvaluationModelTitleInterface:
-    # Enforce the evaluation models' title while removing its explicit mention in the `EvaluationModel` init.
-    # This way, when `EvaluationModel` is used as a type hint, there won't be a warning about unfilled title.
-    def __init__(self, title: str, **kwargs) -> None:
-        self._title = title
-
-    @property
-    def title(self) -> str:
-        return self._title
-
-
-class EvaluationModel(EvaluationModelTitleInterface, ABC):
+class EvaluationModel(ABC):
     def __init__(
         self,
         train_embeddings: Embeddings,
@@ -95,6 +84,11 @@ def num_classes(self) -> int:
     def evaluate(self) -> float:
         ...
 
+    @classmethod
+    @abstractmethod
+    def title(cls) -> str:
+        ...
+
 
 class ClassificationModel(EvaluationModel):
     @abstractmethod
diff --git a/tti_eval/evaluation/evaluator.py b/tti_eval/evaluation/evaluator.py
index d62fce8..a0ca294 100644
--- a/tti_eval/evaluation/evaluator.py
+++ b/tti_eval/evaluation/evaluator.py
@@ -3,6 +3,7 @@
 
 from natsort import natsorted, ns
 from tabulate import tabulate
+from tqdm.auto import tqdm
 
 from tti_eval.common import EmbeddingDefinition, Split
 from tti_eval.constants import OUTPUT_PATH
@@ -47,9 +48,9 @@ def run_evaluation(
     embedding_definitions: list[EmbeddingDefinition],
 ) -> dict[EmbeddingDefinition, dict[str, float]]:
     embeddings_performance: dict[EmbeddingDefinition, dict[str, float]] = {}
-    model_keys: set[str] = set()
+    used_evaluators: set[str] = set()
 
-    for def_ in embedding_definitions:
+    for def_ in tqdm(embedding_definitions, desc="Evaluating embedding definitions", leave=False):
         train_embeddings = def_.load_embeddings(Split.TRAIN)
         validation_embeddings = def_.load_embeddings(Split.VALIDATION)
 
@@ -68,11 +69,13 @@ def run_evaluation(
                 train_embeddings=train_embeddings,
                 validation_embeddings=validation_embeddings,
             )
-            evaluator_performance[evaluator.title] = evaluator.evaluate()
-            model_keys.add(evaluator.title)
+            evaluator_performance[evaluator.title()] = evaluator.evaluate()
+            used_evaluators.add(evaluator.title())
 
-    for n in model_keys:
-        print_evaluation_results(embeddings_performance, n)
+    for evaluator_type in evaluators:
+        evaluator_title = evaluator_type.title()
+        if evaluator_title in used_evaluators:
+            print_evaluation_results(embeddings_performance, evaluator_title)
     return embeddings_performance
 
 
diff --git a/tti_eval/evaluation/image_retrieval.py b/tti_eval/evaluation/image_retrieval.py
index b9b617c..fba005b 100644
--- a/tti_eval/evaluation/image_retrieval.py
+++ b/tti_eval/evaluation/image_retrieval.py
@@ -5,6 +5,7 @@
 from autofaiss import build_index
 
 from tti_eval.common import Embeddings
+from tti_eval.utils import disable_tqdm, enable_tqdm
 
 from .base import EvaluationModel
 
@@ -12,6 +13,10 @@
 
 
 class I2IRetrievalEvaluator(EvaluationModel):
+    @classmethod
+    def title(cls) -> str:
+        return "I2IR"
+
     def __init__(
         self,
         train_embeddings: Embeddings,
@@ -33,14 +38,16 @@ def __init__(
 
         :raises ValueError: If the build of the faiss index for similarity search fails.
         """
-        super().__init__(train_embeddings, validation_embeddings, num_classes, title="I2IR")
+        super().__init__(train_embeddings, validation_embeddings, num_classes)
         self.k = min(k, len(validation_embeddings.images))
 
         class_ids, counts = np.unique(self._val_embeddings.labels, return_counts=True)
         self._class_counts = np.zeros(self.num_classes, dtype=np.int32)
         self._class_counts[class_ids] = counts
 
+        disable_tqdm()  # Disable tqdm progress bar when building the index
         index, self.index_infos = build_index(self._val_embeddings.images, save_on_disk=False, verbose=logging.ERROR)
+        enable_tqdm()
         if index is None:
             raise ValueError("Failed to build an index for knn search")
         self._index = index
diff --git a/tti_eval/evaluation/knn.py b/tti_eval/evaluation/knn.py
index 47d754b..dadaaf6 100644
--- a/tti_eval/evaluation/knn.py
+++ b/tti_eval/evaluation/knn.py
@@ -5,6 +5,7 @@
 from autofaiss import build_index
 
 from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
+from tti_eval.utils import disable_tqdm, enable_tqdm
 
 from .base import ClassificationModel
 from .utils import softmax
@@ -13,6 +14,10 @@
 
 
 class WeightedKNNClassifier(ClassificationModel):
+    @classmethod
+    def title(cls) -> str:
+        return "wKNN"
+
     def __init__(
         self,
         train_embeddings: Embeddings,
@@ -36,15 +41,13 @@ def __init__(
 
         :raises ValueError: If the build of the faiss index for KNN fails.
         """
-        super().__init__(train_embeddings, validation_embeddings, num_classes, title="wKNN")
+        super().__init__(train_embeddings, validation_embeddings, num_classes)
         self.k = k
-
+        disable_tqdm()  # Disable tqdm progress bar when building the index
         index, self.index_infos = build_index(
-            train_embeddings.images,
-            metric_type="l2",
-            save_on_disk=False,
-            verbose=logging.ERROR,
+            train_embeddings.images, metric_type="l2", save_on_disk=False, verbose=logging.ERROR
         )
+        enable_tqdm()
         if index is None:
             raise ValueError("Failed to build an index for knn search")
         self._index = index
diff --git a/tti_eval/evaluation/linear_probe.py b/tti_eval/evaluation/linear_probe.py
index 3076a37..0286779 100644
--- a/tti_eval/evaluation/linear_probe.py
+++ b/tti_eval/evaluation/linear_probe.py
@@ -11,6 +11,10 @@
 
 
 class LinearProbeClassifier(ClassificationModel):
+    @classmethod
+    def title(cls) -> str:
+        return "linear_probe"
+
     def __init__(
         self,
         train_embeddings: Embeddings,
@@ -28,7 +32,7 @@ def __init__(
         :param log_reg_params: Parameters for the Logistic Regression model.
         :param use_cross_validation: Flag that indicated whether to use cross-validation when training the model.
         """
-        super().__init__(train_embeddings, validation_embeddings, num_classes, title="linear_probe")
+        super().__init__(train_embeddings, validation_embeddings, num_classes)
         params = log_reg_params or {}
 
         self.classifier: LogisticRegressionCV | LogisticRegression
diff --git a/tti_eval/evaluation/zero_shot.py b/tti_eval/evaluation/zero_shot.py
index dace450..d70a5c7 100644
--- a/tti_eval/evaluation/zero_shot.py
+++ b/tti_eval/evaluation/zero_shot.py
@@ -6,6 +6,10 @@
 
 
 class ZeroShotClassifier(ClassificationModel):
+    @classmethod
+    def title(cls) -> str:
+        return "zero_shot"
+
     def __init__(
         self,
         train_embeddings: Embeddings,
@@ -19,7 +23,7 @@ def __init__(
         :param validation_embeddings: Embeddings and their labels used for evaluating the search space.
         :param num_classes: Number of classes. If not specified, it will be inferred from the train labels.
         """
-        super().__init__(train_embeddings, validation_embeddings, num_classes, title="zero_shot")
+        super().__init__(train_embeddings, validation_embeddings, num_classes)
 
         if self._train_embeddings.classes is None:
             raise ValueError("Expected class embeddings in `train_embeddings`, got `None`")
diff --git a/tti_eval/utils.py b/tti_eval/utils.py
index b094f18..ab050ca 100644
--- a/tti_eval/utils.py
+++ b/tti_eval/utils.py
@@ -1,10 +1,21 @@
+from functools import partialmethod
 from itertools import chain
 from typing import Literal, overload
 
+from tqdm import tqdm
+
 from tti_eval.common import EmbeddingDefinition
 from tti_eval.constants import PROJECT_PATHS
 
 
+def disable_tqdm():
+    tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
+
+
+def enable_tqdm():
+    tqdm.__init__ = partialmethod(tqdm.__init__, disable=False)
+
+
 @overload
 def read_all_cached_embeddings(as_list: Literal[True]) -> list[EmbeddingDefinition]:
     ...