Skip to content

Commit

Permalink
No public description
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 584622267
  • Loading branch information
Mohamedelfatih Mohamedkhair authored and copybara-github committed Nov 30, 2023
1 parent 949e305 commit 427b522
Show file tree
Hide file tree
Showing 2 changed files with 334 additions and 0 deletions.
163 changes: 163 additions & 0 deletions src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
"""Library for using Vision-Language for zero-shot evaluation of damage."""

import abc
import string
from typing import Iterator

import jax
import numpy as np


def _batch_array(
array: np.ndarray, batch_size: int = 32
) -> Iterator[np.ndarray]:
"""Batch a numpy array.
Args:
array: The input numpy array to be batched with at least one dimension.
batch_size: The size of the batch for which the array will be batched.
Yields:
A sequence of batched numpy arrays each with batch size of batch_size.
If number of arrays is not dividable then the last array will have a
batch size of number_of_arrays % batch_size.
"""
for i in range(0, array.shape[0], batch_size):
yield array[i : i + batch_size]


def _expand_labels_with_contexts(labels: list[str], contexts: list[str]):
"""Expand labels with contexts.
Args:
labels: List of label descriptions i.e damaged buildings.
contexts: List of contexts that can be used to format the labels, i.e "This
a satelite image of {}"
Returns:
A list of labels that are formatted with the contexts.
"""

def _is_formattable(context: str) -> bool:
if not context:
return False
fieldnames = [val[1] for val in string.Formatter().parse(context)]
if fieldnames[0] is None:
return False
return True

expanded_labels = []
for context in contexts:
if not _is_formattable(context):
expanded_labels.extend(labels)
continue
for label in labels:
expanded_labels.append(context.format(label))

expanded_labels = list(set(expanded_labels))
return expanded_labels


class VLM(abc.ABC):
"""Provide a generic interface for Vision Language models."""

def __init__(self):
self.label_embeddings = None

@abc.abstractmethod
def tokenize(self, texts: list[str]) -> np.ndarray:
raise NotImplementedError()

@abc.abstractmethod
def encode_tokens(self, tokens: np.ndarray) -> np.ndarray:
raise NotImplementedError()

@abc.abstractmethod
def preprocess_images(self, images: np.ndarray) -> np.ndarray:
raise NotImplementedError()

@abc.abstractmethod
def encode_images(self, images: np.ndarray) -> np.ndarray:
raise NotImplementedError()

@abc.abstractmethod
def get_temperature(self) -> float:
raise NotImplementedError()

def set_label_embeddings(
self,
positive_labels: list[str],
negative_labels: list[str],
contexts: list[str],
):
"""Set label embeddings.
Args:
positive_labels: List of label descriptions of the spositive class i.e
damaged buildings.
negative_labels: List of label descriptions of the negative class i.e
damaged buildings.
contexts: List of contexts that can be used to format the labels, i.e
"This a satelite image of {}"
"""

def _get_embedding(labels, contexts):
expanded_labels = _expand_labels_with_contexts(labels, contexts)
embeddings = []
for labels_batch in _batch_array(np.array(expanded_labels)):
tokens = self.tokenize(labels_batch.tolist())
embeddings.append(self.encode_tokens(tokens))
embeddings = np.concatenate(embeddings, axis=0)
embedding = np.mean(embeddings, axis=0)
return embedding

negative_embedding = _get_embedding(negative_labels, contexts)
positive_embedding = _get_embedding(positive_labels, contexts)
self.label_embeddings = np.stack(
[positive_embedding, negative_embedding], axis=0
)

def predict(self, images: np.ndarray) -> np.ndarray:
"""Generate probability scores for a batch of images.
Args:
images: A batch of images.
Returns:
A 2-D array of shape (batch, 2), where the second dimension contains the
positive and negative class probabilities respectively.
Raises:
ValueError if the connected device is not TPU or labels are not set.
"""
if self.label_embeddings is None:
raise ValueError('Label embeddings are not set.')

if jax.lib.xla_bridge.get_backend().platform != 'tpu':
raise ValueError('Not connected to TPU.')

images = self.preprocess_images(images)

batch_size, image_size, _, _ = images.shape
num_images_to_augment = 0
if batch_size % jax.local_device_count() != 0:
num_images_to_augment = jax.local_device_count() - (
batch_size % jax.local_device_count()
)
images_to_augment = np.zeros((num_images_to_augment,) + images.shape[1:])
images = np.concatenate([images, images_to_augment], axis=0)

images = images.reshape(
jax.local_device_count(),
(batch_size + num_images_to_augment) // jax.local_device_count(),
image_size,
image_size,
3,
)
images_embd = self.encode_images(images)
images_embd = images_embd.reshape(batch_size + num_images_to_augment, -1)[
:batch_size, :
]
sims = images_embd @ self.label_embeddings.T * self.get_temperature()
probability_scores = np.array(jax.nn.softmax(sims, axis=-1))
return probability_scores
171 changes: 171 additions & 0 deletions src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
"""Tests for vlm_zero_shot_lib."""

import dataclasses
import sys
from absl.testing import parameterized
import jax
import mock
import numpy as np
from skai.experiment.zero_shot_eval import vlm_zero_shot_lib
from google3.testing.pybase import googletest

DIVISIBLE_ARRAY = np.random.randint(size=(64, 224, 224, 3), low=0, high=255)
INDIVISIBLE_ARRAY = np.random.randint(size=(73, 224, 224, 3), low=0, high=255)
DIVISIBLE_ONE_DIM_ARRAY = np.random.randint(size=(64,), low=0, high=255)
INDIVISIBLE_ONE_DIM_ARRAY = np.random.randint(size=(73,), low=0, high=255)
PACKAGE = "skai.experiment.zero_shot_eval.vlm_zero_shot_lib."


MAX_INT = sys.maxsize


class _FakeVLM(vlm_zero_shot_lib.VLM):

def __init__(self, embd_size: int):
super().__init__()
self.embd_size = embd_size

def tokenize(self, texts: list[str]) -> np.ndarray:
return np.random.randn(len(texts), self.embd_size)

def encode_tokens(self, tokens: np.ndarray) -> np.ndarray:
return np.random.randn(*tokens.shape[:-1], self.embd_size)

def preprocess_images(self, images: np.ndarray) -> np.ndarray:
return np.random.randn(*images.shape)

def encode_images(self, images: np.ndarray) -> np.ndarray:
return np.random.randn(*images.shape[:-3], self.embd_size)

def get_temperature(self) -> float:
return 10.0


class VlmZeroShotLibTest(parameterized.TestCase):

@parameterized.named_parameters(
("divisible arrays", DIVISIBLE_ARRAY, 32),
("indivisible arrays", INDIVISIBLE_ARRAY, 32),
("divisible array with one dimension", DIVISIBLE_ONE_DIM_ARRAY, 32),
("indivisible array with one dimension", INDIVISIBLE_ONE_DIM_ARRAY, 32),
("number of arrays less than batch size", DIVISIBLE_ARRAY, MAX_INT),
)
def test_equality_batch_array(self, array, batch_size):
batched_arrays = vlm_zero_shot_lib._batch_array(array, batch_size)
stacked_arrays = np.concatenate(list(batched_arrays), axis=0)
np.testing.assert_array_equal(stacked_arrays, array)

@parameterized.named_parameters(
(
"formattable contexts",
["damaged building", "damaged roof"],
["This is a satellite image of {}", "This is top down view of {}"],
[
"This is a satellite image of damaged building",
"This is top down view of damaged building",
"This is a satellite image of damaged roof",
"This is top down view of damaged roof",
],
),
(
"unformattable contexts",
["damaged building", "damaged roof"],
["This is a satellite image of", "This is top down view of"],
["damaged building", "damaged roof"],
),
(
"Partially formattable contexts",
["damaged building", "damaged roof"],
["This is a satellite image of {}", "This is top down view of"],
[
"This is a satellite image of damaged building",
"damaged building",
"This is a satellite image of damaged roof",
"damaged roof",
],
),
)
def test_expand_labels_with_contexts(
self, labels, contexts, expected_labels_with_contexts
):
labels_with_contexts = vlm_zero_shot_lib._expand_labels_with_contexts(
labels, contexts
)
self.assertCountEqual(labels_with_contexts, expected_labels_with_contexts)

@parameterized.named_parameters(
(
"formattable contexts",
["damaged building", "damaged roof"],
["undamaged buidling", "intact houses"],
["This is a satellite image of {}", "This is top down view of {}"],
),
(
"unformattable contexts",
["damaged building", "damaged roof"],
["undamaged buidling", "intact houses"],
["This is a satellite image of", "This is top down view of"],
),
(
"Partially formattable contexts",
["damaged building", "damaged roof"],
["undamaged buidling", "intact houses"],
["This is a satellite image of {}", "This is top down view of"],
),
)
def test_vlm_set_label_embeddings(
self, positive_labels, negative_labels, contexts
):
embd_size = 1024
vlm = _FakeVLM(embd_size)
vlm.set_label_embeddings(positive_labels, negative_labels, contexts)
self.assertEqual(list(vlm.label_embeddings.shape), [2, 1024])

@parameterized.named_parameters(
(
"divisible arrays",
DIVISIBLE_ARRAY,
["damaged building"],
["intact houses"],
["This is top down view of {}"],
),
(
"indivisible arrays",
INDIVISIBLE_ARRAY,
["damaged building"],
["intact houses"],
["This is top down view of {}"],
),
(
"unformattable contexts",
DIVISIBLE_ARRAY,
["damaged building", "damaged roof"],
["undamaged buidling", "intact houses"],
["This is a satellite image of", "This is top down view of"],
)
)
@mock.patch(PACKAGE + "jax")
def test_vlm_predict(
self, images, positive_labels, negative_labels, contexts, mocked_jax
):
@dataclasses.dataclass(frozen=True)
class _MockedClient:
platform: str

mocked_jax.local_device_count.return_value = 8
mocked_jax.nn.softmax.return_value = jax.nn.softmax(
np.random.randn(images.shape[0], 2)
)
mocked_jax.lib.xla_bridge.get_backend.return_value = _MockedClient(
platform="tpu"
)

vlm = _FakeVLM(1024)
vlm.set_label_embeddings(positive_labels, negative_labels, contexts)
scores = vlm.predict(images)

self.assertEqual(scores.shape, (images.shape[0], 2))


if __name__ == "__main__":
googletest.main()

0 comments on commit 427b522

Please sign in to comment.