Skip to content

Commit

Permalink
fix: set fixed evaluators order and title as classmethod (#76)
Browse files Browse the repository at this point in the history
  • Loading branch information
eloy-encord authored May 9, 2024
1 parent cab13d8 commit f8d9233
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 27 deletions.
18 changes: 6 additions & 12 deletions tti_eval/evaluation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,7 @@
from .utils import normalize


class EvaluationModelTitleInterface:
# Enforce the evaluation models' title while removing its explicit mention in the `EvaluationModel` init.
# This way, when `EvaluationModel` is used as a type hint, there won't be a warning about unfilled title.
# The constructor keeps the title in a private attribute; any additional keyword arguments are
# accepted but not used here.
def __init__(self, title: str, **kwargs) -> None:
self._title = title

# Read-only accessor for the title supplied at construction time.
@property
def title(self) -> str:
return self._title


class EvaluationModel(EvaluationModelTitleInterface, ABC):
class EvaluationModel(ABC):
def __init__(
self,
train_embeddings: Embeddings,
Expand Down Expand Up @@ -95,6 +84,11 @@ def num_classes(self) -> int:
def evaluate(self) -> float:
...

# Abstract class-level hook: every concrete evaluation model must return its title string.
# Being a classmethod, the title can be read from the class itself, without building an instance.
@classmethod
@abstractmethod
def title(cls) -> str:
...


class ClassificationModel(EvaluationModel):
@abstractmethod
Expand Down
15 changes: 9 additions & 6 deletions tti_eval/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from natsort import natsorted, ns
from tabulate import tabulate
from tqdm.auto import tqdm

from tti_eval.common import EmbeddingDefinition, Split
from tti_eval.constants import OUTPUT_PATH
Expand Down Expand Up @@ -47,9 +48,9 @@ def run_evaluation(
embedding_definitions: list[EmbeddingDefinition],
) -> dict[EmbeddingDefinition, dict[str, float]]:
embeddings_performance: dict[EmbeddingDefinition, dict[str, float]] = {}
model_keys: set[str] = set()
used_evaluators: set[str] = set()

for def_ in embedding_definitions:
for def_ in tqdm(embedding_definitions, desc="Evaluating embedding definitions", leave=False):
train_embeddings = def_.load_embeddings(Split.TRAIN)
validation_embeddings = def_.load_embeddings(Split.VALIDATION)

Expand All @@ -68,11 +69,13 @@ def run_evaluation(
train_embeddings=train_embeddings,
validation_embeddings=validation_embeddings,
)
evaluator_performance[evaluator.title] = evaluator.evaluate()
model_keys.add(evaluator.title)
evaluator_performance[evaluator.title()] = evaluator.evaluate()
used_evaluators.add(evaluator.title())

for n in model_keys:
print_evaluation_results(embeddings_performance, n)
for evaluator_type in evaluators:
evaluator_title = evaluator_type.title()
if evaluator_title in used_evaluators:
print_evaluation_results(embeddings_performance, evaluator_title)
return embeddings_performance


Expand Down
9 changes: 8 additions & 1 deletion tti_eval/evaluation/image_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,18 @@
from autofaiss import build_index

from tti_eval.common import Embeddings
from tti_eval.utils import disable_tqdm, enable_tqdm

from .base import EvaluationModel

logger = logging.getLogger("multiclips")


class I2IRetrievalEvaluator(EvaluationModel):
# Class-level title of this evaluator; readable from the class without creating an instance.
@classmethod
def title(cls) -> str:
return "I2IR"

def __init__(
self,
train_embeddings: Embeddings,
Expand All @@ -33,14 +38,16 @@ def __init__(
:raises ValueError: If the build of the faiss index for similarity search fails.
"""
super().__init__(train_embeddings, validation_embeddings, num_classes, title="I2IR")
super().__init__(train_embeddings, validation_embeddings, num_classes)
self.k = min(k, len(validation_embeddings.images))

class_ids, counts = np.unique(self._val_embeddings.labels, return_counts=True)
self._class_counts = np.zeros(self.num_classes, dtype=np.int32)
self._class_counts[class_ids] = counts

disable_tqdm() # Disable tqdm progress bar when building the index
index, self.index_infos = build_index(self._val_embeddings.images, save_on_disk=False, verbose=logging.ERROR)
enable_tqdm()
if index is None:
raise ValueError("Failed to build an index for knn search")
self._index = index
Expand Down
15 changes: 9 additions & 6 deletions tti_eval/evaluation/knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from autofaiss import build_index

from tti_eval.common import ClassArray, Embeddings, ProbabilityArray
from tti_eval.utils import disable_tqdm, enable_tqdm

from .base import ClassificationModel
from .utils import softmax
Expand All @@ -13,6 +14,10 @@


class WeightedKNNClassifier(ClassificationModel):
# Class-level title of this classifier; readable from the class without creating an instance.
@classmethod
def title(cls) -> str:
return "wKNN"

def __init__(
self,
train_embeddings: Embeddings,
Expand All @@ -36,15 +41,13 @@ def __init__(
:raises ValueError: If the build of the faiss index for KNN fails.
"""
super().__init__(train_embeddings, validation_embeddings, num_classes, title="wKNN")
super().__init__(train_embeddings, validation_embeddings, num_classes)
self.k = k

disable_tqdm() # Disable tqdm progress bar when building the index
index, self.index_infos = build_index(
train_embeddings.images,
metric_type="l2",
save_on_disk=False,
verbose=logging.ERROR,
train_embeddings.images, metric_type="l2", save_on_disk=False, verbose=logging.ERROR
)
enable_tqdm()
if index is None:
raise ValueError("Failed to build an index for knn search")
self._index = index
Expand Down
6 changes: 5 additions & 1 deletion tti_eval/evaluation/linear_probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@


class LinearProbeClassifier(ClassificationModel):
# Class-level title of this classifier; readable from the class without creating an instance.
@classmethod
def title(cls) -> str:
return "linear_probe"

def __init__(
self,
train_embeddings: Embeddings,
Expand All @@ -28,7 +32,7 @@ def __init__(
:param log_reg_params: Parameters for the Logistic Regression model.
:param use_cross_validation: Flag that indicated whether to use cross-validation when training the model.
"""
super().__init__(train_embeddings, validation_embeddings, num_classes, title="linear_probe")
super().__init__(train_embeddings, validation_embeddings, num_classes)

params = log_reg_params or {}
self.classifier: LogisticRegressionCV | LogisticRegression
Expand Down
6 changes: 5 additions & 1 deletion tti_eval/evaluation/zero_shot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@


class ZeroShotClassifier(ClassificationModel):
# Class-level title of this classifier; readable from the class without creating an instance.
@classmethod
def title(cls) -> str:
return "zero_shot"

def __init__(
self,
train_embeddings: Embeddings,
Expand All @@ -19,7 +23,7 @@ def __init__(
:param validation_embeddings: Embeddings and their labels used for evaluating the search space.
:param num_classes: Number of classes. If not specified, it will be inferred from the train labels.
"""
super().__init__(train_embeddings, validation_embeddings, num_classes, title="zero_shot")
super().__init__(train_embeddings, validation_embeddings, num_classes)
if self._train_embeddings.classes is None:
raise ValueError("Expected class embeddings in `train_embeddings`, got `None`")

Expand Down
11 changes: 11 additions & 0 deletions tti_eval/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,21 @@
from functools import partialmethod
from itertools import chain
from typing import Literal, overload

from tqdm import tqdm

from tti_eval.common import EmbeddingDefinition
from tti_eval.constants import PROJECT_PATHS


# Pristine `tqdm` constructor, captured once when this module is imported.
# Both helpers below always start from this object, so repeated calls never stack
# one `partialmethod` wrapper on top of another.
_TQDM_ORIGINAL_INIT = tqdm.__init__


def disable_tqdm() -> None:
    """Make every `tqdm` progress bar created from now on default to `disable=True`.

    The class constructor is replaced by a thin wrapper around the original one that
    pre-fills `disable=True`. A caller that passes `disable=...` explicitly still wins,
    because call-site keywords override the pre-filled ones.

    Calling this function repeatedly is safe: the wrapper is always built on the
    original constructor, not on a previously patched one.
    """
    tqdm.__init__ = partialmethod(_TQDM_ORIGINAL_INIT, disable=True)


def enable_tqdm() -> None:
    """Undo `disable_tqdm` by putting the original `tqdm` constructor back.

    After this call progress bars behave exactly as they did before any patching,
    i.e. they follow tqdm's own default for `disable`.
    """
    tqdm.__init__ = _TQDM_ORIGINAL_INIT


@overload
def read_all_cached_embeddings(as_list: Literal[True]) -> list[EmbeddingDefinition]:
...
Expand Down

0 comments on commit f8d9233

Please sign in to comment.