Skip to content

Commit

Permalink
Save models and write inference results.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701159583
  • Loading branch information
sdenton4 committed Nov 29, 2024
1 parent b9eaec7 commit 7bc1eb5
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_pip.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
pip install absl-py
pip install requests
pip install tensorflow-cpu
pip install git+https://github.com/googlestaging/hoplite.git
pip install git+https://github.com/google-research/hoplite.git
- name: Test db with unittest
run: python -m unittest discover -s hoplite/db/tests -p "*test.py"
- name: Test taxonomy with unittest
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Perch Hoplite
# Hoplite Vector Database

Hoplite is a system for storing large volumes of embeddings from machine
perception models. We focus on combining vector search with active learning
Expand Down
43 changes: 32 additions & 11 deletions hoplite/agile/2_agile_modeling_v2.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
"source": [
"#@title Imports. { vertical-output: true }\n",
"\n",
"import numpy as np\n",
"import os\n",
"\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"\n",
"from hoplite.agile import audio_loader\n",
"from hoplite.agile import classifier\n",
Expand Down Expand Up @@ -107,9 +109,7 @@
"if not target_sampling:\n",
" target_score = None\n",
"score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n",
"results, all_scores = brutalism.threaded_brute_search(\n",
" db, query_embedding, num_results, score_fn=score_fn,\n",
" sample_size=sample_size)\n",
"results, all_scores = db.search(query_embedding, num_results, exact=True)\n",
"\n",
"# TODO(tomdenton): Better histogram when target sampling.\n",
"_ = plt.hist(all_scores, bins=100)\n",
Expand Down Expand Up @@ -175,7 +175,7 @@
"#@markdown Set of labels to classify. If None, auto-populated from the DB.\n",
"target_labels = None #@param\n",
"\n",
"#@markdown Classifier traning params. These should not require tuning.\n",
"#@markdown Classifier traning hyperparams. These should not require tuning.\n",
"learning_rate = 1e-3 #@param\n",
"weak_neg_weight = 0.05 #@param\n",
"l2_mu = 0.000 #@param\n",
Expand All @@ -196,21 +196,22 @@
" rng=np.random.default_rng(seed=5))\n",
"print('Training for target labels : ')\n",
"print(data_manager.get_target_labels())\n",
"params, eval_scores = classifier.train_linear_classifier(\n",
"linear_classifier, eval_scores = classifier.train_linear_classifier(\n",
" data_manager=data_manager,\n",
" learning_rate=learning_rate,\n",
" weak_neg_weight=weak_neg_weight,\n",
" l2_mu=l2_mu,\n",
" num_train_steps=num_steps,\n",
" loss_name=loss_fn_name,\n",
")\n",
"print('\\n' + '-' * 80)\n",
"top1 = eval_scores['top1_acc']\n",
"print(f'top-1 {top1:.3f}')\n",
"rocauc = eval_scores['roc_auc']\n",
"print(f'roc_auc {rocauc:.3f}')\n",
"cmap = eval_scores['cmap']\n",
"print(f'cmap {cmap:.3f}')"
"print(f'cmap {cmap:.3f}')\n",
"\n",
"# Save linear classifier.\n",
"linear_classifier.save(os.path.join(db_path, 'agile_classifier_v2.pt'))"
]
},
{
Expand All @@ -228,8 +229,8 @@
"num_results = 50 #@param\n",
"\n",
"target_label_idx = data_manager.get_target_labels().index(target_label)\n",
"class_query = params['beta'][:, target_label_idx]\n",
"bias = params['beta_bias'][target_label_idx]\n",
"class_query = linear_classifier.beta[:, target_label_idx]\n",
"bias = linear_classifier.beta_bias[target_label_idx]\n",
"\n",
"#@markdown Number of (randomly selected) database entries to search over.\n",
"sample_size = 1_000_000 #@param\n",
Expand Down Expand Up @@ -289,6 +290,26 @@
"print('\\nnew_lbls: ', new_lbls)\n",
"print('\\nprev_lbls: ', prev_lbls)"
]
},
{
"metadata": {
"id": "kBH-2kz4SaS2"
},
"cell_type": "code",
"source": [
"#@title Run inference with trained classifier. { vertical-output: true }\n",
"\n",
"output_csv_filepath = '' #@param {type:'string'}\n",
"logit_threshold = 1.0 #@param\n",
"# Set labels to a tuple of desired labels if you want to run inference on a\n",
"# subset of the labels.\n",
"labels = None #@param\n",
"\n",
"classifier.write_inference_csv(\n",
" linear_classifier, db, output_csv_filepath, logit_threshold, labels=labels)\n"
],
"outputs": [],
"execution_count": null
}
],
"metadata": {
Expand Down
109 changes: 106 additions & 3 deletions hoplite/agile/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,64 @@

"""Functions for training and applying a linear classifier."""

from typing import Any
import base64
import dataclasses
import json
from typing import Any, Sequence

from hoplite.agile import classifier_data
from hoplite.agile import metrics
from hoplite.db import interface as db_interface
from hoplite.taxonomy import namespace
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tqdm


@dataclasses.dataclass
class LinearClassifier:
"""Wrapper for linear classifier params and metadata."""

beta: np.ndarray
beta_bias: np.ndarray
classes: tuple[str, ...]
embedding_model_config: Any

def __call__(self, embeddings: np.ndarray):
return np.dot(embeddings, self.beta) + self.beta_bias

def save(self, path: str):
"""Save the classifier to a path."""
cfg = config_dict.ConfigDict()
cfg.model_config = self.embedding_model_config
cfg.classes = self.classes
# Convert numpy arrays to base64 encoded blobs.
beta_bytes = base64.b64encode(np.float32(self.beta).tobytes()).decode(
'ascii'
)
beta_bias_bytes = base64.b64encode(
np.float32(self.beta_bias).tobytes()
).decode('ascii')
cfg.beta = beta_bytes
cfg.beta_bias = beta_bias_bytes
with open(path, 'w') as f:
f.write(cfg.to_json())

@classmethod
def load(cls, path: str):
"""Load a classifier from a path."""
with open(path, 'r') as f:
cfg_json = json.loads(f.read())
cfg = config_dict.ConfigDict(cfg_json)
classes = cfg.classes
beta = np.frombuffer(base64.b64decode(cfg.beta), dtype=np.float32)
beta = np.reshape(beta, (-1, len(classes)))
beta_bias = np.frombuffer(base64.b64decode(cfg.beta_bias), dtype=np.float32)
embedding_model_config = cfg.model_config
return cls(beta, beta_bias, classes, embedding_model_config)


def get_linear_model(embedding_dim: int, num_classes: int) -> tf.keras.Model:
"""Create a simple linear Keras model."""
model = tf.keras.Sequential([
Expand Down Expand Up @@ -105,7 +154,7 @@ def train_linear_classifier(
learning_rate: float,
weak_neg_weight: float,
num_train_steps: int,
):
) -> tuple[LinearClassifier, dict[str, float]]:
"""Train a linear classifier."""
embedding_dim = data_manager.db.embedding_dimension()
num_classes = len(data_manager.get_target_labels())
Expand Down Expand Up @@ -147,4 +196,58 @@ def train_step(y_true, embeddings, is_labeled_mask):
'beta_bias': lin_model.get_weights()[1],
}
eval_scores = eval_classifier(params, data_manager, eval_idxes)
return params, eval_scores

model_config = data_manager.db.get_metadata('model_config')
linear_classifier = LinearClassifier(
beta=params['beta'],
beta_bias=params['beta_bias'],
classes=data_manager.get_target_labels(),
embedding_model_config=model_config,
)
return linear_classifier, eval_scores


def write_inference_csv(
linear_classifier: LinearClassifier,
db: db_interface.HopliteDBInterface,
output_filepath: str,
threshold: float,
labels: Sequence[str] | None = None,
):
"""Write a CSV for all audio windows with logits above a threshold.
Args:
params: The parameters of the linear classifier.
class_list: The class list of labels associated with the classifier.
db: HopliteDBInterface to read embeddings from.
output_filepath: Path to write the CSV to.
threshold: Logits must be above this value to be written.
labels: If provided, only write logits for these labels. If None, write
logits for all labels.
Returns:
None
"""
idxes = db.get_embedding_ids()
if labels is None:
labels = linear_classifier.classes
label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)}
target_label_ids = np.array([label_ids[l] for l in labels])
logits_fn = lambda emb: linear_classifier(emb)[target_label_ids]
with open(output_filepath, 'w') as f:
f.write('idx,dataset_name,source_id,offset,label,logits\n')
for idx in tqdm.tqdm(idxes):
source = db.get_embedding_source(idx)
emb = db.get_embedding(idx)
logits = logits_fn(emb)
for a in np.argwhere(logits > threshold):
lbl = labels[a]
row = [
idx,
source.dataset_name,
source.source_id,
source.offsets[0],
lbl,
logits[a],
]
f.write(','.join(map(str, row)) + '\n')
74 changes: 74 additions & 0 deletions hoplite/agile/tests/classifier_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# coding=utf-8
# Copyright 2024 The Perch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for linear classifier implementation."""

import os
import tempfile

from hoplite.agile import classifier
from ml_collections import config_dict
import numpy as np

from absl.testing import absltest


class ClassifierTest(absltest.TestCase):

def setUp(self):
super().setUp()
# `self.create_tempdir()` raises an UnparsedFlagAccessError, which is why
# we use `tempdir` directly.
self.tempdir = tempfile.mkdtemp()

def _make_linear_classifier(self, embedding_dim, classes):
np.random.seed(1234)
beta = np.float32(np.random.normal(size=(embedding_dim, len(classes))))
beta_bias = np.float32(np.random.normal(size=(len(classes),)))
embedding_model_config = config_dict.ConfigDict({
'model_name': 'nelson',
})
return classifier.LinearClassifier(
beta, beta_bias, classes, embedding_model_config
)

def test_call_linear_classifier(self):
embedding_dim = 8
classes = ('a', 'b', 'c')
classy = self._make_linear_classifier(embedding_dim, classes)

batch_embeddings = np.random.normal(size=(10, embedding_dim))
predictions = classy(batch_embeddings)
self.assertEqual(predictions.shape, (10, len(classes)))

single_embedding = np.random.normal(size=(embedding_dim,))
predictions = classy(single_embedding)
self.assertEqual(predictions.shape, (len(classes),))

def test_save_load_linear_classifier(self):
embedding_dim = 8
classes = ('a', 'b', 'c')
classy = self._make_linear_classifier(embedding_dim, classes)
classy_path = os.path.join(self.tempdir, 'classifier.json')
classy.save(classy_path)
classy_loaded = classifier.LinearClassifier.load(classy_path)
np.testing.assert_allclose(classy_loaded.beta, classy.beta)
np.testing.assert_allclose(classy_loaded.beta_bias, classy.beta_bias)
self.assertSequenceEqual(classy_loaded.classes, classy.classes)
self.assertEqual(classy_loaded.embedding_model_config.model_name, 'nelson')


if __name__ == '__main__':
absltest.main()

0 comments on commit 7bc1eb5

Please sign in to comment.