Skip to content

Commit

Permalink
Add segmask support for random_flip (#775)
Browse files Browse the repository at this point in the history
* Add segmask support for random_flip

* Set seed for test case

* Undo RRC changes

* Make shared _flip_image method
  • Loading branch information
ianstenbit authored Sep 8, 2022
1 parent 02dda74 commit 305c083
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 12 deletions.
32 changes: 20 additions & 12 deletions keras_cv/layers/preprocessing/random_flip.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,7 @@ def augment_label(self, label, transformation, **kwargs):
return label

def augment_image(self, image, transformation, **kwargs):
flipped_output = tf.cond(
transformation["flip_horizontal"],
lambda: tf.image.flip_left_right(image),
lambda: image,
)
flipped_output = tf.cond(
transformation["flip_vertical"],
lambda: tf.image.flip_up_down(flipped_output),
lambda: flipped_output,
)
flipped_output.set_shape(image.shape)
return flipped_output
return RandomFlip._flip_image(image, transformation)

def get_random_transformation(self, **kwargs):
flip_horizontal = False
Expand All @@ -110,6 +99,20 @@ def get_random_transformation(self, **kwargs):
"flip_vertical": tf.cast(flip_vertical, dtype=tf.bool),
}

def _flip_image(image, transformation):
flipped_output = tf.cond(
transformation["flip_horizontal"],
lambda: tf.image.flip_left_right(image),
lambda: image,
)
flipped_output = tf.cond(
transformation["flip_vertical"],
lambda: tf.image.flip_up_down(flipped_output),
lambda: flipped_output,
)
flipped_output.set_shape(image.shape)
return flipped_output

def _flip_bounding_boxes_horizontal(bounding_boxes):
x1, x2, x3, x4, rest = tf.split(
bounding_boxes, [1, 1, 1, 1, bounding_boxes.shape[-1] - 4], axis=-1
Expand Down Expand Up @@ -186,6 +189,11 @@ def augment_bounding_boxes(
)
return bounding_boxes

def augment_segmentation_mask(
self, segmentation_mask, transformation=None, **kwargs
):
return RandomFlip._flip_image(segmentation_mask, transformation)

def compute_output_shape(self, input_shape):
return input_shape

Expand Down
22 changes: 22 additions & 0 deletions keras_cv/layers/preprocessing/random_flip_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,25 @@ def test_augment_bbox_batched_input(self):
)
expected_output = np.reshape(expected_output, (2, 2, 5))
self.assertAllClose(expected_output, output["bounding_boxes"])

def test_augment_segmentation_mask(self):
np.random.seed(1337)
image = np.random.random((1, 20, 20, 3)).astype(np.float32)
mask = np.random.randint(2, size=(1, 20, 20, 1)).astype(np.float32)

input = {"images": image, "segmentation_masks": mask}

# Flip both vertically and horizontally
mock_random = [0.6, 0.6]
layer = RandomFlip()

with unittest.mock.patch.object(
layer._random_generator,
"random_uniform",
side_effect=mock_random,
):
output = layer(input, training=True)

expected_mask = np.flip(np.flip(mask, axis=1), axis=2)

self.assertAllClose(expected_mask, output["segmentation_masks"])

0 comments on commit 305c083

Please sign in to comment.