37 commits
0d109f0
Add config vars. to resize images
Aug 31, 2016
0b7439a
Make anchor ratio to be configurable in a prototxt
Aug 31, 2016
46405a1
Add bounding box voting
Sep 1, 2016
450cc1d
Update README.md
Sep 1, 2016
6b76263
Update the link to Caffe submodule
kyehyeon Sep 2, 2016
df5cce8
Correct deprecated variables (param_str_ -> param_str)
kyehyeon Sep 2, 2016
de03c83
Add PVANET prototxts and a script for downloading caffemodels
kyehyeon Sep 2, 2016
abf4288
Update README
kyehyeon Sep 2, 2016
300e9e0
Update README
kyehyeon Sep 2, 2016
6e6f36c
Update README
kyehyeon Sep 2, 2016
568a08c
Update README
kyehyeon Sep 2, 2016
4321a60
Update submodule 'caffe'
Sep 2, 2016
4956132
Update README.md and downloading script
Sep 1, 2016
e9d8ae7
Update a script for downloading original caffemodels
kyehyeon Sep 13, 2016
c1803f7
Add PVANET ImageNet pretrained models, PVANET-lite models, and their …
kyehyeon Sep 19, 2016
4cbd40c
Update the download script for ImageNet pretrained models
kyehyeon Sep 19, 2016
ddf6040
Add Google Drive download links for PVANET models
kyehyeon Sep 28, 2016
d23bf1f
Update README.md
Sep 30, 2016
4799b28
Add training prototxts
Sep 30, 2016
7be7250
Fix a bug in training examples
Oct 11, 2016
5f037cd
Add a hotfix to enable 'average_loss'
Oct 21, 2016
b528853
Update README.md
Oct 21, 2016
39570aa
Add a tool to merge 'Conv-BN-Scale' into a single 'Conv' layer.
Nov 2, 2016
e8440f9
Update README.md
Nov 21, 2016
ab76d48
Update README.md (adding an arXiv link for EMDNN-accepted version)
Dec 8, 2016
4fc32d8
aUpdate README.md (adding BibTeX)
Dec 9, 2016
b07da75
Merge branch 'master' into develop
Dec 26, 2016
8957e2a
Add PVANet 9.1
Dec 26, 2016
6815a7f
Add weighted box scoring (vote heuristics)
Dec 27, 2016
04ba579
Update README.md
Dec 27, 2016
44a1428
Add a training example
Dec 27, 2016
69156c4
Add a prototxt for ImageNet classification (192x192 model)
Dec 28, 2016
f3a4773
Fix PVA9.1 prototxt (eliminated missing layers)
Dec 28, 2016
d5a41d6
Update README.md
Jan 5, 2017
2c93baa
Fix README.md
Jan 16, 2017
e590825
Update demo.py
ygren Jun 8, 2017
e23be55
Create README.md
ygren Jun 8, 2017
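Commit 39570aa above adds a tool to merge 'Conv-BN-Scale' triplets into a single 'Conv' layer. The repository's tool operates on Caffe nets; purely as a sketch of the arithmetic involved (the function name and array layout here are assumptions, not the repo's API), folding batch-norm statistics and scale parameters into the preceding convolution looks like this:

```python
import numpy as np

def fold_bn_scale_into_conv(W, b, bn_mean, bn_var, gamma, beta, eps=1e-5):
    """Fold BatchNorm statistics and Scale parameters into conv weights.

    W: conv weights, shape (C_out, C_in, kh, kw); b: conv bias, shape (C_out,)
    (use zeros if the conv had no bias). The returned (W', b') satisfy
    conv(x, W') + b' == scale(batchnorm(conv(x, W) + b)) for any input x.
    """
    factor = gamma / np.sqrt(bn_var + eps)        # per-output-channel multiplier
    W_folded = W * factor[:, None, None, None]    # rescale each output channel
    b_folded = (b - bn_mean) * factor + beta
    return W_folded, b_folded
```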
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@
lib/build
lib/pycocotools/_mask.c
lib/pycocotools/_mask.so
*.caffemodel
6 changes: 3 additions & 3 deletions .gitmodules
@@ -1,4 +1,4 @@
[submodule "caffe-fast-rcnn"]
[submodule "caffe-for-pva"]
path = caffe-fast-rcnn
url = https://github.com/rbgirshick/caffe-fast-rcnn.git
branch = fast-rcnn
url = https://github.com/sanghoon/caffe.git
branch = dev_pvanet
258 changes: 84 additions & 174 deletions README.md
@@ -1,217 +1,127 @@
### Disclaimer
## PVANet: Lightweight Deep Neural Networks for Real-time Object Detection
by Sanghoon Hong, Byungseok Roh, Kye-hyeon Kim, Yeongjae Cheon, Minje Park (Intel Imaging and Camera Technology)
Presented in [EMDNN2016](http://allenai.org/plato/emdnn/), a NIPS2016 workshop ([arXiv link](https://arxiv.org/abs/1611.08588))

The official Faster R-CNN code (written in MATLAB) is available [here](https://github.com/ShaoqingRen/faster_rcnn).
If your goal is to reproduce the results in our NIPS 2015 paper, please use the [official code](https://github.com/ShaoqingRen/faster_rcnn).
### Introduction

This repository contains a Python *reimplementation* of the MATLAB code.
This Python implementation is built on a fork of [Fast R-CNN](https://github.com/rbgirshick/fast-rcnn).
There are slight differences between the two implementations.
In particular, this Python port
- is ~10% slower at test-time, because some operations execute on the CPU in Python layers (e.g., 220ms / image vs. 200ms / image for VGG16)
- gives similar, but not exactly the same, mAP as the MATLAB version
- is *not compatible* with models trained using the MATLAB code due to the minor implementation differences
- **includes approximate joint training** that is 1.5x faster than alternating optimization (for VGG16) -- see these [slides](https://www.dropbox.com/s/xtr4yd4i5e0vw8g/iccv15_tutorial_training_rbg.pdf?dl=0) for more information
This repository is a fork from [py-faster-rcnn](https://github.com/rbgirshick/py-faster-rcnn) and demonstrates the performance of PVANet.

# *Faster* R-CNN: Towards Real-Time Object Detection with Region Proposal Networks
You can refer to [py-faster-rcnn README.md](https://github.com/rbgirshick/py-faster-rcnn/blob/master/README.md) and [faster-rcnn README.md](https://github.com/ShaoqingRen/faster_rcnn/blob/master/README.md) for more information.

By Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun (Microsoft Research)
### Disclaimer

This Python implementation contains contributions from Sean Bell (Cornell) written during an MSR internship.
Please note that this repository doesn't contain the in-house code used in the published article.
- This version of py-faster-rcnn is slower than our in-house runtime code (e.g., the image pre-processing here is written in Python).
- PVANet was trained with our in-house deep learning library, not with this implementation.
- There might be a tiny difference in VOC2012 test results, because some hidden parameters in py-faster-rcnn may be set differently from ours.

Please see the official [README.md](https://github.com/ShaoqingRen/faster_rcnn/blob/master/README.md) for more details.
### Citing PVANet

Faster R-CNN was initially described in an [arXiv tech report](http://arxiv.org/abs/1506.01497) and was subsequently published in NIPS 2015.

### License

Faster R-CNN is released under the MIT License (refer to the LICENSE file for details).

### Citing Faster R-CNN

If you find Faster R-CNN useful in your research, please consider citing:

@inproceedings{renNIPS15fasterrcnn,
Author = {Shaoqing Ren and Kaiming He and Ross Girshick and Jian Sun},
Title = {Faster {R-CNN}: Towards Real-Time Object Detection
with Region Proposal Networks},
Booktitle = {Advances in Neural Information Processing Systems ({NIPS})},
Year = {2015}
}

### Contents
1. [Requirements: software](#requirements-software)
2. [Requirements: hardware](#requirements-hardware)
3. [Basic installation](#installation-sufficient-for-the-demo)
4. [Demo](#demo)
5. [Beyond the demo: training and testing](#beyond-the-demo-installation-for-training-and-testing-models)
6. [Usage](#usage)

### Requirements: software

1. Requirements for `Caffe` and `pycaffe` (see: [Caffe installation instructions](http://caffe.berkeleyvision.org/installation.html))

**Note:** Caffe *must* be built with support for Python layers!

```make
# In your Makefile.config, make sure to have this line uncommented
WITH_PYTHON_LAYER := 1
# Unrelatedly, it's also recommended that you use CUDNN
USE_CUDNN := 1
```

You can download my [Makefile.config](http://www.cs.berkeley.edu/~rbg/fast-rcnn-data/Makefile.config) for reference.
2. Python packages you might not have: `cython`, `python-opencv`, `easydict`
3. [Optional] MATLAB is required for **official** PASCAL VOC evaluation only. The code now includes unofficial Python evaluation code.

### Requirements: hardware

1. For training smaller networks (ZF, VGG_CNN_M_1024) a good GPU (e.g., Titan, K20, K40, ...) with at least 3G of memory suffices
2. For training Fast R-CNN with VGG16, you'll need a K40 (~11G of memory)
3. For training the end-to-end version of Faster R-CNN with VGG16, 3G of GPU memory is sufficient (using CUDNN)

### Installation (sufficient for the demo)
If you want to cite this work in your publication:
```
@article{hong2016pvanet,
title={{PVANet}: Lightweight Deep Neural Networks for Real-time Object Detection},
author={Hong, Sanghoon and Roh, Byungseok and Kim, Kye-Hyeon and Cheon, Yeongjae and Park, Minje},
journal={arXiv preprint arXiv:1611.08588},
year={2016}
}
```

### Installation
1. Clone the Faster R-CNN repository
```Shell
# Make sure to clone with --recursive
git clone --recursive https://github.com/rbgirshick/py-faster-rcnn.git
```

2. We'll call the directory that you cloned Faster R-CNN into `FRCN_ROOT`

*Ignore notes 1 and 2 if you followed step 1 above.*

**Note 1:** If you didn't clone Faster R-CNN with the `--recursive` flag, then you'll need to manually clone the `caffe-fast-rcnn` submodule:
```Shell
git submodule update --init --recursive
# Make sure to clone with --recursive
git clone --recursive https://github.com/sanghoon/pva-faster-rcnn.git
```
**Note 2:** The `caffe-fast-rcnn` submodule needs to be on the `faster-rcnn` branch (or equivalent detached state). This will happen automatically *if you followed step 1 instructions*.

3. Build the Cython modules
2. We'll call the directory that you cloned Faster R-CNN into `FRCN_ROOT`. Build the Cython modules
```Shell
cd $FRCN_ROOT/lib
make
```

4. Build Caffe and pycaffe
3. Build Caffe and pycaffe
```Shell
cd $FRCN_ROOT/caffe-fast-rcnn
# Now follow the Caffe installation instructions here:
# http://caffe.berkeleyvision.org/installation.html
# For your Makefile.config:
# Uncomment `WITH_PYTHON_LAYER := 1`

# If you're experienced with Caffe and have all of the requirements installed
# and your Makefile.config in place, then simply do:
cp Makefile.config.example Makefile.config
make -j8 && make pycaffe
```

5. Download pre-computed Faster R-CNN detectors
4. Download PVANet detection model for VOC2007
```Shell
cd $FRCN_ROOT
./data/scripts/fetch_faster_rcnn_models.sh
./models/pvanet/download_voc2007.sh
```

This will populate the `$FRCN_ROOT/data` folder with `faster_rcnn_models`. See `data/README.md` for details.
These models were trained on VOC 2007 trainval.

### Demo

*After successfully completing [basic installation](#installation-sufficient-for-the-demo)*, you'll be ready to run the demo.

To run the demo
```Shell
cd $FRCN_ROOT
./tools/demo.py
```
The demo performs detection using a VGG16 network trained for detection on PASCAL VOC 2007.

### Beyond the demo: installation for training and testing models
1. Download the training, validation, test data and VOCdevkit

```Shell
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
wget http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCdevkit_08-Jun-2007.tar
```

2. Extract all of these tars into one directory named `VOCdevkit`

```Shell
tar xvf VOCtrainval_06-Nov-2007.tar
tar xvf VOCtest_06-Nov-2007.tar
tar xvf VOCdevkit_08-Jun-2007.tar
```

3. It should have this basic structure

```Shell
$VOCdevkit/ # development kit
$VOCdevkit/VOCcode/ # VOC utility code
$VOCdevkit/VOC2007 # image sets, annotations, etc.
# ... and several other directories ...
```

4. Create symlinks for the PASCAL VOC dataset

```Shell
cd $FRCN_ROOT/data
ln -s $VOCdevkit VOCdevkit2007
5. Download PVANet detection model for VOC2012 (published model)
```Shell
cd $FRCN_ROOT
./models/pvanet/download_voc_best.sh
```

6. (Optional) Download all available models (including pre-trained and compressed models)
```Shell
cd $FRCN_ROOT
./models/pvanet/download_all_models.sh
```
Using symlinks is a good idea because you will likely want to share the same PASCAL dataset installation between multiple projects.
5. [Optional] follow similar steps to get PASCAL VOC 2010 and 2012
6. [Optional] If you want to use COCO, please see some notes under `data/README.md`
7. Follow the next sections to download pre-trained ImageNet models

### Download pre-trained ImageNet models

Pre-trained ImageNet models can be downloaded for the three networks described in the paper: ZF and VGG16.

```Shell
cd $FRCN_ROOT
./data/scripts/fetch_imagenet_models.sh
```
VGG16 comes from the [Caffe Model Zoo](https://github.com/BVLC/caffe/wiki/Model-Zoo), but is provided here for your convenience.
ZF was trained at MSRA.
7. (Optional) Download ILSVRC2012 (ImageNet) classification model
```Shell
cd $FRCN_ROOT
./models/pvanet/download_imagenet_model.sh
```

### Usage
8. (Optional) If the scripts don't work, please download the models from the Google Drive links below.

To train and test a Faster R-CNN detector using the **alternating optimization** algorithm from our NIPS 2015 paper, use `experiments/scripts/faster_rcnn_alt_opt.sh`.
Output is written underneath `$FRCN_ROOT/output`.
| Model | Google Drive |
| ------ | ---- |
| PVANet for VOC2007 | [link](https://drive.google.com/open?id=0Bw_6VpHzQoMVRGZOSEctOEVMLXc) |
| PVANet for VOC2012 | [link](https://drive.google.com/open?id=0Bw_6VpHzQoMVa3M0Zm5zNnEtQUE) |
| PVANet for VOC2012 (compressed) | [link](https://drive.google.com/open?id=0Bw_6VpHzQoMVZU1BdEJDZG5MVXM) |
| PVANet for ILSVRC2012 (ImageNet) | [link](https://drive.google.com/open?id=0Bw_6VpHzQoMVTjctVVhjMXo1X3c) |
| PVANet pre-trained | [link](https://drive.google.com/open?id=0Bw_6VpHzQoMVak5FVFBWU0Uyb3M) |

```Shell
cd $FRCN_ROOT
./experiments/scripts/faster_rcnn_alt_opt.sh [GPU_ID] [NET] [--set ...]
# GPU_ID is the GPU you want to train on
# NET in {ZF, VGG_CNN_M_1024, VGG16} is the network arch to use
# --set ... allows you to specify fast_rcnn.config options, e.g.
# --set EXP_DIR seed_rng1701 RNG_SEED 1701
```
### How to run the demo

("alt opt" refers to the alternating optimization training algorithm described in the NIPS paper.)
1. Download PASCAL VOC 2007 and 2012
-- Follow the instructions in [py-faster-rcnn README.md](https://github.com/rbgirshick/py-faster-rcnn#beyond-the-demo-installation-for-training-and-testing-models)

To train and test a Faster R-CNN detector using the **approximate joint training** method, use `experiments/scripts/faster_rcnn_end2end.sh`.
Output is written underneath `$FRCN_ROOT/output`.
2. PVANet on PASCAL VOC 2007
```Shell
cd $FRCN_ROOT
./tools/test_net.py --net models/pvanet/pva9.1/PVA9.1_ImgNet_COCO_VOC0712.caffemodel --def models/pvanet/pva9.1/faster_rcnn_train_test_21cls.pt --cfg models/pvanet/cfgs/submit_1019.yml --gpu 0
```

```Shell
cd $FRCN_ROOT
./experiments/scripts/faster_rcnn_end2end.sh [GPU_ID] [NET] [--set ...]
# GPU_ID is the GPU you want to train on
# NET in {ZF, VGG_CNN_M_1024, VGG16} is the network arch to use
# --set ... allows you to specify fast_rcnn.config options, e.g.
# --set EXP_DIR seed_rng1701 RNG_SEED 1701
```
3. PVANet (compressed)
```Shell
cd $FRCN_ROOT
./tools/test_net.py --net models/pvanet/pva9.1/PVA9.1_ImgNet_COCO_VOC0712plus_compressed.caffemodel --def models/pvanet/pva9.1/faster_rcnn_train_test_ft_rcnn_only_plus_comp.pt --cfg models/pvanet/cfgs/submit_1019.yml --gpu 0
```
4. Visualization: run demo.py
```Shell
./tools/demo.py --gpu 0 --def models/pvanet/comp/test.pt --net models/pvanet/comp/test.model
```

This method trains the RPN module jointly with the Fast R-CNN network, rather than alternating between training the two. It results in faster (~ 1.5x speedup) training times and similar detection accuracy. See these [slides](https://www.dropbox.com/s/xtr4yd4i5e0vw8g/iccv15_tutorial_training_rbg.pdf?dl=0) for more details.
### Expected results

Artifacts generated by the scripts in `tools` are written in this directory.
#### Mean Average Precision on VOC detection tasks

Trained Fast R-CNN networks are saved under:
| Model | VOC2007 mAP (%) | VOC2012 mAP (%) |
| --------- | ------- | ------- |
| PVANet+ (VOC2007) | **84.9** | N/A |
| PVANet+ (VOC2012) | *89.8* | **84.2** |
| PVANet+ (VOC2012 + compressed) | *87.8* | 83.7 |
- The training set for the VOC2012 model includes the VOC2007 test set, so this model's VOC2007 accuracies are not meaningful; they are shown only for reference.

```
output/<experiment directory>/<dataset name>/
```
#### Validation error on ILSVRC2012

Test outputs are saved under:
| Input size | Top-1 error (%) | Top-5 error (%) |
| --- | --- | --- |
| 192x192 | 30.00 | N/A |
| 224x224 | 27.66 | 8.84 |
- The 224x224 model was re-trained using the 192x192 model as its base.

```
output/<experiment directory>/<dataset name>/<network snapshot name>/
```
2 changes: 1 addition & 1 deletion caffe-fast-rcnn
Submodule caffe-fast-rcnn updated 198 files
13 changes: 13 additions & 0 deletions lib/fast_rcnn/config.py
@@ -37,6 +37,9 @@
# Each scale is the pixel size of an image's shortest side
__C.TRAIN.SCALES = (600,)

# Resize training images so that their width and height are multiples of this value
__C.TRAIN.SCALE_MULTIPLE_OF = 1

# Max pixel size of the longest side of a scaled input image
__C.TRAIN.MAX_SIZE = 1000

@@ -134,6 +137,9 @@
# Each scale is the pixel size of an image's shortest side
__C.TEST.SCALES = (600,)

# Resize test images so that their width and height are multiples of this value
__C.TEST.SCALE_MULTIPLE_OF = 1

# Max pixel size of the longest side of a scaled input image
__C.TEST.MAX_SIZE = 1000

@@ -163,6 +169,13 @@
# Proposal height and width both need to be greater than RPN_MIN_SIZE (at orig image scale)
__C.TEST.RPN_MIN_SIZE = 16

# Apply bounding box voting
__C.TEST.BBOX_VOTE = False

# Apply box scoring heuristics
__C.TEST.BBOX_VOTE_N_WEIGHTED_SCORE = 1
__C.TEST.BBOX_VOTE_WEIGHT_EMPTY = 0.5


#
# MISC
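The new `SCALE_MULTIPLE_OF` options round the resized image's sides to a multiple of a fixed value, which is convenient when the network downsamples by a fixed stride. A minimal sketch of the intended computation, assuming the usual shorter-side/longer-side rules from `SCALES`/`MAX_SIZE`; this helper is illustrative, not the repository's actual resizing code:

```python
import numpy as np

def compute_resize_shape(height, width, target_size=600, max_size=1000,
                         multiple_of=1):
    """Scale the shorter side to ~target_size, cap the longer side at
    max_size, then round both sides to the nearest multiple of multiple_of."""
    scale = float(target_size) / min(height, width)
    if round(scale * max(height, width)) > max_size:
        scale = float(max_size) / max(height, width)
    new_h = int(np.round(height * scale / multiple_of) * multiple_of)
    new_w = int(np.round(width * scale / multiple_of) * multiple_of)
    return new_h, new_w

# e.g. a 375x500 image with multiple_of=32 resizes to 608x800
```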
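`TEST.BBOX_VOTE` and the two scoring options map to the "bounding box voting" and "weighted box scoring" commits above. A rough sketch of the idea, assuming detections are rows of [x1, y1, x2, y2, score] and that the post-NMS boxes are a subset of the pre-NMS boxes; the rescoring rule shown for `BBOX_VOTE_N_WEIGHTED_SCORE` and `BBOX_VOTE_WEIGHT_EMPTY` is a guess and may differ from the repository's implementation:

```python
import numpy as np

def iou(box, boxes):
    """IoU between one box and an array of boxes, all as [x1, y1, x2, y2]."""
    xx1 = np.maximum(box[0], boxes[:, 0])
    yy1 = np.maximum(box[1], boxes[:, 1])
    xx2 = np.minimum(box[2], boxes[:, 2])
    yy2 = np.minimum(box[3], boxes[:, 3])
    inter = np.maximum(0.0, xx2 - xx1 + 1) * np.maximum(0.0, yy2 - yy1 + 1)
    area = (box[2] - box[0] + 1) * (box[3] - box[1] + 1)
    areas = (boxes[:, 2] - boxes[:, 0] + 1) * (boxes[:, 3] - boxes[:, 1] + 1)
    return inter / (area + areas - inter)

def box_vote(nms_dets, all_dets, overlap=0.5, n_top=1, weight_empty=0.5):
    """Replace each post-NMS box by the score-weighted average of all pre-NMS
    boxes overlapping it by >= `overlap` IoU (each kept box overlaps itself,
    so the member set is never empty). The new score averages the top `n_top`
    member scores, padding missing slots with weight_empty * the box's score."""
    voted = nms_dets.copy()
    for i, det in enumerate(nms_dets):
        members = all_dets[iou(det[:4], all_dets[:, :4]) >= overlap]
        w = members[:, 4]
        voted[i, :4] = np.dot(w, members[:, :4]) / w.sum()  # weighted average
        top = np.sort(w)[::-1][:n_top]
        pad = n_top - top.size
        voted[i, 4] = (top.sum() + pad * weight_empty * det[4]) / n_top
    return voted
```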