Save models and write inference results. Also fix a flipped sign in BCE loss, and fix some small problems in agile modeling notebook. #2

Merged
merged 1 commit on Dec 7, 2024
85 changes: 61 additions & 24 deletions hoplite/agile/2_agile_modeling_v2.ipynb
@@ -10,15 +10,19 @@
"source": [
"#@title Imports. { vertical-output: true }\n",
"\n",
"import numpy as np\n",
"import os\n",
"\n",
"from matplotlib import pyplot as plt\n",
"import numpy as np\n",
"\n",
"from hoplite.agile import audio_loader\n",
"from hoplite.agile import classifier\n",
"from hoplite.agile import classifier_data\n",
"from hoplite.agile import embedding_display\n",
"from hoplite.agile import source_info\n",
"from hoplite.db import brutalism\n",
"from hoplite.db import score_functions\n",
"from hoplite.db import search_results\n",
"from hoplite.db import sqlite_usearch_impl\n",
"from hoplite.zoo import model_configs\n"
]
@@ -40,13 +44,18 @@
"\n",
"db = sqlite_usearch_impl.SQLiteUsearchDB.create(db_path)\n",
"db_model_config = db.get_metadata('model_config')\n",
"embed_config = db.get_metadata('embed_config')\n",
"embed_config = db.get_metadata('audio_sources')\n",
"model_class = model_configs.MODEL_CLASS_MAP[db_model_config.model_key]\n",
"embedding_model = model_class.from_config(db_model_config.model_config)\n",
"\n",
"audio_sources = source_info.AudioSources.from_config_dict(\n",
" embed_config.audio_sources)\n",
"if hasattr(embedding_model, 'window_size_s'):\n",
" window_size_s = embedding_model.window_size_s\n",
"else:\n",
" window_size_s = 5.0\n",
"audio_filepath_loader = audio_loader.make_filepath_loader(\n",
" audio_globs=embed_config.audio_globs,\n",
" window_size_s=embedding_model.window_size_s,\n",
" audio_sources=audio_sources,\n",
" window_size_s=window_size_s,\n",
" sample_rate_hz=embedding_model.sample_rate,\n",
")"
]
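The `hasattr` fallback above can be written more compactly with `getattr`; an equivalent one-liner (a sketch, not part of the PR):

```python
# Default to a 5 s context window when the embedding model doesn't declare one.
window_size_s = getattr(embedding_model, 'window_size_s', 5.0)
```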
@@ -94,28 +103,35 @@
"num_results = 50 #@param\n",
"query_embedding = embedding_model.embed(\n",
" query.get_audio_window()).embeddings[0, 0]\n",
"#@markdown Number of (randomly selected) database entries to search over.\n",
"sample_size = 1_000_000 #@param\n",
"\n",
"#@markdown If checked, search for examples\n",
"#@markdown near a particular target score.\n",
"target_sampling = False #@param {type: 'boolean'}\n",
"\n",
"#@markdown When target sampling, target this score.\n",
"target_score = -1.0 #@param\n",
"\n",
"if not target_sampling:\n",
" target_score = None\n",
"score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n",
"results, all_scores = brutalism.threaded_brute_search(\n",
" db, query_embedding, num_results, score_fn=score_fn,\n",
" sample_size=sample_size)\n",
"\n",
"# TODO(tomdenton): Better histogram when target sampling.\n",
"_ = plt.hist(all_scores, bins=100)\n",
"hit_scores = [r.sort_score for r in results.search_results]\n",
"plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n",
" color='r', alpha=0.5)\n"
"#@markdown If True, search the full DB. Otherwise, use approximate\n",
"#@markdown nearest-neighbor search.\n",
"exact_search = False #@param {type: 'boolean'}\n",
"\n",
"if exact_search:\n",
" target_score = None\n",
" score_fn = score_functions.get_score_fn('dot', target_score=target_score)\n",
" results, all_scores = brutalism.threaded_brute_search(\n",
" db, query_embedding, num_results, score_fn=score_fn)\n",
" # TODO(tomdenton): Better histogram when target sampling.\n",
" _ = plt.hist(all_scores, bins=100)\n",
" hit_scores = [r.sort_score for r in results.search_results]\n",
" plt.scatter(hit_scores, np.zeros_like(hit_scores), marker='|',\n",
" color='r', alpha=0.5)\n",
"else:\n",
" ann_matches = db.ui.search(query_embedding, count=num_results)\n",
" results = search_results.TopKSearchResults(top_k=num_results)\n",
" for k, d in zip(ann_matches.keys, ann_matches.distances):\n",
" results.update(search_results.SearchResult(k, d))\n"
]
},
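The ANN branch collects its hits into the same `TopKSearchResults` container as the exact branch, so they can be inspected the same way; a sketch, assuming `sort_score` carries the usearch distance passed to `SearchResult` (smaller means closer):

```python
# Inspect the ANN hit scores collected above.
hit_scores = [r.sort_score for r in results.search_results]
print(f'best: {min(hit_scores):.3f} worst: {max(hit_scores):.3f}')
```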
{
@@ -175,7 +191,7 @@
"#@markdown Set of labels to classify. If None, auto-populated from the DB.\n",
"target_labels = None #@param\n",
"\n",
"#@markdown Classifier traning params. These should not require tuning.\n",
"#@markdown Classifier traning hyperparams. These should not require tuning.\n",
"learning_rate = 1e-3 #@param\n",
"weak_neg_weight = 0.05 #@param\n",
"l2_mu = 0.000 #@param\n",
@@ -196,21 +212,22 @@
" rng=np.random.default_rng(seed=5))\n",
"print('Training for target labels : ')\n",
"print(data_manager.get_target_labels())\n",
"params, eval_scores = classifier.train_linear_classifier(\n",
"linear_classifier, eval_scores = classifier.train_linear_classifier(\n",
" data_manager=data_manager,\n",
" learning_rate=learning_rate,\n",
" weak_neg_weight=weak_neg_weight,\n",
" l2_mu=l2_mu,\n",
" num_train_steps=num_steps,\n",
" loss_name=loss_fn_name,\n",
")\n",
"print('\\n' + '-' * 80)\n",
"top1 = eval_scores['top1_acc']\n",
"print(f'top-1 {top1:.3f}')\n",
"rocauc = eval_scores['roc_auc']\n",
"print(f'roc_auc {rocauc:.3f}')\n",
"cmap = eval_scores['cmap']\n",
"print(f'cmap {cmap:.3f}')"
"print(f'cmap {cmap:.3f}')\n",
"\n",
"# Save linear classifier.\n",
"linear_classifier.save(os.path.join(db_path, 'agile_classifier_v2.pt'))"
]
},
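The classifier saved above can be restored later without retraining, via the `LinearClassifier.load` classmethod added in this PR (a sketch, using the same path as above):

```python
restored = classifier.LinearClassifier.load(
    os.path.join(db_path, 'agile_classifier_v2.pt'))
print(restored.classes)
print(restored.beta.shape)  # (embedding_dim, num_classes)
```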
{
@@ -228,8 +245,8 @@
"num_results = 50 #@param\n",
"\n",
"target_label_idx = data_manager.get_target_labels().index(target_label)\n",
"class_query = params['beta'][:, target_label_idx]\n",
"bias = params['beta_bias'][target_label_idx]\n",
"class_query = linear_classifier.beta[:, target_label_idx]\n",
"bias = linear_classifier.beta_bias[target_label_idx]\n",
"\n",
"#@markdown Number of (randomly selected) database entries to search over.\n",
"sample_size = 1_000_000 #@param\n",
@@ -289,6 +306,26 @@
"print('\\nnew_lbls: ', new_lbls)\n",
"print('\\nprev_lbls: ', prev_lbls)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kBH-2kz4SaS2"
},
"outputs": [],
"source": [
"#@title Run inference with trained classifier. { vertical-output: true }\n",
"\n",
"output_csv_filepath = '' #@param {type:'string'}\n",
"logit_threshold = 1.0 #@param\n",
"# Set labels to a tuple of desired labels if you want to run inference on a\n",
"# subset of the labels.\n",
"labels = None #@param\n",
"\n",
"classifier.write_inference_csv(\n",
" linear_classifier, db, output_csv_filepath, logit_threshold, labels=labels)\n"
]
}
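`write_inference_csv` emits rows under the header `idx,dataset_name,source_id,offset,label,logits` (see the classifier.py diff below); a quick way to summarize the output (a sketch, assuming pandas is available):

```python
import pandas as pd

detections = pd.read_csv(output_csv_filepath)
print(detections.groupby('label').size())  # detection count per class
```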
],
"metadata": {
10 changes: 5 additions & 5 deletions hoplite/agile/audio_loader.py
@@ -20,11 +20,12 @@

from etils import epath
from hoplite import audio_io
from hoplite.agile import source_info
import numpy as np


def make_filepath_loader(
audio_globs: dict[str, tuple[str, str]],
audio_sources: source_info.AudioSources,
sample_rate_hz: int = 32000,
window_size_s: float = 5.0,
dtype: str = 'float32',
@@ -34,8 +35,7 @@ def make_filepath_loader(
Note that if multiple globs match a given source ID, the first match is used.

Args:
audio_globs: Mapping from dataset name to pairs of `(root directory, file
glob)`. (See `embed.EmbedConfig` for details.)
audio_sources: Embedding audio sources.
sample_rate_hz: Sample rate of the audio.
window_size_s: Window size of the audio.
dtype: Data type of the audio.
@@ -49,8 +49,8 @@

def loader(source_id: str, offset_s: float) -> np.ndarray:
found_path = None
for base_path, _ in audio_globs.values():
path = epath.Path(base_path) / source_id
for audio_source in audio_sources.audio_globs:
path = epath.Path(audio_source.base_path) / source_id
if path.exists():
found_path = path
break
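For illustration, the updated loader might be used like this (a sketch with hypothetical paths; `audio_sources` built via `source_info.AudioSources.from_config_dict` as in the notebook above):

```python
loader = make_filepath_loader(
    audio_sources=audio_sources,
    sample_rate_hz=32000,
    window_size_s=5.0,
)
# Load a 5 s window starting 10 s into a (hypothetical) recording.
audio = loader('site_01/recording.wav', offset_s=10.0)
```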
139 changes: 134 additions & 5 deletions hoplite/agile/classifier.py
@@ -15,15 +15,64 @@

"""Functions for training and applying a linear classifier."""

from typing import Any
import base64
import dataclasses
import json
from typing import Any, Sequence

from hoplite.agile import classifier_data
from hoplite.agile import metrics
from hoplite.db import interface as db_interface
from hoplite.taxonomy import namespace
from ml_collections import config_dict
import numpy as np
import tensorflow as tf
import tqdm


@dataclasses.dataclass
class LinearClassifier:
"""Wrapper for linear classifier params and metadata."""

beta: np.ndarray
beta_bias: np.ndarray
classes: tuple[str, ...]
embedding_model_config: Any

def __call__(self, embeddings: np.ndarray):
return np.dot(embeddings, self.beta) + self.beta_bias

def save(self, path: str):
"""Save the classifier to a path."""
cfg = config_dict.ConfigDict()
cfg.model_config = self.embedding_model_config
cfg.classes = self.classes
# Convert numpy arrays to base64 encoded blobs.
beta_bytes = base64.b64encode(np.float32(self.beta).tobytes()).decode(
'ascii'
)
beta_bias_bytes = base64.b64encode(
np.float32(self.beta_bias).tobytes()
).decode('ascii')
cfg.beta = beta_bytes
cfg.beta_bias = beta_bias_bytes
with open(path, 'w') as f:
f.write(cfg.to_json())

@classmethod
def load(cls, path: str):
"""Load a classifier from a path."""
with open(path, 'r') as f:
cfg_json = json.loads(f.read())
cfg = config_dict.ConfigDict(cfg_json)
classes = cfg.classes
beta = np.frombuffer(base64.b64decode(cfg.beta), dtype=np.float32)
beta = np.reshape(beta, (-1, len(classes)))
beta_bias = np.frombuffer(base64.b64decode(cfg.beta_bias), dtype=np.float32)
embedding_model_config = cfg.model_config
return cls(beta, beta_bias, classes, embedding_model_config)
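Since `save` base64-encodes the float32 weights into JSON, a round trip should reproduce them bit-exactly; a minimal check (a sketch with toy shapes and a hypothetical path):

```python
clf = LinearClassifier(
    beta=np.random.normal(size=(1280, 2)).astype(np.float32),
    beta_bias=np.zeros(2, dtype=np.float32),
    classes=('call_a', 'call_b'),
    embedding_model_config=config_dict.ConfigDict({'model_key': 'demo'}),
)
clf.save('/tmp/linear_classifier.json')
loaded = LinearClassifier.load('/tmp/linear_classifier.json')
assert np.array_equal(clf.beta, loaded.beta)
assert np.array_equal(clf.beta_bias, loaded.beta_bias)
```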


def get_linear_model(embedding_dim: int, num_classes: int) -> tf.keras.Model:
"""Create a simple linear Keras model."""
model = tf.keras.Sequential([
@@ -43,12 +92,28 @@ def bce_loss(
y_true = tf.cast(y_true, dtype=logits.dtype)
log_p = tf.math.log_sigmoid(logits)
log_not_p = tf.math.log_sigmoid(-logits)
raw_bce = -y_true * log_p + (1.0 - y_true) * log_not_p
# optax sigmoid_binary_cross_entropy:
# -labels * log_p - (1.0 - labels) * log_not_p
raw_bce = -y_true * log_p - (1.0 - y_true) * log_not_p
is_labeled_mask = tf.cast(is_labeled_mask, dtype=logits.dtype)
weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask
return tf.reduce_mean(raw_bce * weights)


def hinge_loss(
y_true: tf.Tensor,
logits: tf.Tensor,
is_labeled_mask: tf.Tensor,
weak_neg_weight: float,
) -> tf.Tensor:
"""Weighted SVM hinge loss."""
# Convert multihot to +/- 1 labels.
y_true = 2 * y_true - 1
weights = (1.0 - is_labeled_mask) * weak_neg_weight + is_labeled_mask
raw_hinge_loss = tf.maximum(0, 1 - y_true * logits)
return tf.reduce_mean(raw_hinge_loss * weights)
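The hinge loss expects ±1 targets, hence the `2 * y_true - 1` remap of the multihot labels; one worked value (a sketch):

```python
# y = 1, logit = 0.5: the unit margin is not yet met, so loss = 1 - 0.5 = 0.5.
y, logit = 1.0, 0.5
print(max(0.0, 1.0 - (2.0 * y - 1.0) * logit))  # 0.5
```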


def infer(params, embeddings: np.ndarray):
"""Apply the model to embeddings."""
return np.dot(embeddings, params['beta']) + params['beta_bias']
@@ -105,19 +170,26 @@ def train_linear_classifier(
learning_rate: float,
weak_neg_weight: float,
num_train_steps: int,
):
loss: str = 'bce',
) -> tuple[LinearClassifier, dict[str, float]]:
"""Train a linear classifier."""
embedding_dim = data_manager.db.embedding_dimension()
num_classes = len(data_manager.get_target_labels())
lin_model = get_linear_model(embedding_dim, num_classes)
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
lin_model.compile(optimizer=optimizer, loss='binary_crossentropy')
if loss == 'hinge':
loss_fn = hinge_loss
elif loss == 'bce':
loss_fn = bce_loss
else:
raise ValueError(f'Unknown loss: {loss}')

@tf.function
def train_step(y_true, embeddings, is_labeled_mask):
with tf.GradientTape() as tape:
logits = lin_model(embeddings, training=True)
loss = bce_loss(y_true, logits, is_labeled_mask, weak_neg_weight)
loss = loss_fn(y_true, logits, is_labeled_mask, weak_neg_weight)
loss = tf.reduce_mean(loss)
grads = tape.gradient(loss, lin_model.trainable_variables)
optimizer.apply_gradients(zip(grads, lin_model.trainable_variables))
@@ -147,4 +219,61 @@ def train_step(y_true, embeddings, is_labeled_mask):
'beta_bias': lin_model.get_weights()[1],
}
eval_scores = eval_classifier(params, data_manager, eval_idxes)
return params, eval_scores

model_config = data_manager.db.get_metadata('model_config')
linear_classifier = LinearClassifier(
beta=params['beta'],
beta_bias=params['beta_bias'],
classes=data_manager.get_target_labels(),
embedding_model_config=model_config,
)
return linear_classifier, eval_scores
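With the new `loss` argument, the trainer can be called like this (a sketch; `data_manager` constructed as in the notebook, hyperparameter values and save path illustrative only):

```python
linear_classifier, eval_scores = train_linear_classifier(
    data_manager=data_manager,
    learning_rate=1e-3,
    weak_neg_weight=0.05,
    num_train_steps=128,
    loss='hinge',  # or the default 'bce'
)
linear_classifier.save('/tmp/agile_classifier_v2.pt')
```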


def write_inference_csv(
linear_classifier: LinearClassifier,
db: db_interface.HopliteDBInterface,
output_filepath: str,
threshold: float,
labels: Sequence[str] | None = None,
):
"""Write a CSV for all audio windows with logits above a threshold.

Args:
    linear_classifier: The trained linear classifier to apply.
db: HopliteDBInterface to read embeddings from.
output_filepath: Path to write the CSV to.
threshold: Logits must be above this value to be written.
labels: If provided, only write logits for these labels. If None, write
logits for all labels.

Returns:
None
"""
idxes = db.get_embedding_ids()
if labels is None:
labels = linear_classifier.classes
label_ids = {cl: i for i, cl in enumerate(linear_classifier.classes)}
target_label_ids = np.array([label_ids[l] for l in labels])
logits_fn = lambda emb: linear_classifier(emb)[target_label_ids]
detection_count = 0
with open(output_filepath, 'w') as f:
f.write('idx,dataset_name,source_id,offset,label,logits\n')
for idx in tqdm.tqdm(idxes):
source = db.get_embedding_source(idx)
emb = db.get_embedding(idx)
logits = logits_fn(emb)
for a in np.argwhere(logits > threshold):
lbl = labels[a[0]]
row = [
idx,
source.dataset_name,
source.source_id,
source.offsets[0],
lbl,
logits[a],
]
f.write(','.join(map(str, row)) + '\n')
detection_count += 1
print(f'Wrote {detection_count} detections to {output_filepath}')