From 427b52219f02eaf8cb6b6e9a5440035485288048 Mon Sep 17 00:00:00 2001
From: Mohamedelfatih Mohamedkhair
Date: Wed, 22 Nov 2023 07:25:11 -0800
Subject: [PATCH] No public description

PiperOrigin-RevId: 584622267
---
 .../zero_shot_eval/vlm_zero_shot_lib.py       | 187 ++++++++++++++++++
 .../zero_shot_eval/vlm_zero_shot_lib_test.py  | 183 ++++++++++++++++++
 2 files changed, 370 insertions(+)
 create mode 100644 src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib.py
 create mode 100644 src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib_test.py

diff --git a/src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib.py b/src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib.py
new file mode 100644
index 00000000..5d2d975a
--- /dev/null
+++ b/src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib.py
@@ -0,0 +1,187 @@
+"""Library for using Vision-Language for zero-shot evaluation of damage."""
+
+import abc
+import string
+from typing import Iterator
+
+import jax
+import numpy as np
+
+
+def _batch_array(
+    array: np.ndarray, batch_size: int = 32
+) -> Iterator[np.ndarray]:
+  """Batch a numpy array along its first dimension.
+
+  Args:
+    array: The input numpy array to be batched with at least one dimension.
+    batch_size: The maximum number of elements per batch. Must be positive.
+
+  Yields:
+    Consecutive slices of the array, each holding batch_size elements along
+    the first dimension. If the number of elements is not divisible by
+    batch_size then the last slice holds the remaining
+    number_of_elements % batch_size elements.
+
+  Raises:
+    ValueError: If batch_size is not positive. Raised when iteration starts.
+  """
+  if batch_size <= 0:
+    raise ValueError(f'batch_size must be positive, got {batch_size}.')
+  for i in range(0, array.shape[0], batch_size):
+    yield array[i : i + batch_size]
+
+
+def _expand_labels_with_contexts(
+    labels: list[str], contexts: list[str]
+) -> list[str]:
+  """Expand labels with contexts.
+
+  Args:
+    labels: List of label descriptions, e.g. "damaged buildings".
+    contexts: List of contexts that can be used to format the labels, e.g.
+      "This is a satellite image of {}". A context without a replacement field
+      contributes the labels unchanged.
+
+  Returns:
+    A list of unique labels formatted with the contexts, in first-seen order.
+  """
+
+  def _is_formattable(context: str) -> bool:
+    # A context is formattable when any parsed segment carries a replacement
+    # field; literal-only segments report a field name of None.
+    return any(
+        field_name is not None
+        for _, field_name, _, _ in string.Formatter().parse(context)
+    )
+
+  expanded_labels = []
+  for context in contexts:
+    if _is_formattable(context):
+      expanded_labels.extend(context.format(label) for label in labels)
+    else:
+      expanded_labels.extend(labels)
+
+  # dict.fromkeys drops duplicates while keeping a deterministic order, unlike
+  # set() whose iteration order for strings can differ between runs.
+  return list(dict.fromkeys(expanded_labels))
+
+
+class VLM(abc.ABC):
+  """Provide a generic interface for Vision Language models."""
+
+  def __init__(self):
+    self.label_embeddings = None
+
+  @abc.abstractmethod
+  def tokenize(self, texts: list[str]) -> np.ndarray:
+    """Convert texts into the token array consumed by encode_tokens."""
+    raise NotImplementedError()
+
+  @abc.abstractmethod
+  def encode_tokens(self, tokens: np.ndarray) -> np.ndarray:
+    """Encode a batch of tokens into text embeddings."""
+    raise NotImplementedError()
+
+  @abc.abstractmethod
+  def preprocess_images(self, images: np.ndarray) -> np.ndarray:
+    """Prepare a batch of raw images for encode_images."""
+    raise NotImplementedError()
+
+  @abc.abstractmethod
+  def encode_images(self, images: np.ndarray) -> np.ndarray:
+    """Encode images with leading (device, per-device batch) axes."""
+    raise NotImplementedError()
+
+  @abc.abstractmethod
+  def get_temperature(self) -> float:
+    """Return the factor that scales similarities before the softmax."""
+    raise NotImplementedError()
+
+  def set_label_embeddings(
+      self,
+      positive_labels: list[str],
+      negative_labels: list[str],
+      contexts: list[str],
+  ):
+    """Set label embeddings.
+
+    Args:
+      positive_labels: List of label descriptions of the positive class, e.g.
+        "damaged buildings".
+      negative_labels: List of label descriptions of the negative class, e.g.
+        "undamaged buildings".
+      contexts: List of contexts that can be used to format the labels, e.g.
+        "This is a satellite image of {}".
+
+    Raises:
+      ValueError: If the labels of a class or the contexts are empty, because
+        no embedding can be computed for that class.
+    """
+
+    def _get_embedding(labels, contexts):
+      expanded_labels = _expand_labels_with_contexts(labels, contexts)
+      if not expanded_labels:
+        raise ValueError(
+            'Labels and contexts must not be empty, got'
+            f' labels={labels} and contexts={contexts}.'
+        )
+      embeddings = []
+      for labels_batch in _batch_array(np.array(expanded_labels)):
+        tokens = self.tokenize(labels_batch.tolist())
+        embeddings.append(self.encode_tokens(tokens))
+      # Average over every label/context combination to get one embedding per
+      # class.
+      return np.mean(np.concatenate(embeddings, axis=0), axis=0)
+
+    negative_embedding = _get_embedding(negative_labels, contexts)
+    positive_embedding = _get_embedding(positive_labels, contexts)
+    self.label_embeddings = np.stack(
+        [positive_embedding, negative_embedding], axis=0
+    )
+
+  def predict(self, images: np.ndarray) -> np.ndarray:
+    """Generate probability scores for a batch of images.
+
+    Args:
+      images: A batch of images.
+
+    Returns:
+      A 2-D array of shape (batch, 2), where the second dimension contains the
+      positive and negative class probabilities respectively.
+
+    Raises:
+      ValueError: If the connected device is not TPU or labels are not set.
+    """
+    if self.label_embeddings is None:
+      raise ValueError('Label embeddings are not set.')
+
+    # jax.default_backend() is the public API for the platform name.
+    if jax.default_backend() != 'tpu':
+      raise ValueError('Not connected to TPU.')
+
+    images = self.preprocess_images(images)
+
+    num_devices = jax.local_device_count()
+    batch_size = images.shape[0]
+    # Pad with blank images so the batch splits evenly across local devices.
+    # Keeping the dtype avoids changing the model input type when padding.
+    num_padding_images = -batch_size % num_devices
+    if num_padding_images:
+      padding = np.zeros(
+          (num_padding_images,) + images.shape[1:], dtype=images.dtype
+      )
+      images = np.concatenate([images, padding], axis=0)
+    padded_batch_size = batch_size + num_padding_images
+
+    # Add a leading device axis and keep the per-image dimensions as they are
+    # rather than assuming square, three-channel images.
+    images = images.reshape(
+        (num_devices, padded_batch_size // num_devices) + images.shape[1:]
+    )
+    images_embd = self.encode_images(images)
+    # Merge the device axis back and drop the embeddings of padding images.
+    images_embd = images_embd.reshape(padded_batch_size, -1)[:batch_size, :]
+    sims = images_embd @ self.label_embeddings.T * self.get_temperature()
+    probability_scores = np.array(jax.nn.softmax(sims, axis=-1))
+    return probability_scores
diff --git a/src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib_test.py b/src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib_test.py
new file mode 100644
index 00000000..2e72ea3e
--- /dev/null
+++ b/src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib_test.py
@@ -0,0 +1,183 @@
+"""Tests for vlm_zero_shot_lib."""
+
+import sys
+from absl.testing import parameterized
+import jax
+import mock
+import numpy as np
+from skai.experiment.zero_shot_eval import vlm_zero_shot_lib
+from google3.testing.pybase import googletest
+
+DIVISIBLE_ARRAY = np.random.randint(size=(64, 224, 224, 3), low=0, high=255)
+INDIVISIBLE_ARRAY = np.random.randint(size=(73, 224, 224, 3), low=0, high=255)
+DIVISIBLE_ONE_DIM_ARRAY = np.random.randint(size=(64,), low=0, high=255)
+INDIVISIBLE_ONE_DIM_ARRAY = np.random.randint(size=(73,), low=0, high=255)
+PACKAGE = "skai.experiment.zero_shot_eval.vlm_zero_shot_lib."
+
+
+MAX_INT = sys.maxsize
+
+
+class _FakeVLM(vlm_zero_shot_lib.VLM):
+
+  def __init__(self, embd_size: int):
+    super().__init__()
+    self.embd_size = embd_size
+
+  def tokenize(self, texts: list[str]) -> np.ndarray:
+    return np.random.randn(len(texts), self.embd_size)
+
+  def encode_tokens(self, tokens: np.ndarray) -> np.ndarray:
+    return np.random.randn(*tokens.shape[:-1], self.embd_size)
+
+  def preprocess_images(self, images: np.ndarray) -> np.ndarray:
+    return np.random.randn(*images.shape)
+
+  def encode_images(self, images: np.ndarray) -> np.ndarray:
+    return np.random.randn(*images.shape[:-3], self.embd_size)
+
+  def get_temperature(self) -> float:
+    return 10.0
+
+
+class VlmZeroShotLibTest(parameterized.TestCase):
+
+  @parameterized.named_parameters(
+      ("divisible arrays", DIVISIBLE_ARRAY, 32),
+      ("indivisible arrays", INDIVISIBLE_ARRAY, 32),
+      ("divisible array with one dimension", DIVISIBLE_ONE_DIM_ARRAY, 32),
+      ("indivisible array with one dimension", INDIVISIBLE_ONE_DIM_ARRAY, 32),
+      ("number of arrays less than batch size", DIVISIBLE_ARRAY, MAX_INT),
+  )
+  def test_equality_batch_array(self, array, batch_size):
+    batched_arrays = vlm_zero_shot_lib._batch_array(array, batch_size)
+    stacked_arrays = np.concatenate(list(batched_arrays), axis=0)
+    np.testing.assert_array_equal(stacked_arrays, array)
+
+  @parameterized.named_parameters(
+      ("zero batch size", 0),
+      ("negative batch size", -1),
+  )
+  def test_batch_array_rejects_non_positive_batch_size(self, batch_size):
+    with self.assertRaises(ValueError):
+      list(vlm_zero_shot_lib._batch_array(DIVISIBLE_ONE_DIM_ARRAY, batch_size))
+
+  @parameterized.named_parameters(
+      (
+          "formattable contexts",
+          ["damaged building", "damaged roof"],
+          ["This is a satellite image of {}", "This is top down view of {}"],
+          [
+              "This is a satellite image of damaged building",
+              "This is top down view of damaged building",
+              "This is a satellite image of damaged roof",
+              "This is top down view of damaged roof",
+          ],
+      ),
+      (
+          "unformattable contexts",
+          ["damaged building", "damaged roof"],
+          ["This is a satellite image of", "This is top down view of"],
+          ["damaged building", "damaged roof"],
+      ),
+      (
+          "Partially formattable contexts",
+          ["damaged building", "damaged roof"],
+          ["This is a satellite image of {}", "This is top down view of"],
+          [
+              "This is a satellite image of damaged building",
+              "damaged building",
+              "This is a satellite image of damaged roof",
+              "damaged roof",
+          ],
+      ),
+      (
+          "escaped braces before the replacement field",
+          ["damaged building"],
+          ["A {{satellite}} image of {}"],
+          ["A {satellite} image of damaged building"],
+      ),
+  )
+  def test_expand_labels_with_contexts(
+      self, labels, contexts, expected_labels_with_contexts
+  ):
+    labels_with_contexts = vlm_zero_shot_lib._expand_labels_with_contexts(
+        labels, contexts
+    )
+    self.assertCountEqual(labels_with_contexts, expected_labels_with_contexts)
+
+  @parameterized.named_parameters(
+      (
+          "formattable contexts",
+          ["damaged building", "damaged roof"],
+          ["undamaged buidling", "intact houses"],
+          ["This is a satellite image of {}", "This is top down view of {}"],
+      ),
+      (
+          "unformattable contexts",
+          ["damaged building", "damaged roof"],
+          ["undamaged buidling", "intact houses"],
+          ["This is a satellite image of", "This is top down view of"],
+      ),
+      (
+          "Partially formattable contexts",
+          ["damaged building", "damaged roof"],
+          ["undamaged buidling", "intact houses"],
+          ["This is a satellite image of {}", "This is top down view of"],
+      ),
+  )
+  def test_vlm_set_label_embeddings(
+      self, positive_labels, negative_labels, contexts
+  ):
+    embd_size = 1024
+    vlm = _FakeVLM(embd_size)
+    vlm.set_label_embeddings(positive_labels, negative_labels, contexts)
+    self.assertEqual(list(vlm.label_embeddings.shape), [2, 1024])
+
+  def test_vlm_set_label_embeddings_rejects_empty_contexts(self):
+    vlm = _FakeVLM(1024)
+    with self.assertRaises(ValueError):
+      vlm.set_label_embeddings(["damaged building"], ["intact houses"], [])
+
+  @parameterized.named_parameters(
+      (
+          "divisible arrays",
+          DIVISIBLE_ARRAY,
+          ["damaged building"],
+          ["intact houses"],
+          ["This is top down view of {}"],
+      ),
+      (
+          "indivisible arrays",
+          INDIVISIBLE_ARRAY,
+          ["damaged building"],
+          ["intact houses"],
+          ["This is top down view of {}"],
+      ),
+      (
+          "unformattable contexts",
+          DIVISIBLE_ARRAY,
+          ["damaged building", "damaged roof"],
+          ["undamaged buidling", "intact houses"],
+          ["This is a satellite image of", "This is top down view of"],
+      )
+  )
+  @mock.patch(PACKAGE + "jax")
+  def test_vlm_predict(
+      self, images, positive_labels, negative_labels, contexts, mocked_jax
+  ):
+    mocked_jax.local_device_count.return_value = 8
+    mocked_jax.nn.softmax.return_value = jax.nn.softmax(
+        np.random.randn(images.shape[0], 2)
+    )
+    mocked_jax.default_backend.return_value = "tpu"
+
+    vlm = _FakeVLM(1024)
+    vlm.set_label_embeddings(positive_labels, negative_labels, contexts)
+    scores = vlm.predict(images)
+
+    self.assertEqual(scores.shape, (images.shape[0], 2))
+
+
+if __name__ == "__main__":
+  googletest.main()