No public description #177

Open
wants to merge 1 commit into base: main
163 changes: 163 additions & 0 deletions src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib.py
@@ -0,0 +1,163 @@
"""Library for using Vision-Language for zero-shot evaluation of damage."""

import abc
import string
from typing import Iterator

import jax
import numpy as np


def _batch_array(
array: np.ndarray, batch_size: int = 32
) -> Iterator[np.ndarray]:
"""Batch a numpy array.

Args:
    array: The input numpy array to be batched; must have at least one
      dimension.
    batch_size: The number of elements in each yielded batch.

Yields:
    A sequence of batched numpy arrays, each with batch size batch_size. If
    the number of elements is not divisible by batch_size, the last array
    will have a batch size of number_of_elements % batch_size.
"""
for i in range(0, array.shape[0], batch_size):
yield array[i : i + batch_size]


def _expand_labels_with_contexts(
    labels: list[str], contexts: list[str]
) -> list[str]:
"""Expand labels with contexts.

Args:
    labels: List of label descriptions, e.g. "damaged building".
    contexts: List of contexts that can be used to format the labels, e.g.
      "This is a satellite image of {}".

Returns:
    A list of labels formatted with the contexts, with duplicates removed.
"""

def _is_formattable(context: str) -> bool:
if not context:
return False
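    # string.Formatter().parse yields (literal_text, field_name, format_spec,
    # conversion) tuples; field_name is None only when the context contains no
    # replacement field such as "{}".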
fieldnames = [val[1] for val in string.Formatter().parse(context)]
if fieldnames[0] is None:
return False
return True

expanded_labels = []
for context in contexts:
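    # A context without a replacement field cannot embed the label text, so
    # fall back to using the bare labels.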
if not _is_formattable(context):
expanded_labels.extend(labels)
continue
for label in labels:
expanded_labels.append(context.format(label))

expanded_labels = list(set(expanded_labels))
return expanded_labels


class VLM(abc.ABC):
"""Provide a generic interface for Vision Language models."""

def __init__(self):
self.label_embeddings = None

@abc.abstractmethod
def tokenize(self, texts: list[str]) -> np.ndarray:
raise NotImplementedError()

@abc.abstractmethod
def encode_tokens(self, tokens: np.ndarray) -> np.ndarray:
raise NotImplementedError()

@abc.abstractmethod
def preprocess_images(self, images: np.ndarray) -> np.ndarray:
raise NotImplementedError()

@abc.abstractmethod
def encode_images(self, images: np.ndarray) -> np.ndarray:
raise NotImplementedError()

@abc.abstractmethod
def get_temperature(self) -> float:
raise NotImplementedError()

def set_label_embeddings(
self,
positive_labels: list[str],
negative_labels: list[str],
contexts: list[str],
):
"""Set label embeddings.

Args:
      positive_labels: List of label descriptions of the positive class, e.g.
        "damaged building".
      negative_labels: List of label descriptions of the negative class, e.g.
        "undamaged building".
      contexts: List of contexts that can be used to format the labels, e.g.
        "This is a satellite image of {}".
"""

def _get_embedding(labels, contexts):
expanded_labels = _expand_labels_with_contexts(labels, contexts)
embeddings = []
for labels_batch in _batch_array(np.array(expanded_labels)):
tokens = self.tokenize(labels_batch.tolist())
embeddings.append(self.encode_tokens(tokens))
embeddings = np.concatenate(embeddings, axis=0)
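      # Average the embeddings of all context-expanded prompts into a single
      # class embedding.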
embedding = np.mean(embeddings, axis=0)
return embedding

negative_embedding = _get_embedding(negative_labels, contexts)
positive_embedding = _get_embedding(positive_labels, contexts)
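    # Index 0 holds the positive-class embedding and index 1 the negative-class
    # embedding, matching the column order of the probabilities returned by
    # predict().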
self.label_embeddings = np.stack(
[positive_embedding, negative_embedding], axis=0
)

def predict(self, images: np.ndarray) -> np.ndarray:
"""Generate probability scores for a batch of images.

Args:
images: A batch of images.

Returns:
A 2-D array of shape (batch, 2), where the second dimension contains the
positive and negative class probabilities respectively.

Raises:
      ValueError: If the connected device is not a TPU or the label embeddings
        are not set.
"""
if self.label_embeddings is None:
raise ValueError('Label embeddings are not set.')

if jax.lib.xla_bridge.get_backend().platform != 'tpu':
raise ValueError('Not connected to TPU.')

images = self.preprocess_images(images)

batch_size, image_size, _, _ = images.shape
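    # Pad the batch with zero-valued images so it splits evenly across
    # jax.local_device_count() devices; the reshape below adds a leading device
    # dimension, and the padded rows are dropped again after encoding.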
num_images_to_augment = 0
if batch_size % jax.local_device_count() != 0:
num_images_to_augment = jax.local_device_count() - (
batch_size % jax.local_device_count()
)
images_to_augment = np.zeros((num_images_to_augment,) + images.shape[1:])
images = np.concatenate([images, images_to_augment], axis=0)

images = images.reshape(
jax.local_device_count(),
(batch_size + num_images_to_augment) // jax.local_device_count(),
image_size,
image_size,
3,
)
images_embd = self.encode_images(images)
images_embd = images_embd.reshape(batch_size + num_images_to_augment, -1)[
:batch_size, :
]
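    # Temperature-scaled dot-product similarities between image embeddings and
    # the two label embeddings; softmax converts them to class probabilities.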
sims = images_embd @ self.label_embeddings.T * self.get_temperature()
probability_scores = np.array(jax.nn.softmax(sims, axis=-1))
return probability_scores
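For reviewers, a minimal usage sketch of the new library (not part of this change): `_RandomVLM` below is an illustrative stand-in, analogous to `_FakeVLM` in the new test, and `VLM.predict` only runs when a TPU backend is attached.

```python
import numpy as np

from skai.experiment.zero_shot_eval import vlm_zero_shot_lib


class _RandomVLM(vlm_zero_shot_lib.VLM):
  """Illustrative stand-in that returns random embeddings."""

  def tokenize(self, texts: list[str]) -> np.ndarray:
    return np.random.randn(len(texts), 16)

  def encode_tokens(self, tokens: np.ndarray) -> np.ndarray:
    return np.random.randn(tokens.shape[0], 64)

  def preprocess_images(self, images: np.ndarray) -> np.ndarray:
    return images.astype(np.float32) / 255.0

  def encode_images(self, images: np.ndarray) -> np.ndarray:
    return np.random.randn(*images.shape[:-3], 64)

  def get_temperature(self) -> float:
    return 10.0


vlm = _RandomVLM()
vlm.set_label_embeddings(
    positive_labels=['damaged building'],
    negative_labels=['undamaged building'],
    contexts=['This is a satellite image of {}'],
)
images = np.random.randint(0, 255, size=(16, 224, 224, 3))
# scores has shape (16, 2): column 0 is P(positive), column 1 is P(negative).
scores = vlm.predict(images)  # Raises ValueError unless a TPU is attached.
```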
171 changes: 171 additions & 0 deletions src/skai/experiment/zero_shot_eval/vlm_zero_shot_lib_test.py
@@ -0,0 +1,171 @@
"""Tests for vlm_zero_shot_lib."""

import dataclasses
import sys
from absl.testing import parameterized
import jax
import mock
import numpy as np
from skai.experiment.zero_shot_eval import vlm_zero_shot_lib
from google3.testing.pybase import googletest
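# Test fixtures: arrays whose leading dimension is (DIVISIBLE_*) or is not
# (INDIVISIBLE_*) an exact multiple of the batch size and device count used in
# the tests below.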

DIVISIBLE_ARRAY = np.random.randint(size=(64, 224, 224, 3), low=0, high=255)
INDIVISIBLE_ARRAY = np.random.randint(size=(73, 224, 224, 3), low=0, high=255)
DIVISIBLE_ONE_DIM_ARRAY = np.random.randint(size=(64,), low=0, high=255)
INDIVISIBLE_ONE_DIM_ARRAY = np.random.randint(size=(73,), low=0, high=255)
PACKAGE = "skai.experiment.zero_shot_eval.vlm_zero_shot_lib."


MAX_INT = sys.maxsize


class _FakeVLM(vlm_zero_shot_lib.VLM):

def __init__(self, embd_size: int):
super().__init__()
self.embd_size = embd_size

def tokenize(self, texts: list[str]) -> np.ndarray:
return np.random.randn(len(texts), self.embd_size)

def encode_tokens(self, tokens: np.ndarray) -> np.ndarray:
return np.random.randn(*tokens.shape[:-1], self.embd_size)

def preprocess_images(self, images: np.ndarray) -> np.ndarray:
return np.random.randn(*images.shape)

def encode_images(self, images: np.ndarray) -> np.ndarray:
return np.random.randn(*images.shape[:-3], self.embd_size)

def get_temperature(self) -> float:
return 10.0


class VlmZeroShotLibTest(parameterized.TestCase):

@parameterized.named_parameters(
("divisible arrays", DIVISIBLE_ARRAY, 32),
("indivisible arrays", INDIVISIBLE_ARRAY, 32),
("divisible array with one dimension", DIVISIBLE_ONE_DIM_ARRAY, 32),
("indivisible array with one dimension", INDIVISIBLE_ONE_DIM_ARRAY, 32),
("number of arrays less than batch size", DIVISIBLE_ARRAY, MAX_INT),
)
def test_equality_batch_array(self, array, batch_size):
batched_arrays = vlm_zero_shot_lib._batch_array(array, batch_size)
stacked_arrays = np.concatenate(list(batched_arrays), axis=0)
np.testing.assert_array_equal(stacked_arrays, array)

@parameterized.named_parameters(
(
"formattable contexts",
["damaged building", "damaged roof"],
["This is a satellite image of {}", "This is top down view of {}"],
[
"This is a satellite image of damaged building",
"This is top down view of damaged building",
"This is a satellite image of damaged roof",
"This is top down view of damaged roof",
],
),
(
"unformattable contexts",
["damaged building", "damaged roof"],
["This is a satellite image of", "This is top down view of"],
["damaged building", "damaged roof"],
),
(
"Partially formattable contexts",
["damaged building", "damaged roof"],
["This is a satellite image of {}", "This is top down view of"],
[
"This is a satellite image of damaged building",
"damaged building",
"This is a satellite image of damaged roof",
"damaged roof",
],
),
)
def test_expand_labels_with_contexts(
self, labels, contexts, expected_labels_with_contexts
):
labels_with_contexts = vlm_zero_shot_lib._expand_labels_with_contexts(
labels, contexts
)
self.assertCountEqual(labels_with_contexts, expected_labels_with_contexts)

@parameterized.named_parameters(
(
"formattable contexts",
["damaged building", "damaged roof"],
["undamaged buidling", "intact houses"],
["This is a satellite image of {}", "This is top down view of {}"],
),
(
"unformattable contexts",
["damaged building", "damaged roof"],
["undamaged buidling", "intact houses"],
["This is a satellite image of", "This is top down view of"],
),
(
"Partially formattable contexts",
["damaged building", "damaged roof"],
["undamaged buidling", "intact houses"],
["This is a satellite image of {}", "This is top down view of"],
),
)
def test_vlm_set_label_embeddings(
self, positive_labels, negative_labels, contexts
):
embd_size = 1024
vlm = _FakeVLM(embd_size)
vlm.set_label_embeddings(positive_labels, negative_labels, contexts)
self.assertEqual(list(vlm.label_embeddings.shape), [2, 1024])

@parameterized.named_parameters(
(
"divisible arrays",
DIVISIBLE_ARRAY,
["damaged building"],
["intact houses"],
["This is top down view of {}"],
),
(
"indivisible arrays",
INDIVISIBLE_ARRAY,
["damaged building"],
["intact houses"],
["This is top down view of {}"],
),
(
"unformattable contexts",
DIVISIBLE_ARRAY,
["damaged building", "damaged roof"],
["undamaged buidling", "intact houses"],
["This is a satellite image of", "This is top down view of"],
)
)
@mock.patch(PACKAGE + "jax")
def test_vlm_predict(
self, images, positive_labels, negative_labels, contexts, mocked_jax
):
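    # The jax module used inside vlm_zero_shot_lib is mocked so that the
    # TPU-only check in VLM.predict passes on CPU; the real jax.nn.softmax is
    # used to build a correctly shaped return value for the mocked call.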
@dataclasses.dataclass(frozen=True)
class _MockedClient:
platform: str

mocked_jax.local_device_count.return_value = 8
mocked_jax.nn.softmax.return_value = jax.nn.softmax(
np.random.randn(images.shape[0], 2)
)
mocked_jax.lib.xla_bridge.get_backend.return_value = _MockedClient(
platform="tpu"
)

vlm = _FakeVLM(1024)
vlm.set_label_embeddings(positive_labels, negative_labels, contexts)
scores = vlm.predict(images)

self.assertEqual(scores.shape, (images.shape[0], 2))


if __name__ == "__main__":
googletest.main()