feat: more friendly cli #28

Merged
merged 7 commits on Feb 23, 2024
114 changes: 0 additions & 114 deletions clip_eval/cli.py

This file was deleted.

Empty file added clip_eval/cli/__init__.py
Empty file.
145 changes: 145 additions & 0 deletions clip_eval/cli/main.py
@@ -0,0 +1,145 @@
from typing import Annotated, Optional

import matplotlib.pyplot as plt
from typer import Option, Typer

from clip_eval.common.data_models import EmbeddingDefinition
from clip_eval.utils import read_all_cached_embeddings

from .utils import (
select_existing_embedding_definitions,
select_from_all_embedding_definitions,
)

cli = Typer(name="clip-eval", no_args_is_help=True, rich_markup_mode="markdown")


@cli.command(
"build",
help="""Build embeddings.
If no arguments are given, you will be prompted to select a combination of dataset and model(s).
You can use [TAB] to select multiple combinations and execute them sequentially.
""",
)
def build_command(
model_dataset: Annotated[str, Option(help="model, dataset pair delimited by model/dataset")] = "",
include_existing: Annotated[
bool,
Option(help="Show also options for which the embeddings have been computed already"),
] = False,
by_dataset: Annotated[
bool,
Option(help="Select dataset first, then model. Will only work if `model_dataset` not specified."),
] = False,
):
if len(model_dataset) > 0:
if model_dataset.count("/") != 1:
raise ValueError("model dataset must contain only 1 /")
model, dataset = model_dataset.split("/")
definitions = [EmbeddingDefinition(model=model, dataset=dataset)]
else:
definitions = select_from_all_embedding_definitions(
include_existing=include_existing,
by_dataset=by_dataset,
)

for embd_defn in definitions:
try:
embeddings = embd_defn.build_embeddings()
print("Made embedding successfully")
embd_defn.save_embeddings(embeddings=embeddings)
print("Saved embedding to file successfully at", embd_defn.embedding_path)
except Exception as e:
print(f"Failed to build embeddings for this bastard: {embd_defn}")
print(e)
import traceback

traceback.print_exc()


@cli.command(
"evaluate",
help="""Evaluate embedding performance.
For this to work, you should have already run the `build` command for the model/dataset of interest.
""",
)
def evaluate_embeddings(
model_datasets: Annotated[
Optional[list[str]],
Option(help="Specify specific combinations of models and datasets"),
] = None,
is_all: Annotated[bool, Option(help="Evaluate all cached model/dataset pairs.")] = False,
save: Annotated[bool, Option(help="Save evaluation results to CSV")] = False,
):
from clip_eval.evaluation import (
LinearProbeClassifier,
WeightedKNNClassifier,
ZeroShotClassifier,
)
from clip_eval.evaluation.evaluator import export_evaluation_to_csv, run_evaluation

model_datasets = model_datasets or []

if is_all:
defns = read_all_cached_embeddings(as_list=True)
elif len(model_datasets) > 0:
# Error could be localised better
if not all([model_dataset.count("/") == 1 for model_dataset in model_datasets]):
raise ValueError("All model,dataset pairs must be presented as MODEL/DATASET")
model_dataset_pairs = [model_dataset.split("/") for model_dataset in model_datasets]
defns = [
EmbeddingDefinition(model=model_dataset[0], dataset=model_dataset[1])
for model_dataset in model_dataset_pairs
]
else:
defns = select_existing_embedding_definitions()

models = [ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier]
performances = run_evaluation(models, defns)
if save:
export_evaluation_to_csv(defns, performances)
print(performances)


@cli.command(
"animate",
help="""Animate 2D embeddings from two different models on the same dataset.
The interface will prompt you to choose which embeddings you want to use.
""",
)
def animate_embeddings():
from clip_eval.plotting.animation import build_animation, save_animation_to_file

# Error could be localised better
defns = select_existing_embedding_definitions(by_dataset=True)
assert len(defns) == 2, "Please select exactly two models to make animation"
def1 = max(defns, key=lambda d: int(d.model == "clip"))
def2 = defns[0] if defns[0] != def1 else defns[1]
anim = build_animation(def1, def2)
save_animation_to_file(anim, *defns)
plt.show()


@cli.command("list", help="List models and datasets. By default, only cached pairs are listed.")
def list_models_datasets(
all: Annotated[
bool,
Option(help="List all models and dataset that are available via the tool."),
] = False,
):
from clip_eval.dataset.provider import dataset_provider
from clip_eval.models import model_provider

if all:
datasets = dataset_provider.list_dataset_names()
models = model_provider.list_model_names()
print(f"Available datasets are: {', '.join(datasets)}")
print(f"Available models are: {', '.join(models)}")
return

defns = read_all_cached_embeddings(as_list=True)
print(f"Available model_dataset pairs: {', '.join([str(defn) for defn in defns])}")


if __name__ == "__main__":
cli()
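
For reviewers: a minimal sketch (not part of this diff) of how the new Typer app can be exercised non-interactively, e.g. in a smoke test; the clip/cifar10 pair below is a hypothetical example.

from typer.testing import CliRunner

from clip_eval.cli.main import cli

runner = CliRunner()

# Equivalent to running `clip-eval list --all` in a shell.
result = runner.invoke(cli, ["list", "--all"])
print(result.stdout)

# Equivalent to `clip-eval build --model-dataset clip/cifar10`
# (hypothetical pair; this would actually compute and cache embeddings).
result = runner.invoke(cli, ["build", "--model-dataset", "clip/cifar10"])
print(result.exit_code)
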
68 changes: 68 additions & 0 deletions clip_eval/cli/utils.py
@@ -0,0 +1,68 @@
from itertools import product

from InquirerPy import inquirer as inq
from InquirerPy.base.control import Choice

from clip_eval.common.data_models import EmbeddingDefinition
from clip_eval.dataset.provider import dataset_provider
from clip_eval.models.provider import model_provider
from clip_eval.utils import read_all_cached_embeddings


def _do_embedding_definition_selection(
defs: list[EmbeddingDefinition], single: bool = False
) -> list[EmbeddingDefinition]:
choices = [Choice(d, f"D: {d.dataset[:15]:18s} M: {d.model}") for d in defs]
message = f"Please select the desired pair{'' if single else 's'}"
definitions: list[EmbeddingDefinition] = inq.fuzzy(
message, choices=choices, multiselect=True, vi_mode=True
).execute() # type: ignore
return definitions


def _by_dataset(defs: list[EmbeddingDefinition] | dict[str, list[EmbeddingDefinition]]) -> list[EmbeddingDefinition]:
if isinstance(defs, list):
defs_list = defs
defs = {}
for d in defs_list:
defs.setdefault(d.dataset, []).append(d)

choices = sorted(
[Choice(v, f"D: {k[:15]:18s} M: {', '.join([d.model for d in v])}") for k, v in defs.items() if len(v)],
key=lambda c: len(c.value),
)
message = "Please select dataset"
definitions: list[EmbeddingDefinition] = inq.fuzzy(
message, choices=choices, multiselect=False, vi_mode=True
).execute() # type: ignore
return definitions


def select_existing_embedding_definitions(
by_dataset: bool = False,
) -> list[EmbeddingDefinition]:
defs = read_all_cached_embeddings(as_list=True)

if by_dataset:
# Subset definitions to specific dataset
defs = _by_dataset(defs)

return _do_embedding_definition_selection(defs)


def select_from_all_embedding_definitions(
include_existing: bool = False, by_dataset: bool = False
) -> list[EmbeddingDefinition]:
existing = set(read_all_cached_embeddings(as_list=True))

models = model_provider.list_model_names()
datasets = dataset_provider.list_dataset_names()

defs = [EmbeddingDefinition(dataset=d, model=m) for d, m in product(datasets, models)]
if not include_existing:
defs = [d for d in defs if d not in existing]

if by_dataset:
defs = _by_dataset(defs)

return _do_embedding_definition_selection(defs)
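
A toy illustration (hypothetical model/dataset names) of the grouping that `_by_dataset` performs before prompting: definitions are bucketed per dataset, so answering the fuzzy prompt once selects every cached model for that dataset.

from clip_eval.common.data_models import EmbeddingDefinition

defs = [
    EmbeddingDefinition(model="clip", dataset="cifar10"),
    EmbeddingDefinition(model="siglip", dataset="cifar10"),
    EmbeddingDefinition(model="clip", dataset="food101"),
]

grouped: dict[str, list[EmbeddingDefinition]] = {}
for d in defs:
    grouped.setdefault(d.dataset, []).append(d)

# {'cifar10': ['clip_cifar10', 'siglip_cifar10'], 'food101': ['clip_food101']}
print({k: [str(d) for d in v] for k, v in grouped.items()})
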
8 changes: 7 additions & 1 deletion clip_eval/common/data_models.py
@@ -1,6 +1,6 @@
import logging
from pathlib import Path
from typing import Annotated
from typing import Annotated, Any

import numpy as np
from pydantic import BaseModel
@@ -64,6 +64,12 @@ def build_embeddings(self) -> Embeddings:
def __str__(self):
return self.model + "_" + self.dataset

def __eq__(self, other: Any) -> bool:
return isinstance(other, EmbeddingDefinition) and self.model == other.model and self.dataset == other.dataset

def __hash__(self):
return hash((self.model, self.dataset))


if __name__ == "__main__":
def_ = EmbeddingDefinition(
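
The `__eq__`/`__hash__` additions above are what make the set-based deduplication in `clip_eval/cli/utils.py` work; a quick sketch of the behaviour (the names are hypothetical and may be subject to validation in the real models):

from clip_eval.common.data_models import EmbeddingDefinition

existing = {EmbeddingDefinition(model="clip", dataset="cifar10")}

# Two separately constructed definitions with identical fields now compare equal
# and hash identically, so membership checks against the cached set succeed.
print(EmbeddingDefinition(model="clip", dataset="cifar10") in existing)   # True
print(EmbeddingDefinition(model="clip", dataset="food101") in existing)   # False
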
2 changes: 1 addition & 1 deletion clip_eval/evaluation/evaluator.py
@@ -80,7 +80,7 @@ def export_evaluation_to_csv(

if __name__ == "__main__":
models = [ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier]
defs = [d for k, v in read_all_cached_embeddings().items() for d in v]
defs = read_all_cached_embeddings(as_list=True)
print(defs)
performances = run_evaluation(models, defs)
export_evaluation_to_csv(defs, performances)
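
For context on the one-line change above, a sketch of the equivalence being relied on, assuming `read_all_cached_embeddings(as_list=True)` flattens its default dict-of-lists return value in the same order:

from clip_eval.utils import read_all_cached_embeddings

# The old code flattened the default {dataset: [definitions]} mapping by hand.
defs_old = [d for v in read_all_cached_embeddings().values() for d in v]

# The new code asks the helper for a flat list directly.
defs_new = read_all_cached_embeddings(as_list=True)

# Element-wise comparison relies on the __eq__ added to EmbeddingDefinition.
print(defs_old == defs_new)
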