Skip to content

Commit

Permalink
Save models and write inference results.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701159583
  • Loading branch information
sdenton4 committed Nov 29, 2024
1 parent b9eaec7 commit 641ba21
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci_pip.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
pip install absl-py
pip install requests
pip install tensorflow-cpu
pip install git+https://github.com/googlestaging/hoplite.git
pip install git+https://github.com/google-research/hoplite.git
- name: Test db with unittest
run: python -m unittest discover -s hoplite/db/tests -p "*test.py"
- name: Test taxonomy with unittest
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Perch Hoplite
# Hoplite Vector Database

Hoplite is a system for storing large volumes of embeddings from machine
perception models. We focus on combining vector search with active learning
Expand Down
102 changes: 99 additions & 3 deletions hoplite/agile/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,57 @@

"""Functions for training and applying a linear classifier."""

from typing import Any
import base64
import dataclasses
from typing import Any, Sequence

from hoplite.agile import classifier_data
from hoplite.agile import metrics
from hoplite.db import interface as db_interface
from hoplite.taxonomy import namespace
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tqdm


@dataclasses.dataclass
class LinearClassifier:
"""Wrapper for linear classifier params and metadata."""

beta: np.ndarray
beta_bias: np.ndarray
classes: tuple[str, ...]
embedding_model_config: Any

def __call__(self, embeddings: np.ndarray):
return np.dot(self.beta, embeddings) + self.beta_bias

def save(self, path: str):
"""Save the classifier to a path."""
cfg = config_dict.ConfigDict()
cfg.model_config = self.embedding_model_config
cfg.classes = self.classes
# Convert numpy arrays to base64 encoded blobs.
beta_bytes = base64.b64encode(self.beta.tobytes())
beta_bias_bytes = base64.b64encode(self.beta_bias.tobytes())
cfg.beta = beta_bytes
cfg.beta_bias = beta_bias_bytes
with open(path, 'wb') as f:
f.write(cfg.to_json())

def load(self, path: str):
"""Load a classifier from a path."""
with open(path, 'rb') as f:
cfg = config_dict.ConfigDict.from_json(f.read())
self.beta = np.frombuffer(base64.b64decode(cfg.beta), dtype=np.float32)
self.beta_bias = np.frombuffer(
base64.b64decode(cfg.beta_bias), dtype=np.float32
)
self.classes = cfg.classes
self.embedding_model_config = cfg.model_config


def get_linear_model(embedding_dim: int, num_classes: int) -> tf.keras.Model:
"""Create a simple linear Keras model."""
model = tf.keras.Sequential([
Expand Down Expand Up @@ -105,7 +147,7 @@ def train_linear_classifier(
learning_rate: float,
weak_neg_weight: float,
num_train_steps: int,
):
) -> tuple[LinearClassifier, dict[str, float]]:
"""Train a linear classifier."""
embedding_dim = data_manager.db.embedding_dimension()
num_classes = len(data_manager.get_target_labels())
Expand Down Expand Up @@ -147,4 +189,58 @@ def train_step(y_true, embeddings, is_labeled_mask):
'beta_bias': lin_model.get_weights()[1],
}
eval_scores = eval_classifier(params, data_manager, eval_idxes)
return params, eval_scores

model_config = data_manager.db.get_metadata('model_config')
linear_classifier = LinearClassifier(
beta=params['beta'],
beta_bias=params['beta_bias'],
classes=data_manager.get_target_labels(),
embedding_model_config=model_config,
)
return linear_classifier, eval_scores


def write_inference_csv(
linear_classifier: LinearClassifier,
db: db_interface.HopliteDBInterface,
output_filepath: str,
threshold: float,
labels: Sequence[str] | None = None,
):
"""Write a CSV for all audio windows with logits above a threshold.
Args:
params: The parameters of the linear classifier.
class_list: The class list of labels associated with the classifier.
db: HopliteDBInterface to read embeddings from.
output_filepath: Path to write the CSV to.
threshold: Logits must be above this value to be written.
labels: If provided, only write logits for these labels. If None, write
logits for all labels.
Returns:
None
"""
idxes = db.get_embedding_ids()
if labels is None:
labels = linear_classifier.classes
label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)}
target_label_ids = np.array([label_ids[l] for l in labels])
logits_fn = lambda emb: linear_classifier(emb)[target_label_ids]
with open(output_filepath, 'w') as f:
f.write('idx,dataset_name,source_id,offset,label,logits\n')
for idx in tqdm.tqdm(idxes):
source = db.get_embedding_source(idx)
emb = db.get_embedding(idx)
logits = logits_fn(emb)
for a in np.argwhere(logits > threshold):
lbl = labels[a]
row = [
idx,
source.dataset_name,
source.source_id,
source.offsets[0],
lbl,
logits[a],
]
f.write(','.join(map(str, row)) + '\n')

0 comments on commit 641ba21

Please sign in to comment.