From 305c083b986f2c7e795bd9a579cc8d364b8d1800 Mon Sep 17 00:00:00 2001 From: Ian Stenbit <3072903+ianstenbit@users.noreply.github.com> Date: Thu, 8 Sep 2022 13:03:18 -0600 Subject: [PATCH] Add segmask support for random_flip (#775) * Add segmask support for random_flip * Set seed for test case * Undo RRC changes * Make shared _flip_image method --- keras_cv/layers/preprocessing/random_flip.py | 32 ++++++++++++------- .../layers/preprocessing/random_flip_test.py | 22 +++++++++++++ 2 files changed, 42 insertions(+), 12 deletions(-) diff --git a/keras_cv/layers/preprocessing/random_flip.py b/keras_cv/layers/preprocessing/random_flip.py index f3ad5dea21..afa42247db 100644 --- a/keras_cv/layers/preprocessing/random_flip.py +++ b/keras_cv/layers/preprocessing/random_flip.py @@ -85,18 +85,7 @@ def augment_label(self, label, transformation, **kwargs): return label def augment_image(self, image, transformation, **kwargs): - flipped_output = tf.cond( - transformation["flip_horizontal"], - lambda: tf.image.flip_left_right(image), - lambda: image, - ) - flipped_output = tf.cond( - transformation["flip_vertical"], - lambda: tf.image.flip_up_down(flipped_output), - lambda: flipped_output, - ) - flipped_output.set_shape(image.shape) - return flipped_output + return RandomFlip._flip_image(image, transformation) def get_random_transformation(self, **kwargs): flip_horizontal = False @@ -110,6 +99,20 @@ def get_random_transformation(self, **kwargs): "flip_vertical": tf.cast(flip_vertical, dtype=tf.bool), } + def _flip_image(image, transformation): + flipped_output = tf.cond( + transformation["flip_horizontal"], + lambda: tf.image.flip_left_right(image), + lambda: image, + ) + flipped_output = tf.cond( + transformation["flip_vertical"], + lambda: tf.image.flip_up_down(flipped_output), + lambda: flipped_output, + ) + flipped_output.set_shape(image.shape) + return flipped_output + def _flip_bounding_boxes_horizontal(bounding_boxes): x1, x2, x3, x4, rest = tf.split( bounding_boxes, [1, 1, 1, 1, bounding_boxes.shape[-1] - 4], axis=-1 @@ -186,6 +189,11 @@ def augment_bounding_boxes( ) return bounding_boxes + def augment_segmentation_mask( + self, segmentation_mask, transformation=None, **kwargs + ): + return RandomFlip._flip_image(segmentation_mask, transformation) + def compute_output_shape(self, input_shape): return input_shape diff --git a/keras_cv/layers/preprocessing/random_flip_test.py b/keras_cv/layers/preprocessing/random_flip_test.py index 388e3b7a8d..6db4b17e29 100644 --- a/keras_cv/layers/preprocessing/random_flip_test.py +++ b/keras_cv/layers/preprocessing/random_flip_test.py @@ -134,3 +134,25 @@ def test_augment_bbox_batched_input(self): ) expected_output = np.reshape(expected_output, (2, 2, 5)) self.assertAllClose(expected_output, output["bounding_boxes"]) + + def test_augment_segmentation_mask(self): + np.random.seed(1337) + image = np.random.random((1, 20, 20, 3)).astype(np.float32) + mask = np.random.randint(2, size=(1, 20, 20, 1)).astype(np.float32) + + input = {"images": image, "segmentation_masks": mask} + + # Flip both vertically and horizontally + mock_random = [0.6, 0.6] + layer = RandomFlip() + + with unittest.mock.patch.object( + layer._random_generator, + "random_uniform", + side_effect=mock_random, + ): + output = layer(input, training=True) + + expected_mask = np.flip(np.flip(mask, axis=1), axis=2) + + self.assertAllClose(expected_mask, output["segmentation_masks"])