diff --git a/.kokoro/README.md b/.kokoro/README.md
new file mode 100644
index 0000000000..2c7724d988
--- /dev/null
+++ b/.kokoro/README.md
@@ -0,0 +1 @@
+CI to run on PR and merge to Master.
\ No newline at end of file
diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh
new file mode 100644
index 0000000000..e8fc6f5d75
--- /dev/null
+++ b/.kokoro/github/ubuntu/gpu/build.sh
@@ -0,0 +1,83 @@
+set -e
+set -x
+
+cd "${KOKORO_ROOT}/"
+
+sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1
+
+PYTHON_BINARY="/usr/bin/python3.9"
+
+"${PYTHON_BINARY}" -m venv venv
+source venv/bin/activate
+# Check the python version
+python --version
+python3 --version
+
+export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
+# Check cuda
+nvidia-smi
+nvcc --version
+
+cd "src/github/keras-cv"
+pip install -U pip setuptools
+
+if [ "${KERAS2:-0}" == "1" ]
+then
+   echo "Keras2 detected."
+   pip install -r requirements-common.txt --progress-bar off
+   pip install tensorflow~=2.14
+   pip install --extra-index-url https://download.pytorch.org/whl/cpu torch==2.1.0+cpu
+   pip install torchvision~=0.16.0
+   pip install "jax[cpu]"
+
+elif [ "$KERAS_BACKEND" == "tensorflow" ]
+then
+   echo "TensorFlow backend detected."
+   pip install -r requirements-tensorflow-cuda.txt --progress-bar off
+
+elif [ "$KERAS_BACKEND" == "jax" ]
+then
+   echo "JAX backend detected."
+   pip install -r requirements-jax-cuda.txt --progress-bar off
+
+elif [ "$KERAS_BACKEND" == "torch" ]
+then
+   echo "PyTorch backend detected."
+   pip install -r requirements-torch-cuda.txt --progress-bar off
+fi
+
+pip install --no-deps -e "." --progress-bar off
+
+# Run Extra Large Tests for Continuous builds
+if [ "${RUN_XLARGE:-0}" == "1" ]
+then
+   pytest --check_gpu --run_large --run_extra_large --durations 0 \
+      keras_cv/bounding_box \
+      keras_cv/callbacks \
+      keras_cv/losses \
+      keras_cv/layers/object_detection \
+      keras_cv/layers/preprocessing \
+      keras_cv/models/backbones \
+      keras_cv/models/classification \
+      keras_cv/models/object_detection/retinanet \
+      keras_cv/models/object_detection/yolo_v8 \
+      keras_cv/models/object_detection_3d \
+      keras_cv/models/segmentation \
+      keras_cv/models/stable_diffusion \
+      --cov=keras-cv
+else
+   pytest --check_gpu --run_large --durations 0 \
+      keras_cv/bounding_box \
+      keras_cv/callbacks \
+      keras_cv/losses \
+      keras_cv/layers/object_detection \
+      keras_cv/layers/preprocessing \
+      keras_cv/models/backbones \
+      keras_cv/models/classification \
+      keras_cv/models/object_detection/retinanet \
+      keras_cv/models/object_detection/yolo_v8 \
+      keras_cv/models/object_detection_3d \
+      keras_cv/models/segmentation \
+      keras_cv/models/stable_diffusion \
+      --cov=keras-cv
+fi
\ No newline at end of file
diff --git a/.kokoro/github/ubuntu/gpu/jax/continuous.cfg b/.kokoro/github/ubuntu/gpu/jax/continuous.cfg
new file mode 100644
index 0000000000..350fd02108
--- /dev/null
+++ b/.kokoro/github/ubuntu/gpu/jax/continuous.cfg
@@ -0,0 +1,18 @@
+build_file: "keras-cv/.kokoro/github/ubuntu/gpu/build.sh"
+
+action {
+  define_artifacts {
+    regex: "**/sponge_log.log"
+    regex: "**/sponge_log.xml"
+  }
+}
+
+env_vars: {
+  key: "KERAS_BACKEND"
+  value: "jax"
+}
+
+env_vars: {
+  key: "RUN_XLARGE"
+  value: "1"
+}
\ No newline at end of file
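Note on the `KERAS_BACKEND` plumbing above: `build.sh` branches on this variable to pick the right requirements file, and multi-backend Keras reads the same variable when `keras_cv` is first imported. A minimal sketch (not part of this PR) of the import-order constraint the CI jobs rely on:

```python
# Sketch only: the backend must be chosen via the environment *before*
# keras_cv is imported, which is why CI sets it in the job config rather
# than inside the test code.
import os

os.environ["KERAS_BACKEND"] = "jax"  # or "tensorflow" / "torch"

from keras_cv.backend import config as backend_config  # noqa: E402

# Confirms which backend the test process ended up on.
print(backend_config.backend())
```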
diff --git a/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg b/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg
new file mode 100644
index 0000000000..2aca2e95ff
--- /dev/null
+++ b/.kokoro/github/ubuntu/gpu/jax/presubmit.cfg
@@ -0,0 +1,16 @@
+build_file: "keras-cv/.kokoro/github/ubuntu/gpu/build.sh"
+
+action {
+  define_artifacts {
+    regex: "**/sponge_log.log"
+    regex: "**/sponge_log.xml"
+  }
+}
+
+env_vars: {
+  key: "KERAS_BACKEND"
+  value: "jax"
+}
+
+# Set timeout to 60 mins from default 180 mins
+timeout_mins: 60
\ No newline at end of file
diff --git a/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg b/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg
new file mode 100644
index 0000000000..361e35235b
--- /dev/null
+++ b/.kokoro/github/ubuntu/gpu/keras2/continuous.cfg
@@ -0,0 +1,18 @@
+build_file: "keras-cv/.kokoro/github/ubuntu/gpu/build.sh"
+
+action {
+  define_artifacts {
+    regex: "**/sponge_log.log"
+    regex: "**/sponge_log.xml"
+  }
+}
+
+env_vars: {
+  key: "KERAS2"
+  value: "1"
+}
+
+env_vars: {
+  key: "RUN_XLARGE"
+  value: "1"
+}
\ No newline at end of file
diff --git a/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg b/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg
new file mode 100644
index 0000000000..d5caba18f7
--- /dev/null
+++ b/.kokoro/github/ubuntu/gpu/keras2/presubmit.cfg
@@ -0,0 +1,16 @@
+build_file: "keras-cv/.kokoro/github/ubuntu/gpu/build.sh"
+
+action {
+  define_artifacts {
+    regex: "**/sponge_log.log"
+    regex: "**/sponge_log.xml"
+  }
+}
+
+env_vars: {
+  key: "KERAS2"
+  value: "1"
+}
+
+# Set timeout to 60 mins from default 180 mins
+timeout_mins: 60
\ No newline at end of file
diff --git a/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg b/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg
new file mode 100644
index 0000000000..9ed8200e71
--- /dev/null
+++ b/.kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg
@@ -0,0 +1,18 @@
+build_file: "keras-cv/.kokoro/github/ubuntu/gpu/build.sh"
+
+action {
+  define_artifacts {
+    regex: "**/sponge_log.log"
+    regex: "**/sponge_log.xml"
+  }
+}
+
+env_vars: {
+  key: "KERAS_BACKEND"
+  value: "tensorflow"
+}
+
+env_vars: {
+  key: "RUN_XLARGE"
+  value: "1"
+}
\ No newline at end of file
diff --git a/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg b/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg
new file mode 100644
index 0000000000..f7e02e6efa
--- /dev/null
+++ b/.kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg
@@ -0,0 +1,16 @@
+build_file: "keras-cv/.kokoro/github/ubuntu/gpu/build.sh"
+
+action {
+  define_artifacts {
+    regex: "**/sponge_log.log"
+    regex: "**/sponge_log.xml"
+  }
+}
+
+env_vars: {
+  key: "KERAS_BACKEND"
+  value: "tensorflow"
+}
+
+# Set timeout to 60 mins from default 180 mins
+timeout_mins: 60
\ No newline at end of file
diff --git a/.kokoro/github/ubuntu/gpu/torch/continuous.cfg b/.kokoro/github/ubuntu/gpu/torch/continuous.cfg
new file mode 100644
index 0000000000..c3e118a6ef
--- /dev/null
+++ b/.kokoro/github/ubuntu/gpu/torch/continuous.cfg
@@ -0,0 +1,18 @@
+build_file: "keras-cv/.kokoro/github/ubuntu/gpu/build.sh"
+
+action {
+  define_artifacts {
+    regex: "**/sponge_log.log"
+    regex: "**/sponge_log.xml"
+  }
+}
+
+env_vars: {
+  key: "KERAS_BACKEND"
+  value: "torch"
+}
+
+env_vars: {
+  key: "RUN_XLARGE"
+  value: "1"
+}
\ No newline at end of file
diff --git a/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg b/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg
new file mode 100644
index 0000000000..a96e865152
--- /dev/null
+++ b/.kokoro/github/ubuntu/gpu/torch/presubmit.cfg
@@ -0,0 +1,16 @@
+build_file: "keras-cv/.kokoro/github/ubuntu/gpu/build.sh"
+
+action {
+  define_artifacts {
+    regex: "**/sponge_log.log"
+    regex: "**/sponge_log.xml"
+  }
+}
+
+env_vars: {
+  key: "KERAS_BACKEND"
+  value: "torch"
+}
+
+# Set timeout to 60 mins from default 180 mins
+timeout_mins: 60
\ No newline at end of file
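For context on the presubmit/continuous split configured above: presubmit jobs run `build.sh` without `RUN_XLARGE`, so only `--run_large` tests execute, while continuous jobs add `--run_extra_large`. A hedged sketch (hypothetical test names) of how a test would opt into those tiers via the markers handled in `conftest.py`:

```python
# Hypothetical tests illustrating the two marker tiers these configs exercise.
import pytest


@pytest.mark.large  # runs in presubmit (--run_large) and continuous builds
def test_backbone_smoke():
    assert True


@pytest.mark.extra_large  # runs only when RUN_XLARGE=1 adds --run_extra_large
def test_full_pipeline():
    assert True
```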
diff --git a/keras_cv/conftest.py b/keras_cv/conftest.py
index b8be780c39..eaee5024b9 100644
--- a/keras_cv/conftest.py
+++ b/keras_cv/conftest.py
@@ -17,6 +17,7 @@
 import tensorflow as tf
 from packaging import version
 
+from keras_cv.backend import config as backend_config
 from keras_cv.backend.config import keras_3
 
 
@@ -33,9 +34,35 @@ def pytest_addoption(parser):
         default=False,
         help="run extra_large tests",
     )
+    parser.addoption(
+        "--check_gpu",
+        action="store_true",
+        default=False,
+        help="fail if a gpu is not present",
+    )
 
 
 def pytest_configure(config):
+    # Verify that the machine has a GPU and that the backend detects it
+    if config.getoption("--check_gpu"):
+        found_gpu = False
+        backend = backend_config.backend()
+        if backend == "jax":
+            import jax
+
+            try:
+                found_gpu = bool(jax.devices("gpu"))
+            except RuntimeError:
+                found_gpu = False
+        elif backend == "tensorflow":
+            found_gpu = bool(tf.config.list_logical_devices("GPU"))
+        elif backend == "torch":
+            import torch
+
+            found_gpu = bool(torch.cuda.device_count())
+        if not found_gpu:
+            pytest.fail(f"No GPUs discovered on the {backend} backend.")
+
     config.addinivalue_line(
         "markers", "large: mark test as being slow or requiring a network"
     )
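The `--check_gpu` hook above fails the whole session early instead of letting a GPU-less runner silently pass everything on CPU. Roughly how it is exercised (the pytest invocation mirrors `build.sh`; the probe below is a standalone sketch of the same per-backend logic, torch branch only):

```python
# Standalone sketch of the conftest probe; in CI it runs as
#   pytest --check_gpu --run_large keras_cv/...
from keras_cv.backend import config as backend_config

backend = backend_config.backend()
if backend == "torch":
    import torch

    # Mirrors `bool(torch.cuda.device_count())` in pytest_configure.
    print("GPU visible to torch:", torch.cuda.device_count() > 0)
```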
diff --git a/keras_cv/layers/preprocessing/equalization.py b/keras_cv/layers/preprocessing/equalization.py
index cdc8808433..7dbeb636a6 100644
--- a/keras_cv/layers/preprocessing/equalization.py
+++ b/keras_cv/layers/preprocessing/equalization.py
@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from functools import partial
+
 import tensorflow as tf
 
 from keras_cv.api_export import keras_cv_export
-from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
-    BaseImageAugmentationLayer,
+from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import (  # noqa: E501
+    VectorizedBaseImageAugmentationLayer,
 )
 from keras_cv.utils import preprocessing
 
 
 @keras_cv_export("keras_cv.layers.Equalization")
-class Equalization(BaseImageAugmentationLayer):
+class Equalization(VectorizedBaseImageAugmentationLayer):
     """Equalization performs histogram equalization on a channel-wise basis.
 
     Args:
@@ -52,7 +54,7 @@ def __init__(self, value_range, bins=256, **kwargs):
         self.bins = bins
         self.value_range = value_range
 
-    def equalize_channel(self, image, channel_index):
+    def equalize_channel(self, images, channel_index):
         """equalize_channel performs histogram equalization on a single
         channel.
 
         Args:
@@ -60,9 +62,25 @@ def equalize_channel(self, image, channel_index):
             with channels last
             channel_index: channel to equalize
         """
-        image = image[..., channel_index]
+        is_single_image = tf.rank(images) == 4 and tf.shape(images)[0] == 1
+
+        images = images[..., channel_index]
         # Compute the histogram of the image channel.
-        histogram = tf.histogram_fixed_width(image, [0, 255], nbins=self.bins)
+
+        # For a single image (a batch of one), calling
+        # tf.histogram_fixed_width directly is much faster than going
+        # through tf.vectorized_map
+        if is_single_image:
+            histogram = tf.histogram_fixed_width(
+                images, [0, 255], nbins=self.bins
+            )
+            histogram = tf.expand_dims(histogram, axis=0)
+        else:
+            partial_hist = partial(
+                tf.histogram_fixed_width, value_range=[0, 255], nbins=self.bins
+            )
+            histogram = tf.vectorized_map(
+                partial_hist, images, fallback_to_while_loop=True, warn=True
+            )
 
         # For the purposes of computing the step, filter out the non-zeros.
         # Zeroes are replaced by a big number while calculating min to keep
@@ -77,56 +95,94 @@
         )
 
         step = (
-            tf.reduce_sum(histogram) - tf.reduce_min(histogram_without_zeroes)
+            tf.reduce_sum(histogram, axis=-1)
+            - tf.reduce_min(histogram_without_zeroes, axis=-1)
         ) // (self.bins - 1)
 
         def build_mapping(histogram, step):
+            batch_size = tf.shape(histogram)[0]
+
+            # Replace where step is 0 with 1 to avoid division by 0.
+            # This doesn't change the result, because where step==0 the
+            # original image is returned
+            _step = tf.where(
+                tf.equal(step, 0),
+                1,
+                step,
+            )
+            _step = tf.expand_dims(_step, -1)
+
             # Compute the cumulative sum, shifting by step // 2
             # and then normalization by step.
-            lookup_table = (tf.cumsum(histogram) + (step // 2)) // step
+            lookup_table = (
+                tf.cumsum(histogram, axis=-1) + (_step // 2)
+            ) // _step
+
             # Shift lookup_table, prepending with 0.
-            lookup_table = tf.concat([[0], lookup_table[:-1]], 0)
+            lookup_table = tf.concat(
+                [tf.tile([[0]], [batch_size, 1]), lookup_table[..., :-1]],
+                axis=1,
+            )
+
             # Clip the counts to be in range. This is done
             # in the C code for image.point.
             return tf.clip_by_value(lookup_table, 0, 255)
 
         # If step is zero, return the original image. Otherwise, build
         # lookup table from the full histogram and step and then index from it.
-        result = tf.cond(
-            tf.equal(step, 0),
-            lambda: image,
-            lambda: tf.gather(build_mapping(histogram, step), image),
+        # The lookup table is built for all images,
+        # regardless of the corresponding value of step.
+        result = tf.where(
+            tf.reshape(tf.equal(step, 0), (-1, 1, 1)),
+            images,
+            tf.gather(
+                build_mapping(histogram, step), images, batch_dims=1, axis=1
+            ),
         )
 
         return result
 
-    def augment_image(self, image, **kwargs):
-        image = preprocessing.transform_value_range(
-            image, self.value_range, (0, 255), dtype=self.compute_dtype
+    def augment_images(self, images, transformations=None, **kwargs):
+        images = preprocessing.transform_value_range(
+            images, self.value_range, (0, 255), dtype=self.compute_dtype
         )
-        image = tf.cast(image, tf.int32)
-        image = tf.map_fn(
-            lambda channel: self.equalize_channel(image, channel),
-            tf.range(tf.shape(image)[-1]),
+        images = tf.cast(images, tf.int32)
+
+        images = tf.map_fn(
+            lambda channel: self.equalize_channel(images, channel),
+            tf.range(tf.shape(images)[-1]),
        )
+        images = tf.transpose(images, [1, 2, 3, 0])
 
-        image = tf.transpose(image, [1, 2, 0])
-        image = tf.cast(image, self.compute_dtype)
-        image = preprocessing.transform_value_range(
-            image, (0, 255), self.value_range, dtype=self.compute_dtype
+        images = tf.cast(images, self.compute_dtype)
+        images = preprocessing.transform_value_range(
+            images, (0, 255), self.value_range, dtype=self.compute_dtype
         )
-        return image
+        return images
 
     def augment_bounding_boxes(self, bounding_boxes, **kwargs):
         return bounding_boxes
 
-    def augment_label(self, label, transformation=None, **kwargs):
-        return label
+    def augment_labels(self, labels, transformations=None, **kwargs):
+        return labels
 
-    def augment_segmentation_mask(
-        self, segmentation_mask, transformation, **kwargs
+    def augment_segmentation_masks(
+        self, segmentation_masks, transformations, **kwargs
     ):
-        return segmentation_mask
+        return segmentation_masks
+
+    def augment_keypoints(self, keypoints, transformations, **kwargs):
+        return keypoints
+
+    def augment_targets(self, targets, transformations, **kwargs):
+        return targets
+
+    def augment_ragged_image(self, image, transformation, **kwargs):
+        image = tf.expand_dims(image, axis=0)
+        image = self.augment_images(
+            images=image, transformations=transformation, **kwargs
+        )
+        return tf.squeeze(image, axis=0)
 
     def get_config(self):
         config = super().get_config()
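With the port to `VectorizedBaseImageAugmentationLayer`, `Equalization` now receives the whole batch at once: histograms are computed per image via `tf.vectorized_map`, lookup tables are gathered with `batch_dims=1`, and the single-image path skips `tf.vectorized_map` entirely. A usage sketch on toy data (shapes are illustrative only):

```python
# Toy batch; values already span the layer's declared value_range.
import tensorflow as tf

import keras_cv

images = tf.random.uniform((8, 64, 64, 3), maxval=255, dtype=tf.float32)
layer = keras_cv.layers.Equalization(value_range=(0, 255), bins=256)
equalized = layer(images)  # one vectorized call for the whole batch
print(equalized.shape)  # (8, 64, 64, 3)
```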
diff --git a/keras_cv/layers/preprocessing/random_color_degeneration.py b/keras_cv/layers/preprocessing/random_color_degeneration.py
index f678a054b2..1d8206a3a2 100644
--- a/keras_cv/layers/preprocessing/random_color_degeneration.py
+++ b/keras_cv/layers/preprocessing/random_color_degeneration.py
@@ -16,14 +16,14 @@
 
 from keras_cv.api_export import keras_cv_export
 from keras_cv.backend import keras
-from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
-    BaseImageAugmentationLayer,
+from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import (  # noqa: E501
+    VectorizedBaseImageAugmentationLayer,
 )
 from keras_cv.utils import preprocessing
 
 
 @keras_cv_export("keras_cv.layers.RandomColorDegeneration")
-class RandomColorDegeneration(BaseImageAugmentationLayer):
+class RandomColorDegeneration(VectorizedBaseImageAugmentationLayer):
     """Randomly performs the color degeneration operation on given images.
 
     The sharpness operation first converts an image to gray scale, then back to
@@ -57,24 +57,39 @@ def __init__(
         )
         self.seed = seed
 
-    def get_random_transformation(self, **kwargs):
-        return self.factor(dtype=self.compute_dtype)
+    def get_random_transformation_batch(self, batch_size, **kwargs):
+        return self.factor(
+            shape=(batch_size, 1, 1, 1), dtype=self.compute_dtype
+        )
 
-    def augment_image(self, image, transformation=None, **kwargs):
-        degenerate = tf.image.grayscale_to_rgb(tf.image.rgb_to_grayscale(image))
-        result = preprocessing.blend(image, degenerate, transformation)
+    def augment_images(self, images, transformations=None, **kwargs):
+        degenerates = tf.image.grayscale_to_rgb(
+            tf.image.rgb_to_grayscale(images)
+        )
+        result = preprocessing.blend(images, degenerates, transformations)
         return result
 
     def augment_bounding_boxes(self, bounding_boxes, **kwargs):
         return bounding_boxes
 
-    def augment_label(self, label, transformation=None, **kwargs):
-        return label
+    def augment_labels(self, labels, transformations=None, **kwargs):
+        return labels
 
-    def augment_segmentation_mask(
-        self, segmentation_mask, transformation, **kwargs
+    def augment_segmentation_masks(
+        self, segmentation_masks, transformations, **kwargs
     ):
-        return segmentation_mask
+        return segmentation_masks
+
+    def augment_keypoints(self, keypoints, transformations, **kwargs):
+        return keypoints
+
+    def augment_targets(self, targets, transformations, **kwargs):
+        return targets
+
+    def augment_ragged_image(self, image, transformation, **kwargs):
+        return self.augment_images(
+            image, transformations=transformation, **kwargs
+        )
 
     def get_config(self):
         config = super().get_config()
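The same pattern applies to `RandomColorDegeneration`: `get_random_transformation_batch` now draws one blend factor per image with shape `(batch_size, 1, 1, 1)`, which broadcasts across `(H, W, C)` inside `preprocessing.blend`. A short usage sketch (toy shapes):

```python
# Each of the 4 images gets its own degeneration factor drawn from [0.2, 0.8].
import tensorflow as tf

import keras_cv

images = tf.random.uniform((4, 32, 32, 3), maxval=255, dtype=tf.float32)
layer = keras_cv.layers.RandomColorDegeneration(factor=(0.2, 0.8))
out = layer(images)
print(out.shape)  # (4, 32, 32, 3)
```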
diff --git a/keras_cv/layers/preprocessing/random_cutout.py b/keras_cv/layers/preprocessing/random_cutout.py
index 30d4c31a84..8760630f5c 100644
--- a/keras_cv/layers/preprocessing/random_cutout.py
+++ b/keras_cv/layers/preprocessing/random_cutout.py
@@ -15,15 +15,18 @@
 import tensorflow as tf
 
 from keras_cv.api_export import keras_cv_export
-from keras_cv.layers.preprocessing.base_image_augmentation_layer import (
-    BaseImageAugmentationLayer,
+from keras_cv.layers.preprocessing.vectorized_base_image_augmentation_layer import (  # noqa: E501
+    VectorizedBaseImageAugmentationLayer,
 )
 from keras_cv.utils import fill_utils
 from keras_cv.utils import preprocessing
 
+H_AXIS = -3
+W_AXIS = -2
+
 
 @keras_cv_export("keras_cv.layers.RandomCutout")
-class RandomCutout(BaseImageAugmentationLayer):
+class RandomCutout(VectorizedBaseImageAugmentationLayer):
     """Randomly cut out rectangles from images and fill them.
 
     Args:
@@ -90,74 +93,132 @@ def __init__(
             f'or "constant". Got `fill_mode`={fill_mode}'
         )
 
-    def _parse_bounds(self, factor):
-        if isinstance(factor, (tuple, list)):
-            return factor[0], factor[1]
-        else:
-            return type(factor)(0), factor
-
-    def get_random_transformation(self, image=None, **kwargs):
-        center_x, center_y = self._compute_rectangle_position(image)
-        rectangle_height, rectangle_width = self._compute_rectangle_size(image)
-        return center_x, center_y, rectangle_height, rectangle_width
+    def get_random_transformation_batch(self, batch_size, images, **kwargs):
+        centers_x, centers_y = self._compute_rectangle_position(images)
+        rectangles_height, rectangles_width = self._compute_rectangle_size(
+            images
+        )
+        return {
+            "centers_x": centers_x,
+            "centers_y": centers_y,
+            "rectangles_height": rectangles_height,
+            "rectangles_width": rectangles_width,
+        }
 
-    def augment_image(self, image, transformation=None, **kwargs):
+    def augment_images(self, images, transformations=None, **kwargs):
         """Apply random cutout."""
-        inputs = tf.expand_dims(image, 0)
-        center_x, center_y, rectangle_height, rectangle_width = transformation
-
-        rectangle_fill = self._compute_rectangle_fill(inputs)
-        inputs = fill_utils.fill_rectangle(
-            inputs,
-            center_x,
-            center_y,
-            rectangle_width,
-            rectangle_height,
-            rectangle_fill,
+        centers_x, centers_y = (
+            transformations["centers_x"],
+            transformations["centers_y"],
+        )
+        rectangles_height, rectangles_width = (
+            transformations["rectangles_height"],
+            transformations["rectangles_width"],
+        )
+
+        rectangles_fill = self._compute_rectangle_fill(images)
+        images = fill_utils.fill_rectangle(
+            images,
+            centers_x,
+            centers_y,
+            rectangles_width,
+            rectangles_height,
+            rectangles_fill,
         )
-        return inputs[0]
+        return images
 
-    def augment_label(self, label, transformation=None, **kwargs):
-        return label
+    def augment_bounding_boxes(self, bounding_boxes, **kwargs):
+        return bounding_boxes
 
-    def augment_segmentation_mask(
-        self, segmentation_masks, transformation=None, **kwargs
+    def augment_labels(self, labels, transformations=None, **kwargs):
+        return labels
+
+    def augment_segmentation_masks(
+        self, segmentation_masks, transformations, **kwargs
     ):
         return segmentation_masks
 
-    def _compute_rectangle_position(self, inputs):
-        input_shape = tf.shape(inputs)
-        image_height, image_width = (
-            input_shape[0],
-            input_shape[1],
+    def augment_keypoints(self, keypoints, transformations, **kwargs):
+        return keypoints
+
+    def augment_targets(self, targets, transformations, **kwargs):
+        return targets
+
+    def augment_ragged_image(self, image, transformation, **kwargs):
+        image = tf.expand_dims(image, axis=0)
+        centers_x, centers_y = (
+            transformation["centers_x"],
+            transformation["centers_y"],
+        )
+        rectangles_height, rectangles_width = (
+            transformation["rectangles_height"],
+            transformation["rectangles_width"],
+        )
+        transformation = {
+            "centers_x": tf.expand_dims(centers_x, axis=0),
+            "centers_y": tf.expand_dims(centers_y, axis=0),
+            "rectangles_height": tf.expand_dims(rectangles_height, axis=0),
+            "rectangles_width": tf.expand_dims(rectangles_width, axis=0),
+        }
+        image = self.augment_images(
+            images=image, transformations=transformation, **kwargs
         )
+        return tf.squeeze(image, axis=0)
+
+    def _get_image_shape(self, images):
+        if isinstance(images, tf.RaggedTensor):
+            heights = tf.reshape(images.row_lengths(), (-1,))
+            widths = tf.reshape(
+                tf.reduce_max(images.row_lengths(axis=2), 1), (-1,)
+            )
+        else:
+            batch_size = tf.shape(images)[0]
+            heights = tf.repeat(tf.shape(images)[H_AXIS], repeats=[batch_size])
+            heights = tf.reshape(heights, shape=(-1,))
+            widths = tf.repeat(tf.shape(images)[W_AXIS], repeats=[batch_size])
+            widths = tf.reshape(widths, shape=(-1,))
+        return tf.cast(heights, dtype=tf.int32), tf.cast(widths, dtype=tf.int32)
+
+    def _compute_rectangle_position(self, inputs):
+        batch_size = tf.shape(inputs)[0]
+        heights, widths = self._get_image_shape(inputs)
+
+        # generate values in float32 and then cast (i.e. round) to int32
+        # because random.uniform does not support maxval broadcasting for
+        # integer types. Needed because maxval is a 1-D tensor to support
+        # ragged inputs.
+        heights = tf.cast(heights, dtype=tf.float32)
+        widths = tf.cast(widths, dtype=tf.float32)
+
         center_x = self._random_generator.uniform(
-            [1], 0, image_width, dtype=tf.int32
+            (batch_size,), 0, widths, dtype=tf.float32
         )
         center_y = self._random_generator.uniform(
-            [1], 0, image_height, dtype=tf.int32
+            (batch_size,), 0, heights, dtype=tf.float32
         )
+
+        center_x = tf.cast(center_x, tf.int32)
+        center_y = tf.cast(center_y, tf.int32)
+
         return center_x, center_y
 
     def _compute_rectangle_size(self, inputs):
-        input_shape = tf.shape(inputs)
-        image_height, image_width = (
-            input_shape[0],
-            input_shape[1],
-        )
-        height = self.height_factor()
-        width = self.width_factor()
+        batch_size = tf.shape(inputs)[0]
+        images_heights, images_widths = self._get_image_shape(inputs)
 
-        height = height * tf.cast(image_height, tf.float32)
-        width = width * tf.cast(image_width, tf.float32)
+        height = self.height_factor(shape=(batch_size,))
+        width = self.width_factor(shape=(batch_size,))
+
+        height = height * tf.cast(images_heights, tf.float32)
+        width = width * tf.cast(images_widths, tf.float32)
 
         height = tf.cast(tf.math.ceil(height), tf.int32)
         width = tf.cast(tf.math.ceil(width), tf.int32)
 
-        height = tf.minimum(height, image_height)
-        width = tf.minimum(width, image_width)
+        height = tf.minimum(height, images_heights)
+        width = tf.minimum(width, images_widths)
 
-        return tf.expand_dims(height, axis=0), tf.expand_dims(width, axis=0)
+        return height, width
 
     def _compute_rectangle_fill(self, inputs):
         input_shape = tf.shape(inputs)
@@ -167,16 +228,25 @@ def _compute_rectangle_fill(self, inputs):
         else:
             # gaussian noise
             fill_value = tf.random.normal(input_shape, dtype=self.compute_dtype)
-
+            # rescale the random noise to the original image range
+            image_max = tf.reduce_max(inputs)
+            image_min = tf.reduce_min(inputs)
+            fill_max = tf.reduce_max(fill_value)
+            fill_min = tf.reduce_min(fill_value)
+            fill_value = (image_max - image_min) * (fill_value - fill_min) / (
+                fill_max - fill_min
+            ) + image_min
         return fill_value
 
     def get_config(self):
-        config = {
-            "height_factor": self.height_factor,
-            "width_factor": self.width_factor,
-            "fill_mode": self.fill_mode,
-            "fill_value": self.fill_value,
-            "seed": self.seed,
-        }
-        base_config = super().get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        config = super().get_config()
+        config.update(
+            {
+                "height_factor": self.height_factor,
+                "width_factor": self.width_factor,
+                "fill_mode": self.fill_mode,
+                "fill_value": self.fill_value,
+                "seed": self.seed,
+            }
+        )
+        return config
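Since `_get_image_shape` reads per-image heights and widths from the ragged row lengths, cutout rectangles are now sampled within each image's own bounds even when the batch is ragged. A sketch with two differently sized images (toy shapes, illustrative only):

```python
# Ragged batch: two images with different spatial sizes.
import tensorflow as tf

import keras_cv

images = tf.ragged.stack(
    [
        tf.random.uniform((48, 64, 3), maxval=255, dtype=tf.float32),
        tf.random.uniform((32, 40, 3), maxval=255, dtype=tf.float32),
    ]
)
layer = keras_cv.layers.RandomCutout(height_factor=0.5, width_factor=0.5)
out = layer(images)  # rectangles stay inside each image's own H and W
```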