Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: better model selection in animation code #47

Merged
merged 5 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions clip_eval/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@cli.command(
"build",
help="""Build embeddings.
If no argumens are given, you will be prompted to select a combination of dataset and model(s).
If no arguments are given, you will be prompted to select a combination of dataset and model(s).
You can use [TAB] to select multiple combinations and execute them sequentially.
""",
)
Expand Down Expand Up @@ -107,22 +107,18 @@ def evaluate_embeddings(
""",
)
def animate_embeddings(
interactive: Annotated[bool, Option(help="Interactive plot instead of animation")] = False,
reduction: Annotated[str, Option(help="Reduction type [pca, tsne, umap (default)")] = "umap",
):
interactive: Annotated[bool, Option(help="Interactive plot instead of animation")] = False,
reduction: Annotated[str, Option(help="Reduction type [pca, tsne, umap (default)")] = "umap",
):
from clip_eval.plotting.animation import build_animation, save_animation_to_file

# Error could be localised better
defns = select_existing_embedding_definitions(by_dataset=True)
assert len(defns) == 2, "Please select exactly two models to make animation"
def1 = max(defns, key=lambda d: int(d.model == "clip"))
def2 = defns[0] if defns[0] != def1 else defns[1]
res = build_animation(def1, def2, interactive=interactive, reduction=reduction)
defs = select_existing_embedding_definitions(by_dataset=True, count=2)
res = build_animation(defs[0], defs[1], interactive=interactive, reduction=reduction)

if res is None:
plt.show()
else:
save_animation_to_file(res, *defns)
save_animation_to_file(res, *defs)


@cli.command("list", help="List models and datasets. By default, only cached pairs are listed.")
Expand Down
35 changes: 29 additions & 6 deletions clip_eval/cli/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from itertools import product
from typing import Literal, overload

from InquirerPy import inquirer as inq
from InquirerPy.base.control import Choice
Expand All @@ -9,13 +10,31 @@
from clip_eval.utils import read_all_cached_embeddings


@overload
def _do_embedding_definition_selection(
defs: list[EmbeddingDefinition], single: bool = False
defs: list[EmbeddingDefinition], allow_multiple: Literal[True] = True
) -> list[EmbeddingDefinition]:
...


@overload
def _do_embedding_definition_selection(
defs: list[EmbeddingDefinition], allow_multiple: Literal[False]
) -> EmbeddingDefinition:
...


def _do_embedding_definition_selection(
defs: list[EmbeddingDefinition],
allow_multiple: bool = True,
) -> list[EmbeddingDefinition] | EmbeddingDefinition:
choices = [Choice(d, f"D: {d.dataset[:15]:18s} M: {d.model}") for d in defs]
message = f"Please select the desired pair{'' if single else 's'}"
definitions: list[EmbeddingDefinition] = inq.fuzzy(
message, choices=choices, multiselect=True, vi_mode=True
message = "Please select the desired pairs" if allow_multiple else "Please select a pair"
definitions = inq.fuzzy(
message,
choices=choices,
multiselect=allow_multiple,
vi_mode=True,
).execute() # type: ignore
return definitions

Expand All @@ -31,7 +50,7 @@ def _by_dataset(defs: list[EmbeddingDefinition] | dict[str, list[EmbeddingDefini
[Choice(v, f"D: {k[:15]:18s} M: {', '.join([d.model for d in v])}") for k, v in defs.items() if len(v)],
key=lambda c: len(c.value),
)
message = "Please select dataset"
message = "Please select a dataset"
definitions: list[EmbeddingDefinition] = inq.fuzzy(
message, choices=choices, multiselect=False, vi_mode=True
).execute() # type: ignore
Expand All @@ -40,14 +59,18 @@ def _by_dataset(defs: list[EmbeddingDefinition] | dict[str, list[EmbeddingDefini

def select_existing_embedding_definitions(
by_dataset: bool = False,
count: int | None = None,
) -> list[EmbeddingDefinition]:
defs = read_all_cached_embeddings(as_list=True)

if by_dataset:
# Subset definitions to specific dataset
defs = _by_dataset(defs)

return _do_embedding_definition_selection(defs)
if count is None:
return _do_embedding_definition_selection(defs)
else:
return [_do_embedding_definition_selection(defs, allow_multiple=False) for _ in range(count)]


def select_from_all_embedding_definitions(
Expand Down
20 changes: 10 additions & 10 deletions clip_eval/plotting/animation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import UTC, datetime
from pathlib import Path
from typing import Literal, overload

Expand Down Expand Up @@ -240,20 +240,20 @@ def build_animation(
)


def save_animation_to_file(anim: animation.FuncAnimation, defn_1, defn_2: EmbeddingDefinition):
ts = datetime.now()
animation_file = OUTPUT_PATH.ANIMATIONS / f"transition_{defn_1}-{defn_2}_{ts.isoformat()}.gif"
def save_animation_to_file(anim: animation.FuncAnimation, def1: EmbeddingDefinition, def2: EmbeddingDefinition):
date_code = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
animation_file = OUTPUT_PATH.ANIMATIONS / f"transition_{def1.dataset}_{def1.model}_{def2.model}_{date_code}.gif"
animation_file.parent.mkdir(parents=True, exist_ok=True) # Ensure that parent folder exists
anim.save(animation_file)
print(f"Stored animation in `{animation_file}`")
print(f"Animation stored at `{animation_file.resolve().as_posix()}`")


if __name__ == "__main__":
defn_1 = EmbeddingDefinition(model="clip", dataset="LungCancer4Types")
defn_2 = EmbeddingDefinition(model="pubmed", dataset="LungCancer4Types")
anim = build_animation(defn_1, defn_2, interactive=False)
ts = datetime.now()
animation_file = OUTPUT_PATH.ANIMATIONS / f"transition_{defn_1}-{defn_2}_{ts.isoformat()}.gif"
def1 = EmbeddingDefinition(model="clip", dataset="LungCancer4Types")
def2 = EmbeddingDefinition(model="pubmed", dataset="LungCancer4Types")
anim = build_animation(def1, def2, interactive=False)
date_code = datetime.now(UTC).strftime("%Y%m%d-%H%M%S")
animation_file = OUTPUT_PATH.ANIMATIONS / f"transition_{def1.dataset}_{def1.model}_{def2.model}_{date_code}.gif"
animation_file.parent.mkdir(parents=True, exist_ok=True) # Ensure that parent folder exists
anim.save(animation_file)
plt.show()
4 changes: 1 addition & 3 deletions clip_eval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@ def read_all_cached_embeddings(as_list: Literal[True]) -> list[EmbeddingDefiniti


@overload
def read_all_cached_embeddings(
as_list: Literal[False] = False,
) -> dict[str, list[EmbeddingDefinition]]:
def read_all_cached_embeddings(as_list: Literal[False] = False) -> dict[str, list[EmbeddingDefinition]]:
...


Expand Down
Loading