Skip to content

Commit

Permalink
Merge pull request #311 from fizyr/fpn-correction
Browse files Browse the repository at this point in the history
Improve COCO results (FPN + NMS changes).
  • Loading branch information
hgaiser authored Mar 3, 2018
2 parents f0bc909 + fa619b8 commit 05bd676
Show file tree
Hide file tree
Showing 10 changed files with 101 additions and 80 deletions.
24 changes: 12 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,18 @@ retinanet-train coco /path/to/MS/COCO
The pretrained MS COCO model can be downloaded [here](https://github.com/fizyr/keras-retinanet/releases/download/0.1/resnet50_coco_best_v1.2.2.h5). Results using the `cocoapi` are shown below (note: according to the paper, this configuration should achieve a mAP of 0.343).

```
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.325
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.513
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.342
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.149
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.354
Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.345
Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.533
Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=100 ] = 0.368
Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.189
Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.380
Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.465
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.288
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.437
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.464
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.263
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.510
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.623
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 1 ] = 0.301
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets= 10 ] = 0.482
Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.529
Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.364
Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.565
Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.666
```

For training on [OID](https://github.com/openimages/dataset), run:
Expand Down Expand Up @@ -118,7 +118,7 @@ from keras_retinanet.models.resnet import custom_objects
model = keras.models.load_model('/path/to/model.h5', custom_objects=custom_objects)
```

Execution time on NVIDIA Pascal Titan X is roughly 75msec for an image of shape `1000x600x3`.
Execution time on NVIDIA Pascal Titan X is roughly 75msec for an image of shape `1000x800x3`.

## CSV datasets
The `CSVGenerator` provides an easy way to define your own datasets.
Expand Down
8 changes: 4 additions & 4 deletions keras_retinanet/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
import keras


def top_k(*args, **kwargs):
return tensorflow.nn.top_k(*args, **kwargs)


def resize_images(*args, **kwargs):
return tensorflow.image.resize_images(*args, **kwargs)

Expand All @@ -34,6 +30,10 @@ def range(*args, **kwargs):
return tensorflow.range(*args, **kwargs)


def scatter_nd(*args, **kwargs):
return tensorflow.scatter_nd(*args, **kwargs)


def gather_nd(*args, **kwargs):
return tensorflow.gather_nd(*args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion keras_retinanet/bin/evaluate_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def main(args=None):
# create a generator for testing data
test_generator = CocoGenerator(
args.coco_path,
args.set,
args.set
)

evaluate_coco(test_generator, model, args.score_threshold)
Expand Down
61 changes: 39 additions & 22 deletions keras_retinanet/layers/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,43 +76,60 @@ def get_config(self):


class NonMaximumSuppression(keras.layers.Layer):
def __init__(self, nms_threshold=0.5, top_k=None, max_boxes=300, *args, **kwargs):
self.nms_threshold = nms_threshold
self.top_k = top_k
self.max_boxes = max_boxes
def __init__(self, nms_threshold=0.5, score_threshold=0.05, max_boxes=300, *args, **kwargs):
self.nms_threshold = nms_threshold
self.score_threshold = score_threshold
self.max_boxes = max_boxes
super(NonMaximumSuppression, self).__init__(*args, **kwargs)

def call(self, inputs, **kwargs):
boxes, classification, detections = inputs

# TODO: support batch size > 1.
boxes = boxes[0]
classification = classification[0]
detections = detections[0]
boxes = inputs[0][0]
classification = inputs[1][0]
other = [i[0] for i in inputs[2:]] # can be any user-specified additional data
indices = backend.range(keras.backend.shape(classification)[0])
selected_scores = []

# perform per class NMS
for c in range(int(classification.shape[1])):
scores = classification[:, c]

# threshold based on score
score_indices = backend.where(keras.backend.greater(scores, self.score_threshold))
score_indices = keras.backend.cast(score_indices, 'int32')
boxes_ = backend.gather_nd(boxes, score_indices)
scores = keras.backend.gather(scores, score_indices)[:, 0]

# perform NMS
nms_indices = backend.non_max_suppression(boxes_, scores, max_output_size=self.max_boxes, iou_threshold=self.nms_threshold)

# filter set of original indices
selected_indices = keras.backend.gather(score_indices, nms_indices)

# mask original classification column, setting all suppressed values to 0
scores = keras.backend.gather(scores, nms_indices)
scores = backend.scatter_nd(selected_indices, scores, keras.backend.shape(classification[:, c]))
scores = keras.backend.expand_dims(scores, axis=1)

scores = keras.backend.max(classification, axis=1)
selected_scores.append(scores)

# selecting best anchors theoretically improves speed at the cost of minor performance
if self.top_k:
scores, indices = backend.top_k(scores, self.top_k, sorted=False)
boxes = keras.backend.gather(boxes, indices)
classification = keras.backend.gather(classification, indices)
detections = keras.backend.gather(detections, indices)
# reconstruct the (suppressed) classification scores
classification = keras.backend.concatenate(selected_scores, axis=1)

indices = backend.non_max_suppression(boxes, scores, max_output_size=self.max_boxes, iou_threshold=self.nms_threshold)
# reconstruct into the expected output
detections = keras.backend.concatenate([boxes, classification] + other, axis=1)

detections = keras.backend.gather(detections, indices)
return keras.backend.expand_dims(detections, axis=0)

def compute_output_shape(self, input_shape):
return (input_shape[2][0], None, input_shape[2][2])
return (input_shape[0][0], input_shape[0][1], sum([i[2] for i in input_shape]))

def get_config(self):
config = super(NonMaximumSuppression, self).get_config()
config.update({
'nms_threshold' : self.nms_threshold,
'top_k' : self.top_k,
'max_boxes' : self.max_boxes,
'nms_threshold' : self.nms_threshold,
'score_threshold' : self.score_threshold,
'max_boxes' : self.max_boxes,
})

return config
Expand Down
12 changes: 7 additions & 5 deletions keras_retinanet/models/retinanet.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,15 @@ def default_regression_model(num_anchors, pyramid_feature_size=256, regression_f

def __create_pyramid_features(C3, C4, C5, feature_size=256):
# upsample C5 to get P5 from the FPN paper
P5 = keras.layers.Conv2D(feature_size, kernel_size=1, strides=1, padding='same', name='P5')(C5)
P5 = keras.layers.Conv2D(feature_size, kernel_size=1, strides=1, padding='same', name='C5_reduced')(C5)
P5_upsampled = layers.UpsampleLike(name='P5_upsampled')([P5, C4])
P5 = keras.layers.Conv2D(feature_size, kernel_size=3, strides=1, padding='same', name='P5')(P5)

# add P5 elementwise to C4
P4 = keras.layers.Conv2D(feature_size, kernel_size=1, strides=1, padding='same', name='C4_reduced')(C4)
P4 = keras.layers.Add(name='P4_merged')([P5_upsampled, P4])
P4 = keras.layers.Conv2D(feature_size, kernel_size=3, strides=1, padding='same', name='P4')(P4)
P4_upsampled = layers.UpsampleLike(name='P4_upsampled')([P4, C3])
P4 = keras.layers.Conv2D(feature_size, kernel_size=3, strides=1, padding='same', name='P4')(P4)

# add P4 elementwise to C3
P3 = keras.layers.Conv2D(feature_size, kernel_size=1, strides=1, padding='same', name='C3_reduced')(C3)
Expand Down Expand Up @@ -207,12 +208,13 @@ def retinanet_bbox(inputs, num_classes, nms=True, name='retinanet-bbox', *args,
classification = model.outputs[2]

# apply predicted regression to anchors
boxes = layers.RegressBoxes(name='boxes')([anchors, regression])
detections = keras.layers.Concatenate(axis=2)([boxes, classification] + model.outputs[3:])
boxes = layers.RegressBoxes(name='boxes')([anchors, regression])

# additionally apply non maximum suppression
if nms:
detections = layers.NonMaximumSuppression(name='nms')([boxes, classification, detections])
detections = layers.NonMaximumSuppression(name='nms')([boxes, classification] + model.outputs[3:])
else:
detections = keras.layers.Concatenate(axis=2)([boxes, classification] + model.outputs[3:])

# construct the model
return keras.models.Model(inputs=inputs, outputs=model.outputs[1:] + [detections], name=name)
4 changes: 2 additions & 2 deletions keras_retinanet/preprocessing/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def __init__(
batch_size=1,
group_method='ratio', # one of 'none', 'random', 'ratio'
shuffle_groups=True,
image_min_side=600,
image_max_side=1024,
image_min_side=800,
image_max_side=1333,
transform_parameters=None,
):
self.transform_generator = transform_generator
Expand Down
20 changes: 12 additions & 8 deletions keras_retinanet/utils/anchors.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,16 +171,20 @@ def bbox_transform(anchors, gt_boxes, mean=None, std=None):
elif not isinstance(std, np.ndarray):
raise ValueError('Expected std to be a np.ndarray, list or tuple. Received: {}'.format(type(std)))

anchor_widths = anchors[:, 2] - anchors[:, 0] + 1.0
anchor_heights = anchors[:, 3] - anchors[:, 1] + 1.0
anchor_widths = anchors[:, 2] - anchors[:, 0]
anchor_heights = anchors[:, 3] - anchors[:, 1]
anchor_ctr_x = anchors[:, 0] + 0.5 * anchor_widths
anchor_ctr_y = anchors[:, 1] + 0.5 * anchor_heights

gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0] + 1.0
gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1] + 1.0
gt_widths = gt_boxes[:, 2] - gt_boxes[:, 0]
gt_heights = gt_boxes[:, 3] - gt_boxes[:, 1]
gt_ctr_x = gt_boxes[:, 0] + 0.5 * gt_widths
gt_ctr_y = gt_boxes[:, 1] + 0.5 * gt_heights

# clip widths to 1
gt_widths = np.maximum(gt_widths, 1)
gt_heights = np.maximum(gt_heights, 1)

targets_dx = (gt_ctr_x - anchor_ctr_x) / anchor_widths
targets_dy = (gt_ctr_y - anchor_ctr_y) / anchor_heights
targets_dw = np.log(gt_widths / anchor_widths)
Expand All @@ -204,15 +208,15 @@ def compute_overlap(a, b):
-------
overlaps: (N, K) ndarray of overlap between boxes and query_boxes
"""
area = (b[:, 2] - b[:, 0] + 1) * (b[:, 3] - b[:, 1] + 1)
area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])

iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0]) + 1
ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1]) + 1
iw = np.minimum(np.expand_dims(a[:, 2], axis=1), b[:, 2]) - np.maximum(np.expand_dims(a[:, 0], 1), b[:, 0])
ih = np.minimum(np.expand_dims(a[:, 3], axis=1), b[:, 3]) - np.maximum(np.expand_dims(a[:, 1], 1), b[:, 1])

iw = np.maximum(iw, 0)
ih = np.maximum(ih, 0)

ua = np.expand_dims((a[:, 2] - a[:, 0] + 1) * (a[:, 3] - a[:, 1] + 1), axis=1) + area - iw * ih
ua = np.expand_dims((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), axis=1) + area - iw * ih

ua = np.maximum(ua, np.finfo(float).eps)

Expand Down
29 changes: 13 additions & 16 deletions keras_retinanet/utils/coco_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ def evaluate_coco(generator, model, threshold=0.05):
# start collecting results
results = []
image_ids = []
for i in range(generator.size()):
image = generator.load_image(i)
for index in range(generator.size()):
image = generator.load_image(index)
image = generator.preprocess_image(image)
image, scale = generator.resize_image(image)

Expand All @@ -50,26 +50,23 @@ def evaluate_coco(generator, model, threshold=0.05):
detections[:, :, 3] -= detections[:, :, 1]

# compute predicted labels and scores
for detection in detections[0, ...]:
positive_labels = np.where(detection[4:] > threshold)[0]

for i, j in np.transpose(np.where(detections[0, :, 4:] > threshold)):
# append detections for each positively labeled class
for label in positive_labels:
image_result = {
'image_id' : generator.image_ids[i],
'category_id' : generator.label_to_coco_label(label),
'score' : float(detection[4 + label]),
'bbox' : (detection[:4]).tolist(),
}
image_result = {
'image_id' : generator.image_ids[index],
'category_id' : generator.label_to_coco_label(j),
'score' : float(detections[0, i, 4 + j]),
'bbox' : (detections[0, i, :4]).tolist(),
}

# append detection to results
results.append(image_result)
# append detection to results
results.append(image_result)

# append image to list of processed images
image_ids.append(generator.image_ids[i])
image_ids.append(generator.image_ids[index])

# print progress
print('{}/{}'.format(i, generator.size()), end='\r')
print('{}/{}'.format(index, generator.size()), end='\r')

if not len(results):
return
Expand Down
2 changes: 1 addition & 1 deletion keras_retinanet/utils/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def apply_transform(matrix, image, params):
return output


def resize_image(img, min_side=600, max_side=1024):
def resize_image(img, min_side=800, max_side=1333):
(rows, cols, _) = img.shape

smallest_side = min(rows, cols)
Expand Down
19 changes: 10 additions & 9 deletions tests/layers/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,19 @@ def test_simple(self):
]], dtype=keras.backend.floatx())
classification = keras.backend.variable(classification)

detections = np.array([[
other = np.array([[
[1, 2, 3],
[4, 5, 6],
]], dtype=keras.backend.floatx())
detections = keras.backend.variable(detections)
other = keras.backend.variable(other)

# compute output
actual = non_maximum_suppression_layer.call([boxes, classification, detections])
actual = non_maximum_suppression_layer.call([boxes, classification, other])
actual = keras.backend.eval(actual)

expected = np.array([[
[4, 5, 6],
[0, 0, 10, 10, 0, 0, 1, 2, 3],
[0, 0, 10, 10, 0, 1, 4, 5, 6],
]], dtype=keras.backend.floatx())

np.testing.assert_array_equal(actual, expected)
Expand Down Expand Up @@ -147,7 +148,7 @@ def test_mini_batch(self):
], dtype=keras.backend.floatx())
classification = keras.backend.variable(classification)

detections = np.array([
other = np.array([
[
[1, 2, 3],
[4, 5, 6],
Expand All @@ -157,18 +158,18 @@ def test_mini_batch(self):
[10, 11, 12],
],
], dtype=keras.backend.floatx())
detections = keras.backend.variable(detections)
other = keras.backend.variable(other)

# compute output
actual = non_maximum_suppression_layer.call([boxes, classification, detections])
actual = non_maximum_suppression_layer.call([boxes, classification, other])
actual = keras.backend.eval(actual)

expected = np.array([
[
[4, 5, 6],
[0, 0, 10, 10, 0, 1, 4, 5, 6],
],
[
[7, 8, 9],
[100, 100, 150, 150, 0, 1, 7, 8, 9],
],
], dtype=keras.backend.floatx())

Expand Down

0 comments on commit 05bd676

Please sign in to comment.