From 7bc1eb5ad2f6f126d7d68bdc772da26586b7c12f Mon Sep 17 00:00:00 2001 From: Tom Denton Date: Thu, 28 Nov 2024 20:20:00 -0800 Subject: [PATCH] Save models and write inference results. PiperOrigin-RevId: 701159583 --- .github/workflows/ci_pip.yml | 2 +- README.md | 2 +- hoplite/agile/2_agile_modeling_v2.ipynb | 43 +++++++--- hoplite/agile/classifier.py | 109 +++++++++++++++++++++++- hoplite/agile/tests/classifier_test.py | 74 ++++++++++++++++ 5 files changed, 214 insertions(+), 16 deletions(-) create mode 100644 hoplite/agile/tests/classifier_test.py diff --git a/.github/workflows/ci_pip.yml b/.github/workflows/ci_pip.yml index 9f950ac..5dfa77c 100644 --- a/.github/workflows/ci_pip.yml +++ b/.github/workflows/ci_pip.yml @@ -29,7 +29,7 @@ jobs: pip install absl-py pip install requests pip install tensorflow-cpu - pip install git+https://github.com/googlestaging/hoplite.git + pip install git+https://github.com/google-research/hoplite.git - name: Test db with unittest run: python -m unittest discover -s hoplite/db/tests -p "*test.py" - name: Test taxonomy with unittest diff --git a/README.md b/README.md index 54c7b36..24e43b0 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Perch Hoplite +# Hoplite Vector Database Hoplite is a system for storing large volumes of embeddings from machine perception models. We focus on combining vector search with active learning diff --git a/hoplite/agile/2_agile_modeling_v2.ipynb b/hoplite/agile/2_agile_modeling_v2.ipynb index 3ba1131..cee0e0a 100644 --- a/hoplite/agile/2_agile_modeling_v2.ipynb +++ b/hoplite/agile/2_agile_modeling_v2.ipynb @@ -10,8 +10,10 @@ "source": [ "#@title Imports. { vertical-output: true }\n", "\n", - "import numpy as np\n", + "import os\n", + "\n", "from matplotlib import pyplot as plt\n", + "import numpy as np\n", "\n", "from hoplite.agile import audio_loader\n", "from hoplite.agile import classifier\n", @@ -107,9 +109,7 @@ "if not target_sampling:\n", " target_score = None\n", "score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n", - "results, all_scores = brutalism.threaded_brute_search(\n", - " db, query_embedding, num_results, score_fn=score_fn,\n", - " sample_size=sample_size)\n", + "results, all_scores = db.search(query_embedding, num_results, exact=True)\n", "\n", "# TODO(tomdenton): Better histogram when target sampling.\n", "_ = plt.hist(all_scores, bins=100)\n", @@ -175,7 +175,7 @@ "#@markdown Set of labels to classify. If None, auto-populated from the DB.\n", "target_labels = None #@param\n", "\n", - "#@markdown Classifier traning params. These should not require tuning.\n", + "#@markdown Classifier traning hyperparams. These should not require tuning.\n", "learning_rate = 1e-3 #@param\n", "weak_neg_weight = 0.05 #@param\n", "l2_mu = 0.000 #@param\n", @@ -196,13 +196,11 @@ " rng=np.random.default_rng(seed=5))\n", "print('Training for target labels : ')\n", "print(data_manager.get_target_labels())\n", - "params, eval_scores = classifier.train_linear_classifier(\n", + "linear_classifier, eval_scores = classifier.train_linear_classifier(\n", " data_manager=data_manager,\n", " learning_rate=learning_rate,\n", " weak_neg_weight=weak_neg_weight,\n", - " l2_mu=l2_mu,\n", " num_train_steps=num_steps,\n", - " loss_name=loss_fn_name,\n", ")\n", "print('\\n' + '-' * 80)\n", "top1 = eval_scores['top1_acc']\n", @@ -210,7 +208,10 @@ "rocauc = eval_scores['roc_auc']\n", "print(f'roc_auc {rocauc:.3f}')\n", "cmap = eval_scores['cmap']\n", - "print(f'cmap {cmap:.3f}')" + "print(f'cmap {cmap:.3f}')\n", + "\n", + "# Save linear classifier.\n", + "linear_classifier.save(os.path.join(db_path, 'agile_classifier_v2.pt'))" ] }, { @@ -228,8 +229,8 @@ "num_results = 50 #@param\n", "\n", "target_label_idx = data_manager.get_target_labels().index(target_label)\n", - "class_query = params['beta'][:, target_label_idx]\n", - "bias = params['beta_bias'][target_label_idx]\n", + "class_query = linear_classifier.beta[:, target_label_idx]\n", + "bias = linear_classifier.beta_bias[target_label_idx]\n", "\n", "#@markdown Number of (randomly selected) database entries to search over.\n", "sample_size = 1_000_000 #@param\n", @@ -289,6 +290,26 @@ "print('\\nnew_lbls: ', new_lbls)\n", "print('\\nprev_lbls: ', prev_lbls)" ] + }, + { + "metadata": { + "id": "kBH-2kz4SaS2" + }, + "cell_type": "code", + "source": [ + "#@title Run inference with trained classifier. { vertical-output: true }\n", + "\n", + "output_csv_filepath = '' #@param {type:'string'}\n", + "logit_threshold = 1.0 #@param\n", + "# Set labels to a tuple of desired labels if you want to run inference on a\n", + "# subset of the labels.\n", + "labels = None #@param\n", + "\n", + "classifier.write_inference_csv(\n", + " linear_classifier, db, output_csv_filepath, logit_threshold, labels=labels)\n" + ], + "outputs": [], + "execution_count": null } ], "metadata": { diff --git a/hoplite/agile/classifier.py b/hoplite/agile/classifier.py index f1f4edc..7abbfb5 100644 --- a/hoplite/agile/classifier.py +++ b/hoplite/agile/classifier.py @@ -15,15 +15,64 @@ """Functions for training and applying a linear classifier.""" -from typing import Any +import base64 +import dataclasses +import json +from typing import Any, Sequence from hoplite.agile import classifier_data from hoplite.agile import metrics +from hoplite.db import interface as db_interface +from hoplite.taxonomy import namespace +from ml_collections import config_dict import numpy as np import tensorflow as tf import tqdm +@dataclasses.dataclass +class LinearClassifier: + """Wrapper for linear classifier params and metadata.""" + + beta: np.ndarray + beta_bias: np.ndarray + classes: tuple[str, ...] + embedding_model_config: Any + + def __call__(self, embeddings: np.ndarray): + return np.dot(embeddings, self.beta) + self.beta_bias + + def save(self, path: str): + """Save the classifier to a path.""" + cfg = config_dict.ConfigDict() + cfg.model_config = self.embedding_model_config + cfg.classes = self.classes + # Convert numpy arrays to base64 encoded blobs. + beta_bytes = base64.b64encode(np.float32(self.beta).tobytes()).decode( + 'ascii' + ) + beta_bias_bytes = base64.b64encode( + np.float32(self.beta_bias).tobytes() + ).decode('ascii') + cfg.beta = beta_bytes + cfg.beta_bias = beta_bias_bytes + with open(path, 'w') as f: + f.write(cfg.to_json()) + + @classmethod + def load(cls, path: str): + """Load a classifier from a path.""" + with open(path, 'r') as f: + cfg_json = json.loads(f.read()) + cfg = config_dict.ConfigDict(cfg_json) + classes = cfg.classes + beta = np.frombuffer(base64.b64decode(cfg.beta), dtype=np.float32) + beta = np.reshape(beta, (-1, len(classes))) + beta_bias = np.frombuffer(base64.b64decode(cfg.beta_bias), dtype=np.float32) + embedding_model_config = cfg.model_config + return cls(beta, beta_bias, classes, embedding_model_config) + + def get_linear_model(embedding_dim: int, num_classes: int) -> tf.keras.Model: """Create a simple linear Keras model.""" model = tf.keras.Sequential([ @@ -105,7 +154,7 @@ def train_linear_classifier( learning_rate: float, weak_neg_weight: float, num_train_steps: int, -): +) -> tuple[LinearClassifier, dict[str, float]]: """Train a linear classifier.""" embedding_dim = data_manager.db.embedding_dimension() num_classes = len(data_manager.get_target_labels()) @@ -147,4 +196,58 @@ def train_step(y_true, embeddings, is_labeled_mask): 'beta_bias': lin_model.get_weights()[1], } eval_scores = eval_classifier(params, data_manager, eval_idxes) - return params, eval_scores + + model_config = data_manager.db.get_metadata('model_config') + linear_classifier = LinearClassifier( + beta=params['beta'], + beta_bias=params['beta_bias'], + classes=data_manager.get_target_labels(), + embedding_model_config=model_config, + ) + return linear_classifier, eval_scores + + +def write_inference_csv( + linear_classifier: LinearClassifier, + db: db_interface.HopliteDBInterface, + output_filepath: str, + threshold: float, + labels: Sequence[str] | None = None, +): + """Write a CSV for all audio windows with logits above a threshold. + + Args: + params: The parameters of the linear classifier. + class_list: The class list of labels associated with the classifier. + db: HopliteDBInterface to read embeddings from. + output_filepath: Path to write the CSV to. + threshold: Logits must be above this value to be written. + labels: If provided, only write logits for these labels. If None, write + logits for all labels. + + Returns: + None + """ + idxes = db.get_embedding_ids() + if labels is None: + labels = linear_classifier.classes + label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)} + target_label_ids = np.array([label_ids[l] for l in labels]) + logits_fn = lambda emb: linear_classifier(emb)[target_label_ids] + with open(output_filepath, 'w') as f: + f.write('idx,dataset_name,source_id,offset,label,logits\n') + for idx in tqdm.tqdm(idxes): + source = db.get_embedding_source(idx) + emb = db.get_embedding(idx) + logits = logits_fn(emb) + for a in np.argwhere(logits > threshold): + lbl = labels[a] + row = [ + idx, + source.dataset_name, + source.source_id, + source.offsets[0], + lbl, + logits[a], + ] + f.write(','.join(map(str, row)) + '\n') diff --git a/hoplite/agile/tests/classifier_test.py b/hoplite/agile/tests/classifier_test.py new file mode 100644 index 0000000..c3b4313 --- /dev/null +++ b/hoplite/agile/tests/classifier_test.py @@ -0,0 +1,74 @@ +# coding=utf-8 +# Copyright 2024 The Perch Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for linear classifier implementation.""" + +import os +import tempfile + +from hoplite.agile import classifier +from ml_collections import config_dict +import numpy as np + +from absl.testing import absltest + + +class ClassifierTest(absltest.TestCase): + + def setUp(self): + super().setUp() + # `self.create_tempdir()` raises an UnparsedFlagAccessError, which is why + # we use `tempdir` directly. + self.tempdir = tempfile.mkdtemp() + + def _make_linear_classifier(self, embedding_dim, classes): + np.random.seed(1234) + beta = np.float32(np.random.normal(size=(embedding_dim, len(classes)))) + beta_bias = np.float32(np.random.normal(size=(len(classes),))) + embedding_model_config = config_dict.ConfigDict({ + 'model_name': 'nelson', + }) + return classifier.LinearClassifier( + beta, beta_bias, classes, embedding_model_config + ) + + def test_call_linear_classifier(self): + embedding_dim = 8 + classes = ('a', 'b', 'c') + classy = self._make_linear_classifier(embedding_dim, classes) + + batch_embeddings = np.random.normal(size=(10, embedding_dim)) + predictions = classy(batch_embeddings) + self.assertEqual(predictions.shape, (10, len(classes))) + + single_embedding = np.random.normal(size=(embedding_dim,)) + predictions = classy(single_embedding) + self.assertEqual(predictions.shape, (len(classes),)) + + def test_save_load_linear_classifier(self): + embedding_dim = 8 + classes = ('a', 'b', 'c') + classy = self._make_linear_classifier(embedding_dim, classes) + classy_path = os.path.join(self.tempdir, 'classifier.json') + classy.save(classy_path) + classy_loaded = classifier.LinearClassifier.load(classy_path) + np.testing.assert_allclose(classy_loaded.beta, classy.beta) + np.testing.assert_allclose(classy_loaded.beta_bias, classy.beta_bias) + self.assertSequenceEqual(classy_loaded.classes, classy.classes) + self.assertEqual(classy_loaded.embedding_model_config.model_name, 'nelson') + + +if __name__ == '__main__': + absltest.main()