diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 1ffd49b7af..cd14821d67 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -65,6 +65,7 @@ then keras_cv/models/object_detection/retinanet \ keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ + keras_cv/models/object_detection/faster_rcnn \ keras_cv/models/segmentation \ keras_cv/models/stable_diffusion else @@ -79,6 +80,7 @@ else keras_cv/models/object_detection/retinanet \ keras_cv/models/object_detection/yolo_v8 \ keras_cv/models/object_detection_3d \ + keras_cv/models/object_detection/faster_rcnn \ keras_cv/models/segmentation \ keras_cv/models/stable_diffusion fi \ No newline at end of file diff --git a/demo.py b/demo.py new file mode 100644 index 0000000000..4f156cc214 --- /dev/null +++ b/demo.py @@ -0,0 +1,76 @@ +import os + +os.environ['CUDA_VISIBLE_DEVICES'] = '-1' + +import keras + +import keras_cv +from keras_cv.models import FasterRCNN + +batch_size = 1 +image_shape = (512, 512, 3) + +images = keras.ops.ones((batch_size,) + image_shape) +labels = { + "boxes": keras.ops.array( + [ + [ + [0, 0, 100, 100], + [100, 100, 200, 200], + [300, 300, 100, 100], + ] + ], + dtype="float32", + ), + "classes": keras.ops.array([[1, 1, 1]], dtype="float32"), +} + +# Initialize the model +model = FasterRCNN( + batch_size=batch_size, + num_classes=2, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet50Backbone.from_preset( + "resnet50_imagenet", + input_shape=image_shape, + ), +) + +# Call the model +outputs = model(images) +print("outputs") +for key, value in outputs.items(): + print(f"{key}: {value.shape}") + +# Compile the model +model.compile( + optimizer=keras.optimizers.Adam(), + box_loss=keras.losses.Huber(), + classification_loss=keras.losses.CategoricalCrossentropy(), + rpn_box_loss=keras.losses.Huber(), + rpn_classification_loss=keras.losses.BinaryCrossentropy(from_logits=True), +) + +# Compute Loss from the model +loss = model.compute_loss(x=images, y=labels, y_pred=None, sample_weight=None) +print(loss) + +# Train step +xs = keras.ops.ones((1, 512, 512, 3), "float32") +ys = { + "classes": keras.ops.array([[1, 1, 1]], dtype="float32"), + "boxes": keras.ops.array( + [ + [ + [0, 0, 100, 100], + [100, 100, 200, 200], + [300, 300, 100, 100], + ] + ], + dtype="float32", + ), +} +import tensorflow as tf +ds = tf.data.Dataset.from_tensor_slices((xs, ys)) +ds = ds.batch(1, drop_remainder=True) +model.fit(ds, epochs=1) \ No newline at end of file diff --git a/keras_cv/bounding_box/utils.py b/keras_cv/bounding_box/utils.py index fd85f2b893..dc0978a259 100644 --- a/keras_cv/bounding_box/utils.py +++ b/keras_cv/bounding_box/utils.py @@ -141,7 +141,7 @@ def _clip_boxes(boxes, box_format, image_shape): if isinstance(image_shape, list) or isinstance(image_shape, tuple): height, width, _ = image_shape - max_length = [height, width, height, width] + max_length = ops.stack([height, width, height, width], axis=-1) else: image_shape = ops.cast(image_shape, dtype=boxes.dtype) height = image_shape[0] diff --git a/keras_cv/layers/object_detection/roi_align.py b/keras_cv/layers/object_detection/roi_align.py index 7821311af7..5b8ed109d3 100644 --- a/keras_cv/layers/object_detection/roi_align.py +++ b/keras_cv/layers/object_detection/roi_align.py @@ -25,31 +25,26 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x): The RoIAlign feature f can be computed by bilinear interpolation of four neighboring feature points f0, f1, f2, and f3. + f(y, x) = [hy, ly] * [[f00, f01], * [hx, lx]^T - [f10, f11]] + [f10, f11]] f(y, x) = (hy*hx)f00 + (hy*lx)f01 + (ly*hx)f10 + (lx*ly)f11 f(y, x) = w00*f00 + w01*f01 + w10*f10 + w11*f11 kernel_y = [hy, ly] kernel_x = [hx, lx] Args: - features: The features are in shape of [batch_size, num_boxes, - output_size * 2, output_size * 2, num_filters]. - kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1]. - kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1]. + features: The features are in shape of [batch_size, num_boxes, output_size * + 2, output_size * 2, num_filters]. + kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1]. + kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1]. Returns: - A 5-D tensor representing feature crop of shape - [batch_size, num_boxes, output_size, output_size, num_filters]. - """ - features_shape = ops.shape(features) - batch_size, num_boxes, output_size, num_filters = ( - features_shape[0], - features_shape[1], - features_shape[2], - features_shape[4], - ) + A 5-D tensor representing feature crop of shape + [batch_size, num_boxes, output_size, output_size, num_filters]. + """ + (batch_size, num_boxes, output_size, _, num_filters) = ops.shape(features) output_size = output_size // 2 kernel_y = ops.reshape( kernel_y, [batch_size, num_boxes, output_size * 2, 1] @@ -69,48 +64,38 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x): features, [batch_size * num_boxes, output_size * 2, output_size * 2, num_filters], ) - features = ops.nn.average_pool( - features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID" - ) + features = ops.average_pool(features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID") features = ops.reshape( features, [batch_size, num_boxes, output_size, output_size, num_filters] ) return features -def _compute_grid_positions( - boxes, - boundaries, - output_size, - sample_offset, -): - """ - Computes the grid position w.r.t. the corresponding feature map. +def _compute_grid_positions(boxes, boundaries, output_size, sample_offset): + """Compute the grid position w.r.t. + + the corresponding feature map. Args: - boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the + boxes: a 3-D tensor of shape [batch_size, num_boxes, 4] encoding the information of each box w.r.t. the corresponding feature map. boxes[:, :, 0:2] are the grid position in (y, x) (float) of the top-left corner of each box. boxes[:, :, 2:4] are the box sizes in (h, w) (float) - in terms of the number of pixels of the corresponding feature map - size. - boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing + in terms of the number of pixels of the corresponding feature map size. + boundaries: a 3-D tensor of shape [batch_size, num_boxes, 2] representing the boundary (in (y, x)) of the corresponding feature map for each box. - Any resampled grid points that go beyond the boundary will be clipped. - output_size: a scalar indicating the output crop size. - sample_offset: a float number in [0, 1] indicates the subpixel sample - offset from grid point. + Any resampled grid points that go beyond the bounary will be clipped. + output_size: a scalar indicating the output crop size. + sample_offset: a float number in [0, 1] indicates the subpixel sample offset + from grid point. Returns: - kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1]. - kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1]. - box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2] - box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2] + kernel_y: Tensor of size [batch_size, boxes, output_size, 2, 1]. + kernel_x: Tensor of size [batch_size, boxes, output_size, 2, 1]. + box_grid_y0y1: Tensor of size [batch_size, boxes, output_size, 2] + box_grid_x0x1: Tensor of size [batch_size, boxes, output_size, 2] """ - boxes_shape = ops.shape(boxes) - batch_size, num_boxes = boxes_shape[0], boxes_shape[1] - if batch_size is None: - batch_size = ops.shape(boxes)[0] + batch_size, num_boxes, _ = ops.shape(boxes) box_grid_x = [] box_grid_y = [] for i in range(output_size): @@ -125,12 +110,8 @@ def _compute_grid_positions( box_grid_y0 = ops.floor(box_grid_y) box_grid_x0 = ops.floor(box_grid_x) - box_grid_x0 = ops.maximum( - ops.cast(0.0, dtype=box_grid_x0.dtype), box_grid_x0 - ) - box_grid_y0 = ops.maximum( - ops.cast(0.0, dtype=box_grid_y0.dtype), box_grid_y0 - ) + box_grid_x0 = ops.maximum(0.0, box_grid_x0) + box_grid_y0 = ops.maximum(0.0, box_grid_y0) box_grid_x0 = ops.minimum( box_grid_x0, ops.expand_dims(boundaries[:, :, 1], -1) @@ -168,51 +149,33 @@ def _compute_grid_positions( def multilevel_crop_and_resize( - features, - boxes, - output_size: int = 7, - sample_offset: float = 0.5, + features, boxes, output_size=7, sample_offset=0.5 ): - """ - Crop and resize on multilevel feature pyramid. + """Crop and resize on multilevel feature pyramid. Generate the (output_size, output_size) set of pixels for each input box by first locating the box into the correct feature level, and then cropping - and resizing it using the corresponding feature map of that level. + and resizing it using the correspoding feature map of that level. Args: - features: A dictionary with key as pyramid level and value as features. - The pyramid level keys need to be represented by strings like so: - "P2", "P3", "P4", and so on. - The features are in shape of [batch_size, height_l, width_l, - num_filters]. - boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row - represents a box with [y1, x1, y2, x2] in un-normalized coordinates. - output_size: A scalar to indicate the output crop size. - sample_offset: a float number in [0, 1] indicates the subpixel sample - offset from grid point. + features: A dictionary with key as pyramid level and value as features. The + features are in shape of [batch_size, height_l, width_l, num_filters]. + boxes: A 3-D Tensor of shape [batch_size, num_boxes, 4]. Each row represents + a box with [y1, x1, y2, x2] in un-normalized coordinates. + output_size: A scalar to indicate the output crop size. Returns: - A 5-D tensor representing feature crop of shape - [batch_size, num_boxes, output_size, output_size, num_filters]. + A 5-D tensor representing feature crop of shape + [batch_size, num_boxes, output_size, output_size, num_filters]. """ - - levels_str = list(features.keys()) - # Levels are represented by strings with a prefix "P" to represent - # pyramid levels. The integer level can be obtained by looking at - # the value that follows the "P". - levels = [int(level_str[1:]) for level_str in levels_str] + levels = list(features.keys()) + levels = [int(level[1:]) for level in levels] min_level = min(levels) max_level = max(levels) - features_shape = ops.shape(features[f"P{min_level}"]) - batch_size, max_feature_height, max_feature_width, num_filters = ( - features_shape[0], - features_shape[1], - features_shape[2], - features_shape[3], + batch_size, max_feature_height, max_feature_width, num_filters = ops.shape( + features[f"P{min_level}"] ) - - num_boxes = ops.shape(boxes)[1] + _, num_boxes, _ = ops.shape(boxes) # Stack feature pyramid into a features_all of shape # [batch_size, levels, height, width, num_filters]. @@ -223,14 +186,14 @@ def multilevel_crop_and_resize( shape = ops.shape(features[f"P{level}"]) feature_heights.append(shape[1]) feature_widths.append(shape[2]) - # Concat tensor of [batch_size, height_l * width_l, num_filters] for - # each level. + # Concat tensor of [batch_size, height_l * width_l, num_filters] for each + # levels. features_all.append( ops.reshape(features[f"P{level}"], [batch_size, -1, num_filters]) ) - features_r2 = ops.reshape( - ops.concatenate(features_all, 1), [-1, num_filters] - ) + features_r2 = ops.reshape( + ops.concatenate(features_all, 1), [-1, num_filters] + ) # Calculate height_l * width_l for each level. level_dim_sizes = [ @@ -242,26 +205,15 @@ def multilevel_crop_and_resize( for i in range(len(feature_widths) - 1): level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i]) batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1] - level_dim_offsets = ( - ops.ones_like(level_dim_offsets, dtype="int32") * level_dim_offsets - ) - height_dim_sizes = ( - ops.ones_like(feature_widths, dtype="int32") * feature_widths - ) + level_dim_offsets = ops.array(level_dim_offsets, dtype="int32") + height_dim_sizes = ops.array(feature_widths, dtype="int32") # Assigns boxes to the right level. box_width = boxes[:, :, 3] - boxes[:, :, 1] box_height = boxes[:, :, 2] - boxes[:, :, 0] - areas_sqrt = ops.sqrt( - ops.cast(box_height, "float32") * ops.cast(box_width, "float32") - ) - - # following the FPN paper to divide by 224. + areas_sqrt = ops.sqrt(box_height * box_width) levels = ops.cast( - ops.floor_divide( - ops.log(ops.divide(areas_sqrt, 224.0)), - ops.log(2.0), - ) + ops.floor_divide(ops.log(ops.divide(areas_sqrt, 224.0)), ops.log(2.0)) + 4.0, dtype="int32", ) @@ -270,7 +222,7 @@ def multilevel_crop_and_resize( # Projects box location and sizes to corresponding feature levels. scale_to_level = ops.cast( - ops.pow(2.0, ops.cast(levels, "float32")), + ops.power(ops.array(2.0), ops.cast(levels, "float32")), dtype=boxes.dtype, ) boxes /= ops.expand_dims(scale_to_level, axis=2) @@ -287,7 +239,7 @@ def multilevel_crop_and_resize( # Maps levels to [0, max_level-min_level]. levels -= min_level - level_strides = ops.pow([[2.0]], ops.cast(levels, "float32")) + level_strides = ops.power([[2.0]], ops.cast(levels, "float32")) boundary = ops.cast( ops.concatenate( [ @@ -308,12 +260,9 @@ def multilevel_crop_and_resize( ) # Compute grid positions. - ( - kernel_y, - kernel_x, - box_gridy0y1, - box_gridx0x1, - ) = _compute_grid_positions(boxes, boundary, output_size, sample_offset) + kernel_y, kernel_x, box_gridy0y1, box_gridx0x1 = _compute_grid_positions( + boxes, boundary, output_size, sample_offset=sample_offset + ) x_indices = ops.cast( ops.reshape(box_gridx0x1, [batch_size, num_boxes, output_size * 2]), @@ -333,8 +282,7 @@ def multilevel_crop_and_resize( # Get level offset for each box. Each box belongs to one level. levels_offset = ops.tile( ops.reshape( - ops.take(level_dim_offsets, levels), - [batch_size, num_boxes, 1, 1], + ops.take(level_dim_offsets, levels), [batch_size, num_boxes, 1, 1] ), [1, 1, output_size * 2, output_size * 2], ) @@ -354,17 +302,11 @@ def multilevel_crop_and_resize( [-1], ) - # TODO(tanzhenyu): replace tf.gather with tf.gather_nd and try to get - # similar performance. + # TODO(tanzhenyu): replace tf.gather with tf.gather_nd and try to get similar + # performance. features_per_box = ops.reshape( - ops.take(features_r2, indices), - [ - batch_size, - num_boxes, - output_size * 2, - output_size * 2, - num_filters, - ], + ops.take(features_r2, indices, axis=0), + [batch_size, num_boxes, output_size * 2, output_size * 2, num_filters], ) # Bilinear interpolation. @@ -397,7 +339,6 @@ def __init__( sample_offset: A `float` in [0, 1] of the subpixel sample offset. **kwargs: Additional keyword arguments passed to Layer. """ - # assert_tf_keras("keras_cv.layers._ROIAligner") self._config_dict = { "bounding_box_format": bounding_box_format, "crop_size": target_size, diff --git a/keras_cv/layers/object_detection/roi_generator.py b/keras_cv/layers/object_detection/roi_generator.py index 6f035a7045..dcad5dac25 100644 --- a/keras_cv/layers/object_detection/roi_generator.py +++ b/keras_cv/layers/object_detection/roi_generator.py @@ -145,7 +145,7 @@ def per_level_gen(boxes, scores): # If so, remove the last dimension to make it 2D if len(scores_shape) == 3: scores = ops.squeeze(scores, axis=-1) - _, num_boxes = scores_shape + num_boxes = scores_shape[1] level_pre_nms_topk = min(num_boxes, pre_nms_topk) level_post_nms_topk = min(num_boxes, post_nms_topk) scores, sorted_indices = ops.top_k( diff --git a/keras_cv/models/__init__.py b/keras_cv/models/__init__.py index 77c3ad33d9..cf6a8188fa 100644 --- a/keras_cv/models/__init__.py +++ b/keras_cv/models/__init__.py @@ -185,6 +185,7 @@ from keras_cv.models.backbones.vit_det.vit_det_backbone import ViTDetBackbone from keras_cv.models.classification.image_classifier import ImageClassifier from keras_cv.models.feature_extractor.clip import CLIP +from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import FasterRCNN from keras_cv.models.object_detection.retinanet.retinanet import RetinaNet from keras_cv.models.object_detection.yolo_v8.yolo_v8_backbone import ( YOLOV8Backbone, diff --git a/keras_cv/models/legacy/__init__.py b/keras_cv/models/legacy/__init__.py index 419ae34b31..36fc9bd43f 100644 --- a/keras_cv/models/legacy/__init__.py +++ b/keras_cv/models/legacy/__init__.py @@ -27,9 +27,6 @@ from keras_cv.models.legacy.mlp_mixer import MLPMixerB16 from keras_cv.models.legacy.mlp_mixer import MLPMixerB32 from keras_cv.models.legacy.mlp_mixer import MLPMixerL16 -from keras_cv.models.legacy.object_detection.faster_rcnn.faster_rcnn import ( - FasterRCNN, -) from keras_cv.models.legacy.regnet import RegNetX002 from keras_cv.models.legacy.regnet import RegNetX004 from keras_cv.models.legacy.regnet import RegNetX006 diff --git a/keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py deleted file mode 100644 index b8930af944..0000000000 --- a/keras_cv/models/legacy/object_detection/faster_rcnn/faster_rcnn_test.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2022 The KerasCV Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import pytest -import tensorflow as tf -from absl.testing import parameterized -from tensorflow import keras -from tensorflow.keras import optimizers - -from keras_cv.models import ResNet18V2Backbone -from keras_cv.models.legacy.object_detection.faster_rcnn.faster_rcnn import ( - FasterRCNN, -) -from keras_cv.models.object_detection.__test_utils__ import ( - _create_bounding_box_dataset, -) -from keras_cv.tests.test_case import TestCase - - -class FasterRCNNTest(TestCase): - # TODO(ianstenbit): Make FasterRCNN support shapes that are not multiples - # of 128, perhaps by adding a flag to the anchor generator for whether to - # include anchors centered outside of the image. (RetinaNet does use those, - # while FasterRCNN doesn't). For more context on why this is the case, see - # https://github.com/keras-team/keras-cv/pull/1882 - @parameterized.parameters( - ((2, 640, 384, 3),), - ((2, 512, 512, 3),), - ((2, 128, 128, 3),), - ) - def test_faster_rcnn_infer(self, batch_shape): - model = FasterRCNN( - num_classes=80, - bounding_box_format="xyxy", - backbone=ResNet18V2Backbone(), - ) - images = tf.random.normal(batch_shape) - outputs = model(images, training=False) - # 1000 proposals in inference - self.assertAllEqual([2, 1000, 81], outputs[1].shape) - self.assertAllEqual([2, 1000, 4], outputs[0].shape) - - @parameterized.parameters( - ((2, 640, 384, 3),), - ((2, 512, 512, 3),), - ((2, 128, 128, 3),), - ) - def test_faster_rcnn_train(self, batch_shape): - model = FasterRCNN( - num_classes=80, - bounding_box_format="xyxy", - backbone=ResNet18V2Backbone(), - ) - images = tf.random.normal(batch_shape) - outputs = model(images, training=True) - self.assertAllEqual([2, 1000, 81], outputs[1].shape) - self.assertAllEqual([2, 1000, 4], outputs[0].shape) - - def test_invalid_compile(self): - model = FasterRCNN( - num_classes=80, - bounding_box_format="yxyx", - backbone=ResNet18V2Backbone(), - ) - with self.assertRaisesRegex(ValueError, "only accepts"): - model.compile(rpn_box_loss="binary_crossentropy") - with self.assertRaisesRegex(ValueError, "only accepts"): - model.compile( - rpn_classification_loss=keras.losses.BinaryCrossentropy( - from_logits=False - ) - ) - - @pytest.mark.large # Fit is slow, so mark these large. - def test_faster_rcnn_with_dictionary_input_format(self): - faster_rcnn = FasterRCNN( - num_classes=20, - bounding_box_format="xywh", - backbone=ResNet18V2Backbone(), - ) - - images, boxes = _create_bounding_box_dataset("xywh") - dataset = tf.data.Dataset.from_tensor_slices( - {"images": images, "bounding_boxes": boxes} - ).batch(5, drop_remainder=True) - - faster_rcnn.compile( - optimizer=optimizers.Adam(), - box_loss="Huber", - classification_loss="SparseCategoricalCrossentropy", - rpn_box_loss="Huber", - rpn_classification_loss="BinaryCrossentropy", - ) - - faster_rcnn.fit(dataset, epochs=1) - faster_rcnn.evaluate(dataset) diff --git a/keras_cv/models/legacy/object_detection/__init__.py b/keras_cv/models/object_detection/faster_rcnn/__init__.py similarity index 65% rename from keras_cv/models/legacy/object_detection/__init__.py rename to keras_cv/models/object_detection/faster_rcnn/__init__.py index 65be099991..d5f9e37b30 100644 --- a/keras_cv/models/legacy/object_detection/__init__.py +++ b/keras_cv/models/object_detection/faster_rcnn/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2022 The KerasCV Authors +# Copyright 2023 The KerasCV Authors # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -11,3 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from keras_cv.models.object_detection.faster_rcnn.feature_pyramid import ( + FeaturePyramid, +) +from keras_cv.models.object_detection.faster_rcnn.rcnn_head import RCNNHead +from keras_cv.models.object_detection.faster_rcnn.rpn_head import RPNHead diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py new file mode 100644 index 0000000000..5079c06c70 --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn.py @@ -0,0 +1,402 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tree + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.bounding_box.converters import _decode_deltas_to_boxes +from keras_cv.bounding_box.utils import _clip_boxes +from keras_cv.layers.object_detection.anchor_generator import AnchorGenerator +from keras_cv.layers.object_detection.box_matcher import BoxMatcher +from keras_cv.layers.object_detection.roi_align import _ROIAligner +from keras_cv.layers.object_detection.roi_generator import ROIGenerator +from keras_cv.layers.object_detection.roi_sampler import _ROISampler +from keras_cv.layers.object_detection.rpn_label_encoder import _RpnLabelEncoder +from keras_cv.models.object_detection.__internal__ import unpack_input +from keras_cv.models.object_detection.faster_rcnn import FeaturePyramid +from keras_cv.models.object_detection.faster_rcnn import RCNNHead +from keras_cv.models.object_detection.faster_rcnn import RPNHead +from keras_cv.models.task import Task +from keras_cv.utils.train import get_feature_extractor + +BOX_VARIANCE = [0.1, 0.1, 0.2, 0.2] + + +# TODO(tanzheny): add more configurations +@keras_cv_export("keras_cv.models.FasterRCNN") +class FasterRCNN(Task): + def __init__( + self, + batch_size, + backbone, + num_classes, + bounding_box_format, + anchor_generator=None, + feature_pyramid=None, + rcnn_head=None, + label_encoder=None, + *args, + **kwargs, + ): + + # 1. Create the Input Layer + extractor_levels = ["P2", "P3", "P4", "P5"] + extractor_layer_names = [ + backbone.pyramid_level_inputs[i] for i in extractor_levels + ] + feature_extractor = get_feature_extractor( + backbone, extractor_layer_names, extractor_levels + ) + feature_pyramid = feature_pyramid or FeaturePyramid( + name="feature_pyramid" + ) + image_shape = feature_extractor.input_shape[ + 1: + ] # exclude the batch size + images = keras.layers.Input( + image_shape, + batch_size=batch_size, + name="images", + ) + + # 2. Create the anchors + scales = [2**x for x in [0]] + aspect_ratios = [0.5, 1.0, 2.0] + anchor_generator = anchor_generator or AnchorGenerator( + bounding_box_format=bounding_box_format, + sizes={ + "P2": 32.0, + "P3": 64.0, + "P4": 128.0, + "P5": 256.0, + "P6": 512.0, + }, + scales=scales, + aspect_ratios=aspect_ratios, + strides={f"P{i}": 2**i for i in range(2, 7)}, + clip_boxes=True, + name="anchor_generator", + ) + # Note: `image_shape` should not be of NoneType + # Need to assert before this line + anchors = anchor_generator(image_shape=image_shape) + + ####################################################################### + # Call RPN + ####################################################################### + + # 3. Get the backbone outputs + backbone_outputs = feature_extractor(images) + feature_map = feature_pyramid(backbone_outputs) + + # 4. Get the Region Proposal Boxes and Scores + num_anchors_per_location = len(scales) * len(aspect_ratios) + rpn_head = RPNHead( + num_anchors_per_location=num_anchors_per_location, name="rpn_head" + ) + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = rpn_head(feature_map) + + # 5. Decode the deltas to boxes + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format=bounding_box_format, + box_format=bounding_box_format, + variance=BOX_VARIANCE, + ) + + # 6. Generate the Region of Interests + roi_generator = ROIGenerator( + bounding_box_format=bounding_box_format, + nms_score_threshold_train=float("-inf"), + nms_score_threshold_test=float("-inf"), + name="roi_generator", + ) + rois, _ = roi_generator(decoded_rpn_boxes, rpn_scores) + rois = _clip_boxes(rois, bounding_box_format, image_shape) + rpn_box_pred = keras.ops.concatenate(tree.flatten(rpn_boxes), axis=1) + rpn_cls_pred = keras.ops.concatenate(tree.flatten(rpn_scores), axis=1) + + ####################################################################### + # Call RCNN + ####################################################################### + + # 7. Pool the region of interests + roi_pooler = _ROIAligner(bounding_box_format="yxyx", name="roi_pooler") + feature_map = roi_pooler(features=feature_map, boxes=rois) + + # 8. Reshape the feature map [BS, H*W*K] + feature_map = keras.ops.reshape( + feature_map, + newshape=keras.ops.shape(rois)[:2] + (-1,), + ) + + # 9. Pass the feature map to RCNN head + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + rcnn_head = rcnn_head or RCNNHead( + num_classes=num_classes, name="rcnn_head" + ) + box_pred, cls_pred = rcnn_head(feature_map=feature_map) + + # 10. Create the model using Functional API + inputs = {"images": images} + box_pred = keras.layers.Concatenate(axis=1, name="box")([box_pred]) + cls_pred = keras.layers.Concatenate(axis=1, name="classification")( + [cls_pred] + ) + rpn_box_pred = keras.layers.Concatenate(axis=1, name="rpn_box")( + [rpn_box_pred] + ) + rpn_cls_pred = keras.layers.Concatenate( + axis=1, name="rpn_classification" + )([rpn_cls_pred]) + outputs = { + "box": box_pred, + "classification": cls_pred, + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + } + + super().__init__(inputs=inputs, outputs=outputs, *args, **kwargs) + + # Define the model parameters + self.bounding_box_format = bounding_box_format + self.anchor_generator = anchor_generator + self.rpn_labeler = label_encoder or _RpnLabelEncoder( + anchor_format="yxyx", + ground_truth_box_format="yxyx", + positive_threshold=0.7, + negative_threshold=0.3, + samples_per_image=256, + positive_fraction=0.5, + box_variance=BOX_VARIANCE, + ) + self.feature_extractor = feature_extractor + self.feature_pyramid = feature_pyramid + self.roi_generator = roi_generator + self.rpn_head = rpn_head + self.box_matcher = BoxMatcher( + thresholds=[0.0, 0.5], match_values=[-2, -1, 1] + ) + self.roi_sampler = _ROISampler( + bounding_box_format="yxyx", + roi_matcher=self.box_matcher, + background_class=num_classes, + num_sampled_rois=512, + ) + self.roi_pooler = roi_pooler + self.rcnn_head = rcnn_head + + def compile( + self, + box_loss=None, + classification_loss=None, + rpn_box_loss=None, + rpn_classification_loss=None, + weight_decay=0.0001, + loss=None, + metrics=None, + **kwargs, + ): + if loss is not None: + raise ValueError( + "`FasterRCNN` does not accept a `loss` to `compile()`. " + "Instead, please pass `box_loss` and `classification_loss`. " + "`loss` will be ignored during training." + ) + box_loss = _parse_box_loss(box_loss) + classification_loss = _parse_classification_loss(classification_loss) + + rpn_box_loss = _parse_box_loss(rpn_box_loss) + rpn_classification_loss = _parse_classification_loss( + rpn_classification_loss + ) + + self.rpn_box_loss = rpn_box_loss + self.rpn_cls_loss = rpn_classification_loss + self.box_loss = box_loss + self.cls_loss = classification_loss + self.weight_decay = weight_decay + losses = { + "box": self.box_loss, + "classification": self.cls_loss, + "rpn_box": self.rpn_box_loss, + "rpn_classification": self.rpn_cls_loss, + } + self._has_user_metrics = metrics is not None and len(metrics) != 0 + self._user_metrics = metrics + super().compile(loss=losses, **kwargs) + + def compute_loss(self, x, y, y_pred, sample_weight, **kwargs): + # 1. Unpack the inputs + images = x + gt_boxes = y["boxes"] + if keras.ops.ndim(y["classes"]) != 2: + raise ValueError( + "Expected 'classes' to be a Tensor of rank 2. " + f"Got y['classes'].shape={keras.ops.shape(y['classes'])}." + ) + # TODO(tanzhenyu): remove this hack and perform broadcasting elsewhere + # gt_classes = keras.ops.expand_dims(y["classes"], axis=-1) + gt_classes = y["classes"] + + # Generate anchors + # image shape must not contain the batch size + local_batch = keras.ops.shape(images)[0] + image_shape = keras.ops.shape(images)[1:] + anchors = self.anchor_generator(image_shape=image_shape) + + # 2. Label with the anchors -- exclusive to compute_loss + ( + rpn_box_targets, + rpn_box_weights, + rpn_cls_targets, + rpn_cls_weights, + ) = self.rpn_labeler( + anchors_dict=keras.ops.concatenate( + tree.flatten(anchors), + axis=0, + ), + gt_boxes=gt_boxes, + gt_classes=gt_classes, + ) + + # 3. Computing the weights + rpn_box_weights /= ( + self.rpn_labeler.samples_per_image * local_batch * 0.25 + ) + rpn_cls_weights /= self.rpn_labeler.samples_per_image * local_batch + + ####################################################################### + # Call RPN + ####################################################################### + + backbone_outputs = self.feature_extractor(images) + feature_map = self.feature_pyramid(backbone_outputs) + + # [BS, num_anchors, 4], [BS, num_anchors, 1] + rpn_boxes, rpn_scores = self.rpn_head(feature_map) + + decoded_rpn_boxes = _decode_deltas_to_boxes( + anchors=anchors, + boxes_delta=rpn_boxes, + anchor_format=self.bounding_box_format, + box_format=self.bounding_box_format, + variance=BOX_VARIANCE, + ) + + rois, _ = self.roi_generator(decoded_rpn_boxes, rpn_scores) + rois = _clip_boxes(rois, self.bounding_box_format, image_shape) + rpn_box_pred = keras.ops.concatenate(tree.flatten(rpn_boxes), axis=1) + rpn_cls_pred = keras.ops.concatenate(tree.flatten(rpn_scores), axis=1) + + # 4. Stop gradient from flowing into the ROI -- exclusive to compute_loss + rois = keras.ops.stop_gradient(rois) + + # 5. Sample the ROIS -- exclusive to compute_loss -- exclusive to compute loss + ( + rois, + box_targets, + box_weights, + cls_targets, + cls_weights, + ) = self.roi_sampler(rois, gt_boxes, gt_classes) + + # 6. Box and class weights -- exclusive to compute loss + box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25 + cls_weights /= self.roi_sampler.num_sampled_rois * local_batch + + ####################################################################### + # Call RCNN + ####################################################################### + + feature_map = self.roi_pooler(features=feature_map, boxes=rois) + + # [BS, H*W*K] + feature_map = keras.ops.reshape( + feature_map, + newshape=keras.ops.shape(rois)[:2] + (-1,), + ) + + # [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1] + box_pred, cls_pred = self.rcnn_head(feature_map=feature_map) + + y_true = { + "rpn_box": rpn_box_targets, + "rpn_classification": rpn_cls_targets, + "box": box_targets, + "classification": cls_targets, + } + y_pred = { + "rpn_box": rpn_box_pred, + "rpn_classification": rpn_cls_pred, + "box": box_pred, + "classification": cls_pred, + } + weights = { + "rpn_box": rpn_box_weights, + "rpn_classification": rpn_cls_weights, + "box": box_weights, + "classification": cls_weights, + } + + return super().compute_loss( + x=x, y=y_true, y_pred=y_pred, sample_weight=weights, **kwargs + ) + + def train_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().train_step(*args, (x, y)) + + def test_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().test_step(*args, (x, y)) + + +def _parse_box_loss(loss): + if not isinstance(loss, str): + # support arbitrary callables + return loss + + # case insensitive comparison + if loss.lower() == "smoothl1": + return keras.losses.SmoothL1Loss(l1_cutoff=1.0, reduction="sum") + if loss.lower() == "huber": + return keras.losses.Huber(reduction="sum") + + raise ValueError( + "Expected `box_loss` to be either a Keras Loss, " + f"callable, or the string 'SmoothL1', 'Huber'. Got loss={loss}." + ) + + +def _parse_classification_loss(loss): + if not isinstance(loss, str): + # support arbitrary callables + return loss + + # case insensitive comparison + if loss.lower() == "focal": + return keras.losses.FocalLoss(from_logits=True, reduction="sum") + + raise ValueError( + "Expected `classification_loss` to be either a Keras Loss, " + f"callable, or the string 'Focal'. Got loss={loss}." + ) diff --git a/keras_cv/models/legacy/object_detection/faster_rcnn/__init__.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py similarity index 94% rename from keras_cv/models/legacy/object_detection/faster_rcnn/__init__.py rename to keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py index 3992ffb59a..1c056e7835 100644 --- a/keras_cv/models/legacy/object_detection/faster_rcnn/__init__.py +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_presets.py @@ -11,3 +11,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""FastrerRCNN Task presets.""" diff --git a/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py new file mode 100644 index 0000000000..ee92e38d7f --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/faster_rcnn_test.py @@ -0,0 +1,347 @@ +# Copyright 2022 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np +import pytest +import tensorflow as tf +from absl.testing import parameterized + +import keras_cv +from keras_cv.backend import keras +from keras_cv.backend import ops +from keras_cv.models.backbones.test_backbone_presets import ( + test_backbone_presets, +) +from keras_cv.models.object_detection.__test_utils__ import ( + _create_bounding_box_dataset, +) +from keras_cv.models.object_detection.faster_rcnn.faster_rcnn import FasterRCNN +from keras_cv.tests.test_case import TestCase + + +class FasterRCNNTest(TestCase): + def test_faster_rcnn_construction(self): + faster_rcnn = FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + + def test_faster_rcnn_call(self): + faster_rcnn = keras_cv.models.FasterRCNN( + num_classes=80, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + images = np.random.uniform(size=(2, 512, 512, 3)) + _ = faster_rcnn(images) + _ = faster_rcnn.predict(images) + + def test_wrong_logits(self): + faster_rcnn = keras_cv.models.FasterRCNN( + num_classes=80, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + + with self.assertRaisesRegex( + ValueError, + "from_logits", + ): + faster_rcnn.compile( + optimizer=keras.optimizers.SGD(learning_rate=0.25), + box_loss=keras_cv.losses.SmoothL1Loss( + l1_cutoff=1.0, reduction="none" + ), + classification_loss=keras_cv.losses.FocalLoss( + from_logits=False, reduction="none" + ), + rpn_box_loss=keras_cv.losses.SmoothL1Loss( + l1_cutoff=1.0, reduction="none" + ), + rpn_classification_loss=keras_cv.losses.FocalLoss( + from_logits=False, reduction="none" + ), + ) + + def test_weights_contained_in_trainable_variables(self): + bounding_box_format = "xyxy" + faster_rcnn = keras_cv.models.FasterRCNN( + num_classes=80, + bounding_box_format=bounding_box_format, + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + faster_rcnn.backbone.trainable = False + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + xs, ys = _create_bounding_box_dataset(bounding_box_format) + + # call once + _ = faster_rcnn(xs) + self.assertEqual(len(faster_rcnn.trainable_variables), 32) + + @pytest.mark.large # Fit is slow, so mark these large. + def test_no_nans(self): + faster_rcnn = keras_cv.models.FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + + # only a -1 box + xs = np.ones((1, 512, 512, 3), "float32") + ys = { + "classes": np.array([[-1]], "float32"), + "boxes": np.array([[[0, 0, 0, 0]]], "float32"), + } + ds = tf.data.Dataset.from_tensor_slices((xs, ys)) + ds = ds.repeat(2) + ds = ds.batch(2, drop_remainder=True) + faster_rcnn.fit(ds, epochs=1) + + weights = faster_rcnn.get_weights() + for weight in weights: + self.assertFalse(ops.any(ops.isnan(weight))) + + @pytest.mark.large # Fit is slow, so mark these large. + def test_weights_change(self): + faster_rcnn = keras_cv.models.FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + + images, boxes = _create_bounding_box_dataset("xyxy") + ds = tf.data.Dataset.from_tensor_slices( + {"images": images, "bounding_boxes": boxes} + ).batch(5, drop_remainder=True) + + # call once + _ = faster_rcnn(ops.ones((1, 512, 512, 3))) + original_fpn_weights = faster_rcnn.feature_pyramid.get_weights() + original_rpn_head_weights = faster_rcnn.rpn_head.get_weights() + original_rcnn_head_weights = faster_rcnn.rcnn_head.get_weights() + + faster_rcnn.fit(ds, epochs=1) + fpn_after_fit = faster_rcnn.feature_pyramid.get_weights() + rpn_head_after_fit_weights = faster_rcnn.rpn_head.get_weights() + rcnn_head_after_fit_weights = faster_rcnn.rcnn_head.get_weights() + + for w1, w2 in zip( + original_rcnn_head_weights, + rcnn_head_after_fit_weights, + ): + self.assertNotAllClose(w1, w2) + + for w1, w2 in zip( + original_rpn_head_weights, rpn_head_after_fit_weights + ): + self.assertNotAllClose(w1, w2) + + for w1, w2 in zip(original_fpn_weights, fpn_after_fit): + self.assertNotAllClose(w1, w2) + + @pytest.mark.large # Saving is slow, so mark these large. + def test_saved_model(self): + model = keras_cv.models.FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + input_batch = ops.ones(shape=(1, 512, 512, 3)) + model_output = model(input_batch) + save_path = os.path.join(self.get_temp_dir(), "faster_rcnn.keras") + model.save(save_path) + restored_model = keras.models.load_model(save_path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, keras_cv.models.FasterRCNN) + + # Check that output matches. + restored_output = restored_model(input_batch) + self.assertAllClose( + tf.nest.map_structure(ops.convert_to_numpy, model_output), + tf.nest.map_structure(ops.convert_to_numpy, restored_output), + ) + + # TODO(ianstenbit): Make FasterRCNN support shapes that are not multiples + # of 128, perhaps by adding a flag to the anchor generator for whether to + # include anchors centered outside of the image. (RetinaNet does use those, + # while FasterRCNN doesn't). For more context on why this is the case, see + # https://github.com/keras-team/keras-cv/pull/1882 + @parameterized.parameters( + ((2, 640, 384, 3),), + ((2, 512, 512, 3),), + ((2, 128, 128, 3),), + ) + def test_faster_rcnn_infer(self, batch_shape): + model = FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=batch_shape[1:] + ), + ) + images = ops.random.normal(batch_shape) + outputs = model(images, training=False) + # 1000 proposals in inference + self.assertAllEqual([2, 1000, 81], outputs["classification"].shape) + self.assertAllEqual([2, 1000, 4], outputs["box"].shape) + + @parameterized.parameters( + ((2, 640, 384, 3),), + ((2, 512, 512, 3),), + ((2, 128, 128, 3),), + ) + def test_faster_rcnn_train(self, batch_shape): + model = FasterRCNN( + num_classes=80, + bounding_box_format="xyxy", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=batch_shape[1:] + ), + ) + images = ops.random.normal(batch_shape) + outputs = model(images, training=True) + self.assertAllEqual([2, 1000, 81], outputs["classification"].shape) + self.assertAllEqual([2, 1000, 4], outputs["box"].shape) + + def test_invalid_compile(self): + model = FasterRCNN( + num_classes=80, + bounding_box_format="yxyx", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + with self.assertRaisesRegex(ValueError, "only accepts"): + model.compile(rpn_box_loss="binary_crossentropy") + with self.assertRaisesRegex(ValueError, "only accepts"): + model.compile( + rpn_classification_loss=keras.losses.BinaryCrossentropy( + from_logits=False + ) + ) + + @pytest.mark.large # Fit is slow, so mark these large. + def test_faster_rcnn_with_dictionary_input_format(self): + faster_rcnn = FasterRCNN( + num_classes=20, + bounding_box_format="xywh", + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + + images, boxes = _create_bounding_box_dataset("xywh") + dataset = tf.data.Dataset.from_tensor_slices( + {"images": images, "bounding_boxes": boxes} + ).batch(5, drop_remainder=True) + + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + + faster_rcnn.fit(dataset, epochs=1) + faster_rcnn.evaluate(dataset) + + # @pytest.mark.large # Fit is slow, so mark these large. + def test_fit_with_no_valid_gt_bbox(self): + bounding_box_format = "xywh" + faster_rcnn = FasterRCNN( + num_classes=20, + bounding_box_format=bounding_box_format, + backbone=keras_cv.models.ResNet18V2Backbone( + input_shape=(512, 512, 3) + ), + ) + + faster_rcnn.compile( + optimizer=keras.optimizers.Adam(), + box_loss="Huber", + classification_loss="SparseCategoricalCrossentropy", + rpn_box_loss="Huber", + rpn_classification_loss="BinaryCrossentropy", + ) + xs, ys = _create_bounding_box_dataset(bounding_box_format) + # Make all bounding_boxes invalid and filter out them + ys["classes"] = -np.ones_like(ys["classes"]) + + faster_rcnn.fit(x=xs, y=ys, epochs=1) + + +@pytest.mark.large +class FasterRCNNSmokeTest(TestCase): + @parameterized.named_parameters( + *[(preset, preset) for preset in test_backbone_presets] + ) + @pytest.mark.extra_large + def test_backbone_preset(self, preset): + model = keras_cv.models.FasterRCNN.from_preset( + preset, + num_classes=20, + bounding_box_format="xywh", + ) + xs, _ = _create_bounding_box_dataset(bounding_box_format="xywh") + output = model(xs) + + # 64 represents number of parameters in a box + # 5376 is the number of anchors for a 512x512 image + self.assertEqual(output["boxes"].shape, (xs.shape[0], 5376, 64)) diff --git a/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py b/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py new file mode 100644 index 0000000000..18648a4ccc --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/feature_pyramid.py @@ -0,0 +1,75 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras + + +@keras_cv_export( + "keras_cv.models.faster_rcnn.FeaturePyramid", + package="keras_cv.models.faster_rcnn", +) +class FeaturePyramid(keras.layers.Layer): + """Builds the Feature Pyramid with the feature maps from the backbone.""" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.conv_c2_1x1 = keras.layers.Conv2D(256, 1, 1, "same") + self.conv_c3_1x1 = keras.layers.Conv2D(256, 1, 1, "same") + self.conv_c4_1x1 = keras.layers.Conv2D(256, 1, 1, "same") + self.conv_c5_1x1 = keras.layers.Conv2D(256, 1, 1, "same") + + self.conv_c2_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c3_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c4_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c5_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c6_3x3 = keras.layers.Conv2D(256, 3, 1, "same") + self.conv_c6_pool = keras.layers.MaxPool2D() + self.upsample_2x = keras.layers.UpSampling2D(2) + + def call(self, inputs, training=None): + c2_output = inputs["P2"] + c3_output = inputs["P3"] + c4_output = inputs["P4"] + c5_output = inputs["P5"] + + c6_output = self.conv_c6_pool(c5_output) + p6_output = c6_output + p5_output = self.conv_c5_1x1(c5_output) + p4_output = self.conv_c4_1x1(c4_output) + p3_output = self.conv_c3_1x1(c3_output) + p2_output = self.conv_c2_1x1(c2_output) + + p4_output = p4_output + self.upsample_2x(p5_output) + p3_output = p3_output + self.upsample_2x(p4_output) + p2_output = p2_output + self.upsample_2x(p3_output) + + p6_output = self.conv_c6_3x3(p6_output) + p5_output = self.conv_c5_3x3(p5_output) + p4_output = self.conv_c4_3x3(p4_output) + p3_output = self.conv_c3_3x3(p3_output) + p2_output = self.conv_c2_3x3(p2_output) + + return { + "P2": p2_output, + "P3": p3_output, + "P4": p4_output, + "P5": p5_output, + "P6": p6_output, + } + + def get_config(self): + config = {} + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py b/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py new file mode 100644 index 0000000000..4caec64076 --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/rcnn_head.py @@ -0,0 +1,71 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras + + +@keras_cv_export( + "keras_cv.models.faster_rcnn.RCNNHead", + package="keras_cv.models.faster_rcnn", +) +class RCNNHead(keras.layers.Layer): + def __init__( + self, + num_classes, + conv_dims=[], + fc_dims=[1024, 1024], + **kwargs, + ): + super().__init__(**kwargs) + self.num_classes = num_classes + self.conv_dims = conv_dims + self.fc_dims = fc_dims + self.convs = [] + for conv_dim in conv_dims: + layer = keras.layers.Conv2D( + filters=conv_dim, + kernel_size=3, + strides=1, + padding="same", + activation="relu", + ) + self.convs.append(layer) + self.fcs = [] + for fc_dim in fc_dims: + layer = keras.layers.Dense(units=fc_dim, activation="relu") + self.fcs.append(layer) + self.box_pred = keras.layers.Dense(units=4) + self.cls_score = keras.layers.Dense( + units=num_classes + 1, activation="softmax" + ) + + def call(self, feature_map, training=None): + x = feature_map + for conv in self.convs: + x = conv(x) + for fc in self.fcs: + x = fc(x) + rcnn_boxes = self.box_pred(x) + rcnn_scores = self.cls_score(x) + return rcnn_boxes, rcnn_scores + + def get_config(self): + config = { + "num_classes": self.num_classes, + "conv_dims": self.conv_dims, + "fc_dims": self.fc_dims, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_cv/models/object_detection/faster_rcnn/rpn_head.py b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py new file mode 100644 index 0000000000..2e8f581dfb --- /dev/null +++ b/keras_cv/models/object_detection/faster_rcnn/rpn_head.py @@ -0,0 +1,113 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tree + +from keras_cv.api_export import keras_cv_export +from keras_cv.backend import keras +from keras_cv.backend import ops + + +@keras_cv_export( + "keras_cv.models.faster_rcnn.RPNHead", + package="keras_cv.models.faster_rcnn", +) +class RPNHead(keras.layers.Layer): + """A Keras layer implementing the RPN architecture. + + Region Proposal Networks (RPN) was first suggested in + [FasterRCNN](https://arxiv.org/abs/1506.01497). + This is an end to end trainable layer which proposes regions + for a detector (RCNN). + + Args: + num_achors_per_location: The number of anchors per location. + """ + + def __init__( + self, + num_anchors_per_location=3, + **kwargs, + ): + super().__init__(**kwargs) + self.num_anchors = num_anchors_per_location + + def build(self, input_shape): + if isinstance(input_shape, (dict, list, tuple)): + input_shape = tree.flatten(input_shape) + input_shape = input_shape[0:4] + filters = input_shape[-1] + self.conv = keras.layers.Conv2D( + filters=filters, + kernel_size=3, + strides=1, + padding="same", + activation="relu", + kernel_initializer="truncated_normal", + ) + self.objectness_logits = keras.layers.Conv2D( + filters=self.num_anchors * 1, + kernel_size=1, + strides=1, + padding="same", + kernel_initializer="truncated_normal", + ) + self.anchor_deltas = keras.layers.Conv2D( + filters=self.num_anchors * 4, + kernel_size=1, + strides=1, + padding="same", + kernel_initializer="truncated_normal", + ) + + def call(self, feature_map, training=None): + def call_single_level(f_map): + batch_size = ops.shape(f_map)[0] + # [BS, H, W, C] + t = self.conv(f_map) + # [BS, H, W, K] + rpn_scores = self.objectness_logits(t) + # [BS, H, W, K * 4] + rpn_boxes = self.anchor_deltas(t) + # [BS, H*W*K, 4] + rpn_boxes = ops.reshape(rpn_boxes, [batch_size, -1, 4]) + # [BS, H*W*K, 1] + rpn_scores = ops.reshape(rpn_scores, [batch_size, -1, 1]) + return rpn_boxes, rpn_scores + + if not isinstance(feature_map, (dict, list, tuple)): + return call_single_level(feature_map) + elif isinstance(feature_map, (list, tuple)): + rpn_boxes = [] + rpn_scores = [] + for f_map in feature_map: + rpn_box, rpn_score = call_single_level(f_map) + rpn_boxes.append(rpn_box) + rpn_scores.append(rpn_score) + return rpn_boxes, rpn_scores + else: + rpn_boxes = {} + rpn_scores = {} + for lvl, f_map in feature_map.items(): + rpn_box, rpn_score = call_single_level(f_map) + rpn_boxes[lvl] = rpn_box + rpn_scores[lvl] = rpn_score + return rpn_boxes, rpn_scores + + def get_config(self): + config = { + "num_anchors_per_location": self.num_anchors, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items()))