feat: friendlier cli with Typer (#28)
frederik-encord authored Feb 23, 2024
1 parent 23b5533 commit 15ebd0d
Showing 11 changed files with 606 additions and 130 deletions.
114 changes: 0 additions & 114 deletions clip_eval/cli.py

This file was deleted.

Empty file added clip_eval/cli/__init__.py
145 changes: 145 additions & 0 deletions clip_eval/cli/main.py
@@ -0,0 +1,145 @@
from typing import Annotated, Optional

import matplotlib.pyplot as plt
from typer import Option, Typer

from clip_eval.common.data_models import EmbeddingDefinition
from clip_eval.utils import read_all_cached_embeddings

from .utils import (
    select_existing_embedding_definitions,
    select_from_all_embedding_definitions,
)

cli = Typer(name="clip-eval", no_args_is_help=True, rich_markup_mode="markdown")
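# no_args_is_help prints the help text on a bare `clip-eval` invocation;
# rich_markup_mode="markdown" renders markdown in the command help strings.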


@cli.command(
    "build",
    help="""Build embeddings.
If no arguments are given, you will be prompted to select a combination of dataset and model(s).
You can use [TAB] to select multiple combinations and execute them sequentially.
""",
)
def build_command(
    model_dataset: Annotated[str, Option(help="Model and dataset pair, delimited as MODEL/DATASET")] = "",
    include_existing: Annotated[
        bool,
        Option(help="Also show options for which embeddings have already been computed"),
    ] = False,
    by_dataset: Annotated[
        bool,
        Option(help="Select dataset first, then model. Only applies when `model_dataset` is not specified."),
    ] = False,
):
    if len(model_dataset) > 0:
        if model_dataset.count("/") != 1:
            raise ValueError("`model_dataset` must contain exactly one '/', as in MODEL/DATASET")
        model, dataset = model_dataset.split("/")
        definitions = [EmbeddingDefinition(model=model, dataset=dataset)]
    else:
        definitions = select_from_all_embedding_definitions(
            include_existing=include_existing,
            by_dataset=by_dataset,
        )

    for embd_defn in definitions:
        try:
            embeddings = embd_defn.build_embeddings()
            print("Built embeddings successfully")
            embd_defn.save_embeddings(embeddings=embeddings)
            print("Saved embeddings successfully to", embd_defn.embedding_path)
        except Exception as e:
            print(f"Failed to build embeddings for {embd_defn}")
            print(e)
            import traceback

            traceback.print_exc()


@cli.command(
    "evaluate",
    help="""Evaluate embedding performance.
For this to work, you should have already run the `build` command for the model/dataset of interest.
""",
)
def evaluate_embeddings(
    model_datasets: Annotated[
        Optional[list[str]],
        Option(help="Specific combinations of models and datasets to evaluate, as MODEL/DATASET"),
    ] = None,
    is_all: Annotated[bool, Option(help="Evaluate all model/dataset pairs with cached embeddings.")] = False,
    save: Annotated[bool, Option(help="Save evaluation results to CSV")] = False,
):
    from clip_eval.evaluation import (
        LinearProbeClassifier,
        WeightedKNNClassifier,
        ZeroShotClassifier,
    )
    from clip_eval.evaluation.evaluator import export_evaluation_to_csv, run_evaluation

    model_datasets = model_datasets or []

    if is_all:
        defns = read_all_cached_embeddings(as_list=True)
    elif len(model_datasets) > 0:
        # Error could be localised better
        if not all(model_dataset.count("/") == 1 for model_dataset in model_datasets):
            raise ValueError("All model/dataset pairs must be presented as MODEL/DATASET")
        model_dataset_pairs = [model_dataset.split("/") for model_dataset in model_datasets]
        defns = [
            EmbeddingDefinition(model=model_dataset[0], dataset=model_dataset[1])
            for model_dataset in model_dataset_pairs
        ]
    else:
        defns = select_existing_embedding_definitions()

    models = [ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier]
    performances = run_evaluation(models, defns)
    if save:
        export_evaluation_to_csv(defns, performances)
    print(performances)


@cli.command(
    "animate",
    help="""Animate 2D embeddings from two different models on the same dataset.
The interface will prompt you to choose which embeddings you want to use.
""",
)
def animate_embeddings():
    from clip_eval.plotting.animation import build_animation, save_animation_to_file

    # Error could be localised better
    defns = select_existing_embedding_definitions(by_dataset=True)
    assert len(defns) == 2, "Please select exactly two models to make an animation"
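    # max() prefers the definition whose model is "clip", making it the first argument to build_animation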
    def1 = max(defns, key=lambda d: int(d.model == "clip"))
    def2 = defns[0] if defns[0] != def1 else defns[1]
    anim = build_animation(def1, def2)
    save_animation_to_file(anim, *defns)
    plt.show()


@cli.command("list", help="List models and datasets. By default, only cached pairs are listed.")
def list_models_datasets(
all: Annotated[
bool,
Option(help="List all models and dataset that are available via the tool."),
] = False,
):
from clip_eval.dataset.provider import dataset_provider
from clip_eval.models import model_provider

if all:
datasets = dataset_provider.list_dataset_names()
models = model_provider.list_model_names()
print(f"Available datasets are: {', '.join(datasets)}")
print(f"Available models are: {', '.join(models)}")
return

defns = read_all_cached_embeddings(as_list=True)
print(f"Available model_dataset pairs: {', '.join([str(defn) for defn in defns])}")


if __name__ == "__main__":
    cli()
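
Taken together, main.py registers four subcommands (build, evaluate, animate, list) on a single Typer app. Below is a minimal sketch of driving the new commands through Typer's test runner, assuming the package is importable; the clip/cifar10 pair is illustrative only:

from typer.testing import CliRunner

from clip_eval.cli.main import cli

runner = CliRunner()

# Equivalent to: clip-eval build --model-dataset clip/cifar10
result = runner.invoke(cli, ["build", "--model-dataset", "clip/cifar10"])
print(result.exit_code)

# Equivalent to: clip-eval list --all
result = runner.invoke(cli, ["list", "--all"])
print(result.stdout)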
68 changes: 68 additions & 0 deletions clip_eval/cli/utils.py
@@ -0,0 +1,68 @@
from itertools import product

from InquirerPy import inquirer as inq
from InquirerPy.base.control import Choice

from clip_eval.common.data_models import EmbeddingDefinition
from clip_eval.dataset.provider import dataset_provider
from clip_eval.models.provider import model_provider
from clip_eval.utils import read_all_cached_embeddings


def _do_embedding_definition_selection(
    defs: list[EmbeddingDefinition], single: bool = False
) -> list[EmbeddingDefinition]:
    choices = [Choice(d, f"D: {d.dataset[:15]:18s} M: {d.model}") for d in defs]
    message = f"Please select the desired pair{'' if single else 's'}"
    definitions: list[EmbeddingDefinition] = inq.fuzzy(
        message, choices=choices, multiselect=True, vi_mode=True
    ).execute()  # type: ignore
    return definitions


def _by_dataset(defs: list[EmbeddingDefinition] | dict[str, list[EmbeddingDefinition]]) -> list[EmbeddingDefinition]:
    if isinstance(defs, list):
        defs_list = defs
        defs = {}
        for d in defs_list:
            defs.setdefault(d.dataset, []).append(d)

    choices = sorted(
        [Choice(v, f"D: {k[:15]:18s} M: {', '.join([d.model for d in v])}") for k, v in defs.items() if len(v)],
        key=lambda c: len(c.value),
    )
    message = "Please select dataset"
    definitions: list[EmbeddingDefinition] = inq.fuzzy(
        message, choices=choices, multiselect=False, vi_mode=True
    ).execute()  # type: ignore
    return definitions


def select_existing_embedding_definitions(
    by_dataset: bool = False,
) -> list[EmbeddingDefinition]:
    defs = read_all_cached_embeddings(as_list=True)

    if by_dataset:
        # Subset definitions to a specific dataset
        defs = _by_dataset(defs)

    return _do_embedding_definition_selection(defs)


def select_from_all_embedding_definitions(
    include_existing: bool = False, by_dataset: bool = False
) -> list[EmbeddingDefinition]:
    existing = set(read_all_cached_embeddings(as_list=True))

    models = model_provider.list_model_names()
    datasets = dataset_provider.list_dataset_names()

    defs = [EmbeddingDefinition(dataset=d, model=m) for d, m in product(datasets, models)]
    if not include_existing:
        defs = [d for d in defs if d not in existing]

    if by_dataset:
        defs = _by_dataset(defs)

    return _do_embedding_definition_selection(defs)
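
The product-minus-existing step in select_from_all_embedding_definitions is what relies on EmbeddingDefinition being hashable (see the __eq__/__hash__ addition to data_models.py below). A minimal sketch of that step in isolation, with hypothetical model and dataset names:

from itertools import product

from clip_eval.common.data_models import EmbeddingDefinition

models = ["clip", "siglip"]        # hypothetical registry contents
datasets = ["cifar10", "food101"]  # hypothetical registry contents

# One pair is already cached, so it should be filtered out
existing = {EmbeddingDefinition(model="clip", dataset="cifar10")}

candidates = [EmbeddingDefinition(dataset=d, model=m) for d, m in product(datasets, models)]
remaining = [d for d in candidates if d not in existing]
assert len(remaining) == 3  # the cached clip/cifar10 pair was dropped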
8 changes: 7 additions & 1 deletion clip_eval/common/data_models.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
-from typing import Annotated
+from typing import Annotated, Any

import numpy as np
from pydantic import BaseModel
@@ -64,6 +64,12 @@ def build_embeddings(self) -> Embeddings:
    def __str__(self):
        return self.model + "_" + self.dataset

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, EmbeddingDefinition) and self.model == other.model and self.dataset == other.dataset

    def __hash__(self):
        return hash((self.model, self.dataset))


if __name__ == "__main__":
    def_ = EmbeddingDefinition(
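
The new __eq__/__hash__ pair gives EmbeddingDefinition value semantics, which is what lets the CLI utilities above keep definitions in a set and test membership. A minimal sketch of the contract (the clip/cifar10 pair is illustrative):

from clip_eval.common.data_models import EmbeddingDefinition

a = EmbeddingDefinition(model="clip", dataset="cifar10")
b = EmbeddingDefinition(model="clip", dataset="cifar10")

assert a == b              # equal by (model, dataset) value, not identity
assert hash(a) == hash(b)  # hash agrees with __eq__
assert len({a, b}) == 1    # so sets deduplicate and membership tests work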
2 changes: 1 addition & 1 deletion clip_eval/evaluation/evaluator.py
@@ -80,7 +80,7 @@ def export_evaluation_to_csv(

if __name__ == "__main__":
    models = [ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier]
-    defs = [d for k, v in read_all_cached_embeddings().items() for d in v]
+    defs = read_all_cached_embeddings(as_list=True)
    print(defs)
    performances = run_evaluation(models, defs)
    export_evaluation_to_csv(defs, performances)