diff --git a/README.md b/README.md index 54c7b36..24e43b0 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Perch Hoplite +# Hoplite Vector Database Hoplite is a system for storing large volumes of embeddings from machine perception models. We focus on combining vector search with active learning diff --git a/hoplite/agile/classifier.py b/hoplite/agile/classifier.py index f1f4edc..17000a7 100644 --- a/hoplite/agile/classifier.py +++ b/hoplite/agile/classifier.py @@ -15,15 +15,57 @@ """Functions for training and applying a linear classifier.""" -from typing import Any +import base64 +import dataclasses +from typing import Any, Sequence from hoplite.agile import classifier_data from hoplite.agile import metrics +from hoplite.db import interface as db_interface +from hoplite.taxonomy import namespace +from ml_collections import config_dict import numpy as np import tensorflow as tf import tqdm +@dataclasses.dataclass +class LinearClassifier: + """Wrapper for linear classifier params and metadata.""" + + beta: np.ndarray + beta_bias: np.ndarray + classes: tuple[str, ...] + embedding_model_config: Any + + def __call__(self, embeddings: np.ndarray): + return np.dot(self.beta, embeddings) + self.beta_bias + + def save(self, path: str): + """Save the classifier to a path.""" + cfg = config_dict.ConfigDict() + cfg.model_config = self.embedding_model_config + cfg.classes = self.classes + # Convert numpy arrays to base64 encoded blobs. + beta_bytes = base64.b64encode(self.beta.tobytes()) + beta_bias_bytes = base64.b64encode(self.beta_bias.tobytes()) + cfg.beta = beta_bytes + cfg.beta_bias = beta_bias_bytes + with open(path, 'wb') as f: + f.write(cfg.to_json()) + + def load(self, path: str): + """Load a classifier from a path.""" + with open(path, 'rb') as f: + cfg = config_dict.ConfigDict.from_json(f.read()) + self.beta = np.frombuffer(base64.b64decode(cfg.beta), dtype=np.float32) + self.beta_bias = np.frombuffer( + base64.b64decode(cfg.beta_bias), dtype=np.float32 + ) + self.classes = cfg.classes + self.embedding_model_config = cfg.model_config + + def get_linear_model(embedding_dim: int, num_classes: int) -> tf.keras.Model: """Create a simple linear Keras model.""" model = tf.keras.Sequential([ @@ -105,7 +147,7 @@ def train_linear_classifier( learning_rate: float, weak_neg_weight: float, num_train_steps: int, -): +) -> tuple[LinearClassifier, dict[str, float]]: """Train a linear classifier.""" embedding_dim = data_manager.db.embedding_dimension() num_classes = len(data_manager.get_target_labels()) @@ -147,4 +189,58 @@ def train_step(y_true, embeddings, is_labeled_mask): 'beta_bias': lin_model.get_weights()[1], } eval_scores = eval_classifier(params, data_manager, eval_idxes) - return params, eval_scores + + model_config = data_manager.db.get_metadata('model_config') + linear_classifier = LinearClassifier( + beta=params['beta'], + beta_bias=params['beta_bias'], + classes=data_manager.get_target_labels(), + embedding_model_config=model_config, + ) + return linear_classifier, eval_scores + + +def write_inference_csv( + linear_classifier: LinearClassifier, + db: db_interface.HopliteDBInterface, + output_filepath: str, + threshold: float, + labels: Sequence[str] | None = None, +): + """Write a CSV for all audio windows with logits above a threshold. + + Args: + params: The parameters of the linear classifier. + class_list: The class list of labels associated with the classifier. + db: HopliteDBInterface to read embeddings from. + output_filepath: Path to write the CSV to. + threshold: Logits must be above this value to be written. + labels: If provided, only write logits for these labels. If None, write + logits for all labels. + + Returns: + None + """ + idxes = db.get_embedding_ids() + if labels is None: + labels = linear_classifier.classes + label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)} + target_label_ids = np.array([label_ids[l] for l in labels]) + logits_fn = lambda emb: linear_classifier(emb)[target_label_ids] + with open(output_filepath, 'w') as f: + f.write('idx,dataset_name,source_id,offset,label,logits\n') + for idx in tqdm.tqdm(idxes): + source = db.get_embedding_source(idx) + emb = db.get_embedding(idx) + logits = logits_fn(emb) + for a in np.argwhere(logits > threshold): + lbl = labels[a] + row = [ + idx, + source.dataset_name, + source.source_id, + source.offsets[0], + lbl, + logits[a], + ] + f.write(','.join(map(str, row)) + '\n')