Skip to content

Commit

Permalink
Add support for segmentation mask in BaseImageAugmentationLayer (#748)
Browse files Browse the repository at this point in the history
* Add segmap to base image augmentation layer

* Address Scott review comments

* More review comments
  • Loading branch information
ianstenbit authored Aug 30, 2022
1 parent 349eadc commit b380c08
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 11 deletions.
48 changes: 41 additions & 7 deletions keras_cv/layers/preprocessing/base_image_augmentation_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
BOUNDING_BOXES = "bounding_boxes"
KEYPOINTS = "keypoints"
RAGGED_BOUNDING_BOXES = "ragged_bounding_boxes"
SEGMENTATION_MASK = "segmentation_mask"
IS_DICT = "is_dict"
USE_TARGETS = "use_targets"

Expand Down Expand Up @@ -143,7 +144,7 @@ def augment_image(self, image, transformation, **kwargs):
`layer.call()`.
transformation: The transformation object produced by
`get_random_transformation`. Used to coordinate the randomness
between image, label and bounding box.
between image, label, bounding box, keypoints, and segmentation mask.
Returns:
output 3D tensor, which will be forward to `layer.call()`.
Expand All @@ -157,7 +158,7 @@ def augment_label(self, label, transformation, **kwargs):
label: 1D label to the layer. Forwarded from `layer.call()`.
transformation: The transformation object produced by
`get_random_transformation`. Used to coordinate the randomness
between image, label and bounding box.
between image, label, bounding box, keypoints, and segmentation mask.
Returns:
output 1D tensor, which will be forward to `layer.call()`.
Expand All @@ -171,7 +172,7 @@ def augment_target(self, target, transformation, **kwargs):
target: 1D label to the layer. Forwarded from `layer.call()`.
transformation: The transformation object produced by
`get_random_transformation`. Used to coordinate the randomness
between image, label and bounding box.
between image, label, bounding box, keypoints, and segmentation mask.
Returns:
output 1D tensor, which will be forward to `layer.call()`.
Expand All @@ -188,7 +189,7 @@ def augment_bounding_boxes(self, bounding_boxes, transformation, **kwargs):
`call()`.
transformation: The transformation object produced by
`get_random_transformation`. Used to coordinate the randomness
between image, label and bounding box.
between image, label, bounding box, keypoints, and segmentation mask.
Returns:
output 2D tensor, which will be forward to `layer.call()`.
Expand All @@ -203,15 +204,36 @@ def augment_keypoints(self, keypoints, transformation, **kwargs):
`layer.call()`.
transformation: The transformation object produced by
`get_random_transformation`. Used to coordinate the randomness
between image, label and bounding box.
between image, label, bounding box, keypoints, and segmentation mask.
Returns:
output 2D tensor, which will be forward to `layer.call()`.
"""
raise NotImplementedError()

def augment_segmentation_mask(self, segmentation_mask, transformation, **kwargs):
    """Augment a single image's segmentation mask during training.

    This is an abstract hook: subclasses that support segmentation masks
    override it, while the base implementation always raises.

    Args:
        segmentation_mask: 3D segmentation mask input tensor to the layer.
            This should generally have the shape [H, W, 1], or in some cases
            [H, W, C] for multilabeled data. Forwarded from `layer.call()`.
        transformation: The transformation object produced by
            `get_random_transformation`. Used to coordinate the randomness
            between image, label, bounding box, keypoints, and segmentation
            mask.

    Returns:
        output 3D tensor containing the augmented segmentation mask, which
        will be forwarded to `layer.call()`.
    """
    raise NotImplementedError()

def get_random_transformation(
self, image=None, label=None, bounding_boxes=None, keypoints=None
self,
image=None,
label=None,
bounding_boxes=None,
keypoints=None,
segmentation_mask=None,
):
"""Produce random transformation config for one single input.
Expand All @@ -222,6 +244,7 @@ def get_random_transformation(
image: 3D image tensor from inputs.
label: optional 1D label tensor from inputs.
bounding_box: optional 2D bounding boxes tensor from inputs.
segmentation_mask: optional 3D segmentation mask tensor from inputs.
Returns:
Any type of object, which will be forwarded to `augment_image`,
Expand Down Expand Up @@ -253,8 +276,13 @@ def _augment(self, inputs):
label = inputs.get(LABELS, None)
bounding_boxes = inputs.get(BOUNDING_BOXES, None)
keypoints = inputs.get(KEYPOINTS, None)
segmentation_mask = inputs.get(SEGMENTATION_MASK, None)
transformation = self.get_random_transformation(
image=image, label=label, bounding_boxes=bounding_boxes, keypoints=keypoints
image=image,
label=label,
bounding_boxes=bounding_boxes,
keypoints=keypoints,
segmentation_mask=segmentation_mask,
)
image = self.augment_image(
image,
Expand Down Expand Up @@ -288,6 +316,12 @@ def _augment(self, inputs):
image=image,
)
result[KEYPOINTS] = keypoints
if segmentation_mask is not None:
segmentation_mask = self.augment_segmentation_mask(
segmentation_mask,
transformation=transformation,
)
result[SEGMENTATION_MASK] = segmentation_mask

# preserve any additional inputs unmodified by this layer.
for key in inputs.keys() - result.keys():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def augment_bounding_boxes(self, bounding_boxes, transformation, **kwargs):
def augment_keypoints(self, keypoints, transformation, **kwargs):
return keypoints + transformation

def augment_segmentation_mask(self, segmentation_mask, transformation, **kwargs):
    # Shift the mask by the same random offset used for the image, so every
    # target produced by this test layer stays in sync per sample.
    shifted_mask = transformation + segmentation_mask
    return shifted_mask


class VectorizeDisabledLayer(BaseImageAugmentationLayer):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -148,15 +151,22 @@ def test_augment_image_and_localization_data(self):
images = np.random.random(size=(8, 8, 3)).astype("float32")
bounding_boxes = np.random.random(size=(3, 5)).astype("float32")
keypoints = np.random.random(size=(3, 5, 2)).astype("float32")
segmentation_mask = np.random.random(size=(8, 8, 1)).astype("float32")

output = add_layer(
{"images": images, "bounding_boxes": bounding_boxes, "keypoints": keypoints}
{
"images": images,
"bounding_boxes": bounding_boxes,
"keypoints": keypoints,
"segmentation_mask": segmentation_mask,
}
)

expected_output = {
"images": images + 2.0,
"bounding_boxes": bounding_boxes + 2.0,
"keypoints": keypoints + 2.0,
"segmentation_mask": segmentation_mask + 2.0,
}
self.assertAllClose(output, expected_output)

Expand All @@ -165,53 +175,78 @@ def test_augment_batch_image_and_localization_data(self):
images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
bounding_boxes = np.random.random(size=(2, 3, 5)).astype("float32")
keypoints = np.random.random(size=(2, 3, 5, 2)).astype("float32")
segmentation_mask = np.random.random(size=(2, 8, 8, 1)).astype("float32")

output = add_layer(
{"images": images, "bounding_boxes": bounding_boxes, "keypoints": keypoints}
{
"images": images,
"bounding_boxes": bounding_boxes,
"keypoints": keypoints,
"segmentation_mask": segmentation_mask,
}
)

bounding_boxes_diff = output["bounding_boxes"] - bounding_boxes
keypoints_diff = output["keypoints"] - keypoints
segmentation_mask_diff = output["segmentation_mask"] - segmentation_mask
self.assertNotAllClose(bounding_boxes_diff[0], bounding_boxes_diff[1])
self.assertNotAllClose(keypoints_diff[0], keypoints_diff[1])
self.assertNotAllClose(segmentation_mask_diff[0], segmentation_mask_diff[1])

@tf.function
def in_tf_function(inputs):
return add_layer(inputs)

output = in_tf_function(
{"images": images, "bounding_boxes": bounding_boxes, "keypoints": keypoints}
{
"images": images,
"bounding_boxes": bounding_boxes,
"keypoints": keypoints,
"segmentation_mask": segmentation_mask,
}
)

bounding_boxes_diff = output["bounding_boxes"] - bounding_boxes
keypoints_diff = output["keypoints"] - keypoints
segmentation_mask_diff = output["segmentation_mask"] - segmentation_mask
self.assertNotAllClose(bounding_boxes_diff[0], bounding_boxes_diff[1])
self.assertNotAllClose(keypoints_diff[0], keypoints_diff[1])
self.assertNotAllClose(segmentation_mask_diff[0], segmentation_mask_diff[1])

def test_augment_all_data_in_tf_function(self):
    """Inside tf.function, each batch element of every target type
    (boxes, keypoints, segmentation masks) receives an independent
    random augmentation."""
    add_layer = RandomAddLayer()
    images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
    bounding_boxes = np.random.random(size=(2, 3, 5)).astype("float32")
    keypoints = np.random.random(size=(2, 3, 5, 2)).astype("float32")
    segmentation_mask = np.random.random(size=(2, 8, 8, 1)).astype("float32")

    @tf.function
    def in_tf_function(inputs):
        return add_layer(inputs)

    layer_inputs = {
        "images": images,
        "bounding_boxes": bounding_boxes,
        "keypoints": keypoints,
        "segmentation_mask": segmentation_mask,
    }
    output = in_tf_function(layer_inputs)

    # The per-sample offsets must differ between the two batch elements
    # for every non-image target.
    for key, original in (
        ("bounding_boxes", bounding_boxes),
        ("keypoints", keypoints),
        ("segmentation_mask", segmentation_mask),
    ):
        diff = output[key] - original
        self.assertNotAllClose(diff[0], diff[1])

def test_raise_error_missing_class_id(self):
add_layer = RandomAddLayer()
images = np.random.random(size=(2, 8, 8, 3)).astype("float32")
bounding_boxes = np.random.random(size=(2, 3, 4)).astype("float32")
keypoints = np.random.random(size=(2, 3, 5, 2)).astype("float32")
segmentation_mask = np.random.random(size=(2, 8, 8, 1)).astype("float32")

with self.assertRaisesRegex(
ValueError,
"Bounding boxes are missing class_id. If you would like to pad the "
Expand All @@ -223,5 +258,6 @@ def test_raise_error_missing_class_id(self):
"images": images,
"bounding_boxes": bounding_boxes,
"keypoints": keypoints,
"segmentation_mask": segmentation_mask,
}
)

0 comments on commit b380c08

Please sign in to comment.