Save models and write inference results.
PiperOrigin-RevId: 701159583
sdenton4 authored and copybara-github committed Dec 7, 2024
1 parent 30d46d9 commit 25ed26a
Showing 6 changed files with 285 additions and 41 deletions.
89 changes: 63 additions & 26 deletions hoplite/agile/2_agile_modeling_v2.ipynb
@@ -10,15 +10,19 @@
"source": [
"#@title Imports. { vertical-output: true }\n",
"\n",
"import numpy as np\n",
"import os\n",
"\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"\n",
"from hoplite.agile import audio_loader\n",
"from hoplite.agile import classifier\n",
"from hoplite.agile import classifier_data\n",
"from hoplite.agile import embedding_display\n",
"from hoplite.agile import source_info\n",
"from hoplite.db import brutalism\n",
"from hoplite.db import score_functions\n",
"from hoplite.db import search_results\n",
"from hoplite.db import sqlite_usearch_impl\n",
"from hoplite.zoo import model_configs\n"
]
@@ -34,19 +34,24 @@
"#@title Load model and connect to database. { vertical-output: true }\n",
"\n",
"#@markdown Location of database containing audio embeddings.\n",
"db_path = '' #@param {type:'string'}\n",
"db_path = '/usr/local/google/home/tomdenton/terrorbyte/anuraset' #@param {type:'string'}\n",
"#@markdown Identifier (eg, name) to attach to labels produced during validation.\n",
"annotator_id = 'linnaeus' #@param {type:'string'}\n",
"annotator_id = 'tmd' #@param {type:'string'}\n",
"\n",
"db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)\n",
"db_model_config = db.get_metadata('model_config')\n",
"embed_config = db.get_metadata('embed_config')\n",
"embed_config = db.get_metadata('audio_sources')\n",
"model_class = model_configs.MODEL_CLASS_MAP[db_model_config.model_key]\n",
"embedding_model = model_class.from_config(db_model_config.model_config)\n",
"\n",
"audio_sources = source_info.AudioSources.from_config_dict(\n",
" embed_config.audio_sources)\n",
"if hasattr(embedding_model, 'window_size_s'):\n",
" window_size_s = embedding_model.window_size_s\n",
"else:\n",
" window_size_s = 5.0\n",
"audio_filepath_loader = audio_loader.make_filepath_loader(\n",
" audio_globs=embed_config.audio_globs,\n",
" window_size_s=embedding_model.window_size_s,\n",
" audio_sources=audio_sources,\n",
" window_size_s=window_size_s,\n",
" sample_rate_hz=embedding_model.sample_rate,\n",
")"
]
@@ -94,28 +103,35 @@
"num_results = 50 #@param\n",
"query_embedding = embedding_model.embed(\n",
" query.get_audio_window()).embeddings[0, 0]\n",
"#@markdown Number of (randomly selected) database entries to search over.\n",
"sample_size = 1_000_000 #@param\n",
"\n",
"#@markdown If checked, search for examples\n",
"#@markdown near a particular target score.\n",
"target_sampling = False #@param {type: 'boolean'}\n",
"\n",
"#@markdown When target sampling, target this score.\n",
"target_score = -1.0 #@param\n",
"\n",
"if not target_sampling:\n",
" target_score = None\n",
"score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n",
"results, all_scores = brutalism.threaded_brute_search(\n",
" db, query_embedding, num_results, score_fn=score_fn,\n",
" sample_size=sample_size)\n",
"\n",
"# TODO(tomdenton): Better histogram when target sampling.\n",
"_ = plt.hist(all_scores, bins=100)\n",
"hit_scores = [r.sort_score for r in results.search_results]\n",
"plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n",
" color='r', alpha=0.5)\n"
"#@markdown If True, search the full DB. Otherwise, use approximate\n",
"#@markdown nearest-neighbor search.\n",
"exact_search = False #@param {type: 'boolean'}\n",
"\n",
"if exact_search:\n",
" target_score = None\n",
" score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n",
" results, all_scores = brutalism.threaded_brute_search(\n",
" db, query_embedding, num_results, score_fn=score_fn)\n",
" # TODO(tomdenton): Better histogram when target sampling.\n",
" _ = plt.hist(all_scores, bins=100)\n",
" hit_scores = [r.sort_score for r in results.search_results]\n",
" plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n",
" color='r', alpha=0.5)\n",
"else:\n",
" ann_matches = db.ui.search(query_embedding, count=num_results)\n",
" results = search_results.TopKSearchResults(top_k=num_results)\n",
" for k, d in zip(ann_matches.keys, ann_matches.distances):\n",
" results.update(search_results.SearchResult(k, d))\n"
]
},
{
@@ -175,7 +191,7 @@
"#@markdown Set of labels to classify. If None, auto-populated from the DB.\n",
"target_labels = None #@param\n",
"\n",
"#@markdown Classifier traning params. These should not require tuning.\n",
"#@markdown Classifier traning hyperparams. These should not require tuning.\n",
"learning_rate = 1e-3 #@param\n",
"weak_neg_weight = 0.05 #@param\n",
"l2_mu = 0.000 #@param\n",
@@ -196,21 +212,22 @@
" rng=np.random.default_rng(seed=5))\n",
"print('Training for target labels : ')\n",
"print(data_manager.get_target_labels())\n",
"params, eval_scores = classifier.train_linear_classifier(\n",
"linear_classifier, eval_scores = classifier.train_linear_classifier(\n",
" data_manager=data_manager,\n",
" learning_rate=learning_rate,\n",
" weak_neg_weight=weak_neg_weight,\n",
" l2_mu=l2_mu,\n",
" num_train_steps=num_steps,\n",
" loss_name=loss_fn_name,\n",
")\n",
"print('\\n' + '-' * 80)\n",
"top1 = eval_scores['top1_acc']\n",
"print(f'top-1 {top1:.3f}')\n",
"rocauc = eval_scores['roc_auc']\n",
"print(f'roc_auc {rocauc:.3f}')\n",
"cmap = eval_scores['cmap']\n",
"print(f'cmap {cmap:.3f}')"
"print(f'cmap {cmap:.3f}')\n",
"\n",
"# Save linear classifier.\n",
"linear_classifier.save(os.path.join(db_path, 'agile_classifier_v2.pt'))"
]
},
{
@@ -228,8 +245,8 @@
"num_results = 50 #@param\n",
"\n",
"target_label_idx = data_manager.get_target_labels().index(target_label)\n",
"class_query = params['beta'][:, target_label_idx]\n",
"bias = params['beta_bias'][target_label_idx]\n",
"class_query = linear_classifier.beta[:, target_label_idx]\n",
"bias = linear_classifier.beta_bias[target_label_idx]\n",
"\n",
"#@markdown Number of (randomly selected) database entries to search over.\n",
"sample_size = 1_000_000 #@param\n",
@@ -289,6 +306,26 @@
"print('\\nnew_lbls: ', new_lbls)\n",
"print('\\nprev_lbls: ', prev_lbls)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kBH-2kz4SaS2"
},
"outputs": [],
"source": [
"#@title Run inference with trained classifier. { vertical-output: true }\n",
"\n",
"output_csv_filepath = '' #@param {type:'string'}\n",
"logit_threshold = 1.0 #@param\n",
"# Set labels to a tuple of desired labels if you want to run inference on a\n",
"# subset of the labels.\n",
"labels = None #@param\n",
"\n",
"classifier.write_inference_csv(\n",
" linear_classifier, db, output_csv_filepath, logit_threshold, labels=labels)\n"
]
}
],
"metadata": {
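Taken together, the changed notebook cells now train a linear classifier, save it next to the database, and run batch inference. A minimal end-to-end sketch of that flow, assuming a hypothetical my_db_path and the agile_classifier_v2.pt filename used by the training cell (paths and threshold are illustrative, not from the commit):

from hoplite.agile import classifier
from hoplite.db import sqlite_usearch_impl

# Hypothetical database location; substitute your own embeddings DB.
my_db_path = '/path/to/embeddings_db'
db = sqlite_usearch_impl.SQLiteUsearchDB.create(my_db_path)

# Reload the classifier written by `linear_classifier.save(...)` above.
linear_classifier = classifier.LinearClassifier.load(
    my_db_path + '/agile_classifier_v2.pt')

# Write one CSV row per (audio window, label) whose logit clears the threshold.
classifier.write_inference_csv(
    linear_classifier, db, my_db_path + '/inference.csv',
    threshold=1.0, labels=None)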
10 changes: 5 additions & 5 deletions hoplite/agile/audio_loader.py
@@ -20,11 +20,12 @@

from etils import epath
from hoplite import audio_io
from hoplite.agile import source_info
import numpy as np


def make_filepath_loader(
audio_globs: dict[str, tuple[str, str]],
audio_sources: source_info.AudioSources,
sample_rate_hz: int = 32000,
window_size_s: float = 5.0,
dtype: str = 'float32',
@@ -34,8 +35,7 @@ def make_filepath_loader(
Note that if multiple globs match a given source ID, the first match is used.

Args:
audio_globs: Mapping from dataset name to pairs of `(root directory, file
glob)`. (See `embed.EmbedConfig` for details.)
audio_sources: Embedding audio sources.
sample_rate_hz: Sample rate of the audio.
window_size_s: Window size of the audio.
dtype: Data type of the audio.
@@ -49,8 +49,8 @@

def loader(source_id: str, offset_s: float) -> np.ndarray:
found_path = None
for base_path, _ in audio_globs.values():
path = epath.Path(base_path) / source_id
for audio_source in audio_sources.audio_globs:
path = epath.Path(audio_source.base_path) / source_id
if path.exists():
found_path = path
break
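make_filepath_loader now resolves a source ID against each AudioSources glob's base_path rather than a raw dict of globs. A usage sketch, assuming audio_sources was reconstructed from DB metadata as in the notebook (the source ID below is illustrative):

from hoplite.agile import audio_loader

loader = audio_loader.make_filepath_loader(
    audio_sources=audio_sources,
    sample_rate_hz=32000,
    window_size_s=5.0,
)
# Returns a float32 array of shape [sample_rate_hz * window_size_s].
audio_window = loader('site_01/recording.wav', offset_s=60.0)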
139 changes: 134 additions & 5 deletions hoplite/agile/classifier.py
@@ -15,15 +15,64 @@

"""Functions for training and applying a linear classifier."""

from typing import Any
import base64
import dataclasses
import json
from typing import Any, Sequence

from hoplite.agile import classifier_data
from hoplite.agile import metrics
from hoplite.db import interface as db_interface
from hoplite.taxonomy import namespace
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tqdm


@dataclasses.dataclass
class LinearClassifier:
"""Wrapper for linear classifier params and metadata."""

beta: np.ndarray
beta_bias: np.ndarray
classes: tuple[str, ...]
embedding_model_config: Any

def __call__(self, embeddings: np.ndarray):
return np.dot(embeddings, self.beta) + self.beta_bias

def save(self, path: str):
"""Save the classifier to a path."""
cfg = config_dict.ConfigDict()
cfg.model_config = self.embedding_model_config
cfg.classes = self.classes
# Convert numpy arrays to base64 encoded blobs.
beta_bytes = base64.b64encode(np.float32(self.beta).tobytes()).decode(
'ascii'
)
beta_bias_bytes = base64.b64encode(
np.float32(self.beta_bias).tobytes()
).decode('ascii')
cfg.beta = beta_bytes
cfg.beta_bias = beta_bias_bytes
with open(path, 'w') as f:
f.write(cfg.to_json())

@classmethod
def load(cls, path: str):
"""Load a classifier from a path."""
with open(path, 'r') as f:
cfg_json = json.loads(f.read())
cfg = config_dict.ConfigDict(cfg_json)
classes = cfg.classes
beta = np.frombuffer(base64.b64decode(cfg.beta), dtype=np.float32)
beta = np.reshape(beta, (-1, len(classes)))
beta_bias = np.frombuffer(base64.b64decode(cfg.beta_bias), dtype=np.float32)
embedding_model_config = cfg.model_config
return cls(beta, beta_bias, classes, embedding_model_config)
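
save stores the parameters as base64-encoded float32 blobs inside a JSON config, and load inverts this, recovering beta's shape from the number of classes. A round-trip sketch with toy values (names and numbers illustrative):

import numpy as np

# Toy classifier: 4-dim embeddings, two classes.
clf = LinearClassifier(
    beta=np.zeros((4, 2), dtype=np.float32),
    beta_bias=np.zeros(2, dtype=np.float32),
    classes=('frog_a', 'frog_b'),
    embedding_model_config={'model_key': 'toy'},  # hypothetical config
)
clf.save('/tmp/toy_classifier.json')
restored = LinearClassifier.load('/tmp/toy_classifier.json')
assert np.allclose(clf.beta, restored.beta)
# __call__ applies the affine map: logits = embeddings @ beta + beta_bias.
logits = restored(np.ones((1, 4), dtype=np.float32))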


def get_linear_model(embedding_dim: int, num_classes: int) -> tf.keras.Model:
"""Create a simple linear Keras model."""
model = tf.keras.Sequential([
@@ -43,12 +92,28 @@ def bce_loss(
y_true = tf.cast(y_true, dtype=logits.dtype)
log_p = tf.math.log_sigmoid(logits)
log_not_p = tf.math.log_sigmoid(-logits)
raw_bce = -y_true * log_p + (1.0 - y_true) * log_not_p
# optax sigmoid_binary_cross_entropy:
# -labels * log_p - (1.0 - labels) * log_not_p
raw_bce = -y_true * log_p - (1.0 - y_true) * log_not_p
is_labeled_mask = tf.cast(is_labeled_mask, dtype=logits.dtype)
weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask
return tf.reduce_mean(raw_bce * weights)
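
The sign fix matters: with the old `+ (1.0 - y_true) * log_not_p`, the negative-class term entered the loss with flipped sign, so minimizing the loss pushed negatives toward high logits. A quick numeric check of the corrected form (values illustrative):

import tensorflow as tf

# y=0, logit=2.0: log_not_p = log_sigmoid(-2.0) ~= -2.127, so
# raw_bce = -(1 - 0) * log_not_p ~= +2.127 (a proper penalty);
# the old '+' sign would have produced ~= -2.127 instead.
loss = bce_loss(
    y_true=tf.constant([[0.0]]),
    logits=tf.constant([[2.0]]),
    is_labeled_mask=tf.constant([[1.0]]),
    weak_neg_weight=0.05,
)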


def hinge_loss(
y_true: tf.Tensor,
logits: tf.Tensor,
is_labeled_mask: tf.Tensor,
weak_neg_weight: float,
) -> tf.Tensor:
"""Weighted SVM hinge loss."""
# Convert multihot to +/- 1 labels.
y_true = 2 * y_true - 1
weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask
raw_hinge_loss = tf.maximum(0, 1 - y_true * logits)
return tf.reduce_mean(raw_hinge_loss * weights)
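
For intuition, a quick check of the hinge weighting (values illustrative): a labeled example pays the full penalty, while an unlabeled one is scaled by weak_neg_weight.

import tensorflow as tf

# One labeled positive (mask=1) and one unlabeled weak negative (mask=0).
y_true = tf.constant([[1.0], [0.0]])
logits = tf.constant([[0.2], [0.5]])
is_labeled_mask = tf.constant([[1.0], [0.0]])
# Labeled row: max(0, 1 - (+1) * 0.2) = 0.8, weight 1.0.
# Weak row:    max(0, 1 - (-1) * 0.5) = 1.5, weight 0.05.
# Loss = mean(0.8, 0.075) = 0.4375.
loss = hinge_loss(y_true, logits, is_labeled_mask, weak_neg_weight=0.05)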


def infer(params, embeddings: np.ndarray):
"""Apply the model to embeddings."""
return np.dot(embeddings, params['beta']) + params['beta_bias']
@@ -105,19 +170,26 @@ def train_linear_classifier(
learning_rate: float,
weak_neg_weight: float,
num_train_steps: int,
):
loss: str = 'bce',
) -> tuple[LinearClassifier, dict[str, float]]:
"""Train a linear classifier."""
embedding_dim = data_manager.db.embedding_dimension()
num_classes = len(data_manager.get_target_labels())
lin_model = get_linear_model(embedding_dim, num_classes)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
lin_model.compile(optimizer=optimizer, loss='binary_crossentropy')
if loss == 'hinge':
loss_fn = hinge_loss
elif loss == 'bce':
loss_fn = bce_loss
else:
raise ValueError(f'Unknown loss: {loss}')

@tf.function
def train_step(y_true, embeddings, is_labeled_mask):
with tf.GradientTape() as tape:
logits = lin_model(embeddings, training=True)
loss = bce_loss(y_true, logits, is_labeled_mask, weak_neg_weight)
loss = loss_fn(y_true, logits, is_labeled_mask, weak_neg_weight)
loss = tf.reduce_mean(loss)
grads = tape.gradient(loss, lin_model.trainable_variables)
optimizer.apply_gradients(zip(grads, lin_model.trainable_variables))
@@ -147,4 +219,61 @@ def train_step(y_true, embeddings, is_labeled_mask):
'beta_bias': lin_model.get_weights()[1],
}
eval_scores = eval_classifier(params, data_manager, eval_idxes)
return params, eval_scores

model_config = data_manager.db.get_metadata('model_config')
linear_classifier = LinearClassifier(
beta=params['beta'],
beta_bias=params['beta_bias'],
classes=data_manager.get_target_labels(),
embedding_model_config=model_config,
)
return linear_classifier, eval_scores


def write_inference_csv(
linear_classifier: LinearClassifier,
db: db_interface.HopliteDBInterface,
output_filepath: str,
threshold: float,
labels: Sequence[str] | None = None,
):
"""Write a CSV for all audio windows with logits above a threshold.
Args:
params: The parameters of the linear classifier.
class_list: The class list of labels associated with the classifier.
db: HopliteDBInterface to read embeddings from.
output_filepath: Path to write the CSV to.
threshold: Logits must be above this value to be written.
labels: If provided, only write logits for these labels. If None, write
logits for all labels.
Returns:
None
"""
idxes = db.get_embedding_ids()
if labels is None:
labels = linear_classifier.classes
label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)}
target_label_ids = np.array([label_ids[l] for l in labels])
logits_fn = lambda emb: linear_classifier(emb)[target_label_ids]
detection_count = 0
with open(output_filepath, 'w') as f:
f.write('idx,dataset_name,source_id,offset,label,logits\n')
for idx in tqdm.tqdm(idxes):
source = db.get_embedding_source(idx)
emb = db.get_embedding(idx)
logits = logits_fn(emb)
for a in np.argwhere(logits > threshold):
lbl = labels[a[0]]
row = [
idx,
source.dataset_name,
source.source_id,
source.offsets[0],
lbl,
logits[a[0]],
]
f.write(','.join(map(str, row)) + '\n')
detection_count += 1
print(f'Wrote {detection_count} detections to {output_filepath}')
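
Each detection becomes one CSV row keyed by embedding index. An illustrative sample of the output (all values made up):

idx,dataset_name,source_id,offset,label,logits
1042,dataset_a,recordings/file_001.wav,15.0,species_x,2.31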