Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to Keras Mish implementation for TfLite compatibility #60

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from setuptools import find_packages, setup


with open("README.md", "r") as f:
long_description = f.read()

Expand All @@ -19,7 +18,8 @@
python_requires=">=3.6",
entry_points={
"console_scripts": [
"convert-darknet-weights = tf2_yolov4.tools.convert_darknet_weights:convert_darknet_weights"
"convert-darknet-weights = tf2_yolov4.tools.convert_darknet_weights:convert_darknet_weights",
"convert-tflite = tf2_yolov4.tools.convert_tflite:convert_tflite",
]
},
classifiers=[
Expand Down
12 changes: 12 additions & 0 deletions tests/tools/test_convert_tflite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from tf2_yolov4.tools.convert_tflite import create_tflite_model

HEIGHT, WIDTH = (640, 960)


def test_import_convert_tflite_script_does_not_fail():
from tf2_yolov4.tools.convert_tflite import convert_tflite


def test_create_tflite_model_returns_correct_type(yolov4_inference):
tflite_model = create_tflite_model(yolov4_inference)
assert isinstance(tflite_model, bytes)
5 changes: 5 additions & 0 deletions tf2_yolov4/activations/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Activations layers"""

from tf2_yolov4.activations.mish import Mish

__all__ = ["Mish"]
24 changes: 24 additions & 0 deletions tf2_yolov4/activations/mish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Tensorflow-Keras Implementation of Mish
Source: https://github.com/digantamisra98/Mish/blob/master/Mish/TFKeras/mish.py
"""
import tensorflow as tf
from tensorflow.keras.layers import Layer


class Mish(Layer):
"""
Mish Activation Function.
.. math::
mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^{x}))
Shape:
- Input: Arbitrary. Use the keyword argument `input_shape`
(tuple of integers, does not include the samples axis)
when using this layer as the first layer in a model.
- Output: Same shape as the input.
Examples:
>>> X = Mish()(X_input)
"""

def call(self, inputs, **kwargs):
return inputs * tf.math.tanh(tf.math.softplus(inputs))
5 changes: 3 additions & 2 deletions tf2_yolov4/layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Common layer architecture such as Conv->BN->Mish or Conv->BN->LeakyReLU"""
import tensorflow as tf
import tensorflow_addons as tfa

from tf2_yolov4.activations import Mish


def conv_bn(
Expand Down Expand Up @@ -41,6 +42,6 @@ def conv_bn(
if activation == "leaky_relu":
x = tf.keras.layers.LeakyReLU(alpha=0.1)(x)
elif activation == "mish":
x = tfa.activations.mish(x)
x = Mish()(x)

return x
68 changes: 68 additions & 0 deletions tf2_yolov4/tools/convert_tflite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""
Script used to create a TfLite YOLOv4 model from previously trained weights.
"""

import click
import tensorflow as tf

from tf2_yolov4.anchors import YOLOV4_ANCHORS
from tf2_yolov4.model import YOLOv4

HEIGHT, WIDTH = (640, 960)

TFLITE_MODEL_PATH = "yolov4.tflite"


def create_tflite_model(model):
"""Converts a YOLOv4 model to a TfLite model

Args:
model (tensorflow.python.keras.engine.training.Model): YOLOv4 model

Returns:
(bytes): a binary TfLite model
"""
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.target_spec.supported_ops = [
tf.lite.OpsSet.TFLITE_BUILTINS,
tf.lite.OpsSet.SELECT_TF_OPS,
]

converter.allow_custom_ops = True
return converter.convert()


@click.command()
@click.option("--num_classes", default=80, help="Number of classes")
@click.option(
"--weights_path", default=None, help="Path to .h5 file with model weights"
)
def convert_tflite(num_classes, weights_path):
"""Creates a .tflite file with a trained YOLOv4 model

Args:
num_classes (int): Number of classes
weights_path (str, optional): Path to .h5 pre-trained weights file
"""
model = YOLOv4(
input_shape=(HEIGHT, WIDTH, 3),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Height and width are parametrizable, is it an argument stored in the tflite model or is it just used for the conversion? We want to make sure users can proceed the inference on any size

anchors=YOLOV4_ANCHORS,
num_classes=num_classes,
training=False,
yolo_max_boxes=100,
yolo_iou_threshold=0.4,
yolo_score_threshold=0.1,
)

if weights_path:
model.load_weights(weights_path)

tflite_model = create_tflite_model(model)

with tf.io.gfile.GFile(TFLITE_MODEL_PATH, "wb") as file:
file.write(tflite_model)


if __name__ == "__main__":
convert_tflite()