-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
2 changed files
with
334 additions
and
0 deletions.
There are no files selected for viewing
163 changes: 163 additions & 0 deletions
163
src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,163 @@ | ||
"""Library for using Vision-Language for zero-shot evaluation of damage.""" | ||
|
||
import abc | ||
import string | ||
from typing import Iterator | ||
|
||
import jax | ||
import numpy as np | ||
|
||
|
||
def _batch_array(
    array: np.ndarray, batch_size: int = 32
) -> Iterator[np.ndarray]:
  """Split a numpy array into consecutive batches along its first axis.

  Args:
    array: The input numpy array to be batched, with at least one dimension.
    batch_size: Maximum number of first-axis entries per batch. Must be
      positive.

  Yields:
    Consecutive slices of `array`, each holding `batch_size` entries along the
    first axis. If the first-axis length is not divisible by `batch_size`, the
    last slice holds the remaining `array.shape[0] % batch_size` entries.
    Nothing is yielded when the first axis is empty.

  Raises:
    ValueError: If `batch_size` is not positive. Because this is a generator,
      the error surfaces when iteration starts.
  """
  # A negative step would make range() empty and silently drop every entry,
  # and a zero step fails with an unrelated-looking range() error.
  if batch_size <= 0:
    raise ValueError(f'batch_size must be positive, got {batch_size}.')
  for start in range(0, array.shape[0], batch_size):
    yield array[start : start + batch_size]
|
||
|
||
def _expand_labels_with_contexts(
    labels: list[str], contexts: list[str]
) -> list[str]:
  """Expand labels with contexts.

  Args:
    labels: List of label descriptions, e.g. "damaged buildings".
    contexts: List of contexts used to format the labels, e.g. "This is a
      satellite image of {}". A context without any replacement field (or an
      empty context) contributes the bare labels instead.

  Returns:
    The de-duplicated list of expanded labels, in first-seen order
    (context-major, then label order).
  """

  def _is_formattable(context: str) -> bool:
    """Returns True if `context` contains at least one replacement field."""
    # Escaped braces ('{{' / '}}') are reported by the parser as separate
    # literal-only chunks whose field name is None, so every chunk has to be
    # inspected; looking only at the first one misses e.g. '{{x}} {}'.
    return any(
        field_name is not None
        for _, field_name, _, _ in string.Formatter().parse(context)
    )

  expanded_labels = []
  for context in contexts:
    if _is_formattable(context):
      expanded_labels.extend(context.format(label) for label in labels)
    else:
      expanded_labels.extend(labels)

  # dict.fromkeys drops duplicates while keeping insertion order, so the
  # result is identical from run to run (a set would reorder the strings
  # depending on hash randomisation).
  return list(dict.fromkeys(expanded_labels))
|
||
|
||
class VLM(abc.ABC):
  """Provide a generic interface for Vision Language models.

  Subclasses implement text tokenization/encoding, image
  preprocessing/encoding and the similarity temperature. This base class
  combines them into a two-class (positive vs. negative) zero-shot classifier:
  `set_label_embeddings` builds one text embedding per class and `predict`
  scores images against those two embeddings.
  """

  def __init__(self):
    # Stacked [positive, negative] class embeddings produced by
    # set_label_embeddings(); None until that method has been called.
    self.label_embeddings = None

  @abc.abstractmethod
  def tokenize(self, texts: list[str]) -> np.ndarray:
    """Converts label texts into the token array passed to `encode_tokens`."""
    raise NotImplementedError()

  @abc.abstractmethod
  def encode_tokens(self, tokens: np.ndarray) -> np.ndarray:
    """Embeds tokens; axis 0 of the result must index the input texts."""
    raise NotImplementedError()

  @abc.abstractmethod
  def preprocess_images(self, images: np.ndarray) -> np.ndarray:
    """Prepares raw images for encoding; axis 0 must index the images."""
    raise NotImplementedError()

  @abc.abstractmethod
  def encode_images(self, images: np.ndarray) -> np.ndarray:
    """Embeds images.

    `predict` passes an array shaped
    (num_local_devices, per_device_batch, *image_dims) and flattens the result
    to (num_local_devices * per_device_batch, -1), so the output must keep the
    same image ordering.
    """
    raise NotImplementedError()

  @abc.abstractmethod
  def get_temperature(self) -> float:
    """Returns the factor that scales similarities before the softmax."""
    raise NotImplementedError()

  def set_label_embeddings(
      self,
      positive_labels: list[str],
      negative_labels: list[str],
      contexts: list[str],
  ):
    """Set label embeddings.

    Each class embedding is the mean of the text embeddings of all of its
    labels expanded with all contexts.

    Args:
      positive_labels: List of label descriptions of the positive class, e.g.
        "damaged buildings".
      negative_labels: List of label descriptions of the negative class, e.g.
        "undamaged buildings".
      contexts: List of contexts that can be used to format the labels, e.g.
        "This is a satellite image of {}".

    Raises:
      ValueError: If a class ends up with no text to embed, i.e. its labels or
        `contexts` are empty.
    """

    def _get_embedding(labels: list[str], class_name: str) -> np.ndarray:
      expanded_labels = _expand_labels_with_contexts(labels, contexts)
      if not expanded_labels:
        # Without this check np.concatenate fails with an unrelated-looking
        # "need at least one array to concatenate" error.
        raise ValueError(
            f'No label texts to embed for the {class_name} class; both the'
            ' labels and the contexts must be non-empty.'
        )
      embeddings = [
          self.encode_tokens(self.tokenize(labels_batch.tolist()))
          for labels_batch in _batch_array(np.array(expanded_labels))
      ]
      # Average over every expanded label to get a single class embedding.
      return np.mean(np.concatenate(embeddings, axis=0), axis=0)

    negative_embedding = _get_embedding(negative_labels, 'negative')
    positive_embedding = _get_embedding(positive_labels, 'positive')
    # Row 0 is the positive class and row 1 the negative class; predict()
    # returns its probability columns in this order.
    self.label_embeddings = np.stack(
        [positive_embedding, negative_embedding], axis=0
    )

  @staticmethod
  def _shard_across_devices(
      images: np.ndarray, num_devices: int
  ) -> np.ndarray:
    """Zero-pads and reshapes a batch so it splits evenly across devices.

    Args:
      images: Array whose first axis indexes the images.
      num_devices: Number of devices to split the batch across.

    Returns:
      An array of shape (num_devices, per_device_batch, *images.shape[1:]) with
      the same dtype as `images`. Padding entries are all zeros and come after
      the real images in flattened order.
    """
    num_padding = -images.shape[0] % num_devices
    if num_padding:
      # Pad with the images' own dtype so the batch dtype does not depend on
      # whether padding happened to be needed.
      padding = np.zeros(
          (num_padding,) + images.shape[1:], dtype=images.dtype
      )
      images = np.concatenate([images, padding], axis=0)
    per_device_batch = images.shape[0] // num_devices
    return images.reshape((num_devices, per_device_batch) + images.shape[1:])

  def predict(self, images: np.ndarray) -> np.ndarray:
    """Generate probability scores for a batch of images.

    Args:
      images: A batch of images.

    Returns:
      A 2-D array of shape (batch, 2), where the second dimension contains the
      positive and negative class probabilities respectively.

    Raises:
      ValueError if the connected device is not TPU or labels are not set.
    """
    if self.label_embeddings is None:
      raise ValueError('Label embeddings are not set.')

    # NOTE(review): jax.lib.xla_bridge is not a public API and is deprecated
    # in recent JAX releases; jax.default_backend() looks like the supported
    # replacement. Kept unchanged here because the unit test mocks this exact
    # attribute path - confirm before migrating.
    if jax.lib.xla_bridge.get_backend().platform != 'tpu':
      raise ValueError('Not connected to TPU.')

    images = self.preprocess_images(images)
    batch_size = images.shape[0]

    sharded_images = self._shard_across_devices(
        images, jax.local_device_count()
    )
    padded_batch_size = sharded_images.shape[0] * sharded_images.shape[1]

    images_embd = self.encode_images(sharded_images)
    # Merge the device axis back into the batch axis and drop the embeddings
    # that belong to the zero padding.
    images_embd = images_embd.reshape(padded_batch_size, -1)[:batch_size, :]
    sims = images_embd @ self.label_embeddings.T * self.get_temperature()
    probability_scores = np.array(jax.nn.softmax(sims, axis=-1))
    return probability_scores
171 changes: 171 additions & 0 deletions
171
src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
"""Tests for vlm_zero_shot_lib.""" | ||
|
||
import dataclasses | ||
import sys | ||
from absl.testing import parameterized | ||
import jax | ||
import mock | ||
import numpy as np | ||
from skai.experiment.zero_shot_eval import vlm_zero_shot_lib | ||
from google3.testing.pybase import googletest | ||
|
||
# Fixture data is drawn from a local, seeded generator so every test run sees
# the same arrays and the global numpy RNG state is left untouched. uint8 is
# used because the tests only check shapes and element equality; the default
# integer dtype would make the two image batches many times larger.
_RNG = np.random.default_rng(seed=0)

# `high` is exclusive, so 256 covers the full uint8 range.
# 64 is a multiple of the batch size (32) and device count (8) used below;
# 73 is a multiple of neither.
DIVISIBLE_ARRAY = _RNG.integers(
    low=0, high=256, size=(64, 224, 224, 3), dtype=np.uint8
)
INDIVISIBLE_ARRAY = _RNG.integers(
    low=0, high=256, size=(73, 224, 224, 3), dtype=np.uint8
)
DIVISIBLE_ONE_DIM_ARRAY = _RNG.integers(
    low=0, high=256, size=(64,), dtype=np.uint8
)
INDIVISIBLE_ONE_DIM_ARRAY = _RNG.integers(
    low=0, high=256, size=(73,), dtype=np.uint8
)

# Prefix for mock.patch targets inside the library under test.
PACKAGE = "skai.experiment.zero_shot_eval.vlm_zero_shot_lib."

# Batch size larger than any fixture, to exercise the single-batch path.
MAX_INT = sys.maxsize
|
||
|
||
class _FakeVLM(vlm_zero_shot_lib.VLM):
  """Test double for VLM that returns random arrays of the expected shapes."""

  def __init__(self, embd_size: int):
    super().__init__()
    # Size of the last axis of every fake token/embedding array.
    self.embd_size = embd_size

  def tokenize(self, texts: list[str]) -> np.ndarray:
    """Returns one random row of length `embd_size` per input text."""
    num_texts = len(texts)
    return np.random.randn(num_texts, self.embd_size)

  def encode_tokens(self, tokens: np.ndarray) -> np.ndarray:
    """Returns random embeddings, keeping all but the last token axis."""
    leading_dims = tokens.shape[:-1]
    return np.random.randn(*leading_dims, self.embd_size)

  def preprocess_images(self, images: np.ndarray) -> np.ndarray:
    """Returns random values with the same shape as `images`."""
    image_dims = images.shape
    return np.random.randn(*image_dims)

  def encode_images(self, images: np.ndarray) -> np.ndarray:
    """Returns random embeddings, replacing the last three image axes."""
    batch_dims = images.shape[:-3]
    return np.random.randn(*batch_dims, self.embd_size)

  def get_temperature(self) -> float:
    """Returns a fixed similarity scale."""
    return 10.0
|
||
|
||
class VlmZeroShotLibTest(parameterized.TestCase):
  """Unit tests for the batching, label expansion and VLM base-class logic."""

  @parameterized.named_parameters(
      ("divisible arrays", DIVISIBLE_ARRAY, 32),
      ("indivisible arrays", INDIVISIBLE_ARRAY, 32),
      ("divisible array with one dimension", DIVISIBLE_ONE_DIM_ARRAY, 32),
      ("indivisible array with one dimension", INDIVISIBLE_ONE_DIM_ARRAY, 32),
      ("number of arrays less than batch size", DIVISIBLE_ARRAY, MAX_INT),
  )
  def test_equality_batch_array(self, array, batch_size):
    """Concatenating the batches must give back the original array."""
    batches = list(vlm_zero_shot_lib._batch_array(array, batch_size))
    reassembled = np.concatenate(batches, axis=0)
    np.testing.assert_array_equal(reassembled, array)

  @parameterized.named_parameters(
      (
          "formattable contexts",
          ["damaged building", "damaged roof"],
          ["This is a satellite image of {}", "This is top down view of {}"],
          [
              "This is a satellite image of damaged building",
              "This is top down view of damaged building",
              "This is a satellite image of damaged roof",
              "This is top down view of damaged roof",
          ],
      ),
      (
          "unformattable contexts",
          ["damaged building", "damaged roof"],
          ["This is a satellite image of", "This is top down view of"],
          ["damaged building", "damaged roof"],
      ),
      (
          "Partially formattable contexts",
          ["damaged building", "damaged roof"],
          ["This is a satellite image of {}", "This is top down view of"],
          [
              "This is a satellite image of damaged building",
              "damaged building",
              "This is a satellite image of damaged roof",
              "damaged roof",
          ],
      ),
  )
  def test_expand_labels_with_contexts(
      self, labels, contexts, expected_labels_with_contexts
  ):
    """Expanded labels must match the expectation, ignoring order."""
    actual = vlm_zero_shot_lib._expand_labels_with_contexts(labels, contexts)
    self.assertCountEqual(actual, expected_labels_with_contexts)

  @parameterized.named_parameters(
      (
          "formattable contexts",
          ["damaged building", "damaged roof"],
          ["undamaged buidling", "intact houses"],
          ["This is a satellite image of {}", "This is top down view of {}"],
      ),
      (
          "unformattable contexts",
          ["damaged building", "damaged roof"],
          ["undamaged buidling", "intact houses"],
          ["This is a satellite image of", "This is top down view of"],
      ),
      (
          "Partially formattable contexts",
          ["damaged building", "damaged roof"],
          ["undamaged buidling", "intact houses"],
          ["This is a satellite image of {}", "This is top down view of"],
      ),
  )
  def test_vlm_set_label_embeddings(
      self, positive_labels, negative_labels, contexts
  ):
    """Label embeddings must hold one row per class."""
    embd_size = 1024
    fake_vlm = _FakeVLM(embd_size)
    fake_vlm.set_label_embeddings(positive_labels, negative_labels, contexts)
    embedding_shape = list(fake_vlm.label_embeddings.shape)
    self.assertEqual(embedding_shape, [2, 1024])

  @parameterized.named_parameters(
      (
          "divisible arrays",
          DIVISIBLE_ARRAY,
          ["damaged building"],
          ["intact houses"],
          ["This is top down view of {}"],
      ),
      (
          "indivisible arrays",
          INDIVISIBLE_ARRAY,
          ["damaged building"],
          ["intact houses"],
          ["This is top down view of {}"],
      ),
      (
          "unformattable contexts",
          DIVISIBLE_ARRAY,
          ["damaged building", "damaged roof"],
          ["undamaged buidling", "intact houses"],
          ["This is a satellite image of", "This is top down view of"],
      ),
  )
  @mock.patch(PACKAGE + "jax")
  def test_vlm_predict(
      self, images, positive_labels, negative_labels, contexts, mocked_jax
  ):
    """predict() must return one (positive, negative) score pair per image."""

    @dataclasses.dataclass(frozen=True)
    class _MockedClient:
      # Stand-in for the backend object; only `platform` is read.
      platform: str

    num_images = images.shape[0]

    # Pretend to run on an 8-core TPU; softmax is stubbed with real softmax
    # output of the expected shape because the library's jax is mocked out.
    mocked_jax.local_device_count.return_value = 8
    mocked_jax.nn.softmax.return_value = jax.nn.softmax(
        np.random.randn(num_images, 2)
    )
    mocked_jax.lib.xla_bridge.get_backend.return_value = _MockedClient(
        platform="tpu"
    )

    fake_vlm = _FakeVLM(1024)
    fake_vlm.set_label_embeddings(positive_labels, negative_labels, contexts)
    scores = fake_vlm.predict(images)

    self.assertEqual(scores.shape, (num_images, 2))
|
||
|
||
if __name__ == "__main__": | ||
googletest.main() |