-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: friendlier cli with Typer (#28)
- Loading branch information
1 parent
23b5533
commit 15ebd0d
Showing
11 changed files
with
606 additions
and
130 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from typing import Annotated, Optional | ||
|
||
import matplotlib.pyplot as plt | ||
from typer import Option, Typer | ||
|
||
from clip_eval.common.data_models import EmbeddingDefinition | ||
from clip_eval.utils import read_all_cached_embeddings | ||
|
||
from .utils import ( | ||
select_existing_embedding_definitions, | ||
select_from_all_embedding_definitions, | ||
) | ||
|
||
# Top-level Typer application: running with no arguments prints the help screen,
# and help texts are rendered as markdown.
cli = Typer(name="clip-eval", no_args_is_help=True, rich_markup_mode="markdown")
|
||
|
||
@cli.command(
    "build",
    help="""Build embeddings.
If no arguments are given, you will be prompted to select a combination of dataset and model(s).
You can use [TAB] to select multiple combinations and execute them sequentially.
""",
)
def build_command(
    model_dataset: Annotated[str, Option(help="model, dataset pair delimited by model/dataset")] = "",
    include_existing: Annotated[
        bool,
        Option(help="Show also options for which the embeddings have been computed already"),
    ] = False,
    by_dataset: Annotated[
        bool,
        Option(help="Select dataset first, then model. Will only work if `model_dataset` not specified."),
    ] = False,
):
    """Build and persist embeddings for one or more (model, dataset) pairs.

    Args:
        model_dataset: Optional ``MODEL/DATASET`` pair. When empty, the user is
            prompted interactively to select combinations.
        include_existing: Offer pairs whose embeddings already exist in the cache.
        by_dataset: Prompt for the dataset before the model (interactive mode only).

    Raises:
        ValueError: If ``model_dataset`` does not contain exactly one ``/``.
    """
    if model_dataset:
        if model_dataset.count("/") != 1:
            raise ValueError("model dataset must contain only 1 /")
        model, dataset = model_dataset.split("/")
        definitions = [EmbeddingDefinition(model=model, dataset=dataset)]
    else:
        definitions = select_from_all_embedding_definitions(
            include_existing=include_existing,
            by_dataset=by_dataset,
        )

    for embd_defn in definitions:
        try:
            embeddings = embd_defn.build_embeddings()
            print("Made embedding successfully")
            embd_defn.save_embeddings(embeddings=embeddings)
            print("Saved embedding to file successfully at", embd_defn.embedding_path)
        except Exception as e:
            # Best-effort: report the failure and continue with the remaining definitions.
            print(f"Failed to build embeddings for {embd_defn}")
            print(e)
            import traceback

            traceback.print_exc()
|
||
|
||
@cli.command(
    "evaluate",
    help="""Evaluate embedding performance.
For this to work, you should have already run the `build` command for the model/dataset of interest.
""",
)
def evaluate_embeddings(
    model_datasets: Annotated[
        Optional[list[str]],
        Option(help="Specify specific combinations of models and datasets"),
    ] = None,
    is_all: Annotated[bool, Option(help="Evaluate all models.")] = False,
    save: Annotated[bool, Option(help="Save evaluation results to csv")] = False,
):
    """Run zero-shot, linear-probe and weighted-KNN evaluation over built embeddings.

    Args:
        model_datasets: Explicit ``MODEL/DATASET`` pairs to evaluate. When omitted
            (and ``is_all`` is False), the user is prompted to select cached pairs.
        is_all: Evaluate every cached embedding definition.
        save: Export the resulting performances to CSV.

    Raises:
        ValueError: If any entry of ``model_datasets`` is not of the form
            ``MODEL/DATASET``.
    """
    # Deferred imports keep CLI start-up fast for the other commands.
    from clip_eval.evaluation import (
        LinearProbeClassifier,
        WeightedKNNClassifier,
        ZeroShotClassifier,
    )
    from clip_eval.evaluation.evaluator import export_evaluation_to_csv, run_evaluation

    model_datasets = model_datasets or []

    if is_all:
        defns = read_all_cached_embeddings(as_list=True)
    elif model_datasets:
        # Validate the pair format up front. Error could be localised better.
        if not all(md.count("/") == 1 for md in model_datasets):
            raise ValueError("All model,dataset pairs must be presented as MODEL/DATASET")
        defns = [
            EmbeddingDefinition(model=model, dataset=dataset)
            for model, dataset in (md.split("/") for md in model_datasets)
        ]
    else:
        defns = select_existing_embedding_definitions()

    models = [ZeroShotClassifier, LinearProbeClassifier, WeightedKNNClassifier]
    performances = run_evaluation(models, defns)
    if save:
        export_evaluation_to_csv(defns, performances)
    print(performances)
|
||
|
||
@cli.command(
    "animate",
    help="""Animate 2D embeddings from two different models on the same dataset.
The interface will prompt you to choose which embeddings you want to use.
""",
)
def animate_embeddings():
    """Prompt for exactly two embedding definitions and display their animation.

    The animation is also saved to file before being shown.

    Raises:
        ValueError: If the user does not select exactly two definitions.
    """
    # Deferred import keeps CLI start-up fast for the other commands.
    from clip_eval.plotting.animation import build_animation, save_animation_to_file

    # Error could be localised better
    defns = select_existing_embedding_definitions(by_dataset=True)
    # `assert` is stripped under `python -O`; raise explicitly instead so the
    # validation always runs (consistent with the other commands' ValueErrors).
    if len(defns) != 2:
        raise ValueError("Please select exactly two models to make animation")
    # Prefer the "clip" model (if selected) as the animation's first definition.
    def1 = max(defns, key=lambda d: int(d.model == "clip"))
    def2 = defns[0] if defns[0] != def1 else defns[1]
    anim = build_animation(def1, def2)
    save_animation_to_file(anim, *defns)
    plt.show()
|
||
|
||
@cli.command("list", help="List models and datasets. By default, only cached pairs are listed.")
def list_models_datasets(
    all: Annotated[
        bool,
        Option(help="List all models and dataset that are available via the tool."),
    ] = False,
):
    """Print the models and datasets known to the tool.

    With ``--all``, every available model and dataset is listed; otherwise only
    the (model, dataset) pairs with cached embeddings are shown.
    """
    from clip_eval.dataset.provider import dataset_provider
    from clip_eval.models import model_provider

    # Default mode: only report pairs for which embeddings are cached.
    if not all:
        cached = read_all_cached_embeddings(as_list=True)
        print(f"Available model_dataset pairs: {', '.join(str(defn) for defn in cached)}")
        return

    print(f"Available datasets are: {', '.join(dataset_provider.list_dataset_names())}")
    print(f"Available models are: {', '.join(model_provider.list_model_names())}")
|
||
|
||
# Entry point when this module is executed directly as a script.
if __name__ == "__main__":
    cli()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from itertools import product | ||
|
||
from InquirerPy import inquirer as inq | ||
from InquirerPy.base.control import Choice | ||
|
||
from clip_eval.common.data_models import EmbeddingDefinition | ||
from clip_eval.dataset.provider import dataset_provider | ||
from clip_eval.models.provider import model_provider | ||
from clip_eval.utils import read_all_cached_embeddings | ||
|
||
|
||
def _do_embedding_definition_selection(
    defs: list[EmbeddingDefinition], single: bool = False
) -> list[EmbeddingDefinition]:
    """Fuzzy-prompt the user to pick embedding definitions from `defs`.

    NOTE(review): `single` only changes the prompt wording; multi-selection
    remains enabled either way — confirm this is intended.
    """
    suffix = "" if single else "s"
    prompt = f"Please select the desired pair{suffix}"
    options = [Choice(d, f"D: {d.dataset[:15]:18s} M: {d.model}") for d in defs]
    selected: list[EmbeddingDefinition] = inq.fuzzy(
        prompt, choices=options, multiselect=True, vi_mode=True
    ).execute()  # type: ignore
    return selected
|
||
|
||
def _by_dataset(defs: list[EmbeddingDefinition] | dict[str, list[EmbeddingDefinition]]) -> list[EmbeddingDefinition]:
    """Prompt the user to pick one dataset and return that dataset's definitions.

    Accepts either a flat list of definitions (grouped here by dataset name)
    or a pre-built mapping of dataset name -> definitions.
    """
    if isinstance(defs, list):
        grouped: dict[str, list[EmbeddingDefinition]] = {}
        for defn in defs:
            grouped.setdefault(defn.dataset, []).append(defn)
        defs = grouped

    # One choice per non-empty dataset, ordered by how many models it has.
    dataset_choices = [
        Choice(group, f"D: {name[:15]:18s} M: {', '.join(d.model for d in group)}")
        for name, group in defs.items()
        if group
    ]
    dataset_choices.sort(key=lambda choice: len(choice.value))
    definitions: list[EmbeddingDefinition] = inq.fuzzy(
        "Please select dataset", choices=dataset_choices, multiselect=False, vi_mode=True
    ).execute()  # type: ignore
    return definitions
|
||
|
||
def select_existing_embedding_definitions(
    by_dataset: bool = False,
) -> list[EmbeddingDefinition]:
    """Interactively pick among embedding definitions already present in the cache.

    Args:
        by_dataset: When True, first narrow the candidates to a single dataset
            before the final selection.

    Returns:
        The definitions chosen by the user.
    """
    candidates = read_all_cached_embeddings(as_list=True)
    if by_dataset:
        # Subset definitions to a specific dataset first.
        candidates = _by_dataset(candidates)
    return _do_embedding_definition_selection(candidates)
|
||
|
||
def select_from_all_embedding_definitions(
    include_existing: bool = False, by_dataset: bool = False
) -> list[EmbeddingDefinition]:
    """Interactively pick among every possible (dataset, model) combination.

    Args:
        include_existing: Also offer pairs whose embeddings are already cached.
        by_dataset: First narrow the candidates to a single dataset.

    Returns:
        The definitions chosen by the user.
    """
    cached = set(read_all_cached_embeddings(as_list=True))

    all_pairs = product(dataset_provider.list_dataset_names(), model_provider.list_model_names())
    candidates = [EmbeddingDefinition(dataset=ds, model=mdl) for ds, mdl in all_pairs]
    if not include_existing:
        candidates = [c for c in candidates if c not in cached]

    if by_dataset:
        candidates = _by_dataset(candidates)

    return _do_embedding_definition_selection(candidates)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.