diff --git a/clip_eval/cli/main.py b/clip_eval/cli/main.py index 9ab33f8..019d937 100644 --- a/clip_eval/cli/main.py +++ b/clip_eval/cli/main.py @@ -106,7 +106,10 @@ def evaluate_embeddings( The interface will prompt you to choose which embeddings you want to use. """, ) -def animate_embeddings(interactive: Annotated[bool, Option(help="Interactive plot instead of animation")] = False): +def animate_embeddings( + interactive: Annotated[bool, Option(help="Interactive plot instead of animation")] = False, + reduction: Annotated[str, Option(help="Reduction type [pca, tsne, umap (default)")] = "umap", + ): from clip_eval.plotting.animation import build_animation, save_animation_to_file # Error could be localised better @@ -114,7 +117,7 @@ def animate_embeddings(interactive: Annotated[bool, Option(help="Interactive plo assert len(defns) == 2, "Please select exactly two models to make animation" def1 = max(defns, key=lambda d: int(d.model == "clip")) def2 = defns[0] if defns[0] != def1 else defns[1] - res = build_animation(def1, def2, interactive=interactive) + res = build_animation(def1, def2, interactive=interactive, reduction=reduction) if res is None: plt.show() diff --git a/clip_eval/plotting/animation.py b/clip_eval/plotting/animation.py index 993d07c..920d362 100644 --- a/clip_eval/plotting/animation.py +++ b/clip_eval/plotting/animation.py @@ -166,8 +166,8 @@ def rotate_to_target(source: N2Array, destination: N2Array) -> N2Array: source = np.pad(source, [(0, 0), (0, 1)], mode="constant", constant_values=0.0) destination = np.pad(destination, [(0, 0), (0, 1)], mode="constant", constant_values=0.0) - rot, *_ = R.align_vectors(source, destination, return_sensitivity=True) - out = source @ rot.as_matrix() + rot, *_ = R.align_vectors(destination, source, return_sensitivity=True) + out = rot.apply(source) return out[:, :2] @@ -220,16 +220,18 @@ def build_animation( reducer = reduction_from_string(reduction) reduced_1 = standardize(reducer.get_reduction(defn_1)) reduced_2 = rotate_to_target(standardize(reducer.get_reduction(defn_2)), reduced_1) + labels = embeds.labels if reduced_1.shape[0] > 2_000: - selection = np.random.permutation(reduced_1.shape[0])[2_000] + selection = np.random.permutation(reduced_1.shape[0])[:2_000] reduced_1 = reduced_1[selection] reduced_2 = reduced_2[selection] + labels = labels[selection] return create_embedding_chart( reduced_1, reduced_2, - embeds.labels, + labels, defn_1.model, defn_2.model, suptitle=dataset.title,