High-level network definitions with pre-trained weights in TensorFlow (tested with 2.1.0 >= TF >= 1.4.0).
- Applicability. Many people already have their own ML workflows and want to add a new model to them. TensorNets can be plugged into those workflows easily because it is designed as simple functional interfaces without custom classes.
- Manageability. Models are written in tf.contrib.layers, which is lightweight like PyTorch and Keras and gives easy access to every weight and end-point. It is also easy to deploy and extend the collection of pre-processing functions and pre-trained weights.
- Readability. With recent TensorFlow APIs, code can be factored more and indented less. For example, all the Inception variants take about 500 lines of code in TensorNets, compared with 2000+ lines in the official TensorFlow models.
- Reproducibility. You can always reproduce the original results with simple APIs, including feature extraction. Furthermore, you don't need to worry about your TensorFlow version, because compatibility with various TensorFlow releases has been checked with Travis.
You can install TensorNets from PyPI (pip install tensornets) or directly from GitHub (pip install git+https://github.com/taehoonlee/tensornets.git).
Each network (see full list) is not a custom class but a function that takes and returns a tf.Tensor as its input and output. Here is an example of ResNet50:
import tensorflow as tf
# import tensorflow.compat.v1 as tf # for TF 2
import tensornets as nets
# tf.disable_v2_behavior() # for TF 2
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
model = nets.ResNet50(inputs)
assert isinstance(model, tf.Tensor)
You can load an example image with utils.load_img, which returns an np.ndarray in NHWC format:
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
assert img.shape == (1, 224, 224, 3)
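The target_size=256 and crop_size=224 arguments suggest a resize-then-center-crop pipeline. The plain-Python sketch below shows only the center-crop step on an image stored as a list of rows, assuming the image has already been resized to 256x256. The helper names are hypothetical; this is not TensorNets' implementation of load_img.

```python
def center_crop_offsets(height, width, crop):
    """Return (top, left) of a centered crop x crop window."""
    if crop > height or crop > width:
        raise ValueError("crop size exceeds image size")
    return (height - crop) // 2, (width - crop) // 2


def center_crop(image, crop):
    """Center-crop an H x W image given as a list of rows."""
    height, width = len(image), len(image[0])
    top, left = center_crop_offsets(height, width, crop)
    return [row[left:left + crop] for row in image[top:top + crop]]


# A 256x256 "image" whose pixels record their own (row, column) position.
image = [[(r, c) for c in range(256)] for r in range(256)]
cropped = center_crop(image, 224)
print(center_crop_offsets(256, 256, 224))  # (16, 16)
print(cropped[0][0])  # (16, 16)
```

With these sizes, 16 pixels are trimmed from each border.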
Once your network is created, you can run it with regular TensorFlow APIs, because all the networks in TensorNets always return a tf.Tensor. Loading pre-trained weights and applying pre-processing are as easy as calling pretrained() and preprocess(), which reproduce the original results:
with tf.Session() as sess:
    img = model.preprocess(img)  # equivalent to img = nets.preprocess(model, img)
    sess.run(model.pretrained())  # equivalent to nets.pretrained(model)
    preds = sess.run(model, {inputs: img})
You can see the most probable classes:
print(nets.utils.decode_predictions(preds, top=2)[0])
[(u'n02124075', u'Egyptian_cat', 0.28067636), (u'n02127052', u'lynx', 0.16826575)]
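decode_predictions ranks the class probabilities and returns the top entries, each with a WordNet ID, a class name, and a score, as in the output above. The core ranking step can be sketched in plain Python as follows; top_k is a hypothetical helper, not part of TensorNets, and the labels and probabilities are toy values.

```python
def top_k(probs, labels, k=2):
    """Return the k (label, probability) pairs with the highest probability."""
    if len(probs) != len(labels):
        raise ValueError("probs and labels must have the same length")
    ranked = sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy three-class example.
labels = ["cat", "dog", "fox"]
probs = [0.1, 0.7, 0.2]
print(top_k(probs, labels, k=2))  # [('dog', 0.7), ('fox', 0.2)]
```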
You can also easily obtain the values of intermediate layers with middles() and outputs():
with tf.Session() as sess:
    # img was already preprocessed above; preprocessing it again would change
    # the input and break the comparison with preds below
    sess.run(model.pretrained())
    middles = sess.run(model.middles(), {inputs: img})
    outputs = sess.run(model.outputs(), {inputs: img})
model.print_middles()
assert middles[0].shape == (1, 56, 56, 256)
assert middles[-1].shape == (1, 7, 7, 2048)
model.print_outputs()
assert sum(sum((outputs[-1] - preds) ** 2)) < 1e-8
With save() and load(), you can store and restore your weight values:
with tf.Session() as sess:
    model.init()
    # ... your training ...
    model.save('test.npz')
with tf.Session() as sess:
    model.load('test.npz')
    # ... your deployment ...
TensorNets enables us to deploy well-known architectures and benchmark their results faster. For more information, you can check out the lists of utilities, examples, and architectures.
Each object detection model can be coupled with any network in TensorNets (see performance) and takes two arguments: a placeholder and a function acting as a stem layer. Here is an example of YOLOv2 for PASCAL VOC:
import tensorflow as tf
import tensornets as nets
inputs = tf.placeholder(tf.float32, [None, 416, 416, 3])
model = nets.YOLOv2(inputs, nets.Darknet19)
img = nets.utils.load_img('cat.png')
with tf.Session() as sess:
    sess.run(model.pretrained())
    preds = sess.run(model, {inputs: model.preprocess(img)})
    boxes = model.get_boxes(preds, img.shape[1:3])
Like other models, a detection model also returns a tf.Tensor as its output. You can see the bounding box predictions (x1, y1, x2, y2, score) by using model.get_boxes(model_output, original_img_shape) and visualize the results:
import numpy as np
import matplotlib.pyplot as plt
from tensornets.datasets import voc
print("%s: %s" % (voc.classnames[7], boxes[7][0]))  # 7 is cat
box = boxes[7][0]
plt.imshow(img[0].astype(np.uint8))
plt.gca().add_patch(plt.Rectangle(
    (box[0], box[1]), box[2] - box[0], box[3] - box[1],
    fill=False, edgecolor='r', linewidth=2))
plt.show()
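plt.Rectangle expects a corner plus a width and a height, so each (x1, y1, x2, y2, score) row needs the small conversion done inline above. The plain-Python helpers below factor that out and pick the highest-scoring box for one class, assuming boxes[c] is a sequence of (x1, y1, x2, y2, score) rows for class c, as the indexing boxes[7][0] suggests. The helper names and the min_score threshold are hypothetical, and the box values are toy numbers.

```python
def to_corner_size(box):
    """Convert (x1, y1, x2, y2, score) to ((x1, y1), width, height, score)."""
    x1, y1, x2, y2, score = box
    return (x1, y1), x2 - x1, y2 - y1, score


def best_box(class_boxes, min_score=0.0):
    """Return the highest-scoring box with score >= min_score, or None."""
    kept = [box for box in class_boxes if box[4] >= min_score]
    if not kept:
        return None
    return max(kept, key=lambda box: box[4])


# Toy detections for a single class.
cat_boxes = [(10, 20, 110, 220, 0.35), (15, 25, 120, 230, 0.80)]
print(to_corner_size(best_box(cat_boxes, min_score=0.5)))
# ((15, 25), 105, 205, 0.8)
```

The first three values of the result map directly onto the xy, width, and height arguments of plt.Rectangle.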
More detection examples such as FasterRCNN on VOC2007 are here. Note that:
- APIs of detection models are slightly different:
  - YOLOv3: sess.run(model.preds, {inputs: img}),
  - YOLOv2: sess.run(model, {inputs: img}),
  - FasterRCNN: sess.run(model, {inputs: img, model.scales: scale}),
- FasterRCNN requires roi_pooling:
  - git clone https://github.com/deepsense-io/roi-pooling && cd roi-pooling && vi roi_pooling/Makefile and edit according to here,
  - python setup.py install.
Besides pretrained() and preprocess(), the output tf.Tensor provides the following useful attributes and methods:
- logits: the tf.Tensor of logits (the values before the softmax),
- middles() (=get_middles()): returns a list of all the representative tf.Tensor end-points,
- outputs() (=get_outputs()): returns a list of all the tf.Tensor end-points,
- weights() (=get_weights()): returns a list of all the tf.Tensor weight matrices,
- summary() (=print_summary()): prints the numbers of layers, weight matrices, and parameters,
- print_middles(): prints all the representative end-points,
- print_outputs(): prints all the end-points,
- print_weights(): prints all the weight matrices.
Example outputs of print methods are:
>>> model.print_middles()
Scope: resnet50
conv2/block1/out:0 (?, 56, 56, 256)
conv2/block2/out:0 (?, 56, 56, 256)
conv2/block3/out:0 (?, 56, 56, 256)
conv3/block1/out:0 (?, 28, 28, 512)
conv3/block2/out:0 (?, 28, 28, 512)
conv3/block3/out:0 (?, 28, 28, 512)
conv3/block4/out:0 (?, 28, 28, 512)
conv4/block1/out:0 (?, 14, 14, 1024)
...
>>> model.print_outputs()
Scope: resnet50
conv1/pad:0 (?, 230, 230, 3)
conv1/conv/BiasAdd:0 (?, 112, 112, 64)
conv1/bn/batchnorm/add_1:0 (?, 112, 112, 64)
conv1/relu:0 (?, 112, 112, 64)
pool1/pad:0 (?, 114, 114, 64)
pool1/MaxPool:0 (?, 56, 56, 64)
conv2/block1/0/conv/BiasAdd:0 (?, 56, 56, 256)
conv2/block1/0/bn/batchnorm/add_1:0 (?, 56, 56, 256)
conv2/block1/1/conv/BiasAdd:0 (?, 56, 56, 64)
conv2/block1/1/bn/batchnorm/add_1:0 (?, 56, 56, 64)
conv2/block1/1/relu:0 (?, 56, 56, 64)
...
>>> model.print_weights()
Scope: resnet50
conv1/conv/weights:0 (7, 7, 3, 64)
conv1/conv/biases:0 (64,)
conv1/bn/beta:0 (64,)
conv1/bn/gamma:0 (64,)
conv1/bn/moving_mean:0 (64,)
conv1/bn/moving_variance:0 (64,)
conv2/block1/0/conv/weights:0 (1, 1, 64, 256)
conv2/block1/0/conv/biases:0 (256,)
conv2/block1/0/bn/beta:0 (256,)
conv2/block1/0/bn/gamma:0 (256,)
...
>>> model.summary()
Scope: resnet50
Total layers: 54
Total weights: 320
Total parameters: 25,636,712
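The shapes printed above can be checked by hand. The plain-Python sketch below reproduces the spatial sizes from print_outputs() with the usual output-size formula floor((n + 2p - k) / s) + 1, assuming the standard ResNet-50 stem (a 7x7 convolution with stride 2, then a 3x3 max pooling with stride 2, each preceded by the explicit padding visible in conv1/pad and pool1/pad). It also counts parameters from variable shapes as printed by print_weights(), counting every listed variable, including the batch-norm moving statistics. The helper names are hypothetical.

```python
def conv_output_size(in_size, kernel, stride, padding=0):
    """Spatial output size of a convolution or pooling window (valid mode)."""
    return (in_size + 2 * padding - kernel) // stride + 1


def count_params(shapes):
    """Total number of scalars held by variables with the given shapes."""
    total = 0
    for shape in shapes:
        size = 1
        for dim in shape:
            size *= dim
        total += size
    return total


# Stem of ResNet-50 on a 224x224 input, following print_outputs() above.
size = 224 + 2 * 3                   # conv1/pad -> 230
size = conv_output_size(size, 7, 2)  # conv1/conv -> 112
size = size + 2 * 1                  # pool1/pad -> 114
size = conv_output_size(size, 3, 2)  # pool1/MaxPool -> 56
print(size)  # 56

# The six conv1 variables listed by print_weights() above.
conv1_shapes = [(7, 7, 3, 64), (64,), (64,), (64,), (64,), (64,)]
print(count_params(conv1_shapes))  # 9728
```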
- Comparison of different networks:
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = [
    nets.MobileNet75(inputs),
    nets.MobileNet100(inputs),
    nets.SqueezeNet(inputs),
]
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
imgs = nets.preprocess(models, img)
with tf.Session() as sess:
    nets.pretrained(models)
    for (model, img) in zip(models, imgs):
        preds = sess.run(model, {inputs: img})
        print(nets.utils.decode_predictions(preds, top=2)[0])
- Transfer learning:
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
outputs = tf.placeholder(tf.float32, [None, 50])
model = nets.DenseNet169(inputs, is_training=True, classes=50)
loss = tf.losses.softmax_cross_entropy(outputs, model.logits)
train = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(loss)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # also initializes Adam's slot variables
    nets.pretrained(model)  # then overwrite the network weights with pre-trained values
    for (x, y) in your_NumPy_data:  # the NHWC and one-hot format
        sess.run(train, {inputs: x, outputs: y})
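The loop above assumes your_NumPy_data yields (x, y) batches with one-hot labels for the 50 classes. A plain-Python sketch of producing such pairs from integer labels is shown below; one_hot and minibatches are hypothetical helpers, and in practice you would convert each batch to NumPy arrays (x in NHWC layout) before feeding it.

```python
def one_hot(labels, num_classes):
    """Encode integer class labels as one-hot rows of floats."""
    rows = []
    for label in labels:
        if not 0 <= label < num_classes:
            raise ValueError("label %d out of range" % label)
        row = [0.0] * num_classes
        row[label] = 1.0
        rows.append(row)
    return rows


def minibatches(xs, labels, batch_size, num_classes):
    """Yield (x, y) pairs of at most batch_size examples, y one-hot encoded."""
    for start in range(0, len(xs), batch_size):
        end = start + batch_size
        yield xs[start:end], one_hot(labels[start:end], num_classes)


# Toy data: five placeholder "images" with labels among 50 classes.
xs = ["img0", "img1", "img2", "img3", "img4"]
labels = [3, 0, 49, 3, 7]
batches = list(minibatches(xs, labels, batch_size=2, num_classes=50))
print(len(batches))  # 3
```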
- Using multi-GPU:
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = []
with tf.device('gpu:0'):
    models.append(nets.ResNeXt50(inputs))
with tf.device('gpu:1'):
    models.append(nets.DenseNet201(inputs))
from tensornets.preprocess import fb_preprocess
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
img = fb_preprocess(img)
with tf.Session() as sess:
    nets.pretrained(models)
    preds = sess.run(models, {inputs: img})
    for pred in preds:
        print(nets.utils.decode_predictions(pred, top=2)[0])
- The top-k accuracies were obtained with TensorNets on the ImageNet validation set and may differ slightly from the original ones.
- Input: input size fed into the models
- Top-1: single center crop, top-1 accuracy
- Top-5: single center crop, top-5 accuracy
- MAC: the number of floating-point operations counted with tf.profiler, rounded
- Size: the number of parameters including fully-connected layers, rounded
- Stem: the number of parameters excluding fully-connected layers, rounded
- The computation times were measured on NVIDIA Tesla P100 (3584 cores, 16 GB global memory) with cuDNN 6.0 and CUDA 8.0.
- Speed: milliseconds to run inference on 100 images
- The summary plot is generated by this script.
Model | Input | Top-1 (%) | Top-5 (%) | MAC | Size | Stem | Speed (ms) | References |
---|---|---|---|---|---|---|---|---|
ResNet50 | 224 | 74.874 | 92.018 | 51.0M | 25.6M | 23.6M | 195.4 | [paper] [tf-slim] [torch-fb] [caffe] [keras] |
ResNet101 | 224 | 76.420 | 92.786 | 88.9M | 44.7M | 42.7M | 311.7 | [paper] [tf-slim] [torch-fb] [caffe] |
ResNet152 | 224 | 76.604 | 93.118 | 120.1M | 60.4M | 58.4M | 439.1 | [paper] [tf-slim] [torch-fb] [caffe] |
ResNet50v2 | 299 | 75.960 | 93.034 | 51.0M | 25.6M | 23.6M | 209.7 | [paper] [tf-slim] [torch-fb] |
ResNet101v2 | 299 | 77.234 | 93.816 | 88.9M | 44.7M | 42.6M | 326.2 | [paper] [tf-slim] [torch-fb] |
ResNet152v2 | 299 | 78.032 | 94.162 | 120.1M | 60.4M | 58.3M | 455.2 | [paper] [tf-slim] [torch-fb] |
ResNet200v2 | 224 | 78.286 | 94.152 | 129.0M | 64.9M | 62.9M | 618.3 | [paper] [tf-slim] [torch-fb] |
ResNeXt50c32 | 224 | 77.740 | 93.810 | 49.9M | 25.1M | 23.0M | 267.4 | [paper] [torch-fb] |
ResNeXt101c32 | 224 | 78.730 | 94.294 | 88.1M | 44.3M | 42.3M | 427.9 | [paper] [torch-fb] |
ResNeXt101c64 | 224 | 79.494 | 94.592 | 0.0M | 83.7M | 81.6M | 877.8 | [paper] [torch-fb] |
WideResNet50 | 224 | 78.018 | 93.934 | 137.6M | 69.0M | 66.9M | 358.1 | [paper] [torch] |
Inception1 | 224 | 66.840 | 87.676 | 14.0M | 7.0M | 6.0M | 165.1 | [paper] [tf-slim] [caffe-zoo] |
Inception2 | 224 | 74.680 | 92.156 | 22.3M | 11.2M | 10.2M | 134.3 | [paper] [tf-slim] |
Inception3 | 299 | 77.946 | 93.758 | 47.6M | 23.9M | 21.8M | 314.6 | [paper] [tf-slim] [keras] |
Inception4 | 299 | 80.120 | 94.978 | 85.2M | 42.7M | 41.2M | 582.1 | [paper] [tf-slim] |
InceptionResNet2 | 299 | 80.256 | 95.252 | 111.5M | 55.9M | 54.3M | 656.8 | [paper] [tf-slim] |
NASNetAlarge | 331 | 82.498 | 96.004 | 186.2M | 93.5M | 89.5M | 2081 | [paper] [tf-slim] |
NASNetAmobile | 224 | 74.366 | 91.854 | 15.3M | 7.7M | 6.7M | 165.8 | [paper] [tf-slim] |
PNASNetlarge | 331 | 82.634 | 96.050 | 171.8M | 86.2M | 81.9M | 1978 | [paper] [tf-slim] |
VGG16 | 224 | 71.268 | 90.050 | 276.7M | 138.4M | 14.7M | 348.4 | [paper] [keras] |
VGG19 | 224 | 71.256 | 89.988 | 287.3M | 143.7M | 20.0M | 399.8 | [paper] [keras] |
DenseNet121 | 224 | 74.972 | 92.258 | 15.8M | 8.1M | 7.0M | 202.9 | [paper] [torch] |
DenseNet169 | 224 | 76.176 | 93.176 | 28.0M | 14.3M | 12.6M | 219.1 | [paper] [torch] |
DenseNet201 | 224 | 77.320 | 93.620 | 39.6M | 20.2M | 18.3M | 272.0 | [paper] [torch] |
MobileNet25 | 224 | 51.582 | 75.792 | 0.9M | 0.5M | 0.2M | 34.46 | [paper] [tf-slim] |
MobileNet50 | 224 | 64.292 | 85.624 | 2.6M | 1.3M | 0.8M | 52.46 | [paper] [tf-slim] |
MobileNet75 | 224 | 68.412 | 88.242 | 5.1M | 2.6M | 1.8M | 70.11 | [paper] [tf-slim] |
MobileNet100 | 224 | 70.424 | 89.504 | 8.4M | 4.3M | 3.2M | 83.41 | [paper] [tf-slim] |
MobileNet35v2 | 224 | 60.086 | 82.432 | 3.3M | 1.7M | 0.4M | 57.04 | [paper] [tf-slim] |
MobileNet50v2 | 224 | 65.194 | 86.062 | 3.9M | 2.0M | 0.7M | 64.35 | [paper] [tf-slim] |
MobileNet75v2 | 224 | 69.532 | 89.176 | 5.2M | 2.7M | 1.4M | 88.68 | [paper] [tf-slim] |
MobileNet100v2 | 224 | 71.336 | 90.142 | 6.9M | 3.5M | 2.3M | 93.82 | [paper] [tf-slim] |
MobileNet130v2 | 224 | 74.680 | 92.122 | 10.7M | 5.4M | 3.8M | 130.4 | [paper] [tf-slim] |
MobileNet140v2 | 224 | 75.230 | 92.422 | 12.1M | 6.2M | 4.4M | 132.9 | [paper] [tf-slim] |
MobileNet75v3large | 224 | 73.754 | 91.618 | 7.9M | 4.0M | 2.7M | 79.73 | [paper] [tf-slim] |
MobileNet100v3large | 224 | 75.790 | 92.840 | 27.3M | 5.5M | 4.2M | 94.71 | [paper] [tf-slim] |
MobileNet100v3largemini | 224 | 72.706 | 90.930 | 7.8M | 3.9M | 2.7M | 70.57 | [paper] [tf-slim] |
MobileNet75v3small | 224 | 66.138 | 86.534 | 4.1M | 2.1M | 1.0M | 37.78 | [paper] [tf-slim] |
MobileNet100v3small | 224 | 68.318 | 87.942 | 5.1M | 2.6M | 1.5M | 42.00 | [paper] [tf-slim] |
MobileNet100v3smallmini | 224 | 63.440 | 84.646 | 4.1M | 2.1M | 1.0M | 29.65 | [paper] [tf-slim] |
EfficientNetB0 | 224 | 77.012 | 93.338 | 26.2M | 5.3M | 4.0M | 147.1 | [paper] [tf-tpu] |
EfficientNetB1 | 240 | 79.040 | 94.284 | 15.4M | 7.9M | 6.6M | 217.3 | [paper] [tf-tpu] |
EfficientNetB2 | 260 | 80.064 | 94.862 | 18.1M | 9.2M | 7.8M | 296.4 | [paper] [tf-tpu] |
EfficientNetB3 | 300 | 81.384 | 95.586 | 24.2M | 12.3M | 10.8M | 482.7 | [paper] [tf-tpu] |
EfficientNetB4 | 380 | 82.588 | 96.094 | 38.4M | 19.5M | 17.7M | 959.5 | [paper] [tf-tpu] |
EfficientNetB5 | 456 | 83.496 | 96.590 | 60.4M | 30.6M | 28.5M | 1872 | [paper] [tf-tpu] |
EfficientNetB6 | 528 | 83.772 | 96.762 | 85.5M | 43.3M | 41.0M | 3503 | [paper] [tf-tpu] |
EfficientNetB7 | 600 | 84.088 | 96.740 | 131.9M | 66.7M | 64.1M | 6149 | [paper] [tf-tpu] |
SqueezeNet | 224 | 54.434 | 78.040 | 2.5M | 1.2M | 0.7M | 71.43 | [paper] [caffe] |
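The Top-1 and Top-5 columns are percentages of validation images whose correct class is among the one or five highest-scoring predictions. A plain-Python sketch of that metric follows; top_k_accuracy is a hypothetical helper and the scores are toy values, not ImageNet results.

```python
def top_k_accuracy(scores, targets, k):
    """Percentage of rows whose target index is among the k highest scores."""
    if len(scores) != len(targets) or not targets:
        raise ValueError("need the same, non-zero number of rows and targets")
    hits = 0
    for row, target in zip(scores, targets):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda i: row[i], reverse=True)
        if target in ranked[:k]:
            hits += 1
    return 100.0 * hits / len(targets)


scores = [[0.1, 0.6, 0.3], [0.5, 0.2, 0.3], [0.2, 0.3, 0.5]]
targets = [1, 2, 0]
print(round(top_k_accuracy(scores, targets, k=1), 3))  # 33.333
print(round(top_k_accuracy(scores, targets, k=2), 3))  # 66.667
```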
- The object detection models can be coupled with any network, but mAPs could be measured only for the models with pre-trained weights. Note that:
  - YOLOv3VOC was trained by taehoonlee with this recipe modified as max_batches=70000, steps=40000,60000,
  - YOLOv2VOC is equivalent to YOLOv2(inputs, Darknet19),
  - TinyYOLOv2VOC: TinyYOLOv2(inputs, TinyDarknet19),
  - FasterRCNN_ZF_VOC: FasterRCNN(inputs, ZF),
  - FasterRCNN_VGG16_VOC: FasterRCNN(inputs, VGG16, stem_out='conv5/3').
- The mAPs were obtained with TensorNets and may differ slightly from the original ones. The test input sizes were the ones reported as the best in the papers:
  - YOLOv3, YOLOv2: 416x416,
  - FasterRCNN: min_shorter_side=600, max_longer_side=1000.
- The computation times were measured on NVIDIA Tesla P100 (3584 cores, 16 GB global memory) with cuDNN 6.0 and CUDA 8.0.
- Size: the number of parameters, rounded
- Speed: milliseconds for network inference only, on a single 416x416 or 608x608 image
- FPS: 1000 / speed
PASCAL VOC2007 test | mAP | Size | Speed | FPS | References |
---|---|---|---|---|---|
YOLOv3VOC (416) | 0.7423 | 62M | 24.09 | 41.51 | [paper] [darknet] [darkflow] |
YOLOv2VOC (416) | 0.7320 | 51M | 14.75 | 67.80 | [paper] [darknet] [darkflow] |
TinyYOLOv2VOC (416) | 0.5303 | 16M | 6.534 | 153.0 | [paper] [darknet] [darkflow] |
FasterRCNN_ZF_VOC | 0.4466 | 59M | 241.4 | 4.143 | [paper] [caffe] [roi-pooling] |
FasterRCNN_VGG16_VOC | 0.6872 | 137M | 300.7 | 3.325 | [paper] [caffe] [roi-pooling] |
MS COCO val2014 | mAP | Size | Speed | FPS | References |
---|---|---|---|---|---|
YOLOv3COCO (608) | 0.6016 | 62M | 60.66 | 16.49 | [paper] [darknet] [darkflow] |
YOLOv3COCO (416) | 0.6028 | 62M | 40.23 | 24.85 | [paper] [darknet] [darkflow] |
YOLOv2COCO (608) | 0.5189 | 51M | 45.88 | 21.80 | [paper] [darknet] [darkflow] |
YOLOv2COCO (416) | 0.4922 | 51M | 21.66 | 46.17 | [paper] [darknet] [darkflow] |
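The FPS column follows from the stated definition FPS = 1000 / speed, with speed in milliseconds per image. The same conversion gives throughput for the classification table, whose speeds are milliseconds per 100 images. A small sketch; throughput is a hypothetical helper.

```python
def throughput(milliseconds, images=1):
    """Images processed per second, given the time in ms for `images` images."""
    if milliseconds <= 0:
        raise ValueError("milliseconds must be positive")
    return images * 1000.0 / milliseconds


# YOLOv3VOC (416): 24.09 ms per image.
print(round(throughput(24.09), 2))  # 41.51
# ResNet50: 195.4 ms per 100 images.
print(round(throughput(195.4, images=100), 2))  # 511.77
```

Applying it to each row is a quick consistency check on the speed and FPS columns.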
- The six variants of MobileNetv3 are released, 12 Mar 2020.
- The eight variants of EfficientNet are released, 28 Jan 2020.
- TensorNets can now be used on TF 2, 23 Jan 2020.
- MS COCO utils are released, 9 Jul 2018.
- PNASNetlarge is released, 12 May 2018.
- The six variants of MobileNetv2 are released, 5 May 2018.
- YOLOv3 for COCO and VOC are released, 4 April 2018.
- Generic object detection models for YOLOv2 and FasterRCNN are released, 26 March 2018.
- Add training code.
- Add image classification models.
  - PolyNet: A Pursuit of Structural Diversity in Very Deep Networks, CVPR 2017, Top-5 error 4.25%
  - Squeeze-and-Excitation Networks, CVPR 2018, Top-5 error 3.79%
  - GPipe: Efficient Training of Giant Neural Networks using Pipeline Parallelism, arXiv 2018, Top-5 error 3.0%
- Add object detection models (MaskRCNN, SSD).
- Add image segmentation models (FCN, UNet).
- Add image datasets (OpenImages).
- Add style transfer examples which can be coupled with any network in TensorNets.
- Add speech and language models with representative datasets (WaveNet, ByteNet).