Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

batch evaluation #89

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions lib/layer_utils/anchor_target_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@
from __future__ import print_function

import os
from model.config import cfg
from lib.model.config import cfg
import numpy as np
import numpy.random as npr
from utils.bbox import bbox_overlaps
from model.bbox_transform import bbox_transform
from lib.utils.bbox import bbox_overlaps
from lib.model.bbox_transform import bbox_transform
import torch

def anchor_target_layer(rpn_cls_score, gt_boxes, im_info, _feat_stride, all_anchors, num_anchors):
Expand Down Expand Up @@ -160,4 +160,4 @@ def _compute_targets(ex_rois, gt_rois):
assert ex_rois.shape[1] == 4
assert gt_rois.shape[1] == 5

return bbox_transform(torch.from_numpy(ex_rois), torch.from_numpy(gt_rois[:, :4])).numpy()
return bbox_transform(torch.from_numpy(ex_rois), torch.from_numpy(gt_rois[:, :4])).numpy()
48 changes: 45 additions & 3 deletions lib/layer_utils/proposal_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
from __future__ import print_function

import numpy as np
from model.config import cfg
from model.bbox_transform import bbox_transform_inv, clip_boxes
from model.nms_wrapper import nms
from lib.model.config import cfg
from lib.model.bbox_transform import bbox_transform_inv, clip_boxes, bbox_transform_inv_batch
from lib.model.nms_wrapper import nms,nms_batch

import torch

Expand Down Expand Up @@ -52,3 +52,45 @@ def proposal_layer(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride,
blob = torch.cat((batch_inds, proposals), 1)

return blob, scores


def proposal_layer_batch(rpn_cls_prob, rpn_bbox_pred, im_info, cfg_key, _feat_stride, anchors, num_anchors, network_device):
"""A simplified version compared to fast/er RCNN
For details please see the technical report
"""
if type(cfg_key) == bytes:
cfg_key = cfg_key.decode('utf-8')
pre_nms_topN = cfg[cfg_key].RPN_PRE_NMS_TOP_N
post_nms_topN = cfg[cfg_key].RPN_POST_NMS_TOP_N
nms_thresh = cfg[cfg_key].RPN_NMS_THRESH

# Get the scores and bounding boxes
scores = rpn_cls_prob[:, :, :, num_anchors:]
rpn_bbox_pred = rpn_bbox_pred.view((rpn_bbox_pred.size(0),-1, 4))
scores = scores.contiguous().view(scores.size(0),-1, 1)
proposals = bbox_transform_inv_batch(anchors, rpn_bbox_pred)#here bug
proposals = list(map(lambda x : clip_boxes(x, im_info[:2]),proposals))

blobs, scoress = [],[]
for i in range(scores.size(0)):
# Pick the top region proposals
score, order = scores[i].view(-1).sort(descending=True)
if pre_nms_topN > 0:
order = order[:pre_nms_topN]
score = score[:pre_nms_topN].view(-1, 1)
proposal = proposals[i][order.data, :]

# Non-maximal suppression
keep = nms_batch(torch.cat((proposal, score), 1).data, nms_thresh, network_device)
# Pick th top region proposals after NMS
if post_nms_topN > 0:
keep = keep[:post_nms_topN]
proposal = proposal[keep]
score = score[keep]

# Only support single image as input
batch_inds = proposal.new_zeros(proposal.size(0), 1)
blob = torch.cat((batch_inds, proposal), 1)
blobs.append(blob)
scoress.append(score)
return blobs, scoress
6 changes: 3 additions & 3 deletions lib/layer_utils/proposal_target_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@

import numpy as np
import numpy.random as npr
from model.config import cfg
from model.bbox_transform import bbox_transform
from utils.bbox import bbox_overlaps
from lib.model.config import cfg
from lib.model.bbox_transform import bbox_transform
from lib.utils.bbox import bbox_overlaps


import torch
Expand Down
51 changes: 49 additions & 2 deletions lib/layer_utils/proposal_top_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,56 @@
from __future__ import print_function

import numpy as np
from model.config import cfg
from model.bbox_transform import bbox_transform_inv, clip_boxes
from lib.model.config import cfg
from lib.model.bbox_transform import bbox_transform_inv, clip_boxes
import numpy.random as npr

import torch
import signal

def proposal_top_layer_batch(rpn_cls_probs, rpn_bbox_preds, im_info, _feat_stride, anchors, num_anchors, network_device):
"""A layer that just selects the top region proposals
without using non-maximal suppression,
For details please see the technical report
"""
rpn_top_n = cfg.TEST.RPN_TOP_N

scores = rpn_cls_probs[:, :, :, num_anchors:]

rpn_bbox_preds = rpn_bbox_preds.view(rpn_bbox_preds.size(0),-1, 4)
scores = scores.contiguous().view(rpn_bbox_preds.size(0),-1, 1)

blobs, scoress = [],[]
for i in range(scores.size(0)):
score = scores[i]
length = score.size(0)
if length < rpn_top_n:
# Random selection, maybe unnecessary and loses good proposals
# But such case rarely happens
top_inds = torch.from_numpy(npr.choice(length, size=rpn_top_n, replace=True)).long().cuda(network_device)
else:
top_inds = score.sort(0, descending=True)[1]
top_inds = top_inds[:rpn_top_n]
top_inds = top_inds.view(rpn_top_n)

# Do the selection here
anchor = anchors[top_inds, :].contiguous()
rpn_bbox_pred = rpn_bbox_preds[i][top_inds, :].contiguous()
score = score[top_inds].contiguous()

# Convert anchors into proposals via bbox transformations
proposal = bbox_transform_inv(anchor, rpn_bbox_pred)
# Clip predicted boxes to image
proposal = clip_boxes(proposal, im_info[:2])

# Output rois blob
# Our RPN implementation only supports a single input image, so all
# batch inds are 0
batch_inds = proposal.new_zeros(proposal.size(0), 1)
blob = torch.cat([batch_inds, proposal], 1)
blobs.append(blob)
scoress.append(score)
return blobs, scoress

def proposal_top_layer(rpn_cls_prob, rpn_bbox_pred, im_info, _feat_stride, anchors, num_anchors):
"""A layer that just selects the top region proposals
Expand Down Expand Up @@ -43,10 +88,12 @@ def proposal_top_layer(rpn_cls_prob, rpn_bbox_pred, im_info, _feat_stride, ancho

# Convert anchors into proposals via bbox transformations
proposals = bbox_transform_inv(anchors, rpn_bbox_pred)
#proposals = torch.zeros((5000,4))

# Clip predicted boxes to image
proposals = clip_boxes(proposals, im_info[:2])


# Output rois blob
# Our RPN implementation only supports a single input image, so all
# batch inds are 0
Expand Down
Binary file not shown.
5 changes: 4 additions & 1 deletion lib/layer_utils/roi_pooling/_ext/roi_pooling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
def _import_symbols(locals):
for symbol in dir(_lib):
fn = getattr(_lib, symbol)
locals[symbol] = _wrap_function(fn, _ffi)
if callable(fn):
locals[symbol] = _wrap_function(fn, _ffi)
else:
locals[symbol] = fn
__all__.append(symbol)

_import_symbols(locals())
Binary file not shown.
2 changes: 1 addition & 1 deletion lib/layer_utils/snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import print_function

import numpy as np
from layer_utils.generate_anchors import generate_anchors
from lib.layer_utils.generate_anchors import generate_anchors

def generate_anchors_pre(height, width, feat_stride, anchor_scales=(8,16,32), anchor_ratios=(0.5,1,2)):
""" A wrapper function to generate anchors given different scales
Expand Down
6 changes: 3 additions & 3 deletions lib/make.sh
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,21 @@ cd layer_utils/roi_pooling/src/cuda
echo "Compiling roi_pooling kernels by nvcc..."
nvcc -c -o roi_pooling_kernel.cu.o roi_pooling_kernel.cu -x cu -Xcompiler -fPIC $CUDA_ARCH
cd ../../
python build.py
python3 build.py
cd ../../

# Build RoIAlign
cd layer_utils/roi_align/src/cuda
echo 'Compiling crop_and_resize kernels by nvcc...'
nvcc -c -o crop_and_resize_kernel.cu.o crop_and_resize_kernel.cu -x cu -Xcompiler -fPIC $CUDA_ARCH
cd ../../
python build.py
python3 build.py
cd ../../

# Build NMS
cd nms/src/cuda
echo "Compiling nms kernels by nvcc..."
nvcc -c -o nms_kernel.cu.o nms_kernel.cu -x cu -Xcompiler -fPIC $CUDA_ARCH
cd ../../
python build.py
python3 build.py
cd ../
31 changes: 30 additions & 1 deletion lib/model/bbox_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,34 @@ def bbox_transform(ex_rois, gt_rois):
return targets


def bbox_transform_inv_batch(boxes, deltas):
# Input should be both tensor or both Variable and on the same device
if len(boxes) == 0:
return deltas.detach() * 0

widths = boxes[:, 2] - boxes[:, 0] + 1.0
heights = boxes[:, 3] - boxes[:, 1] + 1.0
ctr_x = boxes[:, 0] + 0.5 * widths
ctr_y = boxes[:, 1] + 0.5 * heights

dx = deltas[:, :, 0::4]
dy = deltas[:, :, 1::4]
dw = deltas[:, :, 2::4]
dh = deltas[:, :, 3::4]

pred_ctr_x = dx * widths.repeat(deltas.size(0),1).unsqueeze(-1) + ctr_x.repeat(deltas.size(0),1).unsqueeze(-1)
pred_ctr_y = dy * heights.repeat(deltas.size(0),1).unsqueeze(-1) + ctr_y.repeat(deltas.size(0),1).unsqueeze(-1)
pred_w = torch.exp(dw) * widths.repeat(deltas.size(0),1).unsqueeze(-1)
pred_h = torch.exp(dh) * heights.repeat(deltas.size(0),1).unsqueeze(-1)

pred_boxes = torch.cat(\
[_.unsqueeze(2) for _ in [pred_ctr_x - 0.5 * pred_w,\
pred_ctr_y - 0.5 * pred_h,\
pred_ctr_x + 0.5 * pred_w,\
pred_ctr_y + 0.5 * pred_h]], 2).view(deltas.size(0),len(boxes), -1)

return pred_boxes

def bbox_transform_inv(boxes, deltas):
# Input should be both tensor or both Variable and on the same device
if len(boxes) == 0:
Expand All @@ -46,7 +74,7 @@ def bbox_transform_inv(boxes, deltas):
dy = deltas[:, 1::4]
dw = deltas[:, 2::4]
dh = deltas[:, 3::4]

pred_ctr_x = dx * widths.unsqueeze(1) + ctr_x.unsqueeze(1)
pred_ctr_y = dy * heights.unsqueeze(1) + ctr_y.unsqueeze(1)
pred_w = torch.exp(dw) * widths.unsqueeze(1)
Expand All @@ -61,6 +89,7 @@ def bbox_transform_inv(boxes, deltas):
return pred_boxes



def clip_boxes(boxes, im_shape):
"""
Clip boxes to image boundaries.
Expand Down
14 changes: 7 additions & 7 deletions lib/model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
# Whether to double the learning rate for bias
__C.TRAIN.DOUBLE_BIAS = True

# Whether to initialize the weights with truncated normal distribution
# Whether to initialize the weights with truncated normal distribution
__C.TRAIN.TRUNCATED = False

# Whether to have weight decay on bias as well
Expand All @@ -50,7 +50,7 @@

# Whether to use aspect-ratio grouping of training images, introduced merely for saving
# GPU memory
__C.TRAIN.ASPECT_GROUPING = False
__C.TRAIN.ASPECT_GROUPING = True

# The number of snapshots kept, older ones are deleted to save space
__C.TRAIN.SNAPSHOT_KEPT = 3
Expand Down Expand Up @@ -139,7 +139,7 @@
__C.TRAIN.RPN_BATCHSIZE = 256

# NMS threshold used on RPN proposals
__C.TRAIN.RPN_NMS_THRESH = 0.7
__C.TRAIN.RPN_NMS_THRESH = 0.8

# Number of top scoring boxes to keep before apply NMS to RPN proposals
__C.TRAIN.RPN_PRE_NMS_TOP_N = 12000
Expand All @@ -155,7 +155,7 @@
# Set to -1.0 to use uniform example weighting
__C.TRAIN.RPN_POSITIVE_WEIGHT = -1.0

# Whether to use all ground truth bounding boxes for training,
# Whether to use all ground truth bounding boxes for training,
# For COCO, setting USE_ALL_GT to False will exclude boxes that are flagged as ''iscrowd''
__C.TRAIN.USE_ALL_GT = True

Expand Down Expand Up @@ -213,8 +213,8 @@

__C.RESNET = edict()

# Option to set if max-pooling is appended after crop_and_resize.
# if true, the region will be resized to a square of 2xPOOLING_SIZE,
# Option to set if max-pooling is appended after crop_and_resize.
# if true, the region will be resized to a square of 2xPOOLING_SIZE,
# then 2x2 max-pooling is applied; otherwise the region will be directly
# resized to a square of POOLING_SIZE
__C.RESNET.MAX_POOL = False
Expand Down Expand Up @@ -267,7 +267,7 @@
__C.EXP_DIR = 'default'

# Use GPU implementation of non-maximum suppression
__C.USE_GPU_NMS = True
__C.USE_GPU_NMS = False

# Default pooling mode
__C.POOLING_MODE = 'crop'
Expand Down
8 changes: 7 additions & 1 deletion lib/model/nms_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,15 @@
from __future__ import division
from __future__ import print_function

from nms.pth_nms import pth_nms
from lib.nms.pth_nms import pth_nms
from lib.nms.pth_nms import pth_nms_batch


def nms_batch(dets, thresh, network_device):
"""Dispatch to either CPU or GPU NMS implementations.
Accept dets as tensor"""
return pth_nms_batch(dets, thresh, network_device)

def nms(dets, thresh):
"""Dispatch to either CPU or GPU NMS implementations.
Accept dets as tensor"""
Expand Down
Loading