From 748c1538a778f661c26813002e849051af9a0114 Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Thu, 28 Nov 2024 20:20:00 -0800 Subject: [PATCH] Save models and write inference results. Also fix a flipped sign in BCE loss, and fix some small problems in agile modeling notebook. PiperOrigin-RevId: 701159583 --- hoplite/agile/2_agile_modeling_v2.ipynb | 85 ++++++++--- hoplite/agile/audio_loader.py | 10 +- hoplite/agile/classifier.py | 139 +++++++++++++++++- hoplite/agile/tests/classifier_test.py | 74 ++++++++++ hoplite/agile/tests/embedding_display_test.py | 12 +- hoplite/audio_io.py | 2 +- 6 files changed, 283 insertions(+), 39 deletions(-) create mode 100644 hoplite/agile/tests/classifier_test.py diff --git a/hoplite/agile/2_agile_modeling_v2.ipynb b/hoplite/agile/2_agile_modeling_v2.ipynb index 3ba1131..ad6960d 100644 --- a/hoplite/agile/2_agile_modeling_v2.ipynb +++ b/hoplite/agile/2_agile_modeling_v2.ipynb @@ -10,15 +10,19 @@ "source": [ "#@title Imports. { vertical-output: true }\n", "\n", - "import numpy as np\n", + "import os\n", + "\n", "from matplotlib import pyplot as plt\n", + "import numpy as np\n", "\n", "from hoplite.agile import audio_loader\n", "from hoplite.agile import classifier\n", "from hoplite.agile import classifier_data\n", "from hoplite.agile import embedding_display\n", + "from hoplite.agile import source_info\n", "from hoplite.db import brutalism\n", "from hoplite.db import score_functions\n", + "from hoplite.db import search_results\n", "from hoplite.db import sqlite_usearch_impl\n", "from hoplite.zoo import model_configs\n" ] @@ -40,13 +44,18 @@ "\n", "db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)\n", "db_model_config = db.get_metadata('model_config')\n", - "embed_config = db.get_metadata('embed_config')\n", + "embed_config = db.get_metadata('audio_sources')\n", "model_class = model_configs.MODEL_CLASS_MAP[db_model_config.model_key]\n", "embedding_model = model_class.from_config(db_model_config.model_config)\n", - "\n", + "audio_sources = source_info.AudioSources.from_config_dict(\n", + " embed_config.audio_sources)\n", + "if hasattr(embedding_model, 'window_size_s'):\n", + " window_size_s = embedding_model.window_size_s\n", + "else:\n", + " window_size_s = 5.0\n", "audio_filepath_loader = audio_loader.make_filepath_loader(\n", - " audio_globs=embed_config.audio_globs,\n", - " window_size_s=embedding_model.window_size_s,\n", + " audio_sources=audio_sources,\n", + " window_size_s=window_size_s,\n", " sample_rate_hz=embedding_model.sample_rate,\n", ")" ] @@ -94,8 +103,6 @@ "num_results = 50 #@param\n", "query_embedding = embedding_model.embed(\n", " query.get_audio_window()).embeddings[0, 0]\n", - "#@markdown Number of (randomly selected) database entries to search over.\n", - "sample_size = 1_000_000 #@param\n", "\n", "#@markdown If checked, search for examples\n", "#@markdown near a particular target score.\n", @@ -103,19 +110,28 @@ "\n", "#@markdown When target sampling, target this score.\n", "target_score = -1.0 #@param\n", - "\n", "if not target_sampling:\n", " target_score = None\n", - "score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n", - "results, all_scores = brutalism.threaded_brute_search(\n", - " db, query_embedding, num_results, score_fn=score_fn,\n", - " sample_size=sample_size)\n", "\n", - "# TODO(tomdenton): Better histogram when target sampling.\n", - "_ = plt.hist(all_scores, bins=100)\n", - "hit_scores = [r.sort_score for r in results.search_results]\n", - "plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n", - " color='r', alpha=0.5)\n" + "#@markdown If True, search the full DB. Otherwise, use approximate\n", + "#@markdown nearest-neighbor search.\n", + "exact_search = False #@param {type: 'boolean'}\n", + "\n", + "if exact_search:\n", + " target_score = None\n", + " score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n", + " results, all_scores = brutalism.threaded_brute_search(\n", + " db, query_embedding, num_results, score_fn=score_fn)\n", + " # TODO(tomdenton): Better histogram when target sampling.\n", + " _ = plt.hist(all_scores, bins=100)\n", + " hit_scores = [r.sort_score for r in results.search_results]\n", + " plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n", + " color='r', alpha=0.5)\n", + "else:\n", + " ann_matches = db.ui.search(query_embedding, count=num_results)\n", + " results = search_results.TopKSearchResults(top_k=num_results)\n", + " for k, d in zip(ann_matches.keys, ann_matches.distances):\n", + " results.update(search_results.SearchResult(k, d))\n" ] }, { @@ -175,7 +191,7 @@ "#@markdown Set of labels to classify. If None, auto-populated from the DB.\n", "target_labels = None #@param\n", "\n", - "#@markdown Classifier traning params. These should not require tuning.\n", + "#@markdown Classifier traning hyperparams. These should not require tuning.\n", "learning_rate = 1e-3 #@param\n", "weak_neg_weight = 0.05 #@param\n", "l2_mu = 0.000 #@param\n", @@ -196,13 +212,11 @@ " rng=np.random.default_rng(seed=5))\n", "print('Training for target labels : ')\n", "print(data_manager.get_target_labels())\n", - "params, eval_scores = classifier.train_linear_classifier(\n", + "linear_classifier, eval_scores = classifier.train_linear_classifier(\n", " data_manager=data_manager,\n", " learning_rate=learning_rate,\n", " weak_neg_weight=weak_neg_weight,\n", - " l2_mu=l2_mu,\n", " num_train_steps=num_steps,\n", - " loss_name=loss_fn_name,\n", ")\n", "print('\\n' + '-' * 80)\n", "top1 = eval_scores['top1_acc']\n", @@ -210,7 +224,10 @@ "rocauc = eval_scores['roc_auc']\n", "print(f'roc_auc {rocauc:.3f}')\n", "cmap = eval_scores['cmap']\n", - "print(f'cmap {cmap:.3f}')" + "print(f'cmap {cmap:.3f}')\n", + "\n", + "# Save linear classifier.\n", + "linear_classifier.save(os.path.join(db_path, 'agile_classifier_v2.pt'))" ] }, { @@ -228,8 +245,8 @@ "num_results = 50 #@param\n", "\n", "target_label_idx = data_manager.get_target_labels().index(target_label)\n", - "class_query = params['beta'][:, target_label_idx]\n", - "bias = params['beta_bias'][target_label_idx]\n", + "class_query = linear_classifier.beta[:, target_label_idx]\n", + "bias = linear_classifier.beta_bias[target_label_idx]\n", "\n", "#@markdown Number of (randomly selected) database entries to search over.\n", "sample_size = 1_000_000 #@param\n", @@ -289,6 +306,26 @@ "print('\\nnew_lbls: ', new_lbls)\n", "print('\\nprev_lbls: ', prev_lbls)" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kBH-2kz4SaS2" + }, + "outputs": [], + "source": [ + "#@title Run inference with trained classifier. { vertical-output: true }\n", + "\n", + "output_csv_filepath = '' #@param {type:'string'}\n", + "logit_threshold = 1.0 #@param\n", + "# Set labels to a tuple of desired labels if you want to run inference on a\n", + "# subset of the labels.\n", + "labels = None #@param\n", + "\n", + "classifier.write_inference_csv(\n", + " linear_classifier, db, output_csv_filepath, logit_threshold, labels=labels)\n" + ] } ], "metadata": { diff --git a/hoplite/agile/audio_loader.py b/hoplite/agile/audio_loader.py index 10224f0..365fd92 100644 --- a/hoplite/agile/audio_loader.py +++ b/hoplite/agile/audio_loader.py @@ -20,11 +20,12 @@ from etils import epath from hoplite import audio_io +from hoplite.agile import source_info import numpy as np def make_filepath_loader( - audio_globs: dict[str, tuple[str, str]], + audio_sources: source_info.AudioSources, sample_rate_hz: int = 32000, window_size_s: float = 5.0, dtype: str = 'float32', @@ -34,8 +35,7 @@ def make_filepath_loader( Note that if multiple globs match a given source ID, the first match is used. Args: - audio_globs: Mapping from dataset name to pairs of `(root directory, file - glob)`. (See `embed.EmbedConfig` for details.) + audio_sources: Embedding audio sources. sample_rate_hz: Sample rate of the audio. window_size_s: Window size of the audio. dtype: Data type of the audio. @@ -49,8 +49,8 @@ def make_filepath_loader( def loader(source_id: str, offset_s: float) -> np.ndarray: found_path = None - for base_path, _ in audio_globs.values(): - path = epath.Path(base_path) / source_id + for audio_source in audio_sources.audio_globs: + path = epath.Path(audio_source.base_path) / source_id if path.exists(): found_path = path break diff --git a/hoplite/agile/classifier.py b/hoplite/agile/classifier.py index f1f4edc..71dc763 100644 --- a/hoplite/agile/classifier.py +++ b/hoplite/agile/classifier.py @@ -15,15 +15,64 @@ """Functions for training and applying a linear classifier.""" -from typing import Any +import base64 +import dataclasses +import json +from typing import Any, Sequence from hoplite.agile import classifier_data from hoplite.agile import metrics +from hoplite.db import interface as db_interface +from hoplite.taxonomy import namespace +from ml_collections import config_dict import numpy as np import tensorflow as tf import tqdm +@dataclasses.dataclass +class LinearClassifier: + """Wrapper for linear classifier params and metadata.""" + + beta: np.ndarray + beta_bias: np.ndarray + classes: tuple[str, ...] + embedding_model_config: Any + + def __call__(self, embeddings: np.ndarray): + return np.dot(embeddings, self.beta) + self.beta_bias + + def save(self, path: str): + """Save the classifier to a path.""" + cfg = config_dict.ConfigDict() + cfg.model_config = self.embedding_model_config + cfg.classes = self.classes + # Convert numpy arrays to base64 encoded blobs. + beta_bytes = base64.b64encode(np.float32(self.beta).tobytes()).decode( + 'ascii' + ) + beta_bias_bytes = base64.b64encode( + np.float32(self.beta_bias).tobytes() + ).decode('ascii') + cfg.beta = beta_bytes + cfg.beta_bias = beta_bias_bytes + with open(path, 'w') as f: + f.write(cfg.to_json()) + + @classmethod + def load(cls, path: str): + """Load a classifier from a path.""" + with open(path, 'r') as f: + cfg_json = json.loads(f.read()) + cfg = config_dict.ConfigDict(cfg_json) + classes = cfg.classes + beta = np.frombuffer(base64.b64decode(cfg.beta), dtype=np.float32) + beta = np.reshape(beta, (-1, len(classes))) + beta_bias = np.frombuffer(base64.b64decode(cfg.beta_bias), dtype=np.float32) + embedding_model_config = cfg.model_config + return cls(beta, beta_bias, classes, embedding_model_config) + + def get_linear_model(embedding_dim: int, num_classes: int) -> tf.keras.Model: """Create a simple linear Keras model.""" model = tf.keras.Sequential([ @@ -43,12 +92,28 @@ def bce_loss( y_true = tf.cast(y_true, dtype=logits.dtype) log_p = tf.math.log_sigmoid(logits) log_not_p = tf.math.log_sigmoid(-logits) - raw_bce = -y_true * log_p + (1.0 - y_true) * log_not_p + # optax sigmoid_binary_cross_entropy: + # -labels * log_p - (1.0 - labels) * log_not_p + raw_bce = -y_true * log_p - (1.0 - y_true) * log_not_p is_labeled_mask = tf.cast(is_labeled_mask, dtype=logits.dtype) weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask return tf.reduce_mean(raw_bce * weights) +def hinge_loss( + y_true: tf.Tensor, + logits: tf.Tensor, + is_labeled_mask: tf.Tensor, + weak_neg_weight: float, +) -> tf.Tensor: + """Weighted SVM hinge loss.""" + # Convert multihot to +/- 1 labels. + y_true = 2 * y_true - 1 + weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask + raw_hinge_loss = tf.maximum(0, 1 - y_true * logits) + return tf.reduce_mean(raw_hinge_loss * weights) + + def infer(params, embeddings: np.ndarray): """Apply the model to embeddings.""" return np.dot(embeddings, params['beta']) + params['beta_bias'] @@ -105,19 +170,26 @@ def train_linear_classifier( learning_rate: float, weak_neg_weight: float, num_train_steps: int, -): + loss: str = 'bce', +) -> tuple[LinearClassifier, dict[str, float]]: """Train a linear classifier.""" embedding_dim = data_manager.db.embedding_dimension() num_classes = len(data_manager.get_target_labels()) lin_model = get_linear_model(embedding_dim, num_classes) optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) lin_model.compile(optimizer=optimizer, loss='binary_crossentropy') + if loss == 'hinge': + loss_fn = hinge_loss + elif loss == 'bce': + loss_fn = bce_loss + else: + raise ValueError(f'Unknown loss: {loss}') @tf.function def train_step(y_true, embeddings, is_labeled_mask): with tf.GradientTape() as tape: logits = lin_model(embeddings, training=True) - loss = bce_loss(y_true, logits, is_labeled_mask, weak_neg_weight) + loss = loss_fn(y_true, logits, is_labeled_mask, weak_neg_weight) loss = tf.reduce_mean(loss) grads = tape.gradient(loss, lin_model.trainable_variables) optimizer.apply_gradients(zip(grads, lin_model.trainable_variables)) @@ -147,4 +219,61 @@ def train_step(y_true, embeddings, is_labeled_mask): 'beta_bias': lin_model.get_weights()[1], } eval_scores = eval_classifier(params, data_manager, eval_idxes) - return params, eval_scores + + model_config = data_manager.db.get_metadata('model_config') + linear_classifier = LinearClassifier( + beta=params['beta'], + beta_bias=params['beta_bias'], + classes=data_manager.get_target_labels(), + embedding_model_config=model_config, + ) + return linear_classifier, eval_scores + + +def write_inference_csv( + linear_classifier: LinearClassifier, + db: db_interface.HopliteDBInterface, + output_filepath: str, + threshold: float, + labels: Sequence[str] | None = None, +): + """Write a CSV for all audio windows with logits above a threshold. + + Args: + params: The parameters of the linear classifier. + class_list: The class list of labels associated with the classifier. + db: HopliteDBInterface to read embeddings from. + output_filepath: Path to write the CSV to. + threshold: Logits must be above this value to be written. + labels: If provided, only write logits for these labels. If None, write + logits for all labels. + + Returns: + None + """ + idxes = db.get_embedding_ids() + if labels is None: + labels = linear_classifier.classes + label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)} + target_label_ids = np.array([label_ids[l] for l in labels]) + logits_fn = lambda emb: linear_classifier(emb)[target_label_ids] + detection_count = 0 + with open(output_filepath, 'w') as f: + f.write('idx,dataset_name,source_id,offset,label,logits\n') + for idx in tqdm.tqdm(idxes): + source = db.get_embedding_source(idx) + emb = db.get_embedding(idx) + logits = logits_fn(emb) + for a in np.argwhere(logits > threshold): + lbl = labels[a[0]] + row = [ + idx, + source.dataset_name, + source.source_id, + source.offsets[0], + lbl, + logits[a], + ] + f.write(','.join(map(str, row)) + '\n') + detection_count += 1 + print(f'Wrote {detection_count} detections to {output_filepath}') diff --git a/hoplite/agile/tests/classifier_test.py b/hoplite/agile/tests/classifier_test.py new file mode 100644 index 0000000..c3b4313 --- /dev/null +++ b/hoplite/agile/tests/classifier_test.py @@ -0,0 +1,74 @@ +# coding=utf-8 +# Copyright 2024 The Perch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for linear classifier implementation.""" + +import os +import tempfile + +from hoplite.agile import classifier +from ml_collections import config_dict +import numpy as np + +from absl.testing import absltest + + +class ClassifierTest(absltest.TestCase): + + def setUp(self): + super().setUp() + # `self.create_tempdir()` raises an UnparsedFlagAccessError, which is why + # we use `tempdir` directly. + self.tempdir = tempfile.mkdtemp() + + def _make_linear_classifier(self, embedding_dim, classes): + np.random.seed(1234) + beta = np.float32(np.random.normal(size=(embedding_dim, len(classes)))) + beta_bias = np.float32(np.random.normal(size=(len(classes),))) + embedding_model_config = config_dict.ConfigDict({ + 'model_name': 'nelson', + }) + return classifier.LinearClassifier( + beta, beta_bias, classes, embedding_model_config + ) + + def test_call_linear_classifier(self): + embedding_dim = 8 + classes = ('a', 'b', 'c') + classy = self._make_linear_classifier(embedding_dim, classes) + + batch_embeddings = np.random.normal(size=(10, embedding_dim)) + predictions = classy(batch_embeddings) + self.assertEqual(predictions.shape, (10, len(classes))) + + single_embedding = np.random.normal(size=(embedding_dim,)) + predictions = classy(single_embedding) + self.assertEqual(predictions.shape, (len(classes),)) + + def test_save_load_linear_classifier(self): + embedding_dim = 8 + classes = ('a', 'b', 'c') + classy = self._make_linear_classifier(embedding_dim, classes) + classy_path = os.path.join(self.tempdir, 'classifier.json') + classy.save(classy_path) + classy_loaded = classifier.LinearClassifier.load(classy_path) + np.testing.assert_allclose(classy_loaded.beta, classy.beta) + np.testing.assert_allclose(classy_loaded.beta_bias, classy.beta_bias) + self.assertSequenceEqual(classy_loaded.classes, classy.classes) + self.assertEqual(classy_loaded.embedding_model_config.model_name, 'nelson') + + +if __name__ == '__main__': + absltest.main() diff --git a/hoplite/agile/tests/embedding_display_test.py b/hoplite/agile/tests/embedding_display_test.py index 98660a5..67692db 100644 --- a/hoplite/agile/tests/embedding_display_test.py +++ b/hoplite/agile/tests/embedding_display_test.py @@ -22,6 +22,7 @@ from hoplite.agile import audio_loader from hoplite.agile import embedding_display +from hoplite.agile import source_info from hoplite.agile.tests import test_utils from hoplite.db import interface import IPython @@ -170,11 +171,14 @@ def test_embedding_display_group(self): sample_rate_hz=sample_rate_hz, ) - audio_globs = { - 'test_dataset': (self.tempdir, '*/*.wav'), - } + audio_source = source_info.AudioSourceConfig( + dataset_name='test_dataset', + base_path=self.tempdir, + file_glob='*/*.wav', + ) + audio_sources = source_info.AudioSources((audio_source,)) filepath_audio_loader = audio_loader.make_filepath_loader( - audio_globs, sample_rate_hz=sample_rate_hz, window_size_s=1.0 + audio_sources, sample_rate_hz=sample_rate_hz, window_size_s=1.0 ) group = embedding_display.EmbeddingDisplayGroup.create( members=[member0, member1], diff --git a/hoplite/audio_io.py b/hoplite/audio_io.py index e8dd122..88591d9 100644 --- a/hoplite/audio_io.py +++ b/hoplite/audio_io.py @@ -115,7 +115,7 @@ def load_audio_window_soundfile( with epath.Path(filepath).open('rb') as f: sf = soundfile.SoundFile(f) if offset_s > 0: - offset = int(offset_s * sf.samplerate) + offset = int(np.float32(offset_s) * sf.samplerate) sf.seek(offset) if window_size_s < 0: a = sf.read()