Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow choosing reduction type on animation #44

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions clip_eval/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,18 @@ def evaluate_embeddings(
The interface will prompt you to choose which embeddings you want to use.
""",
)
def animate_embeddings(interactive: Annotated[bool, Option(help="Interactive plot instead of animation")] = False):
def animate_embeddings(
interactive: Annotated[bool, Option(help="Interactive plot instead of animation")] = False,
reduction: Annotated[str, Option(help="Reduction type [pca, tsne, umap (default)")] = "umap",
):
from clip_eval.plotting.animation import build_animation, save_animation_to_file

# Error could be localised better
defns = select_existing_embedding_definitions(by_dataset=True)
assert len(defns) == 2, "Please select exactly two models to make animation"
def1 = max(defns, key=lambda d: int(d.model == "clip"))
def2 = defns[0] if defns[0] != def1 else defns[1]
res = build_animation(def1, def2, interactive=interactive)
res = build_animation(def1, def2, interactive=interactive, reduction=reduction)

if res is None:
plt.show()
Expand Down
10 changes: 6 additions & 4 deletions clip_eval/plotting/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ def rotate_to_target(source: N2Array, destination: N2Array) -> N2Array:
source = np.pad(source, [(0, 0), (0, 1)], mode="constant", constant_values=0.0)
destination = np.pad(destination, [(0, 0), (0, 1)], mode="constant", constant_values=0.0)

rot, *_ = R.align_vectors(source, destination, return_sensitivity=True)
out = source @ rot.as_matrix()
rot, *_ = R.align_vectors(destination, source, return_sensitivity=True)
out = rot.apply(source)
return out[:, :2]


Expand Down Expand Up @@ -220,16 +220,18 @@ def build_animation(
reducer = reduction_from_string(reduction)
reduced_1 = standardize(reducer.get_reduction(defn_1))
reduced_2 = rotate_to_target(standardize(reducer.get_reduction(defn_2)), reduced_1)
labels = embeds.labels

if reduced_1.shape[0] > 2_000:
selection = np.random.permutation(reduced_1.shape[0])[2_000]
selection = np.random.permutation(reduced_1.shape[0])[:2_000]
reduced_1 = reduced_1[selection]
reduced_2 = reduced_2[selection]
labels = labels[selection]

return create_embedding_chart(
reduced_1,
reduced_2,
embeds.labels,
labels,
defn_1.model,
defn_2.model,
suptitle=dataset.title,
Expand Down
Loading