diff --git a/notebooks/fashion-mnist.ipynb b/notebooks/fashion-mnist.ipynb new file mode 100644 index 0000000..2405780 --- /dev/null +++ b/notebooks/fashion-mnist.ipynb @@ -0,0 +1,425 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "d244da11-5d40-492d-adf7-a4882c8ed113", + "metadata": {}, + "source": [ + "# Fashion MNIST\n", + "\n", + "In this notebook we're comparing t-SNE and UMAP on the Fashion MNIST dataset and try to understand how the two popular embedding methods differ in terms of the visual intermixing of the image classes and relative neighborhood changes of the image classes." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "050472bd-4df5-4cad-88d6-e55bc21d0189", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pyarrow as pa\n", + "import pandas as pd\n", + "from cev.widgets import Embedding, EmbeddingComparisonWidget" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "659e7463-60a3-416c-bf0a-c9e0cd52ef5a", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import requests\n", + "from io import BytesIO\n", + "\n", + "r = requests.get(\n", + " \"https://storage.googleapis.com/flekschas/regl-scatterplot/fashion-mnist-embeddings.arrow\"\n", + ")\n", + "df = pa.ipc.open_file(BytesIO(r.content)).read_all().to_pandas()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "bbdac8a7-86c1-478b-a812-a3f92eb19082", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "cmap = {\n", + " \"T-shirt/top\": \"#FFFF00\",\n", + " \"Trouser\": \"#1CE6FF\",\n", + " \"Pullover\": \"#FF34FF\",\n", + " \"Dress\": \"#FF4A46\",\n", + " \"Coat\": \"#008941\",\n", + " \"Sandal\": \"#006FA6\",\n", + " \"Shirt\": \"#A30059\",\n", + " \"Sneaker\": \"#FFDBE5\",\n", + " \"Bag\": \"#7A4900\",\n", + " \"Ankle boot\": \"#0000A6\",\n", + "}\n", + "\n", + "labels = (\n", + " df[\"class\"]\n", + " .map({i: label for i, label in 
enumerate(cmap.keys())})\n", + " .astype(\"category\")\n", + ")\n", + "\n", + "tsne = df[[\"tsneX\", \"tsneY\"]].values\n", + "umap = df[[\"umapX\", \"umapY\"]].values" + ] + }, + { + "cell_type": "markdown", + "id": "11fc3963-2c9d-4c4c-afda-3f44a87b1ad8", + "metadata": {}, + "source": [ + "## Using the Image Classes as the Labels\n", + "\n", + "In this first experiment, we're using the image classes that come with the Fashion MNIST dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c4ab34d1-7f19-421e-8eb9-28641852d6d3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "844f58fa5a74407786ec13f02e3a8bc4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EmbeddingComparisonWidget(children=(VBox(children=(HBox(children=(WidthOptimizer(), Dropdown(description='Metr…" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tsne_embedding = Embedding(coords=tsne, labels=labels)\n", + "umap_embedding = Embedding(coords=umap, labels=labels)\n", + "\n", + "tsne_vs_umap = EmbeddingComparisonWidget(\n", + " tsne_embedding,\n", + " umap_embedding,\n", + " titles=[\"t-SNE\", \"UMAP\"],\n", + " metric=\"confusion\",\n", + " selection=\"synced\",\n", + " auto_zoom=True,\n", + " row_height=320,\n", + ")\n", + "\n", + "tsne_vs_umap.left.categorical_scatter.color(map=cmap)\n", + "tsne_vs_umap.left.categorical_scatter.legend(True)\n", + "tsne_vs_umap.right.categorical_scatter.color(map=cmap)\n", + "tsne_vs_umap.right.categorical_scatter.legend(True)\n", + "\n", + "tsne_vs_umap" + ] + }, + { + "cell_type": "markdown", + "id": "4f306e39-e586-4e86-a438-e246a79ca107", + "metadata": {}, + "source": [ + "## Using HDBScan Clusters as the Labels\n", + "\n", + "In the next experiment, we show how we can handle the case where we want to compare two embedding methods without any label information.\n", + "\n", + "The 
idea is to cluster one embedding and use the cluster IDs as labels for comparing the two embeddings. We will do this in both directions:\n", + "1. Using clusters derived with HDBScan from the t-SNE embedding\n", + "2. Using clusters derived with HDBScan from the UMAP embedding" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "6614285d-8d74-44c3-913b-00e2a4af0ca6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import hdbscan" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f9f62cdd-9a87-4cfd-acc0-a96de884f2b4", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2c23828a56dc421eb2149f1f7723de53", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EmbeddingComparisonWidget(children=(VBox(children=(HBox(children=(WidthOptimizer(), Dropdown(description='Metr…" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tsne_clusters = hdbscan.HDBSCAN(\n", + "    min_cluster_size=15, cluster_selection_epsilon=0.015\n", + ").fit_predict(tsne)\n", + "tsne_cluster_labels = pd.Series([str(i) for i in tsne_clusters]).astype(\"category\")\n", + "\n", + "tsne_based_tsne_embedding = Embedding(\n", + "    coords=tsne, labels=tsne_cluster_labels, robust=tsne_clusters >= 0\n", + ")\n", + "tsne_based_umap_embedding = Embedding(\n", + "    coords=umap, labels=tsne_cluster_labels, robust=tsne_clusters >= 0\n", + ")\n", + "\n", + "tsne_based_tsne_vs_umap = EmbeddingComparisonWidget(\n", + "    tsne_based_tsne_embedding,\n", + "    tsne_based_umap_embedding,\n", + "    titles=[\"t-SNE with t-SNE Clusters\", \"UMAP with t-SNE Clusters\"],\n", + "    metric=\"confusion\",\n", + "    selection=\"synced\",\n", + "    auto_zoom=True,\n", + "    row_height=320,\n", + ")\n", + "\n", + "tsne_based_tsne_vs_umap" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ba5b10da-870c-43fb-89bd-99b932aa43a7", + 
"metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "284364489c2445cb918b94654e7b779c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EmbeddingComparisonWidget(children=(VBox(children=(HBox(children=(WidthOptimizer(), Dropdown(description='Metr…" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "umap_clusterer = hdbscan.HDBSCAN(min_cluster_size=10, cluster_selection_epsilon=0)\n", + "umap_clusters = umap_clusterer.fit_predict(umap)\n", + "umap_cluster_labels = pd.Series([str(i) for i in umap_clusters]).astype(\"category\")\n", + "\n", + "umap_based_tsne_embedding = Embedding(\n", + " coords=tsne, labels=umap_cluster_labels, robust=umap_clusters >= 0\n", + ")\n", + "umap_based_umap_embedding = Embedding(\n", + " coords=umap, labels=umap_cluster_labels, robust=umap_clusters >= 0\n", + ")\n", + "\n", + "umap_based_tsne_vs_umap = EmbeddingComparisonWidget(\n", + " umap_based_tsne_embedding,\n", + " umap_based_umap_embedding,\n", + " titles=[\"t-SNE with UMAP Clusters\", \"UMAP with UMAP Clusters\"],\n", + " metric=\"confusion\",\n", + " selection=\"synced\",\n", + " auto_zoom=True,\n", + " row_height=320,\n", + ")\n", + "\n", + "umap_based_tsne_vs_umap" + ] + }, + { + "cell_type": "markdown", + "id": "2c4f564e-0ce1-42b9-bb16-0b4bb4668741", + "metadata": {}, + "source": [ + "# Class-Based Sub-Clusters\n", + "\n", + "Finally, we combine both approaches by relying on the classes that come with Fashion MNIST but sub-clustering each class." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4aa404d6-3bca-4f2e-97dc-1e826ba8028d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f4669b4dcb1843a5a4363331e7e001d9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EmbeddingComparisonWidget(children=(VBox(children=(HBox(children=(WidthOptimizer(), Dropdown(description='Metr…" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tsne_subclass_ids = np.zeros_like(labels.values).astype(int)\n", + "\n", + "k = 0\n", + "for class_id in df[\"class\"].unique():\n", + " indices = np.where(df[\"class\"].values == class_id)[0]\n", + "\n", + " cluster_labels = hdbscan.HDBSCAN(\n", + " min_cluster_size=20, cluster_selection_epsilon=0.05\n", + " ).fit_predict(tsne[indices])\n", + "\n", + " tsne_subclass_ids[indices] = class_id + cluster_labels + k\n", + " tsne_subclass_ids[indices[np.where(cluster_labels == -1)]] = -1\n", + " k += np.max(cluster_labels) + 1\n", + "\n", + "tsne_subcluster_labels = pd.Series([str(x) for x in tsne_subclass_ids]).astype(\n", + " \"category\"\n", + ")\n", + "\n", + "tsne_subcluster_based_tsne_embedding = Embedding(\n", + " coords=tsne, labels=tsne_subcluster_labels, robust=tsne_subclass_ids >= 0\n", + ")\n", + "tsne_subcluster_based_umap_embedding = Embedding(\n", + " coords=umap, labels=tsne_subcluster_labels, robust=tsne_subclass_ids >= 0\n", + ")\n", + "\n", + "tsne_subcluster_based_tsne_vs_umap = EmbeddingComparisonWidget(\n", + " tsne_subcluster_based_tsne_embedding,\n", + " tsne_subcluster_based_umap_embedding,\n", + " titles=[\"t-SNE with t-SNE Sub-Clusters\", \"UMAP with t-SNE Sub-Clusters\"],\n", + " metric=\"confusion\",\n", + " selection=\"synced\",\n", + " auto_zoom=True,\n", + " row_height=320,\n", + ")\n", + "\n", + "tsne_subcluster_based_tsne_vs_umap" + ] + }, + { + "cell_type": "code", + 
"execution_count": 9, + "id": "872549dd-a176-45ec-9abb-d416ada06863", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b5446df83e6f4c44a6b9f275362ba538", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "EmbeddingComparisonWidget(children=(VBox(children=(HBox(children=(WidthOptimizer(), Dropdown(description='Metr…" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "umap_subclass_ids = np.zeros_like(labels.values).astype(int)\n", + "\n", + "k = 0\n", + "for class_id in df[\"class\"].unique():\n", + " indices = np.where(df[\"class\"].values == class_id)[0]\n", + "\n", + " cluster_labels = hdbscan.HDBSCAN(\n", + " min_cluster_size=20, cluster_selection_epsilon=0.05\n", + " ).fit_predict(umap[indices])\n", + "\n", + " umap_subclass_ids[indices] = class_id + cluster_labels + k\n", + " umap_subclass_ids[indices[np.where(cluster_labels == -1)]] = -1\n", + " k += np.max(cluster_labels) + 1\n", + "\n", + "umap_subcluster_labels = pd.Series([str(x) for x in umap_subclass_ids]).astype(\n", + " \"category\"\n", + ")\n", + "\n", + "umap_subcluster_based_tsne_embedding = Embedding(\n", + " coords=tsne, labels=umap_subcluster_labels, robust=umap_subclass_ids >= 0\n", + ")\n", + "umap_subcluster_based_umap_embedding = Embedding(\n", + " coords=umap, labels=umap_subcluster_labels, robust=umap_subclass_ids >= 0\n", + ")\n", + "\n", + "umap_subcluster_based_tsne_vs_umap = EmbeddingComparisonWidget(\n", + " umap_subcluster_based_tsne_embedding,\n", + " umap_subcluster_based_umap_embedding,\n", + " titles=[\"t-SNE with UMAP Sub-Clusters\", \"UMAP with UMAP Sub-Clusters\"],\n", + " metric=\"confusion\",\n", + " selection=\"synced\",\n", + " auto_zoom=True,\n", + " row_height=320,\n", + ")\n", + "\n", + "umap_subcluster_based_tsne_vs_umap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": 
"02b7a7bc-d086-44b6-b4f7-7551cde9bf8e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/cev/_embedding_widget.py b/src/cev/_embedding_widget.py index 7dea648..44c1cc2 100644 --- a/src/cev/_embedding_widget.py +++ b/src/cev/_embedding_widget.py @@ -11,7 +11,12 @@ import traitlets from cev._embedding import Embedding -from cev._widget_utils import create_colormaps, link_widgets, robust_labels +from cev._widget_utils import ( + NON_ROBUST_LABEL, + create_colormaps, + link_widgets, + robust_labels, +) from cev.components import MarkerCompositionLogo _LABEL_COLUMN = "label" @@ -83,7 +88,13 @@ def _on_labels_change(self, change): np.asarray(self._labeler(labels)), dtype="category" ) self.logo.counts = self.label_counts(self.categorical_scatter.widget.selection) - self.has_markers = "+" in self._data[_LABEL_COLUMN][0] + self.has_markers = ( + isinstance(self._data[_LABEL_COLUMN][0], str) + and "+" in self._data[_LABEL_COLUMN][0] + ) + self.metric_scatter.filter( + np.argwhere(self.robust_labels.values != NON_ROBUST_LABEL) + ) @traitlets.validate("distances") def _validate_distances(self, proposal: object):